mirror of
https://github.com/SickGear/SickGear.git
synced 2025-01-20 16:43:43 +00:00
Fixed issues with editing/saving custom scene exceptions.
Fixed charmap issues for anime show names. Fixed issues with display show page and epCat key errors. Fixed duplicate log messages for clearing provider caches. Fixed issues with email notifier ep names not properly being encoded to UTF-8. TVDB<->TVRAGE Indexer ID mapping is now performed on demand to be used when needed such as newznab providers can be searched with tvrage_id's and some will return tvrage_id's that later can be used to create show objects from for faster and more accurate name parsing, mapping is done via Trakt API calls. Added stop event signals to schedualed tasks, SR now waits indefinate till task has been fully stopped before completing a restart or shutdown event. NameParserCache is now persistent and stores 200 parsed results at any given time for quicker lookups and better performance, this helps maintain results between updates or shutdown/startup events. Black and White lists for anime now only get used for anime shows as intended, performance gain for non-anime shows that dont need to load these lists. Internal name cache now builds it self on demand when needed per show request plus checks if show is already in cache and if true exits routine to save time. Schedualer and QueueItems classes are now a sub-class of threading.Thread and a stop threading event signal has been added to each. If I forgot to list something it doesn't mean its not fixed so please test and report back if anything is wrong or has been corrected by this new release.
This commit is contained in:
parent
09f53d3537
commit
d02c0bd6eb
304 changed files with 922 additions and 102786 deletions
|
@ -455,7 +455,7 @@ class SickRage(object):
|
|||
sickbeard.showList.append(curShow)
|
||||
except Exception, e:
|
||||
logger.log(
|
||||
u"There was an error creating the show in " + sqlShow["location"] + ": " + str(e).decode('utf-8'),
|
||||
u"There was an error creating the show in " + sqlShow["location"] + ": " + str(e).decode('utf-8', 'replace'),
|
||||
logger.ERROR)
|
||||
|
||||
def restore(self, srcDir, dstDir):
|
||||
|
@ -477,14 +477,14 @@ class SickRage(object):
|
|||
# stop all tasks
|
||||
sickbeard.halt()
|
||||
|
||||
# save all shows to DB
|
||||
sickbeard.saveAll()
|
||||
|
||||
# shutdown web server
|
||||
if self.webserver:
|
||||
self.webserver.shutDown()
|
||||
self.webserver = None
|
||||
|
||||
# save all shows to DB
|
||||
sickbeard.saveAll()
|
||||
|
||||
# if run as daemon delete the pidfile
|
||||
if self.runAsDaemon and self.CREATEPID:
|
||||
self.remove_pid_file(self.PIDFILE)
|
||||
|
|
|
@ -190,13 +190,13 @@
|
|||
#if $show.rls_ignore_words:
|
||||
<tr><td class="showLegend">Ignored Words: </td><td>#echo $show.rls_ignore_words#</td></tr>
|
||||
#end if
|
||||
#if $bwl.get_white_keywords_for("release_group"):
|
||||
#if $bwl and $bwl.get_white_keywords_for("release_group"):
|
||||
<tr>
|
||||
<td class="showLegend">Wanted Group#if len($bwl.get_white_keywords_for("release_group"))>1 then "s" else ""#:</td>
|
||||
<td>#echo ', '.join($bwl.get_white_keywords_for("release_group"))#</td>
|
||||
</tr>
|
||||
#end if
|
||||
#if $bwl.get_black_keywords_for("release_group"):
|
||||
#if $bwl and $bwl.get_black_keywords_for("release_group"):
|
||||
<tr>
|
||||
<td class="showLegend">Unwanted Group#if len($bwl.get_black_keywords_for("release_group"))>1 then "s" else ""#:</td>
|
||||
<td>#echo ', '.join($bwl.get_black_keywords_for("release_group"))#</td>
|
||||
|
@ -265,6 +265,11 @@
|
|||
<table class="sickbeardTable" cellspacing="1" border="0" cellpadding="0">
|
||||
|
||||
#for $epResult in $sqlResults:
|
||||
#set $epStr = str($epResult["season"]) + "x" + str($epResult["episode"])
|
||||
#if not $epStr in $epCats:
|
||||
#continue
|
||||
#end if
|
||||
|
||||
#if not $sickbeard.DISPLAY_SHOW_SPECIALS and int($epResult["season"]) == 0:
|
||||
#continue
|
||||
#end if
|
||||
|
@ -314,7 +319,6 @@
|
|||
#set $curSeason = int($epResult["season"])
|
||||
#end if
|
||||
|
||||
#set $epStr = str($epResult["season"]) + "x" + str($epResult["episode"])
|
||||
#set $epLoc = $epResult["location"]
|
||||
<tr class="$Overview.overviewStrings[$epCats[$epStr]] season-$curSeason">
|
||||
<td width="1%">
|
||||
|
|
|
@ -1,3 +0,0 @@
|
|||
from pkgutil import extend_path
|
||||
|
||||
__path__ = extend_path(__path__, __name__)
|
|
@ -1,23 +0,0 @@
|
|||
# Copyright 2009 Brian Quinlan. All Rights Reserved.
|
||||
# Licensed to PSF under a Contributor Agreement.
|
||||
|
||||
"""Execute computations asynchronously using threads or processes."""
|
||||
|
||||
__author__ = 'Brian Quinlan (brian@sweetapp.com)'
|
||||
|
||||
from concurrent.futures._base import (FIRST_COMPLETED,
|
||||
FIRST_EXCEPTION,
|
||||
ALL_COMPLETED,
|
||||
CancelledError,
|
||||
TimeoutError,
|
||||
Future,
|
||||
Executor,
|
||||
wait,
|
||||
as_completed)
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
|
||||
# Jython doesn't have multiprocessing
|
||||
try:
|
||||
from concurrent.futures.process import ProcessPoolExecutor
|
||||
except ImportError:
|
||||
pass
|
|
@ -1,577 +0,0 @@
|
|||
# Copyright 2009 Brian Quinlan. All Rights Reserved.
|
||||
# Licensed to PSF under a Contributor Agreement.
|
||||
|
||||
from __future__ import with_statement
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
|
||||
try:
|
||||
from collections import namedtuple
|
||||
except ImportError:
|
||||
from concurrent.futures._compat import namedtuple
|
||||
|
||||
__author__ = 'Brian Quinlan (brian@sweetapp.com)'
|
||||
|
||||
FIRST_COMPLETED = 'FIRST_COMPLETED'
|
||||
FIRST_EXCEPTION = 'FIRST_EXCEPTION'
|
||||
ALL_COMPLETED = 'ALL_COMPLETED'
|
||||
_AS_COMPLETED = '_AS_COMPLETED'
|
||||
|
||||
# Possible future states (for internal use by the futures package).
|
||||
PENDING = 'PENDING'
|
||||
RUNNING = 'RUNNING'
|
||||
# The future was cancelled by the user...
|
||||
CANCELLED = 'CANCELLED'
|
||||
# ...and _Waiter.add_cancelled() was called by a worker.
|
||||
CANCELLED_AND_NOTIFIED = 'CANCELLED_AND_NOTIFIED'
|
||||
FINISHED = 'FINISHED'
|
||||
|
||||
_FUTURE_STATES = [
|
||||
PENDING,
|
||||
RUNNING,
|
||||
CANCELLED,
|
||||
CANCELLED_AND_NOTIFIED,
|
||||
FINISHED
|
||||
]
|
||||
|
||||
_STATE_TO_DESCRIPTION_MAP = {
|
||||
PENDING: "pending",
|
||||
RUNNING: "running",
|
||||
CANCELLED: "cancelled",
|
||||
CANCELLED_AND_NOTIFIED: "cancelled",
|
||||
FINISHED: "finished"
|
||||
}
|
||||
|
||||
# Logger for internal use by the futures package.
|
||||
LOGGER = logging.getLogger("concurrent.futures")
|
||||
|
||||
class Error(Exception):
|
||||
"""Base class for all future-related exceptions."""
|
||||
pass
|
||||
|
||||
class CancelledError(Error):
|
||||
"""The Future was cancelled."""
|
||||
pass
|
||||
|
||||
class TimeoutError(Error):
|
||||
"""The operation exceeded the given deadline."""
|
||||
pass
|
||||
|
||||
class _Waiter(object):
|
||||
"""Provides the event that wait() and as_completed() block on."""
|
||||
def __init__(self):
|
||||
self.event = threading.Event()
|
||||
self.finished_futures = []
|
||||
|
||||
def add_result(self, future):
|
||||
self.finished_futures.append(future)
|
||||
|
||||
def add_exception(self, future):
|
||||
self.finished_futures.append(future)
|
||||
|
||||
def add_cancelled(self, future):
|
||||
self.finished_futures.append(future)
|
||||
|
||||
class _AsCompletedWaiter(_Waiter):
|
||||
"""Used by as_completed()."""
|
||||
|
||||
def __init__(self):
|
||||
super(_AsCompletedWaiter, self).__init__()
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def add_result(self, future):
|
||||
with self.lock:
|
||||
super(_AsCompletedWaiter, self).add_result(future)
|
||||
self.event.set()
|
||||
|
||||
def add_exception(self, future):
|
||||
with self.lock:
|
||||
super(_AsCompletedWaiter, self).add_exception(future)
|
||||
self.event.set()
|
||||
|
||||
def add_cancelled(self, future):
|
||||
with self.lock:
|
||||
super(_AsCompletedWaiter, self).add_cancelled(future)
|
||||
self.event.set()
|
||||
|
||||
class _FirstCompletedWaiter(_Waiter):
|
||||
"""Used by wait(return_when=FIRST_COMPLETED)."""
|
||||
|
||||
def add_result(self, future):
|
||||
super(_FirstCompletedWaiter, self).add_result(future)
|
||||
self.event.set()
|
||||
|
||||
def add_exception(self, future):
|
||||
super(_FirstCompletedWaiter, self).add_exception(future)
|
||||
self.event.set()
|
||||
|
||||
def add_cancelled(self, future):
|
||||
super(_FirstCompletedWaiter, self).add_cancelled(future)
|
||||
self.event.set()
|
||||
|
||||
class _AllCompletedWaiter(_Waiter):
|
||||
"""Used by wait(return_when=FIRST_EXCEPTION and ALL_COMPLETED)."""
|
||||
|
||||
def __init__(self, num_pending_calls, stop_on_exception):
|
||||
self.num_pending_calls = num_pending_calls
|
||||
self.stop_on_exception = stop_on_exception
|
||||
self.lock = threading.Lock()
|
||||
super(_AllCompletedWaiter, self).__init__()
|
||||
|
||||
def _decrement_pending_calls(self):
|
||||
with self.lock:
|
||||
self.num_pending_calls -= 1
|
||||
if not self.num_pending_calls:
|
||||
self.event.set()
|
||||
|
||||
def add_result(self, future):
|
||||
super(_AllCompletedWaiter, self).add_result(future)
|
||||
self._decrement_pending_calls()
|
||||
|
||||
def add_exception(self, future):
|
||||
super(_AllCompletedWaiter, self).add_exception(future)
|
||||
if self.stop_on_exception:
|
||||
self.event.set()
|
||||
else:
|
||||
self._decrement_pending_calls()
|
||||
|
||||
def add_cancelled(self, future):
|
||||
super(_AllCompletedWaiter, self).add_cancelled(future)
|
||||
self._decrement_pending_calls()
|
||||
|
||||
class _AcquireFutures(object):
|
||||
"""A context manager that does an ordered acquire of Future conditions."""
|
||||
|
||||
def __init__(self, futures):
|
||||
self.futures = sorted(futures, key=id)
|
||||
|
||||
def __enter__(self):
|
||||
for future in self.futures:
|
||||
future._condition.acquire()
|
||||
|
||||
def __exit__(self, *args):
|
||||
for future in self.futures:
|
||||
future._condition.release()
|
||||
|
||||
def _create_and_install_waiters(fs, return_when):
|
||||
if return_when == _AS_COMPLETED:
|
||||
waiter = _AsCompletedWaiter()
|
||||
elif return_when == FIRST_COMPLETED:
|
||||
waiter = _FirstCompletedWaiter()
|
||||
else:
|
||||
pending_count = sum(
|
||||
f._state not in [CANCELLED_AND_NOTIFIED, FINISHED] for f in fs)
|
||||
|
||||
if return_when == FIRST_EXCEPTION:
|
||||
waiter = _AllCompletedWaiter(pending_count, stop_on_exception=True)
|
||||
elif return_when == ALL_COMPLETED:
|
||||
waiter = _AllCompletedWaiter(pending_count, stop_on_exception=False)
|
||||
else:
|
||||
raise ValueError("Invalid return condition: %r" % return_when)
|
||||
|
||||
for f in fs:
|
||||
f._waiters.append(waiter)
|
||||
|
||||
return waiter
|
||||
|
||||
def as_completed(fs, timeout=None):
|
||||
"""An iterator over the given futures that yields each as it completes.
|
||||
|
||||
Args:
|
||||
fs: The sequence of Futures (possibly created by different Executors) to
|
||||
iterate over.
|
||||
timeout: The maximum number of seconds to wait. If None, then there
|
||||
is no limit on the wait time.
|
||||
|
||||
Returns:
|
||||
An iterator that yields the given Futures as they complete (finished or
|
||||
cancelled).
|
||||
|
||||
Raises:
|
||||
TimeoutError: If the entire result iterator could not be generated
|
||||
before the given timeout.
|
||||
"""
|
||||
if timeout is not None:
|
||||
end_time = timeout + time.time()
|
||||
|
||||
with _AcquireFutures(fs):
|
||||
finished = set(
|
||||
f for f in fs
|
||||
if f._state in [CANCELLED_AND_NOTIFIED, FINISHED])
|
||||
pending = set(fs) - finished
|
||||
waiter = _create_and_install_waiters(fs, _AS_COMPLETED)
|
||||
|
||||
try:
|
||||
for future in finished:
|
||||
yield future
|
||||
|
||||
while pending:
|
||||
if timeout is None:
|
||||
wait_timeout = None
|
||||
else:
|
||||
wait_timeout = end_time - time.time()
|
||||
if wait_timeout < 0:
|
||||
raise TimeoutError(
|
||||
'%d (of %d) futures unfinished' % (
|
||||
len(pending), len(fs)))
|
||||
|
||||
waiter.event.wait(wait_timeout)
|
||||
|
||||
with waiter.lock:
|
||||
finished = waiter.finished_futures
|
||||
waiter.finished_futures = []
|
||||
waiter.event.clear()
|
||||
|
||||
for future in finished:
|
||||
yield future
|
||||
pending.remove(future)
|
||||
|
||||
finally:
|
||||
for f in fs:
|
||||
f._waiters.remove(waiter)
|
||||
|
||||
DoneAndNotDoneFutures = namedtuple(
|
||||
'DoneAndNotDoneFutures', 'done not_done')
|
||||
def wait(fs, timeout=None, return_when=ALL_COMPLETED):
|
||||
"""Wait for the futures in the given sequence to complete.
|
||||
|
||||
Args:
|
||||
fs: The sequence of Futures (possibly created by different Executors) to
|
||||
wait upon.
|
||||
timeout: The maximum number of seconds to wait. If None, then there
|
||||
is no limit on the wait time.
|
||||
return_when: Indicates when this function should return. The options
|
||||
are:
|
||||
|
||||
FIRST_COMPLETED - Return when any future finishes or is
|
||||
cancelled.
|
||||
FIRST_EXCEPTION - Return when any future finishes by raising an
|
||||
exception. If no future raises an exception
|
||||
then it is equivalent to ALL_COMPLETED.
|
||||
ALL_COMPLETED - Return when all futures finish or are cancelled.
|
||||
|
||||
Returns:
|
||||
A named 2-tuple of sets. The first set, named 'done', contains the
|
||||
futures that completed (is finished or cancelled) before the wait
|
||||
completed. The second set, named 'not_done', contains uncompleted
|
||||
futures.
|
||||
"""
|
||||
with _AcquireFutures(fs):
|
||||
done = set(f for f in fs
|
||||
if f._state in [CANCELLED_AND_NOTIFIED, FINISHED])
|
||||
not_done = set(fs) - done
|
||||
|
||||
if (return_when == FIRST_COMPLETED) and done:
|
||||
return DoneAndNotDoneFutures(done, not_done)
|
||||
elif (return_when == FIRST_EXCEPTION) and done:
|
||||
if any(f for f in done
|
||||
if not f.cancelled() and f.exception() is not None):
|
||||
return DoneAndNotDoneFutures(done, not_done)
|
||||
|
||||
if len(done) == len(fs):
|
||||
return DoneAndNotDoneFutures(done, not_done)
|
||||
|
||||
waiter = _create_and_install_waiters(fs, return_when)
|
||||
|
||||
waiter.event.wait(timeout)
|
||||
for f in fs:
|
||||
f._waiters.remove(waiter)
|
||||
|
||||
done.update(waiter.finished_futures)
|
||||
return DoneAndNotDoneFutures(done, set(fs) - done)
|
||||
|
||||
class Future(object):
|
||||
"""Represents the result of an asynchronous computation."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initializes the future. Should not be called by clients."""
|
||||
self._condition = threading.Condition()
|
||||
self._state = PENDING
|
||||
self._result = None
|
||||
self._exception = None
|
||||
self._waiters = []
|
||||
self._done_callbacks = []
|
||||
|
||||
def _invoke_callbacks(self):
|
||||
for callback in self._done_callbacks:
|
||||
try:
|
||||
callback(self)
|
||||
except Exception:
|
||||
LOGGER.exception('exception calling callback for %r', self)
|
||||
|
||||
def __repr__(self):
|
||||
with self._condition:
|
||||
if self._state == FINISHED:
|
||||
if self._exception:
|
||||
return '<Future at %s state=%s raised %s>' % (
|
||||
hex(id(self)),
|
||||
_STATE_TO_DESCRIPTION_MAP[self._state],
|
||||
self._exception.__class__.__name__)
|
||||
else:
|
||||
return '<Future at %s state=%s returned %s>' % (
|
||||
hex(id(self)),
|
||||
_STATE_TO_DESCRIPTION_MAP[self._state],
|
||||
self._result.__class__.__name__)
|
||||
return '<Future at %s state=%s>' % (
|
||||
hex(id(self)),
|
||||
_STATE_TO_DESCRIPTION_MAP[self._state])
|
||||
|
||||
def cancel(self):
|
||||
"""Cancel the future if possible.
|
||||
|
||||
Returns True if the future was cancelled, False otherwise. A future
|
||||
cannot be cancelled if it is running or has already completed.
|
||||
"""
|
||||
with self._condition:
|
||||
if self._state in [RUNNING, FINISHED]:
|
||||
return False
|
||||
|
||||
if self._state in [CANCELLED, CANCELLED_AND_NOTIFIED]:
|
||||
return True
|
||||
|
||||
self._state = CANCELLED
|
||||
self._condition.notify_all()
|
||||
|
||||
self._invoke_callbacks()
|
||||
return True
|
||||
|
||||
def cancelled(self):
|
||||
"""Return True if the future has cancelled."""
|
||||
with self._condition:
|
||||
return self._state in [CANCELLED, CANCELLED_AND_NOTIFIED]
|
||||
|
||||
def isAlive(self):
|
||||
return self.running()
|
||||
|
||||
def running(self):
|
||||
"""Return True if the future is currently executing."""
|
||||
with self._condition:
|
||||
return self._state == RUNNING
|
||||
|
||||
def done(self):
|
||||
"""Return True of the future was cancelled or finished executing."""
|
||||
with self._condition:
|
||||
return self._state in [CANCELLED, CANCELLED_AND_NOTIFIED, FINISHED]
|
||||
|
||||
def __get_result(self):
|
||||
if self._exception:
|
||||
raise self._exception
|
||||
else:
|
||||
return self._result
|
||||
|
||||
def add_done_callback(self, fn):
|
||||
"""Attaches a callable that will be called when the future finishes.
|
||||
|
||||
Args:
|
||||
fn: A callable that will be called with this future as its only
|
||||
argument when the future completes or is cancelled. The callable
|
||||
will always be called by a thread in the same process in which
|
||||
it was added. If the future has already completed or been
|
||||
cancelled then the callable will be called immediately. These
|
||||
callables are called in the order that they were added.
|
||||
"""
|
||||
with self._condition:
|
||||
if self._state not in [CANCELLED, CANCELLED_AND_NOTIFIED, FINISHED]:
|
||||
self._done_callbacks.append(fn)
|
||||
return
|
||||
fn(self)
|
||||
|
||||
def result(self, timeout=None):
|
||||
"""Return the result of the call that the future represents.
|
||||
|
||||
Args:
|
||||
timeout: The number of seconds to wait for the result if the future
|
||||
isn't done. If None, then there is no limit on the wait time.
|
||||
|
||||
Returns:
|
||||
The result of the call that the future represents.
|
||||
|
||||
Raises:
|
||||
CancelledError: If the future was cancelled.
|
||||
TimeoutError: If the future didn't finish executing before the given
|
||||
timeout.
|
||||
Exception: If the call raised then that exception will be raised.
|
||||
"""
|
||||
with self._condition:
|
||||
if self._state in [CANCELLED, CANCELLED_AND_NOTIFIED]:
|
||||
raise CancelledError()
|
||||
elif self._state == FINISHED:
|
||||
return self.__get_result()
|
||||
|
||||
self._condition.wait(timeout)
|
||||
|
||||
if self._state in [CANCELLED, CANCELLED_AND_NOTIFIED]:
|
||||
raise CancelledError()
|
||||
elif self._state == FINISHED:
|
||||
return self.__get_result()
|
||||
else:
|
||||
raise TimeoutError()
|
||||
|
||||
def exception(self, timeout=None):
|
||||
"""Return the exception raised by the call that the future represents.
|
||||
|
||||
Args:
|
||||
timeout: The number of seconds to wait for the exception if the
|
||||
future isn't done. If None, then there is no limit on the wait
|
||||
time.
|
||||
|
||||
Returns:
|
||||
The exception raised by the call that the future represents or None
|
||||
if the call completed without raising.
|
||||
|
||||
Raises:
|
||||
CancelledError: If the future was cancelled.
|
||||
TimeoutError: If the future didn't finish executing before the given
|
||||
timeout.
|
||||
"""
|
||||
|
||||
with self._condition:
|
||||
if self._state in [CANCELLED, CANCELLED_AND_NOTIFIED]:
|
||||
raise CancelledError()
|
||||
elif self._state == FINISHED:
|
||||
return self._exception
|
||||
|
||||
self._condition.wait(timeout)
|
||||
|
||||
if self._state in [CANCELLED, CANCELLED_AND_NOTIFIED]:
|
||||
raise CancelledError()
|
||||
elif self._state == FINISHED:
|
||||
return self._exception
|
||||
else:
|
||||
raise TimeoutError()
|
||||
|
||||
# The following methods should only be used by Executors and in tests.
|
||||
def set_running_or_notify_cancel(self):
|
||||
"""Mark the future as running or process any cancel notifications.
|
||||
|
||||
Should only be used by Executor implementations and unit tests.
|
||||
|
||||
If the future has been cancelled (cancel() was called and returned
|
||||
True) then any threads waiting on the future completing (though calls
|
||||
to as_completed() or wait()) are notified and False is returned.
|
||||
|
||||
If the future was not cancelled then it is put in the running state
|
||||
(future calls to running() will return True) and True is returned.
|
||||
|
||||
This method should be called by Executor implementations before
|
||||
executing the work associated with this future. If this method returns
|
||||
False then the work should not be executed.
|
||||
|
||||
Returns:
|
||||
False if the Future was cancelled, True otherwise.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if this method was already called or if set_result()
|
||||
or set_exception() was called.
|
||||
"""
|
||||
with self._condition:
|
||||
if self._state == CANCELLED:
|
||||
self._state = CANCELLED_AND_NOTIFIED
|
||||
for waiter in self._waiters:
|
||||
waiter.add_cancelled(self)
|
||||
# self._condition.notify_all() is not necessary because
|
||||
# self.cancel() triggers a notification.
|
||||
return False
|
||||
elif self._state == PENDING:
|
||||
self._state = RUNNING
|
||||
return True
|
||||
else:
|
||||
LOGGER.critical('Future %s in unexpected state: %s',
|
||||
id(self.future),
|
||||
self.future._state)
|
||||
raise RuntimeError('Future in unexpected state')
|
||||
|
||||
def set_result(self, result):
|
||||
"""Sets the return value of work associated with the future.
|
||||
|
||||
Should only be used by Executor implementations and unit tests.
|
||||
"""
|
||||
with self._condition:
|
||||
self._result = result
|
||||
self._state = FINISHED
|
||||
for waiter in self._waiters:
|
||||
waiter.add_result(self)
|
||||
self._condition.notify_all()
|
||||
self._invoke_callbacks()
|
||||
|
||||
def set_exception(self, exception):
|
||||
"""Sets the result of the future as being the given exception.
|
||||
|
||||
Should only be used by Executor implementations and unit tests.
|
||||
"""
|
||||
with self._condition:
|
||||
self._exception = exception
|
||||
self._state = FINISHED
|
||||
for waiter in self._waiters:
|
||||
waiter.add_exception(self)
|
||||
self._condition.notify_all()
|
||||
self._invoke_callbacks()
|
||||
|
||||
class Executor(object):
|
||||
"""This is an abstract base class for concrete asynchronous executors."""
|
||||
|
||||
def submit(self, fn, *args, **kwargs):
|
||||
"""Submits a callable to be executed with the given arguments.
|
||||
|
||||
Schedules the callable to be executed as fn(*args, **kwargs) and returns
|
||||
a Future instance representing the execution of the callable.
|
||||
|
||||
Returns:
|
||||
A Future representing the given call.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def map(self, fn, *iterables, **kwargs):
|
||||
"""Returns a iterator equivalent to map(fn, iter).
|
||||
|
||||
Args:
|
||||
fn: A callable that will take as many arguments as there are
|
||||
passed iterables.
|
||||
timeout: The maximum number of seconds to wait. If None, then there
|
||||
is no limit on the wait time.
|
||||
|
||||
Returns:
|
||||
An iterator equivalent to: map(func, *iterables) but the calls may
|
||||
be evaluated out-of-order.
|
||||
|
||||
Raises:
|
||||
TimeoutError: If the entire result iterator could not be generated
|
||||
before the given timeout.
|
||||
Exception: If fn(*args) raises for any values.
|
||||
"""
|
||||
timeout = kwargs.get('timeout')
|
||||
if timeout is not None:
|
||||
end_time = timeout + time.time()
|
||||
|
||||
fs = [self.submit(fn, *args) for args in zip(*iterables)]
|
||||
|
||||
try:
|
||||
for future in fs:
|
||||
if timeout is None:
|
||||
yield future.result()
|
||||
else:
|
||||
yield future.result(end_time - time.time())
|
||||
finally:
|
||||
for future in fs:
|
||||
future.cancel()
|
||||
|
||||
def shutdown(self, wait=True):
|
||||
"""Clean-up the resources associated with the Executor.
|
||||
|
||||
It is safe to call this method several times. Otherwise, no other
|
||||
methods can be called after this one.
|
||||
|
||||
Args:
|
||||
wait: If True then shutdown will not return until all running
|
||||
futures have finished executing and the resources used by the
|
||||
executor have been reclaimed.
|
||||
"""
|
||||
pass
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.shutdown(wait=True)
|
||||
return False
|
|
@ -1,101 +0,0 @@
|
|||
from keyword import iskeyword as _iskeyword
|
||||
from operator import itemgetter as _itemgetter
|
||||
import sys as _sys
|
||||
|
||||
|
||||
def namedtuple(typename, field_names):
|
||||
"""Returns a new subclass of tuple with named fields.
|
||||
|
||||
>>> Point = namedtuple('Point', 'x y')
|
||||
>>> Point.__doc__ # docstring for the new class
|
||||
'Point(x, y)'
|
||||
>>> p = Point(11, y=22) # instantiate with positional args or keywords
|
||||
>>> p[0] + p[1] # indexable like a plain tuple
|
||||
33
|
||||
>>> x, y = p # unpack like a regular tuple
|
||||
>>> x, y
|
||||
(11, 22)
|
||||
>>> p.x + p.y # fields also accessable by name
|
||||
33
|
||||
>>> d = p._asdict() # convert to a dictionary
|
||||
>>> d['x']
|
||||
11
|
||||
>>> Point(**d) # convert from a dictionary
|
||||
Point(x=11, y=22)
|
||||
>>> p._replace(x=100) # _replace() is like str.replace() but targets named fields
|
||||
Point(x=100, y=22)
|
||||
|
||||
"""
|
||||
|
||||
# Parse and validate the field names. Validation serves two purposes,
|
||||
# generating informative error messages and preventing template injection attacks.
|
||||
if isinstance(field_names, basestring):
|
||||
field_names = field_names.replace(',', ' ').split() # names separated by whitespace and/or commas
|
||||
field_names = tuple(map(str, field_names))
|
||||
for name in (typename,) + field_names:
|
||||
if not all(c.isalnum() or c=='_' for c in name):
|
||||
raise ValueError('Type names and field names can only contain alphanumeric characters and underscores: %r' % name)
|
||||
if _iskeyword(name):
|
||||
raise ValueError('Type names and field names cannot be a keyword: %r' % name)
|
||||
if name[0].isdigit():
|
||||
raise ValueError('Type names and field names cannot start with a number: %r' % name)
|
||||
seen_names = set()
|
||||
for name in field_names:
|
||||
if name.startswith('_'):
|
||||
raise ValueError('Field names cannot start with an underscore: %r' % name)
|
||||
if name in seen_names:
|
||||
raise ValueError('Encountered duplicate field name: %r' % name)
|
||||
seen_names.add(name)
|
||||
|
||||
# Create and fill-in the class template
|
||||
numfields = len(field_names)
|
||||
argtxt = repr(field_names).replace("'", "")[1:-1] # tuple repr without parens or quotes
|
||||
reprtxt = ', '.join('%s=%%r' % name for name in field_names)
|
||||
dicttxt = ', '.join('%r: t[%d]' % (name, pos) for pos, name in enumerate(field_names))
|
||||
template = '''class %(typename)s(tuple):
|
||||
'%(typename)s(%(argtxt)s)' \n
|
||||
__slots__ = () \n
|
||||
_fields = %(field_names)r \n
|
||||
def __new__(_cls, %(argtxt)s):
|
||||
return _tuple.__new__(_cls, (%(argtxt)s)) \n
|
||||
@classmethod
|
||||
def _make(cls, iterable, new=tuple.__new__, len=len):
|
||||
'Make a new %(typename)s object from a sequence or iterable'
|
||||
result = new(cls, iterable)
|
||||
if len(result) != %(numfields)d:
|
||||
raise TypeError('Expected %(numfields)d arguments, got %%d' %% len(result))
|
||||
return result \n
|
||||
def __repr__(self):
|
||||
return '%(typename)s(%(reprtxt)s)' %% self \n
|
||||
def _asdict(t):
|
||||
'Return a new dict which maps field names to their values'
|
||||
return {%(dicttxt)s} \n
|
||||
def _replace(_self, **kwds):
|
||||
'Return a new %(typename)s object replacing specified fields with new values'
|
||||
result = _self._make(map(kwds.pop, %(field_names)r, _self))
|
||||
if kwds:
|
||||
raise ValueError('Got unexpected field names: %%r' %% kwds.keys())
|
||||
return result \n
|
||||
def __getnewargs__(self):
|
||||
return tuple(self) \n\n''' % locals()
|
||||
for i, name in enumerate(field_names):
|
||||
template += ' %s = _property(_itemgetter(%d))\n' % (name, i)
|
||||
|
||||
# Execute the template string in a temporary namespace and
|
||||
# support tracing utilities by setting a value for frame.f_globals['__name__']
|
||||
namespace = dict(_itemgetter=_itemgetter, __name__='namedtuple_%s' % typename,
|
||||
_property=property, _tuple=tuple)
|
||||
try:
|
||||
exec(template, namespace)
|
||||
except SyntaxError:
|
||||
e = _sys.exc_info()[1]
|
||||
raise SyntaxError(e.message + ':\n' + template)
|
||||
result = namespace[typename]
|
||||
|
||||
# For pickling to work, the __module__ variable needs to be set to the frame
|
||||
# where the named tuple is created. Bypass this step in enviroments where
|
||||
# sys._getframe is not defined (Jython for example).
|
||||
if hasattr(_sys, '_getframe'):
|
||||
result.__module__ = _sys._getframe(1).f_globals.get('__name__', '__main__')
|
||||
|
||||
return result
|
|
@ -1,363 +0,0 @@
|
|||
# Copyright 2009 Brian Quinlan. All Rights Reserved.
|
||||
# Licensed to PSF under a Contributor Agreement.
|
||||
|
||||
"""Implements ProcessPoolExecutor.
|
||||
|
||||
The follow diagram and text describe the data-flow through the system:
|
||||
|
||||
|======================= In-process =====================|== Out-of-process ==|
|
||||
|
||||
+----------+ +----------+ +--------+ +-----------+ +---------+
|
||||
| | => | Work Ids | => | | => | Call Q | => | |
|
||||
| | +----------+ | | +-----------+ | |
|
||||
| | | ... | | | | ... | | |
|
||||
| | | 6 | | | | 5, call() | | |
|
||||
| | | 7 | | | | ... | | |
|
||||
| Process | | ... | | Local | +-----------+ | Process |
|
||||
| Pool | +----------+ | Worker | | #1..n |
|
||||
| Executor | | Thread | | |
|
||||
| | +----------- + | | +-----------+ | |
|
||||
| | <=> | Work Items | <=> | | <= | Result Q | <= | |
|
||||
| | +------------+ | | +-----------+ | |
|
||||
| | | 6: call() | | | | ... | | |
|
||||
| | | future | | | | 4, result | | |
|
||||
| | | ... | | | | 3, except | | |
|
||||
+----------+ +------------+ +--------+ +-----------+ +---------+
|
||||
|
||||
Executor.submit() called:
|
||||
- creates a uniquely numbered _WorkItem and adds it to the "Work Items" dict
|
||||
- adds the id of the _WorkItem to the "Work Ids" queue
|
||||
|
||||
Local worker thread:
|
||||
- reads work ids from the "Work Ids" queue and looks up the corresponding
|
||||
WorkItem from the "Work Items" dict: if the work item has been cancelled then
|
||||
it is simply removed from the dict, otherwise it is repackaged as a
|
||||
_CallItem and put in the "Call Q". New _CallItems are put in the "Call Q"
|
||||
until "Call Q" is full. NOTE: the size of the "Call Q" is kept small because
|
||||
calls placed in the "Call Q" can no longer be cancelled with Future.cancel().
|
||||
- reads _ResultItems from "Result Q", updates the future stored in the
|
||||
"Work Items" dict and deletes the dict entry
|
||||
|
||||
Process #1..n:
|
||||
- reads _CallItems from "Call Q", executes the calls, and puts the resulting
|
||||
_ResultItems in "Request Q"
|
||||
"""
|
||||
|
||||
from __future__ import with_statement
|
||||
import atexit
|
||||
import multiprocessing
|
||||
import threading
|
||||
import weakref
|
||||
import sys
|
||||
|
||||
from concurrent.futures import _base
|
||||
|
||||
try:
|
||||
import queue
|
||||
except ImportError:
|
||||
import Queue as queue
|
||||
|
||||
__author__ = 'Brian Quinlan (brian@sweetapp.com)'
|
||||
|
||||
# Workers are created as daemon threads and processes. This is done to allow the
|
||||
# interpreter to exit when there are still idle processes in a
|
||||
# ProcessPoolExecutor's process pool (i.e. shutdown() was not called). However,
|
||||
# allowing workers to die with the interpreter has two undesirable properties:
|
||||
# - The workers would still be running during interpretor shutdown,
|
||||
# meaning that they would fail in unpredictable ways.
|
||||
# - The workers could be killed while evaluating a work item, which could
|
||||
# be bad if the callable being evaluated has external side-effects e.g.
|
||||
# writing to a file.
|
||||
#
|
||||
# To work around this problem, an exit handler is installed which tells the
|
||||
# workers to exit when their work queues are empty and then waits until the
|
||||
# threads/processes finish.
|
||||
|
||||
_threads_queues = weakref.WeakKeyDictionary()
|
||||
_shutdown = False
|
||||
|
||||
def _python_exit():
|
||||
global _shutdown
|
||||
_shutdown = True
|
||||
items = list(_threads_queues.items())
|
||||
for t, q in items:
|
||||
q.put(None)
|
||||
for t, q in items:
|
||||
t.join()
|
||||
|
||||
# Controls how many more calls than processes will be queued in the call queue.
|
||||
# A smaller number will mean that processes spend more time idle waiting for
|
||||
# work while a larger number will make Future.cancel() succeed less frequently
|
||||
# (Futures in the call queue cannot be cancelled).
|
||||
EXTRA_QUEUED_CALLS = 1
|
||||
|
||||
class _WorkItem(object):
|
||||
def __init__(self, future, fn, args, kwargs):
|
||||
self.future = future
|
||||
self.fn = fn
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
|
||||
class _ResultItem(object):
|
||||
def __init__(self, work_id, exception=None, result=None):
|
||||
self.work_id = work_id
|
||||
self.exception = exception
|
||||
self.result = result
|
||||
|
||||
class _CallItem(object):
|
||||
def __init__(self, work_id, fn, args, kwargs):
|
||||
self.work_id = work_id
|
||||
self.fn = fn
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
|
||||
def _process_worker(call_queue, result_queue):
|
||||
"""Evaluates calls from call_queue and places the results in result_queue.
|
||||
|
||||
This worker is run in a separate process.
|
||||
|
||||
Args:
|
||||
call_queue: A multiprocessing.Queue of _CallItems that will be read and
|
||||
evaluated by the worker.
|
||||
result_queue: A multiprocessing.Queue of _ResultItems that will written
|
||||
to by the worker.
|
||||
shutdown: A multiprocessing.Event that will be set as a signal to the
|
||||
worker that it should exit when call_queue is empty.
|
||||
"""
|
||||
while True:
|
||||
call_item = call_queue.get(block=True)
|
||||
if call_item is None:
|
||||
# Wake up queue management thread
|
||||
result_queue.put(None)
|
||||
return
|
||||
try:
|
||||
r = call_item.fn(*call_item.args, **call_item.kwargs)
|
||||
except BaseException:
|
||||
e = sys.exc_info()[1]
|
||||
result_queue.put(_ResultItem(call_item.work_id,
|
||||
exception=e))
|
||||
else:
|
||||
result_queue.put(_ResultItem(call_item.work_id,
|
||||
result=r))
|
||||
|
||||
def _add_call_item_to_queue(pending_work_items,
|
||||
work_ids,
|
||||
call_queue):
|
||||
"""Fills call_queue with _WorkItems from pending_work_items.
|
||||
|
||||
This function never blocks.
|
||||
|
||||
Args:
|
||||
pending_work_items: A dict mapping work ids to _WorkItems e.g.
|
||||
{5: <_WorkItem...>, 6: <_WorkItem...>, ...}
|
||||
work_ids: A queue.Queue of work ids e.g. Queue([5, 6, ...]). Work ids
|
||||
are consumed and the corresponding _WorkItems from
|
||||
pending_work_items are transformed into _CallItems and put in
|
||||
call_queue.
|
||||
call_queue: A multiprocessing.Queue that will be filled with _CallItems
|
||||
derived from _WorkItems.
|
||||
"""
|
||||
while True:
|
||||
if call_queue.full():
|
||||
return
|
||||
try:
|
||||
work_id = work_ids.get(block=False)
|
||||
except queue.Empty:
|
||||
return
|
||||
else:
|
||||
work_item = pending_work_items[work_id]
|
||||
|
||||
if work_item.future.set_running_or_notify_cancel():
|
||||
call_queue.put(_CallItem(work_id,
|
||||
work_item.fn,
|
||||
work_item.args,
|
||||
work_item.kwargs),
|
||||
block=True)
|
||||
else:
|
||||
del pending_work_items[work_id]
|
||||
continue
|
||||
|
||||
def _queue_management_worker(executor_reference,
|
||||
processes,
|
||||
pending_work_items,
|
||||
work_ids_queue,
|
||||
call_queue,
|
||||
result_queue):
|
||||
"""Manages the communication between this process and the worker processes.
|
||||
|
||||
This function is run in a local thread.
|
||||
|
||||
Args:
|
||||
executor_reference: A weakref.ref to the ProcessPoolExecutor that owns
|
||||
this thread. Used to determine if the ProcessPoolExecutor has been
|
||||
garbage collected and that this function can exit.
|
||||
process: A list of the multiprocessing.Process instances used as
|
||||
workers.
|
||||
pending_work_items: A dict mapping work ids to _WorkItems e.g.
|
||||
{5: <_WorkItem...>, 6: <_WorkItem...>, ...}
|
||||
work_ids_queue: A queue.Queue of work ids e.g. Queue([5, 6, ...]).
|
||||
call_queue: A multiprocessing.Queue that will be filled with _CallItems
|
||||
derived from _WorkItems for processing by the process workers.
|
||||
result_queue: A multiprocessing.Queue of _ResultItems generated by the
|
||||
process workers.
|
||||
"""
|
||||
nb_shutdown_processes = [0]
|
||||
def shutdown_one_process():
|
||||
"""Tell a worker to terminate, which will in turn wake us again"""
|
||||
call_queue.put(None)
|
||||
nb_shutdown_processes[0] += 1
|
||||
while True:
|
||||
_add_call_item_to_queue(pending_work_items,
|
||||
work_ids_queue,
|
||||
call_queue)
|
||||
|
||||
result_item = result_queue.get(block=True)
|
||||
if result_item is not None:
|
||||
work_item = pending_work_items[result_item.work_id]
|
||||
del pending_work_items[result_item.work_id]
|
||||
|
||||
if result_item.exception:
|
||||
work_item.future.set_exception(result_item.exception)
|
||||
else:
|
||||
work_item.future.set_result(result_item.result)
|
||||
# Check whether we should start shutting down.
|
||||
executor = executor_reference()
|
||||
# No more work items can be added if:
|
||||
# - The interpreter is shutting down OR
|
||||
# - The executor that owns this worker has been collected OR
|
||||
# - The executor that owns this worker has been shutdown.
|
||||
if _shutdown or executor is None or executor._shutdown_thread:
|
||||
# Since no new work items can be added, it is safe to shutdown
|
||||
# this thread if there are no pending work items.
|
||||
if not pending_work_items:
|
||||
while nb_shutdown_processes[0] < len(processes):
|
||||
shutdown_one_process()
|
||||
# If .join() is not called on the created processes then
|
||||
# some multiprocessing.Queue methods may deadlock on Mac OS
|
||||
# X.
|
||||
for p in processes:
|
||||
p.join()
|
||||
call_queue.close()
|
||||
return
|
||||
del executor
|
||||
|
||||
_system_limits_checked = False
|
||||
_system_limited = None
|
||||
def _check_system_limits():
|
||||
global _system_limits_checked, _system_limited
|
||||
if _system_limits_checked:
|
||||
if _system_limited:
|
||||
raise NotImplementedError(_system_limited)
|
||||
_system_limits_checked = True
|
||||
try:
|
||||
import os
|
||||
nsems_max = os.sysconf("SC_SEM_NSEMS_MAX")
|
||||
except (AttributeError, ValueError):
|
||||
# sysconf not available or setting not available
|
||||
return
|
||||
if nsems_max == -1:
|
||||
# indetermine limit, assume that limit is determined
|
||||
# by available memory only
|
||||
return
|
||||
if nsems_max >= 256:
|
||||
# minimum number of semaphores available
|
||||
# according to POSIX
|
||||
return
|
||||
_system_limited = "system provides too few semaphores (%d available, 256 necessary)" % nsems_max
|
||||
raise NotImplementedError(_system_limited)
|
||||
|
||||
class ProcessPoolExecutor(_base.Executor):
|
||||
def __init__(self, max_workers=None):
|
||||
"""Initializes a new ProcessPoolExecutor instance.
|
||||
|
||||
Args:
|
||||
max_workers: The maximum number of processes that can be used to
|
||||
execute the given calls. If None or not given then as many
|
||||
worker processes will be created as the machine has processors.
|
||||
"""
|
||||
_check_system_limits()
|
||||
|
||||
if max_workers is None:
|
||||
self._max_workers = multiprocessing.cpu_count()
|
||||
else:
|
||||
self._max_workers = max_workers
|
||||
|
||||
# Make the call queue slightly larger than the number of processes to
|
||||
# prevent the worker processes from idling. But don't make it too big
|
||||
# because futures in the call queue cannot be cancelled.
|
||||
self._call_queue = multiprocessing.Queue(self._max_workers +
|
||||
EXTRA_QUEUED_CALLS)
|
||||
self._result_queue = multiprocessing.Queue()
|
||||
self._work_ids = queue.Queue()
|
||||
self._queue_management_thread = None
|
||||
self._processes = set()
|
||||
|
||||
# Shutdown is a two-step process.
|
||||
self._shutdown_thread = False
|
||||
self._shutdown_lock = threading.Lock()
|
||||
self._queue_count = 0
|
||||
self._pending_work_items = {}
|
||||
|
||||
def _start_queue_management_thread(self):
|
||||
# When the executor gets lost, the weakref callback will wake up
|
||||
# the queue management thread.
|
||||
def weakref_cb(_, q=self._result_queue):
|
||||
q.put(None)
|
||||
if self._queue_management_thread is None:
|
||||
self._queue_management_thread = threading.Thread(
|
||||
target=_queue_management_worker,
|
||||
args=(weakref.ref(self, weakref_cb),
|
||||
self._processes,
|
||||
self._pending_work_items,
|
||||
self._work_ids,
|
||||
self._call_queue,
|
||||
self._result_queue))
|
||||
self._queue_management_thread.daemon = True
|
||||
self._queue_management_thread.start()
|
||||
_threads_queues[self._queue_management_thread] = self._result_queue
|
||||
|
||||
def _adjust_process_count(self):
|
||||
for _ in range(len(self._processes), self._max_workers):
|
||||
p = multiprocessing.Process(
|
||||
target=_process_worker,
|
||||
args=(self._call_queue,
|
||||
self._result_queue))
|
||||
p.start()
|
||||
self._processes.add(p)
|
||||
|
||||
def submit(self, fn, *args, **kwargs):
|
||||
with self._shutdown_lock:
|
||||
if self._shutdown_thread:
|
||||
raise RuntimeError('cannot schedule new futures after shutdown')
|
||||
|
||||
f = _base.Future()
|
||||
w = _WorkItem(f, fn, args, kwargs)
|
||||
|
||||
self._pending_work_items[self._queue_count] = w
|
||||
self._work_ids.put(self._queue_count)
|
||||
self._queue_count += 1
|
||||
# Wake up queue management thread
|
||||
self._result_queue.put(None)
|
||||
|
||||
self._start_queue_management_thread()
|
||||
self._adjust_process_count()
|
||||
return f
|
||||
submit.__doc__ = _base.Executor.submit.__doc__
|
||||
|
||||
def shutdown(self, wait=True):
|
||||
with self._shutdown_lock:
|
||||
self._shutdown_thread = True
|
||||
if self._queue_management_thread:
|
||||
# Wake up queue management thread
|
||||
self._result_queue.put(None)
|
||||
if wait:
|
||||
self._queue_management_thread.join()
|
||||
# To reduce the risk of openning too many files, remove references to
|
||||
# objects that use file descriptors.
|
||||
self._queue_management_thread = None
|
||||
self._call_queue = None
|
||||
self._result_queue = None
|
||||
self._processes = None
|
||||
shutdown.__doc__ = _base.Executor.shutdown.__doc__
|
||||
|
||||
atexit.register(_python_exit)
|
|
@ -1,145 +0,0 @@
|
|||
# Copyright 2009 Brian Quinlan. All Rights Reserved.
|
||||
# Licensed to PSF under a Contributor Agreement.
|
||||
|
||||
"""Implements ThreadPoolExecutor."""
|
||||
|
||||
from __future__ import with_statement
|
||||
import atexit
|
||||
import threading
|
||||
import weakref
|
||||
import sys
|
||||
|
||||
from concurrent.futures import _base
|
||||
|
||||
try:
|
||||
import queue
|
||||
except ImportError:
|
||||
import Queue as queue
|
||||
|
||||
__author__ = 'Brian Quinlan (brian@sweetapp.com)'
|
||||
|
||||
# Workers are created as daemon threads. This is done to allow the interpreter
|
||||
# to exit when there are still idle threads in a ThreadPoolExecutor's thread
|
||||
# pool (i.e. shutdown() was not called). However, allowing workers to die with
|
||||
# the interpreter has two undesirable properties:
|
||||
# - The workers would still be running during interpretor shutdown,
|
||||
# meaning that they would fail in unpredictable ways.
|
||||
# - The workers could be killed while evaluating a work item, which could
|
||||
# be bad if the callable being evaluated has external side-effects e.g.
|
||||
# writing to a file.
|
||||
#
|
||||
# To work around this problem, an exit handler is installed which tells the
|
||||
# workers to exit when their work queues are empty and then waits until the
|
||||
# threads finish.
|
||||
|
||||
_threads_queues = weakref.WeakKeyDictionary()
|
||||
_shutdown = False
|
||||
|
||||
def _python_exit():
|
||||
global _shutdown
|
||||
_shutdown = True
|
||||
items = list(_threads_queues.items())
|
||||
for t, q in items:
|
||||
q.put(None)
|
||||
for t, q in items:
|
||||
t.join()
|
||||
|
||||
atexit.register(_python_exit)
|
||||
|
||||
class _WorkItem(object):
|
||||
def __init__(self, future, fn, args, kwargs):
|
||||
self.future = future
|
||||
self.fn = fn
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
|
||||
def run(self):
|
||||
if not self.future.set_running_or_notify_cancel():
|
||||
return
|
||||
|
||||
try:
|
||||
result = self.fn(*self.args, **self.kwargs)
|
||||
except BaseException:
|
||||
e = sys.exc_info()[1]
|
||||
self.future.set_exception(e)
|
||||
else:
|
||||
self.future.set_result(result)
|
||||
|
||||
def _worker(executor_reference, work_queue):
|
||||
try:
|
||||
while True:
|
||||
work_item = work_queue.get(block=True)
|
||||
if work_item is not None:
|
||||
work_item.run()
|
||||
continue
|
||||
executor = executor_reference()
|
||||
# Exit if:
|
||||
# - The interpreter is shutting down OR
|
||||
# - The executor that owns the worker has been collected OR
|
||||
# - The executor that owns the worker has been shutdown.
|
||||
if _shutdown or executor is None or executor._shutdown:
|
||||
# Notice other workers
|
||||
work_queue.put(None)
|
||||
return
|
||||
del executor
|
||||
except BaseException:
|
||||
_base.LOGGER.critical('Exception in worker', exc_info=True)
|
||||
|
||||
class ThreadPoolExecutor(_base.Executor):
|
||||
def __init__(self, max_workers):
|
||||
"""Initializes a new ThreadPoolExecutor instance.
|
||||
|
||||
Args:
|
||||
max_workers: The maximum number of threads that can be used to
|
||||
execute the given calls.
|
||||
"""
|
||||
self._max_workers = max_workers
|
||||
self._work_queue = queue.Queue()
|
||||
self._threads = set()
|
||||
self._shutdown = False
|
||||
self._shutdown_lock = threading.Lock()
|
||||
|
||||
def submit(self, fn, *args, **kwargs):
|
||||
with self._shutdown_lock:
|
||||
if self._shutdown:
|
||||
raise RuntimeError('cannot schedule new futures after shutdown')
|
||||
|
||||
f = _base.Future()
|
||||
w = _WorkItem(f, fn, args, kwargs)
|
||||
|
||||
self._work_queue.put(w)
|
||||
|
||||
name = None
|
||||
if kwargs.has_key('name'):
|
||||
name = kwargs.pop('name')
|
||||
|
||||
self._adjust_thread_count(name)
|
||||
return f
|
||||
submit.__doc__ = _base.Executor.submit.__doc__
|
||||
|
||||
def _adjust_thread_count(self, name=None):
|
||||
# When the executor gets lost, the weakref callback will wake up
|
||||
# the worker threads.
|
||||
def weakref_cb(_, q=self._work_queue):
|
||||
q.put(None)
|
||||
# TODO(bquinlan): Should avoid creating new threads if there are more
|
||||
# idle threads than items in the work queue.
|
||||
if len(self._threads) < self._max_workers:
|
||||
t = threading.Thread(target=_worker,
|
||||
args=(weakref.ref(self, weakref_cb),
|
||||
self._work_queue),)
|
||||
if name:
|
||||
t.name = name
|
||||
t.daemon = True
|
||||
t.start()
|
||||
self._threads.add(t)
|
||||
_threads_queues[t] = self._work_queue
|
||||
|
||||
def shutdown(self, wait=True):
|
||||
with self._shutdown_lock:
|
||||
self._shutdown = True
|
||||
self._work_queue.put(None)
|
||||
if wait:
|
||||
for t in self._threads:
|
||||
t.join()
|
||||
shutdown.__doc__ = _base.Executor.shutdown.__doc__
|
|
@ -1,24 +0,0 @@
|
|||
# Copyright 2009 Brian Quinlan. All Rights Reserved.
|
||||
# Licensed to PSF under a Contributor Agreement.
|
||||
|
||||
"""Execute computations asynchronously using threads or processes."""
|
||||
|
||||
import warnings
|
||||
|
||||
from concurrent.futures import (FIRST_COMPLETED,
|
||||
FIRST_EXCEPTION,
|
||||
ALL_COMPLETED,
|
||||
CancelledError,
|
||||
TimeoutError,
|
||||
Future,
|
||||
Executor,
|
||||
wait,
|
||||
as_completed,
|
||||
ProcessPoolExecutor,
|
||||
ThreadPoolExecutor)
|
||||
|
||||
__author__ = 'Brian Quinlan (brian@sweetapp.com)'
|
||||
|
||||
warnings.warn('The futures package has been deprecated. '
|
||||
'Use the concurrent.futures package instead.',
|
||||
DeprecationWarning)
|
|
@ -1 +0,0 @@
|
|||
from concurrent.futures import ProcessPoolExecutor
|
|
@ -1 +0,0 @@
|
|||
from concurrent.futures import ThreadPoolExecutor
|
|
@ -1,519 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
'''Common object storage frontend.'''
|
||||
|
||||
import os
|
||||
import zlib
|
||||
import urllib
|
||||
try:
|
||||
import cPickle as pickle
|
||||
except ImportError:
|
||||
import pickle
|
||||
from collections import deque
|
||||
|
||||
try:
|
||||
# Import store and cache entry points if setuptools installed
|
||||
import pkg_resources
|
||||
stores = dict((_store.name, _store) for _store in
|
||||
pkg_resources.iter_entry_points('shove.stores'))
|
||||
caches = dict((_cache.name, _cache) for _cache in
|
||||
pkg_resources.iter_entry_points('shove.caches'))
|
||||
# Pass if nothing loaded
|
||||
if not stores and not caches:
|
||||
raise ImportError()
|
||||
except ImportError:
|
||||
# Static store backend registry
|
||||
stores = dict(
|
||||
bsddb='shove.store.bsdb:BsdStore',
|
||||
cassandra='shove.store.cassandra:CassandraStore',
|
||||
dbm='shove.store.dbm:DbmStore',
|
||||
durus='shove.store.durusdb:DurusStore',
|
||||
file='shove.store.file:FileStore',
|
||||
firebird='shove.store.db:DbStore',
|
||||
ftp='shove.store.ftp:FtpStore',
|
||||
hdf5='shove.store.hdf5:HDF5Store',
|
||||
leveldb='shove.store.leveldbstore:LevelDBStore',
|
||||
memory='shove.store.memory:MemoryStore',
|
||||
mssql='shove.store.db:DbStore',
|
||||
mysql='shove.store.db:DbStore',
|
||||
oracle='shove.store.db:DbStore',
|
||||
postgres='shove.store.db:DbStore',
|
||||
redis='shove.store.redisdb:RedisStore',
|
||||
s3='shove.store.s3:S3Store',
|
||||
simple='shove.store.simple:SimpleStore',
|
||||
sqlite='shove.store.db:DbStore',
|
||||
svn='shove.store.svn:SvnStore',
|
||||
zodb='shove.store.zodb:ZodbStore',
|
||||
)
|
||||
# Static cache backend registry
|
||||
caches = dict(
|
||||
bsddb='shove.cache.bsdb:BsdCache',
|
||||
file='shove.cache.file:FileCache',
|
||||
filelru='shove.cache.filelru:FileLRUCache',
|
||||
firebird='shove.cache.db:DbCache',
|
||||
memcache='shove.cache.memcached:MemCached',
|
||||
memlru='shove.cache.memlru:MemoryLRUCache',
|
||||
memory='shove.cache.memory:MemoryCache',
|
||||
mssql='shove.cache.db:DbCache',
|
||||
mysql='shove.cache.db:DbCache',
|
||||
oracle='shove.cache.db:DbCache',
|
||||
postgres='shove.cache.db:DbCache',
|
||||
redis='shove.cache.redisdb:RedisCache',
|
||||
simple='shove.cache.simple:SimpleCache',
|
||||
simplelru='shove.cache.simplelru:SimpleLRUCache',
|
||||
sqlite='shove.cache.db:DbCache',
|
||||
)
|
||||
|
||||
|
||||
def getbackend(uri, engines, **kw):
|
||||
'''
|
||||
Loads the right backend based on a URI.
|
||||
|
||||
@param uri Instance or name string
|
||||
@param engines A dictionary of scheme/class pairs
|
||||
'''
|
||||
if isinstance(uri, basestring):
|
||||
mod = engines[uri.split('://', 1)[0]]
|
||||
# Load module if setuptools not present
|
||||
if isinstance(mod, basestring):
|
||||
# Isolate classname from dot path
|
||||
module, klass = mod.split(':')
|
||||
# Load module
|
||||
mod = getattr(__import__(module, '', '', ['']), klass)
|
||||
# Load appropriate class from setuptools entry point
|
||||
else:
|
||||
mod = mod.load()
|
||||
# Return instance
|
||||
return mod(uri, **kw)
|
||||
# No-op for existing instances
|
||||
return uri
|
||||
|
||||
|
||||
def synchronized(func):
|
||||
'''
|
||||
Decorator to lock and unlock a method (Phillip J. Eby).
|
||||
|
||||
@param func Method to decorate
|
||||
'''
|
||||
def wrapper(self, *__args, **__kw):
|
||||
self._lock.acquire()
|
||||
try:
|
||||
return func(self, *__args, **__kw)
|
||||
finally:
|
||||
self._lock.release()
|
||||
wrapper.__name__ = func.__name__
|
||||
wrapper.__dict__ = func.__dict__
|
||||
wrapper.__doc__ = func.__doc__
|
||||
return wrapper
|
||||
|
||||
|
||||
class Base(object):
|
||||
|
||||
'''Base Mapping class.'''
|
||||
|
||||
def __init__(self, engine, **kw):
|
||||
'''
|
||||
@keyword compress True, False, or an integer compression level (1-9).
|
||||
'''
|
||||
self._compress = kw.get('compress', False)
|
||||
self._protocol = kw.get('protocol', pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
def __getitem__(self, key):
|
||||
raise NotImplementedError()
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
raise NotImplementedError()
|
||||
|
||||
def __delitem__(self, key):
|
||||
raise NotImplementedError()
|
||||
|
||||
def __contains__(self, key):
|
||||
try:
|
||||
value = self[key]
|
||||
except KeyError:
|
||||
return False
|
||||
return True
|
||||
|
||||
def get(self, key, default=None):
|
||||
'''
|
||||
Fetch a given key from the mapping. If the key does not exist,
|
||||
return the default.
|
||||
|
||||
@param key Keyword of item in mapping.
|
||||
@param default Default value (default: None)
|
||||
'''
|
||||
try:
|
||||
return self[key]
|
||||
except KeyError:
|
||||
return default
|
||||
|
||||
def dumps(self, value):
|
||||
'''Optionally serializes and compresses an object.'''
|
||||
# Serialize everything but ASCII strings
|
||||
value = pickle.dumps(value, protocol=self._protocol)
|
||||
if self._compress:
|
||||
level = 9 if self._compress is True else self._compress
|
||||
value = zlib.compress(value, level)
|
||||
return value
|
||||
|
||||
def loads(self, value):
|
||||
'''Deserializes and optionally decompresses an object.'''
|
||||
if self._compress:
|
||||
try:
|
||||
value = zlib.decompress(value)
|
||||
except zlib.error:
|
||||
pass
|
||||
value = pickle.loads(value)
|
||||
return value
|
||||
|
||||
|
||||
class BaseStore(Base):
|
||||
|
||||
'''Base Store class (based on UserDict.DictMixin).'''
|
||||
|
||||
def __init__(self, engine, **kw):
|
||||
super(BaseStore, self).__init__(engine, **kw)
|
||||
self._store = None
|
||||
|
||||
def __cmp__(self, other):
|
||||
if other is None:
|
||||
return False
|
||||
if isinstance(other, BaseStore):
|
||||
return cmp(dict(self.iteritems()), dict(other.iteritems()))
|
||||
|
||||
def __del__(self):
|
||||
# __init__ didn't succeed, so don't bother closing
|
||||
if not hasattr(self, '_store'):
|
||||
return
|
||||
self.close()
|
||||
|
||||
def __iter__(self):
|
||||
for k in self.keys():
|
||||
yield k
|
||||
|
||||
def __len__(self):
|
||||
return len(self.keys())
|
||||
|
||||
def __repr__(self):
|
||||
return repr(dict(self.iteritems()))
|
||||
|
||||
def close(self):
|
||||
'''Closes internal store and clears object references.'''
|
||||
try:
|
||||
self._store.close()
|
||||
except AttributeError:
|
||||
pass
|
||||
self._store = None
|
||||
|
||||
def clear(self):
|
||||
'''Removes all keys and values from a store.'''
|
||||
for key in self.keys():
|
||||
del self[key]
|
||||
|
||||
def items(self):
|
||||
'''Returns a list with all key/value pairs in the store.'''
|
||||
return list(self.iteritems())
|
||||
|
||||
def iteritems(self):
|
||||
'''Lazily returns all key/value pairs in a store.'''
|
||||
for k in self:
|
||||
yield (k, self[k])
|
||||
|
||||
def iterkeys(self):
|
||||
'''Lazy returns all keys in a store.'''
|
||||
return self.__iter__()
|
||||
|
||||
def itervalues(self):
|
||||
'''Lazily returns all values in a store.'''
|
||||
for _, v in self.iteritems():
|
||||
yield v
|
||||
|
||||
def keys(self):
|
||||
'''Returns a list with all keys in a store.'''
|
||||
raise NotImplementedError()
|
||||
|
||||
def pop(self, key, *args):
|
||||
'''
|
||||
Removes and returns a value from a store.
|
||||
|
||||
@param args Default to return if key not present.
|
||||
'''
|
||||
if len(args) > 1:
|
||||
raise TypeError('pop expected at most 2 arguments, got ' + repr(
|
||||
1 + len(args))
|
||||
)
|
||||
try:
|
||||
value = self[key]
|
||||
# Return default if key not in store
|
||||
except KeyError:
|
||||
if args:
|
||||
return args[0]
|
||||
del self[key]
|
||||
return value
|
||||
|
||||
def popitem(self):
|
||||
'''Removes and returns a key, value pair from a store.'''
|
||||
try:
|
||||
k, v = self.iteritems().next()
|
||||
except StopIteration:
|
||||
raise KeyError('Store is empty.')
|
||||
del self[k]
|
||||
return (k, v)
|
||||
|
||||
def setdefault(self, key, default=None):
|
||||
'''
|
||||
Returns the value corresponding to an existing key or sets the
|
||||
to key to the default and returns the default.
|
||||
|
||||
@param default Default value (default: None)
|
||||
'''
|
||||
try:
|
||||
return self[key]
|
||||
except KeyError:
|
||||
self[key] = default
|
||||
return default
|
||||
|
||||
def update(self, other=None, **kw):
|
||||
'''
|
||||
Adds to or overwrites the values in this store with values from
|
||||
another store.
|
||||
|
||||
other Another store
|
||||
kw Additional keys and values to store
|
||||
'''
|
||||
if other is None:
|
||||
pass
|
||||
elif hasattr(other, 'iteritems'):
|
||||
for k, v in other.iteritems():
|
||||
self[k] = v
|
||||
elif hasattr(other, 'keys'):
|
||||
for k in other.keys():
|
||||
self[k] = other[k]
|
||||
else:
|
||||
for k, v in other:
|
||||
self[k] = v
|
||||
if kw:
|
||||
self.update(kw)
|
||||
|
||||
def values(self):
|
||||
'''Returns a list with all values in a store.'''
|
||||
return list(v for _, v in self.iteritems())
|
||||
|
||||
|
||||
class Shove(BaseStore):
|
||||
|
||||
'''Common object frontend class.'''
|
||||
|
||||
def __init__(self, store='simple://', cache='simple://', **kw):
|
||||
super(Shove, self).__init__(store, **kw)
|
||||
# Load store
|
||||
self._store = getbackend(store, stores, **kw)
|
||||
# Load cache
|
||||
self._cache = getbackend(cache, caches, **kw)
|
||||
# Buffer for lazy writing and setting for syncing frequency
|
||||
self._buffer, self._sync = dict(), kw.get('sync', 2)
|
||||
|
||||
def __getitem__(self, key):
|
||||
'''Gets a item from shove.'''
|
||||
try:
|
||||
return self._cache[key]
|
||||
except KeyError:
|
||||
# Synchronize cache and store
|
||||
self.sync()
|
||||
value = self._store[key]
|
||||
self._cache[key] = value
|
||||
return value
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
'''Sets an item in shove.'''
|
||||
self._cache[key] = self._buffer[key] = value
|
||||
# When the buffer reaches self._limit, writes the buffer to the store
|
||||
if len(self._buffer) >= self._sync:
|
||||
self.sync()
|
||||
|
||||
def __delitem__(self, key):
|
||||
'''Deletes an item from shove.'''
|
||||
try:
|
||||
del self._cache[key]
|
||||
except KeyError:
|
||||
pass
|
||||
self.sync()
|
||||
del self._store[key]
|
||||
|
||||
def keys(self):
|
||||
'''Returns a list of keys in shove.'''
|
||||
self.sync()
|
||||
return self._store.keys()
|
||||
|
||||
def sync(self):
|
||||
'''Writes buffer to store.'''
|
||||
for k, v in self._buffer.iteritems():
|
||||
self._store[k] = v
|
||||
self._buffer.clear()
|
||||
|
||||
def close(self):
|
||||
'''Finalizes and closes shove.'''
|
||||
# If close has been called, pass
|
||||
if self._store is not None:
|
||||
try:
|
||||
self.sync()
|
||||
except AttributeError:
|
||||
pass
|
||||
self._store.close()
|
||||
self._store = self._cache = self._buffer = None
|
||||
|
||||
|
||||
class FileBase(Base):
|
||||
|
||||
'''Base class for file based storage.'''
|
||||
|
||||
def __init__(self, engine, **kw):
|
||||
super(FileBase, self).__init__(engine, **kw)
|
||||
if engine.startswith('file://'):
|
||||
engine = urllib.url2pathname(engine.split('://')[1])
|
||||
self._dir = engine
|
||||
# Create directory
|
||||
if not os.path.exists(self._dir):
|
||||
self._createdir()
|
||||
|
||||
def __getitem__(self, key):
|
||||
# (per Larry Meyn)
|
||||
try:
|
||||
item = open(self._key_to_file(key), 'rb')
|
||||
data = item.read()
|
||||
item.close()
|
||||
return self.loads(data)
|
||||
except:
|
||||
raise KeyError(key)
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
# (per Larry Meyn)
|
||||
try:
|
||||
item = open(self._key_to_file(key), 'wb')
|
||||
item.write(self.dumps(value))
|
||||
item.close()
|
||||
except (IOError, OSError):
|
||||
raise KeyError(key)
|
||||
|
||||
def __delitem__(self, key):
|
||||
try:
|
||||
os.remove(self._key_to_file(key))
|
||||
except (IOError, OSError):
|
||||
raise KeyError(key)
|
||||
|
||||
def __contains__(self, key):
|
||||
return os.path.exists(self._key_to_file(key))
|
||||
|
||||
def __len__(self):
|
||||
return len(os.listdir(self._dir))
|
||||
|
||||
def _createdir(self):
|
||||
'''Creates the store directory.'''
|
||||
try:
|
||||
os.makedirs(self._dir)
|
||||
except OSError:
|
||||
raise EnvironmentError(
|
||||
'Cache directory "%s" does not exist and ' \
|
||||
'could not be created' % self._dir
|
||||
)
|
||||
|
||||
def _key_to_file(self, key):
|
||||
'''Gives the filesystem path for a key.'''
|
||||
return os.path.join(self._dir, urllib.quote_plus(key))
|
||||
|
||||
def keys(self):
|
||||
'''Returns a list of keys in the store.'''
|
||||
return [urllib.unquote_plus(name) for name in os.listdir(self._dir)]
|
||||
|
||||
|
||||
class SimpleBase(Base):
|
||||
|
||||
'''Single-process in-memory store base class.'''
|
||||
|
||||
def __init__(self, engine, **kw):
|
||||
super(SimpleBase, self).__init__(engine, **kw)
|
||||
self._store = dict()
|
||||
|
||||
def __getitem__(self, key):
|
||||
try:
|
||||
return self._store[key]
|
||||
except:
|
||||
raise KeyError(key)
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self._store[key] = value
|
||||
|
||||
def __delitem__(self, key):
|
||||
try:
|
||||
del self._store[key]
|
||||
except:
|
||||
raise KeyError(key)
|
||||
|
||||
def __len__(self):
|
||||
return len(self._store)
|
||||
|
||||
def keys(self):
|
||||
'''Returns a list of keys in the store.'''
|
||||
return self._store.keys()
|
||||
|
||||
|
||||
class LRUBase(SimpleBase):
|
||||
|
||||
def __init__(self, engine, **kw):
|
||||
super(LRUBase, self).__init__(engine, **kw)
|
||||
self._max_entries = kw.get('max_entries', 300)
|
||||
self._hits = 0
|
||||
self._misses = 0
|
||||
self._queue = deque()
|
||||
self._refcount = dict()
|
||||
|
||||
def __getitem__(self, key):
|
||||
try:
|
||||
value = super(LRUBase, self).__getitem__(key)
|
||||
self._hits += 1
|
||||
except KeyError:
|
||||
self._misses += 1
|
||||
raise
|
||||
self._housekeep(key)
|
||||
return value
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
super(LRUBase, self).__setitem__(key, value)
|
||||
self._housekeep(key)
|
||||
if len(self._store) > self._max_entries:
|
||||
while len(self._store) > self._max_entries:
|
||||
k = self._queue.popleft()
|
||||
self._refcount[k] -= 1
|
||||
if not self._refcount[k]:
|
||||
super(LRUBase, self).__delitem__(k)
|
||||
del self._refcount[k]
|
||||
|
||||
def _housekeep(self, key):
|
||||
self._queue.append(key)
|
||||
self._refcount[key] = self._refcount.get(key, 0) + 1
|
||||
if len(self._queue) > self._max_entries * 4:
|
||||
self._purge_queue()
|
||||
|
||||
def _purge_queue(self):
|
||||
for i in [None] * len(self._queue):
|
||||
k = self._queue.popleft()
|
||||
if self._refcount[k] == 1:
|
||||
self._queue.append(k)
|
||||
else:
|
||||
self._refcount[k] -= 1
|
||||
|
||||
|
||||
class DbBase(Base):
|
||||
|
||||
'''Database common base class.'''
|
||||
|
||||
def __init__(self, engine, **kw):
|
||||
super(DbBase, self).__init__(engine, **kw)
|
||||
|
||||
def __delitem__(self, key):
|
||||
self._store.delete(self._store.c.key == key).execute()
|
||||
|
||||
def __len__(self):
|
||||
return self._store.count().execute().fetchone()[0]
|
||||
|
||||
|
||||
__all__ = ['Shove']
|
1
lib/shove/cache/__init__.py
vendored
1
lib/shove/cache/__init__.py
vendored
|
@ -1 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
117
lib/shove/cache/db.py
vendored
117
lib/shove/cache/db.py
vendored
|
@ -1,117 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
'''
|
||||
Database object cache.
|
||||
|
||||
The shove psuedo-URL used for database object caches is the format used by
|
||||
SQLAlchemy:
|
||||
|
||||
<driver>://<username>:<password>@<host>:<port>/<database>
|
||||
|
||||
<driver> is the database engine. The engines currently supported SQLAlchemy are
|
||||
sqlite, mysql, postgres, oracle, mssql, and firebird.
|
||||
<username> is the database account user name
|
||||
<password> is the database accound password
|
||||
<host> is the database location
|
||||
<port> is the database port
|
||||
<database> is the name of the specific database
|
||||
|
||||
For more information on specific databases see:
|
||||
|
||||
http://www.sqlalchemy.org/docs/dbengine.myt#dbengine_supported
|
||||
'''
|
||||
|
||||
import time
|
||||
import random
|
||||
from datetime import datetime
|
||||
try:
|
||||
from sqlalchemy import (
|
||||
MetaData, Table, Column, String, Binary, DateTime, select, update,
|
||||
insert, delete,
|
||||
)
|
||||
from shove import DbBase
|
||||
except ImportError:
|
||||
raise ImportError('Requires SQLAlchemy >= 0.4')
|
||||
|
||||
__all__ = ['DbCache']
|
||||
|
||||
|
||||
class DbCache(DbBase):
|
||||
|
||||
'''database cache backend'''
|
||||
|
||||
def __init__(self, engine, **kw):
|
||||
super(DbCache, self).__init__(engine, **kw)
|
||||
# Get table name
|
||||
tablename = kw.get('tablename', 'cache')
|
||||
# Bind metadata
|
||||
self._metadata = MetaData(engine)
|
||||
# Make cache table
|
||||
self._store = Table(tablename, self._metadata,
|
||||
Column('key', String(60), primary_key=True, nullable=False),
|
||||
Column('value', Binary, nullable=False),
|
||||
Column('expires', DateTime, nullable=False),
|
||||
)
|
||||
# Create cache table if it does not exist
|
||||
if not self._store.exists():
|
||||
self._store.create()
|
||||
# Set maximum entries
|
||||
self._max_entries = kw.get('max_entries', 300)
|
||||
# Maximum number of entries to cull per call if cache is full
|
||||
self._maxcull = kw.get('maxcull', 10)
|
||||
# Set timeout
|
||||
self.timeout = kw.get('timeout', 300)
|
||||
|
||||
def __getitem__(self, key):
|
||||
row = select(
|
||||
[self._store.c.value, self._store.c.expires],
|
||||
self._store.c.key == key
|
||||
).execute().fetchone()
|
||||
if row is not None:
|
||||
# Remove if item expired
|
||||
if row.expires < datetime.now().replace(microsecond=0):
|
||||
del self[key]
|
||||
raise KeyError(key)
|
||||
return self.loads(str(row.value))
|
||||
raise KeyError(key)
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
timeout, value, cache = self.timeout, self.dumps(value), self._store
|
||||
# Cull if too many items
|
||||
if len(self) >= self._max_entries:
|
||||
self._cull()
|
||||
# Generate expiration time
|
||||
expires = datetime.fromtimestamp(
|
||||
time.time() + timeout
|
||||
).replace(microsecond=0)
|
||||
# Update database if key already present
|
||||
if key in self:
|
||||
update(
|
||||
cache,
|
||||
cache.c.key == key,
|
||||
dict(value=value, expires=expires),
|
||||
).execute()
|
||||
# Insert new key if key not present
|
||||
else:
|
||||
insert(
|
||||
cache, dict(key=key, value=value, expires=expires)
|
||||
).execute()
|
||||
|
||||
def _cull(self):
|
||||
'''Remove items in cache to make more room.'''
|
||||
cache, maxcull = self._store, self._maxcull
|
||||
# Remove items that have timed out
|
||||
now = datetime.now().replace(microsecond=0)
|
||||
delete(cache, cache.c.expires < now).execute()
|
||||
# Remove any items over the maximum allowed number in the cache
|
||||
if len(self) >= self._max_entries:
|
||||
# Upper limit for key query
|
||||
ul = maxcull * 2
|
||||
# Get list of keys
|
||||
keys = [
|
||||
i[0] for i in select(
|
||||
[cache.c.key], limit=ul
|
||||
).execute().fetchall()
|
||||
]
|
||||
# Get some keys at random
|
||||
delkeys = list(random.choice(keys) for i in xrange(maxcull))
|
||||
delete(cache, cache.c.key.in_(delkeys)).execute()
|
46
lib/shove/cache/file.py
vendored
46
lib/shove/cache/file.py
vendored
|
@ -1,46 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
'''
|
||||
File-based cache
|
||||
|
||||
shove's psuedo-URL for file caches follows the form:
|
||||
|
||||
file://<path>
|
||||
|
||||
Where the path is a URL path to a directory on a local filesystem.
|
||||
Alternatively, a native pathname to the directory can be passed as the 'engine'
|
||||
argument.
|
||||
'''
|
||||
|
||||
import time
|
||||
|
||||
from shove import FileBase
|
||||
from shove.cache.simple import SimpleCache
|
||||
|
||||
|
||||
class FileCache(FileBase, SimpleCache):
|
||||
|
||||
'''File-based cache backend'''
|
||||
|
||||
def __init__(self, engine, **kw):
|
||||
super(FileCache, self).__init__(engine, **kw)
|
||||
|
||||
def __getitem__(self, key):
|
||||
try:
|
||||
exp, value = super(FileCache, self).__getitem__(key)
|
||||
# Remove item if time has expired.
|
||||
if exp < time.time():
|
||||
del self[key]
|
||||
raise KeyError(key)
|
||||
return value
|
||||
except:
|
||||
raise KeyError(key)
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
if len(self) >= self._max_entries:
|
||||
self._cull()
|
||||
super(FileCache, self).__setitem__(
|
||||
key, (time.time() + self.timeout, value)
|
||||
)
|
||||
|
||||
|
||||
__all__ = ['FileCache']
|
23
lib/shove/cache/filelru.py
vendored
23
lib/shove/cache/filelru.py
vendored
|
@ -1,23 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
'''
|
||||
File-based LRU cache
|
||||
|
||||
shove's psuedo-URL for file caches follows the form:
|
||||
|
||||
file://<path>
|
||||
|
||||
Where the path is a URL path to a directory on a local filesystem.
|
||||
Alternatively, a native pathname to the directory can be passed as the 'engine'
|
||||
argument.
|
||||
'''
|
||||
|
||||
from shove import FileBase
|
||||
from shove.cache.simplelru import SimpleLRUCache
|
||||
|
||||
|
||||
class FileCache(FileBase, SimpleLRUCache):
|
||||
|
||||
'''File-based LRU cache backend'''
|
||||
|
||||
|
||||
__all__ = ['FileCache']
|
43
lib/shove/cache/memcached.py
vendored
43
lib/shove/cache/memcached.py
vendored
|
@ -1,43 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
'''
|
||||
"memcached" cache.
|
||||
|
||||
The shove psuedo-URL for a memcache cache is:
|
||||
|
||||
memcache://<memcache_server>
|
||||
'''
|
||||
|
||||
try:
|
||||
import memcache
|
||||
except ImportError:
|
||||
raise ImportError("Memcache cache requires the 'memcache' library")
|
||||
|
||||
from shove import Base
|
||||
|
||||
|
||||
class MemCached(Base):
|
||||
|
||||
'''Memcached cache backend'''
|
||||
|
||||
def __init__(self, engine, **kw):
|
||||
super(MemCached, self).__init__(engine, **kw)
|
||||
if engine.startswith('memcache://'):
|
||||
engine = engine.split('://')[1]
|
||||
self._store = memcache.Client(engine.split(';'))
|
||||
# Set timeout
|
||||
self.timeout = kw.get('timeout', 300)
|
||||
|
||||
def __getitem__(self, key):
|
||||
value = self._store.get(key)
|
||||
if value is None:
|
||||
raise KeyError(key)
|
||||
return self.loads(value)
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self._store.set(key, self.dumps(value), self.timeout)
|
||||
|
||||
def __delitem__(self, key):
|
||||
self._store.delete(key)
|
||||
|
||||
|
||||
__all__ = ['MemCached']
|
38
lib/shove/cache/memlru.py
vendored
38
lib/shove/cache/memlru.py
vendored
|
@ -1,38 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
'''
|
||||
Thread-safe in-memory cache using LRU.
|
||||
|
||||
The shove psuedo-URL for a memory cache is:
|
||||
|
||||
memlru://
|
||||
'''
|
||||
|
||||
import copy
|
||||
import threading
|
||||
|
||||
from shove import synchronized
|
||||
from shove.cache.simplelru import SimpleLRUCache
|
||||
|
||||
|
||||
class MemoryLRUCache(SimpleLRUCache):
|
||||
|
||||
'''Thread-safe in-memory cache backend using LRU.'''
|
||||
|
||||
def __init__(self, engine, **kw):
|
||||
super(MemoryLRUCache, self).__init__(engine, **kw)
|
||||
self._lock = threading.Condition()
|
||||
|
||||
@synchronized
|
||||
def __setitem__(self, key, value):
|
||||
super(MemoryLRUCache, self).__setitem__(key, value)
|
||||
|
||||
@synchronized
|
||||
def __getitem__(self, key):
|
||||
return copy.deepcopy(super(MemoryLRUCache, self).__getitem__(key))
|
||||
|
||||
@synchronized
|
||||
def __delitem__(self, key):
|
||||
super(MemoryLRUCache, self).__delitem__(key)
|
||||
|
||||
|
||||
__all__ = ['MemoryLRUCache']
|
38
lib/shove/cache/memory.py
vendored
38
lib/shove/cache/memory.py
vendored
|
@ -1,38 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
'''
|
||||
Thread-safe in-memory cache.
|
||||
|
||||
The shove psuedo-URL for a memory cache is:
|
||||
|
||||
memory://
|
||||
'''
|
||||
|
||||
import copy
|
||||
import threading
|
||||
|
||||
from shove import synchronized
|
||||
from shove.cache.simple import SimpleCache
|
||||
|
||||
|
||||
class MemoryCache(SimpleCache):
|
||||
|
||||
'''Thread-safe in-memory cache backend.'''
|
||||
|
||||
def __init__(self, engine, **kw):
|
||||
super(MemoryCache, self).__init__(engine, **kw)
|
||||
self._lock = threading.Condition()
|
||||
|
||||
@synchronized
|
||||
def __setitem__(self, key, value):
|
||||
super(MemoryCache, self).__setitem__(key, value)
|
||||
|
||||
@synchronized
|
||||
def __getitem__(self, key):
|
||||
return copy.deepcopy(super(MemoryCache, self).__getitem__(key))
|
||||
|
||||
@synchronized
|
||||
def __delitem__(self, key):
|
||||
super(MemoryCache, self).__delitem__(key)
|
||||
|
||||
|
||||
__all__ = ['MemoryCache']
|
45
lib/shove/cache/redisdb.py
vendored
45
lib/shove/cache/redisdb.py
vendored
|
@ -1,45 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
'''
|
||||
Redis-based object cache
|
||||
|
||||
The shove psuedo-URL for a redis cache is:
|
||||
|
||||
redis://<host>:<port>/<db>
|
||||
'''
|
||||
|
||||
import urlparse
|
||||
|
||||
try:
|
||||
import redis
|
||||
except ImportError:
|
||||
raise ImportError('This store requires the redis library')
|
||||
|
||||
from shove import Base
|
||||
|
||||
|
||||
class RedisCache(Base):
|
||||
|
||||
'''Redis cache backend'''
|
||||
|
||||
init = 'redis://'
|
||||
|
||||
def __init__(self, engine, **kw):
|
||||
super(RedisCache, self).__init__(engine, **kw)
|
||||
spliturl = urlparse.urlsplit(engine)
|
||||
host, port = spliturl[1].split(':')
|
||||
db = spliturl[2].replace('/', '')
|
||||
self._store = redis.Redis(host, int(port), db)
|
||||
# Set timeout
|
||||
self.timeout = kw.get('timeout', 300)
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.loads(self._store[key])
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self._store.setex(key, self.dumps(value), self.timeout)
|
||||
|
||||
def __delitem__(self, key):
|
||||
self._store.delete(key)
|
||||
|
||||
|
||||
__all__ = ['RedisCache']
|
68
lib/shove/cache/simple.py
vendored
68
lib/shove/cache/simple.py
vendored
|
@ -1,68 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
'''
|
||||
Single-process in-memory cache.
|
||||
|
||||
The shove psuedo-URL for a simple cache is:
|
||||
|
||||
simple://
|
||||
'''
|
||||
|
||||
import time
|
||||
import random
|
||||
|
||||
from shove import SimpleBase
|
||||
|
||||
|
||||
class SimpleCache(SimpleBase):
|
||||
|
||||
'''Single-process in-memory cache.'''
|
||||
|
||||
def __init__(self, engine, **kw):
|
||||
super(SimpleCache, self).__init__(engine, **kw)
|
||||
# Get random seed
|
||||
random.seed()
|
||||
# Set maximum number of items to cull if over max
|
||||
self._maxcull = kw.get('maxcull', 10)
|
||||
# Set max entries
|
||||
self._max_entries = kw.get('max_entries', 300)
|
||||
# Set timeout
|
||||
self.timeout = kw.get('timeout', 300)
|
||||
|
||||
def __getitem__(self, key):
|
||||
exp, value = super(SimpleCache, self).__getitem__(key)
|
||||
# Delete if item timed out.
|
||||
if exp < time.time():
|
||||
super(SimpleCache, self).__delitem__(key)
|
||||
raise KeyError(key)
|
||||
return value
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
# Cull values if over max # of entries
|
||||
if len(self) >= self._max_entries:
|
||||
self._cull()
|
||||
# Set expiration time and value
|
||||
exp = time.time() + self.timeout
|
||||
super(SimpleCache, self).__setitem__(key, (exp, value))
|
||||
|
||||
def _cull(self):
|
||||
'''Remove items in cache to make room.'''
|
||||
num, maxcull = 0, self._maxcull
|
||||
# Cull number of items allowed (set by self._maxcull)
|
||||
for key in self.keys():
|
||||
# Remove only maximum # of items allowed by maxcull
|
||||
if num <= maxcull:
|
||||
# Remove items if expired
|
||||
try:
|
||||
self[key]
|
||||
except KeyError:
|
||||
num += 1
|
||||
else:
|
||||
break
|
||||
# Remove any additional items up to max # of items allowed by maxcull
|
||||
while len(self) >= self._max_entries and num <= maxcull:
|
||||
# Cull remainder of allowed quota at random
|
||||
del self[random.choice(self.keys())]
|
||||
num += 1
|
||||
|
||||
|
||||
__all__ = ['SimpleCache']
|
18
lib/shove/cache/simplelru.py
vendored
18
lib/shove/cache/simplelru.py
vendored
|
@ -1,18 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
'''
|
||||
Single-process in-memory LRU cache.
|
||||
|
||||
The shove psuedo-URL for a simple cache is:
|
||||
|
||||
simplelru://
|
||||
'''
|
||||
|
||||
from shove import LRUBase
|
||||
|
||||
|
||||
class SimpleLRUCache(LRUBase):
|
||||
|
||||
'''In-memory cache that purges based on least recently used item.'''
|
||||
|
||||
|
||||
__all__ = ['SimpleLRUCache']
|
|
@ -1,48 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from urllib import url2pathname
|
||||
from shove.store.simple import SimpleStore
|
||||
|
||||
|
||||
class ClientStore(SimpleStore):
|
||||
|
||||
'''Base class for stores where updates have to be committed.'''
|
||||
|
||||
def __init__(self, engine, **kw):
|
||||
super(ClientStore, self).__init__(engine, **kw)
|
||||
if engine.startswith(self.init):
|
||||
self._engine = url2pathname(engine.split('://')[1])
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.loads(super(ClientStore, self).__getitem__(key))
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
super(ClientStore, self).__setitem__(key, self.dumps(value))
|
||||
|
||||
|
||||
class SyncStore(ClientStore):
|
||||
|
||||
'''Base class for stores where updates have to be committed.'''
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.loads(super(SyncStore, self).__getitem__(key))
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
super(SyncStore, self).__setitem__(key, value)
|
||||
try:
|
||||
self.sync()
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
def __delitem__(self, key):
|
||||
super(SyncStore, self).__delitem__(key)
|
||||
try:
|
||||
self.sync()
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
|
||||
__all__ = [
|
||||
'bsdb', 'db', 'dbm', 'durusdb', 'file', 'ftp', 'memory', 's3', 'simple',
|
||||
'svn', 'zodb', 'redisdb', 'hdf5db', 'leveldbstore', 'cassandra',
|
||||
]
|
|
@ -1,48 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
'''
|
||||
Berkeley Source Database Store.
|
||||
|
||||
shove's psuedo-URL for BSDDB stores follows the form:
|
||||
|
||||
bsddb://<path>
|
||||
|
||||
Where the path is a URL path to a Berkeley database. Alternatively, the native
|
||||
pathname to a Berkeley database can be passed as the 'engine' parameter.
|
||||
'''
|
||||
try:
|
||||
import bsddb
|
||||
except ImportError:
|
||||
raise ImportError('requires bsddb library')
|
||||
|
||||
import threading
|
||||
|
||||
from shove import synchronized
|
||||
from shove.store import SyncStore
|
||||
|
||||
|
||||
class BsdStore(SyncStore):
|
||||
|
||||
'''Class for Berkeley Source Database Store.'''
|
||||
|
||||
init = 'bsddb://'
|
||||
|
||||
def __init__(self, engine, **kw):
|
||||
super(BsdStore, self).__init__(engine, **kw)
|
||||
self._store = bsddb.hashopen(self._engine)
|
||||
self._lock = threading.Condition()
|
||||
self.sync = self._store.sync
|
||||
|
||||
@synchronized
|
||||
def __getitem__(self, key):
|
||||
return super(BsdStore, self).__getitem__(key)
|
||||
|
||||
@synchronized
|
||||
def __setitem__(self, key, value):
|
||||
super(BsdStore, self).__setitem__(key, value)
|
||||
|
||||
@synchronized
|
||||
def __delitem__(self, key):
|
||||
super(BsdStore, self).__delitem__(key)
|
||||
|
||||
|
||||
__all__ = ['BsdStore']
|
|
@ -1,72 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
'''
|
||||
Cassandra-based object store
|
||||
|
||||
The shove psuedo-URL for a cassandra-based store is:
|
||||
|
||||
cassandra://<host>:<port>/<keyspace>/<columnFamily>
|
||||
'''
|
||||
|
||||
import urlparse
|
||||
|
||||
try:
|
||||
import pycassa
|
||||
except ImportError:
|
||||
raise ImportError('This store requires the pycassa library')
|
||||
|
||||
from shove import BaseStore
|
||||
|
||||
|
||||
class CassandraStore(BaseStore):
|
||||
|
||||
'''Cassandra based store'''
|
||||
|
||||
init = 'cassandra://'
|
||||
|
||||
def __init__(self, engine, **kw):
|
||||
super(CassandraStore, self).__init__(engine, **kw)
|
||||
spliturl = urlparse.urlsplit(engine)
|
||||
_, keyspace, column_family = spliturl[2].split('/')
|
||||
try:
|
||||
self._pool = pycassa.connect(keyspace, [spliturl[1]])
|
||||
self._store = pycassa.ColumnFamily(self._pool, column_family)
|
||||
except pycassa.InvalidRequestException:
|
||||
from pycassa.system_manager import SystemManager
|
||||
system_manager = SystemManager(spliturl[1])
|
||||
system_manager.create_keyspace(
|
||||
keyspace,
|
||||
pycassa.system_manager.SIMPLE_STRATEGY,
|
||||
{'replication_factor': str(kw.get('replication', 1))}
|
||||
)
|
||||
system_manager.create_column_family(keyspace, column_family)
|
||||
self._pool = pycassa.connect(keyspace, [spliturl[1]])
|
||||
self._store = pycassa.ColumnFamily(self._pool, column_family)
|
||||
|
||||
def __getitem__(self, key):
|
||||
try:
|
||||
item = self._store.get(key).get(key)
|
||||
if item is not None:
|
||||
return self.loads(item)
|
||||
raise KeyError(key)
|
||||
except pycassa.NotFoundException:
|
||||
raise KeyError(key)
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self._store.insert(key, dict(key=self.dumps(value)))
|
||||
|
||||
def __delitem__(self, key):
|
||||
# beware eventual consistency
|
||||
try:
|
||||
self._store.remove(key)
|
||||
except pycassa.NotFoundException:
|
||||
raise KeyError(key)
|
||||
|
||||
def clear(self):
|
||||
# beware eventual consistency
|
||||
self._store.truncate()
|
||||
|
||||
def keys(self):
|
||||
return list(i[0] for i in self._store.get_range())
|
||||
|
||||
|
||||
__all__ = ['CassandraStore']
|
|
@ -1,73 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
'''
|
||||
Database object store.
|
||||
|
||||
The shove psuedo-URL used for database object stores is the format used by
|
||||
SQLAlchemy:
|
||||
|
||||
<driver>://<username>:<password>@<host>:<port>/<database>
|
||||
|
||||
<driver> is the database engine. The engines currently supported SQLAlchemy are
|
||||
sqlite, mysql, postgres, oracle, mssql, and firebird.
|
||||
<username> is the database account user name
|
||||
<password> is the database accound password
|
||||
<host> is the database location
|
||||
<port> is the database port
|
||||
<database> is the name of the specific database
|
||||
|
||||
For more information on specific databases see:
|
||||
|
||||
http://www.sqlalchemy.org/docs/dbengine.myt#dbengine_supported
|
||||
'''
|
||||
|
||||
try:
|
||||
from sqlalchemy import MetaData, Table, Column, String, Binary, select
|
||||
from shove import BaseStore, DbBase
|
||||
except ImportError, e:
|
||||
raise ImportError('Error: ' + e + ' Requires SQLAlchemy >= 0.4')
|
||||
|
||||
|
||||
class DbStore(BaseStore, DbBase):
|
||||
|
||||
'''Database cache backend.'''
|
||||
|
||||
def __init__(self, engine, **kw):
|
||||
super(DbStore, self).__init__(engine, **kw)
|
||||
# Get tablename
|
||||
tablename = kw.get('tablename', 'store')
|
||||
# Bind metadata
|
||||
self._metadata = MetaData(engine)
|
||||
# Make store table
|
||||
self._store = Table(tablename, self._metadata,
|
||||
Column('key', String(255), primary_key=True, nullable=False),
|
||||
Column('value', Binary, nullable=False),
|
||||
)
|
||||
# Create store table if it does not exist
|
||||
if not self._store.exists():
|
||||
self._store.create()
|
||||
|
||||
def __getitem__(self, key):
|
||||
row = select(
|
||||
[self._store.c.value], self._store.c.key == key,
|
||||
).execute().fetchone()
|
||||
if row is not None:
|
||||
return self.loads(str(row.value))
|
||||
raise KeyError(key)
|
||||
|
||||
def __setitem__(self, k, v):
|
||||
v, store = self.dumps(v), self._store
|
||||
# Update database if key already present
|
||||
if k in self:
|
||||
store.update(store.c.key == k).execute(value=v)
|
||||
# Insert new key if key not present
|
||||
else:
|
||||
store.insert().execute(key=k, value=v)
|
||||
|
||||
def keys(self):
|
||||
'''Returns a list of keys in the store.'''
|
||||
return list(i[0] for i in select(
|
||||
[self._store.c.key]
|
||||
).execute().fetchall())
|
||||
|
||||
|
||||
__all__ = ['DbStore']
|
|
@ -1,33 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
'''
|
||||
DBM Database Store.
|
||||
|
||||
shove's psuedo-URL for DBM stores follows the form:
|
||||
|
||||
dbm://<path>
|
||||
|
||||
Where <path> is a URL path to a DBM database. Alternatively, the native
|
||||
pathname to a DBM database can be passed as the 'engine' parameter.
|
||||
'''
|
||||
|
||||
import anydbm
|
||||
|
||||
from shove.store import SyncStore
|
||||
|
||||
|
||||
class DbmStore(SyncStore):
|
||||
|
||||
'''Class for variants of the DBM database.'''
|
||||
|
||||
init = 'dbm://'
|
||||
|
||||
def __init__(self, engine, **kw):
|
||||
super(DbmStore, self).__init__(engine, **kw)
|
||||
self._store = anydbm.open(self._engine, 'c')
|
||||
try:
|
||||
self.sync = self._store.sync
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
|
||||
__all__ = ['DbmStore']
|
|
@ -1,43 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
'''
|
||||
Durus object database frontend.
|
||||
|
||||
shove's psuedo-URL for Durus stores follows the form:
|
||||
|
||||
durus://<path>
|
||||
|
||||
|
||||
Where the path is a URL path to a durus FileStorage database. Alternatively, a
|
||||
native pathname to a durus database can be passed as the 'engine' parameter.
|
||||
'''
|
||||
|
||||
try:
|
||||
from durus.connection import Connection
|
||||
from durus.file_storage import FileStorage
|
||||
except ImportError:
|
||||
raise ImportError('Requires Durus library')
|
||||
|
||||
from shove.store import SyncStore
|
||||
|
||||
|
||||
class DurusStore(SyncStore):
|
||||
|
||||
'''Class for Durus object database frontend.'''
|
||||
|
||||
init = 'durus://'
|
||||
|
||||
def __init__(self, engine, **kw):
|
||||
super(DurusStore, self).__init__(engine, **kw)
|
||||
self._db = FileStorage(self._engine)
|
||||
self._connection = Connection(self._db)
|
||||
self.sync = self._connection.commit
|
||||
self._store = self._connection.get_root()
|
||||
|
||||
def close(self):
|
||||
'''Closes all open storage and connections.'''
|
||||
self.sync()
|
||||
self._db.close()
|
||||
super(DurusStore, self).close()
|
||||
|
||||
|
||||
__all__ = ['DurusStore']
|
|
@ -1,25 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
'''
|
||||
Filesystem-based object store
|
||||
|
||||
shove's psuedo-URL for filesystem-based stores follows the form:
|
||||
|
||||
file://<path>
|
||||
|
||||
Where the path is a URL path to a directory on a local filesystem.
|
||||
Alternatively, a native pathname to the directory can be passed as the 'engine'
|
||||
argument.
|
||||
'''
|
||||
|
||||
from shove import BaseStore, FileBase
|
||||
|
||||
|
||||
class FileStore(FileBase, BaseStore):
|
||||
|
||||
'''File-based store.'''
|
||||
|
||||
def __init__(self, engine, **kw):
|
||||
super(FileStore, self).__init__(engine, **kw)
|
||||
|
||||
|
||||
__all__ = ['FileStore']
|
|
@ -1,88 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
'''
|
||||
FTP-accessed stores
|
||||
|
||||
shove's URL for FTP accessed stores follows the standard form for FTP URLs
|
||||
defined in RFC-1738:
|
||||
|
||||
ftp://<user>:<password>@<host>:<port>/<url-path>
|
||||
'''
|
||||
|
||||
import urlparse
|
||||
try:
|
||||
from cStringIO import StringIO
|
||||
except ImportError:
|
||||
from StringIO import StringIO
|
||||
from ftplib import FTP, error_perm
|
||||
|
||||
from shove import BaseStore
|
||||
|
||||
|
||||
class FtpStore(BaseStore):
|
||||
|
||||
def __init__(self, engine, **kw):
|
||||
super(FtpStore, self).__init__(engine, **kw)
|
||||
user = kw.get('user', 'anonymous')
|
||||
password = kw.get('password', '')
|
||||
spliturl = urlparse.urlsplit(engine)
|
||||
# Set URL, path, and strip 'ftp://' off
|
||||
base, path = spliturl[1], spliturl[2] + '/'
|
||||
if '@' in base:
|
||||
auth, base = base.split('@')
|
||||
user, password = auth.split(':')
|
||||
self._store = FTP(base, user, password)
|
||||
# Change to remote path if it exits
|
||||
try:
|
||||
self._store.cwd(path)
|
||||
except error_perm:
|
||||
self._makedir(path)
|
||||
self._base, self._user, self._password = base, user, password
|
||||
self._updated, self ._keys = True, None
|
||||
|
||||
def __getitem__(self, key):
|
||||
try:
|
||||
local = StringIO()
|
||||
# Download item
|
||||
self._store.retrbinary('RETR %s' % key, local.write)
|
||||
self._updated = False
|
||||
return self.loads(local.getvalue())
|
||||
except:
|
||||
raise KeyError(key)
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
local = StringIO(self.dumps(value))
|
||||
self._store.storbinary('STOR %s' % key, local)
|
||||
self._updated = True
|
||||
|
||||
def __delitem__(self, key):
|
||||
try:
|
||||
self._store.delete(key)
|
||||
self._updated = True
|
||||
except:
|
||||
raise KeyError(key)
|
||||
|
||||
def _makedir(self, path):
|
||||
'''Makes remote paths on an FTP server.'''
|
||||
paths = list(reversed([i for i in path.split('/') if i != '']))
|
||||
while paths:
|
||||
tpath = paths.pop()
|
||||
self._store.mkd(tpath)
|
||||
self._store.cwd(tpath)
|
||||
|
||||
def keys(self):
|
||||
'''Returns a list of keys in a store.'''
|
||||
if self._updated or self._keys is None:
|
||||
rlist, nlist = list(), list()
|
||||
# Remote directory listing
|
||||
self._store.retrlines('LIST -a', rlist.append)
|
||||
for rlisting in rlist:
|
||||
# Split remote file based on whitespace
|
||||
rfile = rlisting.split()
|
||||
# Append tuple of remote item type & name
|
||||
if rfile[-1] not in ('.', '..') and rfile[0].startswith('-'):
|
||||
nlist.append(rfile[-1])
|
||||
self._keys = nlist
|
||||
return self._keys
|
||||
|
||||
|
||||
__all__ = ['FtpStore']
|
|
@ -1,34 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
'''
|
||||
HDF5 Database Store.
|
||||
|
||||
shove's psuedo-URL for HDF5 stores follows the form:
|
||||
|
||||
hdf5://<path>/<group>
|
||||
|
||||
Where <path> is a URL path to a HDF5 database. Alternatively, the native
|
||||
pathname to a HDF5 database can be passed as the 'engine' parameter.
|
||||
<group> is the name of the database.
|
||||
'''
|
||||
|
||||
try:
|
||||
import h5py
|
||||
except ImportError:
|
||||
raise ImportError('This store requires h5py library')
|
||||
|
||||
from shove.store import ClientStore
|
||||
|
||||
|
||||
class HDF5Store(ClientStore):
|
||||
|
||||
'''LevelDB based store'''
|
||||
|
||||
init = 'hdf5://'
|
||||
|
||||
def __init__(self, engine, **kw):
|
||||
super(HDF5Store, self).__init__(engine, **kw)
|
||||
engine, group = self._engine.rsplit('/')
|
||||
self._store = h5py.File(engine).require_group(group).attrs
|
||||
|
||||
|
||||
__all__ = ['HDF5Store']
|
|
@ -1,47 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
'''
|
||||
LevelDB Database Store.
|
||||
|
||||
shove's psuedo-URL for LevelDB stores follows the form:
|
||||
|
||||
leveldb://<path>
|
||||
|
||||
Where <path> is a URL path to a LevelDB database. Alternatively, the native
|
||||
pathname to a LevelDB database can be passed as the 'engine' parameter.
|
||||
'''
|
||||
|
||||
try:
|
||||
import leveldb
|
||||
except ImportError:
|
||||
raise ImportError('This store requires py-leveldb library')
|
||||
|
||||
from shove.store import ClientStore
|
||||
|
||||
|
||||
class LevelDBStore(ClientStore):
|
||||
|
||||
'''LevelDB based store'''
|
||||
|
||||
init = 'leveldb://'
|
||||
|
||||
def __init__(self, engine, **kw):
|
||||
super(LevelDBStore, self).__init__(engine, **kw)
|
||||
self._store = leveldb.LevelDB(self._engine)
|
||||
|
||||
def __getitem__(self, key):
|
||||
item = self.loads(self._store.Get(key))
|
||||
if item is not None:
|
||||
return item
|
||||
raise KeyError(key)
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self._store.Put(key, self.dumps(value))
|
||||
|
||||
def __delitem__(self, key):
|
||||
self._store.Delete(key)
|
||||
|
||||
def keys(self):
|
||||
return list(k for k in self._store.RangeIter(include_value=False))
|
||||
|
||||
|
||||
__all__ = ['LevelDBStore']
|
|
@ -1,38 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
'''
|
||||
Thread-safe in-memory store.
|
||||
|
||||
The shove psuedo-URL for a memory store is:
|
||||
|
||||
memory://
|
||||
'''
|
||||
|
||||
import copy
|
||||
import threading
|
||||
|
||||
from shove import synchronized
|
||||
from shove.store.simple import SimpleStore
|
||||
|
||||
|
||||
class MemoryStore(SimpleStore):
|
||||
|
||||
'''Thread-safe in-memory store.'''
|
||||
|
||||
def __init__(self, engine, **kw):
|
||||
super(MemoryStore, self).__init__(engine, **kw)
|
||||
self._lock = threading.Condition()
|
||||
|
||||
@synchronized
|
||||
def __getitem__(self, key):
|
||||
return copy.deepcopy(super(MemoryStore, self).__getitem__(key))
|
||||
|
||||
@synchronized
|
||||
def __setitem__(self, key, value):
|
||||
super(MemoryStore, self).__setitem__(key, value)
|
||||
|
||||
@synchronized
|
||||
def __delitem__(self, key):
|
||||
super(MemoryStore, self).__delitem__(key)
|
||||
|
||||
|
||||
__all__ = ['MemoryStore']
|
|
@ -1,50 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
'''
|
||||
Redis-based object store
|
||||
|
||||
The shove psuedo-URL for a redis-based store is:
|
||||
|
||||
redis://<host>:<port>/<db>
|
||||
'''
|
||||
|
||||
import urlparse
|
||||
|
||||
try:
|
||||
import redis
|
||||
except ImportError:
|
||||
raise ImportError('This store requires the redis library')
|
||||
|
||||
from shove.store import ClientStore
|
||||
|
||||
|
||||
class RedisStore(ClientStore):
|
||||
|
||||
'''Redis based store'''
|
||||
|
||||
init = 'redis://'
|
||||
|
||||
def __init__(self, engine, **kw):
|
||||
super(RedisStore, self).__init__(engine, **kw)
|
||||
spliturl = urlparse.urlsplit(engine)
|
||||
host, port = spliturl[1].split(':')
|
||||
db = spliturl[2].replace('/', '')
|
||||
self._store = redis.Redis(host, int(port), db)
|
||||
|
||||
def __contains__(self, key):
|
||||
return self._store.exists(key)
|
||||
|
||||
def clear(self):
|
||||
self._store.flushdb()
|
||||
|
||||
def keys(self):
|
||||
return self._store.keys()
|
||||
|
||||
def setdefault(self, key, default=None):
|
||||
return self._store.getset(key, default)
|
||||
|
||||
def update(self, other=None, **kw):
|
||||
args = kw if other is not None else other
|
||||
self._store.mset(args)
|
||||
|
||||
|
||||
__all__ = ['RedisStore']
|
|
@ -1,91 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
'''
|
||||
S3-accessed stores
|
||||
|
||||
shove's psuedo-URL for stores found on Amazon.com's S3 web service follows this
|
||||
form:
|
||||
|
||||
s3://<s3_key>:<s3_secret>@<bucket>
|
||||
|
||||
<s3_key> is the Access Key issued by Amazon
|
||||
<s3_secret> is the Secret Access Key issued by Amazon
|
||||
<bucket> is the name of the bucket accessed through the S3 service
|
||||
'''
|
||||
|
||||
try:
|
||||
from boto.s3.connection import S3Connection
|
||||
from boto.s3.key import Key
|
||||
except ImportError:
|
||||
raise ImportError('Requires boto library')
|
||||
|
||||
from shove import BaseStore
|
||||
|
||||
|
||||
class S3Store(BaseStore):
|
||||
|
||||
def __init__(self, engine=None, **kw):
|
||||
super(S3Store, self).__init__(engine, **kw)
|
||||
# key = Access Key, secret=Secret Access Key, bucket=bucket name
|
||||
key, secret, bucket = kw.get('key'), kw.get('secret'), kw.get('bucket')
|
||||
if engine is not None:
|
||||
auth, bucket = engine.split('://')[1].split('@')
|
||||
key, secret = auth.split(':')
|
||||
# kw 'secure' = (True or False, use HTTPS)
|
||||
self._conn = S3Connection(key, secret, kw.get('secure', False))
|
||||
buckets = self._conn.get_all_buckets()
|
||||
# Use bucket if it exists
|
||||
for b in buckets:
|
||||
if b.name == bucket:
|
||||
self._store = b
|
||||
break
|
||||
# Create bucket if it doesn't exist
|
||||
else:
|
||||
self._store = self._conn.create_bucket(bucket)
|
||||
# Set bucket permission ('private', 'public-read',
|
||||
# 'public-read-write', 'authenticated-read'
|
||||
self._store.set_acl(kw.get('acl', 'private'))
|
||||
# Updated flag used for avoiding network calls
|
||||
self._updated, self._keys = True, None
|
||||
|
||||
def __getitem__(self, key):
|
||||
rkey = self._store.lookup(key)
|
||||
if rkey is None:
|
||||
raise KeyError(key)
|
||||
# Fetch string
|
||||
value = self.loads(rkey.get_contents_as_string())
|
||||
# Flag that the store has not been updated
|
||||
self._updated = False
|
||||
return value
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
rkey = Key(self._store)
|
||||
rkey.key = key
|
||||
rkey.set_contents_from_string(self.dumps(value))
|
||||
# Flag that the store has been updated
|
||||
self._updated = True
|
||||
|
||||
def __delitem__(self, key):
|
||||
try:
|
||||
self._store.delete_key(key)
|
||||
# Flag that the store has been updated
|
||||
self._updated = True
|
||||
except:
|
||||
raise KeyError(key)
|
||||
|
||||
def keys(self):
|
||||
'''Returns a list of keys in the store.'''
|
||||
return list(i[0] for i in self.items())
|
||||
|
||||
def items(self):
|
||||
'''Returns a list of items from the store.'''
|
||||
if self._updated or self._keys is None:
|
||||
self._keys = self._store.get_all_keys()
|
||||
return list((str(k.key), k) for k in self._keys)
|
||||
|
||||
def iteritems(self):
|
||||
'''Lazily returns items from the store.'''
|
||||
for k in self.items():
|
||||
yield (k.key, k)
|
||||
|
||||
|
||||
__all__ = ['S3Store']
|
|
@ -1,21 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
'''
|
||||
Single-process in-memory store.
|
||||
|
||||
The shove psuedo-URL for a simple store is:
|
||||
|
||||
simple://
|
||||
'''
|
||||
|
||||
from shove import BaseStore, SimpleBase
|
||||
|
||||
|
||||
class SimpleStore(SimpleBase, BaseStore):
|
||||
|
||||
'''Single-process in-memory store.'''
|
||||
|
||||
def __init__(self, engine, **kw):
|
||||
super(SimpleStore, self).__init__(engine, **kw)
|
||||
|
||||
|
||||
__all__ = ['SimpleStore']
|
|
@ -1,110 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
'''
|
||||
subversion managed store.
|
||||
|
||||
The shove psuedo-URL used for a subversion store that is password protected is:
|
||||
|
||||
svn:<username><password>:<path>?url=<url>
|
||||
|
||||
or for non-password protected repositories:
|
||||
|
||||
svn://<path>?url=<url>
|
||||
|
||||
<path> is the local repository copy
|
||||
<url> is the URL of the subversion repository
|
||||
'''
|
||||
|
||||
import os
|
||||
import urllib
|
||||
import threading
|
||||
|
||||
try:
|
||||
import pysvn
|
||||
except ImportError:
|
||||
raise ImportError('Requires Python Subversion library')
|
||||
|
||||
from shove import BaseStore, synchronized
|
||||
|
||||
|
||||
class SvnStore(BaseStore):
|
||||
|
||||
'''Class for subversion store.'''
|
||||
|
||||
def __init__(self, engine=None, **kw):
|
||||
super(SvnStore, self).__init__(engine, **kw)
|
||||
# Get path, url from keywords if used
|
||||
path, url = kw.get('path'), kw.get('url')
|
||||
# Get username. password from keywords if used
|
||||
user, password = kw.get('user'), kw.get('password')
|
||||
# Process psuedo URL if used
|
||||
if engine is not None:
|
||||
path, query = engine.split('n://')[1].split('?')
|
||||
url = query.split('=')[1]
|
||||
# Check for username, password
|
||||
if '@' in path:
|
||||
auth, path = path.split('@')
|
||||
user, password = auth.split(':')
|
||||
path = urllib.url2pathname(path)
|
||||
# Create subversion client
|
||||
self._client = pysvn.Client()
|
||||
# Assign username, password
|
||||
if user is not None:
|
||||
self._client.set_username(user)
|
||||
if password is not None:
|
||||
self._client.set_password(password)
|
||||
# Verify that store exists in repository
|
||||
try:
|
||||
self._client.info2(url)
|
||||
# Create store in repository if it doesn't exist
|
||||
except pysvn.ClientError:
|
||||
self._client.mkdir(url, 'Adding directory')
|
||||
# Verify that local copy exists
|
||||
try:
|
||||
if self._client.info(path) is None:
|
||||
self._client.checkout(url, path)
|
||||
# Check it out if it doesn't exist
|
||||
except pysvn.ClientError:
|
||||
self._client.checkout(url, path)
|
||||
self._path, self._url = path, url
|
||||
# Lock
|
||||
self._lock = threading.Condition()
|
||||
|
||||
@synchronized
|
||||
def __getitem__(self, key):
|
||||
try:
|
||||
return self.loads(self._client.cat(self._key_to_file(key)))
|
||||
except:
|
||||
raise KeyError(key)
|
||||
|
||||
@synchronized
|
||||
def __setitem__(self, key, value):
|
||||
fname = self._key_to_file(key)
|
||||
# Write value to file
|
||||
open(fname, 'wb').write(self.dumps(value))
|
||||
# Add to repository
|
||||
if key not in self:
|
||||
self._client.add(fname)
|
||||
self._client.checkin([fname], 'Adding %s' % fname)
|
||||
|
||||
@synchronized
|
||||
def __delitem__(self, key):
|
||||
try:
|
||||
fname = self._key_to_file(key)
|
||||
self._client.remove(fname)
|
||||
# Remove deleted value from repository
|
||||
self._client.checkin([fname], 'Removing %s' % fname)
|
||||
except:
|
||||
raise KeyError(key)
|
||||
|
||||
def _key_to_file(self, key):
|
||||
'''Gives the filesystem path for a key.'''
|
||||
return os.path.join(self._path, urllib.quote_plus(key))
|
||||
|
||||
@synchronized
|
||||
def keys(self):
|
||||
'''Returns a list of keys in the subversion repository.'''
|
||||
return list(str(i.name.split('/')[-1]) for i
|
||||
in self._client.ls(self._path))
|
||||
|
||||
|
||||
__all__ = ['SvnStore']
|
|
@ -1,48 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
'''
|
||||
Zope Object Database store frontend.
|
||||
|
||||
shove's psuedo-URL for ZODB stores follows the form:
|
||||
|
||||
zodb:<path>
|
||||
|
||||
|
||||
Where the path is a URL path to a ZODB FileStorage database. Alternatively, a
|
||||
native pathname to a ZODB database can be passed as the 'engine' argument.
|
||||
'''
|
||||
|
||||
try:
|
||||
import transaction
|
||||
from ZODB import FileStorage, DB
|
||||
except ImportError:
|
||||
raise ImportError('Requires ZODB library')
|
||||
|
||||
from shove.store import SyncStore
|
||||
|
||||
|
||||
class ZodbStore(SyncStore):
|
||||
|
||||
'''ZODB store front end.'''
|
||||
|
||||
init = 'zodb://'
|
||||
|
||||
def __init__(self, engine, **kw):
|
||||
super(ZodbStore, self).__init__(engine, **kw)
|
||||
# Handle psuedo-URL
|
||||
self._storage = FileStorage.FileStorage(self._engine)
|
||||
self._db = DB(self._storage)
|
||||
self._connection = self._db.open()
|
||||
self._store = self._connection.root()
|
||||
# Keeps DB in synch through commits of transactions
|
||||
self.sync = transaction.commit
|
||||
|
||||
def close(self):
|
||||
'''Closes all open storage and connections.'''
|
||||
self.sync()
|
||||
super(ZodbStore, self).close()
|
||||
self._connection.close()
|
||||
self._db.close()
|
||||
self._storage.close()
|
||||
|
||||
|
||||
__all__ = ['ZodbStore']
|
|
@ -1 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
|
@ -1,133 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import unittest
|
||||
|
||||
|
||||
class TestBsdbStore(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
from shove import Shove
|
||||
self.store = Shove('bsddb://test.db', compress=True)
|
||||
|
||||
def tearDown(self):
|
||||
import os
|
||||
self.store.close()
|
||||
os.remove('test.db')
|
||||
|
||||
def test__getitem__(self):
|
||||
self.store['max'] = 3
|
||||
self.assertEqual(self.store['max'], 3)
|
||||
|
||||
def test__setitem__(self):
|
||||
self.store['max'] = 3
|
||||
self.assertEqual(self.store['max'], 3)
|
||||
|
||||
def test__delitem__(self):
|
||||
self.store['max'] = 3
|
||||
del self.store['max']
|
||||
self.assertEqual('max' in self.store, False)
|
||||
|
||||
def test_get(self):
|
||||
self.store['max'] = 3
|
||||
self.assertEqual(self.store.get('min'), None)
|
||||
|
||||
def test__cmp__(self):
|
||||
from shove import Shove
|
||||
tstore = Shove()
|
||||
self.store['max'] = 3
|
||||
tstore['max'] = 3
|
||||
self.assertEqual(self.store, tstore)
|
||||
|
||||
def test__len__(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.assertEqual(len(self.store), 2)
|
||||
|
||||
def test_close(self):
|
||||
self.store.close()
|
||||
self.assertEqual(self.store, None)
|
||||
|
||||
def test_clear(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
self.store.clear()
|
||||
self.assertEqual(len(self.store), 0)
|
||||
|
||||
def test_items(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = list(self.store.items())
|
||||
self.assertEqual(('min', 6) in slist, True)
|
||||
|
||||
def test_iteritems(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = list(self.store.iteritems())
|
||||
self.assertEqual(('min', 6) in slist, True)
|
||||
|
||||
def test_iterkeys(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = list(self.store.iterkeys())
|
||||
self.assertEqual('min' in slist, True)
|
||||
|
||||
def test_itervalues(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = list(self.store.itervalues())
|
||||
self.assertEqual(6 in slist, True)
|
||||
|
||||
def test_pop(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
item = self.store.pop('min')
|
||||
self.assertEqual(item, 6)
|
||||
|
||||
def test_popitem(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
item = self.store.popitem()
|
||||
self.assertEqual(len(item) + len(self.store), 4)
|
||||
|
||||
def test_setdefault(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['powl'] = 7
|
||||
self.store.setdefault('pow', 8)
|
||||
self.assertEqual(self.store['pow'], 8)
|
||||
|
||||
def test_update(self):
|
||||
from shove import Shove
|
||||
tstore = Shove()
|
||||
tstore['max'] = 3
|
||||
tstore['min'] = 6
|
||||
tstore['pow'] = 7
|
||||
self.store['max'] = 2
|
||||
self.store['min'] = 3
|
||||
self.store['pow'] = 7
|
||||
self.store.update(tstore)
|
||||
self.assertEqual(self.store['min'], 6)
|
||||
|
||||
def test_values(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = self.store.values()
|
||||
self.assertEqual(6 in slist, True)
|
||||
|
||||
def test_keys(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = self.store.keys()
|
||||
self.assertEqual('min' in slist, True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
|
@ -1,137 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import unittest
|
||||
|
||||
|
||||
class TestCassandraStore(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
from shove import Shove
|
||||
from pycassa.system_manager import SystemManager
|
||||
system_manager = SystemManager('localhost:9160')
|
||||
try:
|
||||
system_manager.create_column_family('Foo', 'shove')
|
||||
except:
|
||||
pass
|
||||
self.store = Shove('cassandra://localhost:9160/Foo/shove')
|
||||
|
||||
def tearDown(self):
|
||||
self.store.clear()
|
||||
self.store.close()
|
||||
from pycassa.system_manager import SystemManager
|
||||
system_manager = SystemManager('localhost:9160')
|
||||
system_manager.drop_column_family('Foo', 'shove')
|
||||
|
||||
def test__getitem__(self):
|
||||
self.store['max'] = 3
|
||||
self.assertEqual(self.store['max'], 3)
|
||||
|
||||
def test__setitem__(self):
|
||||
self.store['max'] = 3
|
||||
self.assertEqual(self.store['max'], 3)
|
||||
|
||||
def test__delitem__(self):
|
||||
self.store['max'] = 3
|
||||
del self.store['max']
|
||||
self.assertEqual('max' in self.store, False)
|
||||
|
||||
def test_get(self):
|
||||
self.store['max'] = 3
|
||||
self.assertEqual(self.store.get('min'), None)
|
||||
|
||||
def test__cmp__(self):
|
||||
from shove import Shove
|
||||
tstore = Shove()
|
||||
self.store['max'] = 3
|
||||
tstore['max'] = 3
|
||||
self.assertEqual(self.store, tstore)
|
||||
|
||||
def test__len__(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.assertEqual(len(self.store), 2)
|
||||
|
||||
# def test_clear(self):
|
||||
# self.store['max'] = 3
|
||||
# self.store['min'] = 6
|
||||
# self.store['pow'] = 7
|
||||
# self.store.clear()
|
||||
# self.assertEqual(len(self.store), 0)
|
||||
|
||||
def test_items(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = list(self.store.items())
|
||||
self.assertEqual(('min', 6) in slist, True)
|
||||
|
||||
def test_iteritems(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = list(self.store.iteritems())
|
||||
self.assertEqual(('min', 6) in slist, True)
|
||||
|
||||
def test_iterkeys(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = list(self.store.iterkeys())
|
||||
self.assertEqual('min' in slist, True)
|
||||
|
||||
def test_itervalues(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = list(self.store.itervalues())
|
||||
self.assertEqual(6 in slist, True)
|
||||
|
||||
def test_pop(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
item = self.store.pop('min')
|
||||
self.assertEqual(item, 6)
|
||||
|
||||
# def test_popitem(self):
|
||||
# self.store['max'] = 3
|
||||
# self.store['min'] = 6
|
||||
# self.store['pow'] = 7
|
||||
# item = self.store.popitem()
|
||||
# self.assertEqual(len(item) + len(self.store), 4)
|
||||
|
||||
def test_setdefault(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
# self.store['pow'] = 7
|
||||
self.store.setdefault('pow', 8)
|
||||
self.assertEqual(self.store.setdefault('pow', 8), 8)
|
||||
self.assertEqual(self.store['pow'], 8)
|
||||
|
||||
def test_update(self):
|
||||
from shove import Shove
|
||||
tstore = Shove()
|
||||
tstore['max'] = 3
|
||||
tstore['min'] = 6
|
||||
tstore['pow'] = 7
|
||||
self.store['max'] = 2
|
||||
self.store['min'] = 3
|
||||
self.store['pow'] = 7
|
||||
self.store.update(tstore)
|
||||
self.assertEqual(self.store['min'], 6)
|
||||
|
||||
def test_values(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = self.store.values()
|
||||
self.assertEqual(6 in slist, True)
|
||||
|
||||
def test_keys(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = self.store.keys()
|
||||
self.assertEqual('min' in slist, True)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
|
@ -1,54 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import unittest
|
||||
|
||||
|
||||
class TestDbCache(unittest.TestCase):
|
||||
|
||||
initstring = 'sqlite:///'
|
||||
|
||||
def setUp(self):
|
||||
from shove.cache.db import DbCache
|
||||
self.cache = DbCache(self.initstring)
|
||||
|
||||
def tearDown(self):
|
||||
self.cache = None
|
||||
|
||||
def test_getitem(self):
|
||||
self.cache['test'] = 'test'
|
||||
self.assertEqual(self.cache['test'], 'test')
|
||||
|
||||
def test_setitem(self):
|
||||
self.cache['test'] = 'test'
|
||||
self.assertEqual(self.cache['test'], 'test')
|
||||
|
||||
def test_delitem(self):
|
||||
self.cache['test'] = 'test'
|
||||
del self.cache['test']
|
||||
self.assertEqual('test' in self.cache, False)
|
||||
|
||||
def test_get(self):
|
||||
self.assertEqual(self.cache.get('min'), None)
|
||||
|
||||
def test_timeout(self):
|
||||
import time
|
||||
from shove.cache.db import DbCache
|
||||
cache = DbCache(self.initstring, timeout=1)
|
||||
cache['test'] = 'test'
|
||||
time.sleep(2)
|
||||
|
||||
def tmp():
|
||||
cache['test']
|
||||
self.assertRaises(KeyError, tmp)
|
||||
|
||||
def test_cull(self):
|
||||
from shove.cache.db import DbCache
|
||||
cache = DbCache(self.initstring, max_entries=1)
|
||||
cache['test'] = 'test'
|
||||
cache['test2'] = 'test'
|
||||
cache['test2'] = 'test'
|
||||
self.assertEquals(len(cache), 1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
|
@ -1,131 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import unittest
|
||||
|
||||
|
||||
class TestDbStore(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
from shove import Shove
|
||||
self.store = Shove('sqlite://', compress=True)
|
||||
|
||||
def tearDown(self):
|
||||
self.store.close()
|
||||
|
||||
def test__getitem__(self):
|
||||
self.store['max'] = 3
|
||||
self.assertEqual(self.store['max'], 3)
|
||||
|
||||
def test__setitem__(self):
|
||||
self.store['max'] = 3
|
||||
self.assertEqual(self.store['max'], 3)
|
||||
|
||||
def test__delitem__(self):
|
||||
self.store['max'] = 3
|
||||
del self.store['max']
|
||||
self.assertEqual('max' in self.store, False)
|
||||
|
||||
def test_get(self):
|
||||
self.store['max'] = 3
|
||||
self.assertEqual(self.store.get('min'), None)
|
||||
|
||||
def test__cmp__(self):
|
||||
from shove import Shove
|
||||
tstore = Shove()
|
||||
self.store['max'] = 3
|
||||
tstore['max'] = 3
|
||||
self.assertEqual(self.store, tstore)
|
||||
|
||||
def test__len__(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.assertEqual(len(self.store), 2)
|
||||
|
||||
def test_close(self):
|
||||
self.store.close()
|
||||
self.assertEqual(self.store, None)
|
||||
|
||||
def test_clear(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
self.store.clear()
|
||||
self.assertEqual(len(self.store), 0)
|
||||
|
||||
def test_items(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = list(self.store.items())
|
||||
self.assertEqual(('min', 6) in slist, True)
|
||||
|
||||
def test_iteritems(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = list(self.store.iteritems())
|
||||
self.assertEqual(('min', 6) in slist, True)
|
||||
|
||||
def test_iterkeys(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = list(self.store.iterkeys())
|
||||
self.assertEqual('min' in slist, True)
|
||||
|
||||
def test_itervalues(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = list(self.store.itervalues())
|
||||
self.assertEqual(6 in slist, True)
|
||||
|
||||
def test_pop(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
item = self.store.pop('min')
|
||||
self.assertEqual(item, 6)
|
||||
|
||||
def test_popitem(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
item = self.store.popitem()
|
||||
self.assertEqual(len(item) + len(self.store), 4)
|
||||
|
||||
def test_setdefault(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['powl'] = 7
|
||||
self.store.setdefault('pow', 8)
|
||||
self.assertEqual(self.store['pow'], 8)
|
||||
|
||||
def test_update(self):
|
||||
from shove import Shove
|
||||
tstore = Shove()
|
||||
tstore['max'] = 3
|
||||
tstore['min'] = 6
|
||||
tstore['pow'] = 7
|
||||
self.store['max'] = 2
|
||||
self.store['min'] = 3
|
||||
self.store['pow'] = 7
|
||||
self.store.update(tstore)
|
||||
self.assertEqual(self.store['min'], 6)
|
||||
|
||||
def test_values(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = self.store.values()
|
||||
self.assertEqual(6 in slist, True)
|
||||
|
||||
def test_keys(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = self.store.keys()
|
||||
self.assertEqual('min' in slist, True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
|
@ -1,136 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import unittest
|
||||
|
||||
|
||||
class TestDbmStore(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
from shove import Shove
|
||||
self.store = Shove('dbm://test.dbm', compress=True)
|
||||
|
||||
def tearDown(self):
|
||||
import os
|
||||
self.store.close()
|
||||
try:
|
||||
os.remove('test.dbm.db')
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def test__getitem__(self):
|
||||
self.store['max'] = 3
|
||||
self.assertEqual(self.store['max'], 3)
|
||||
|
||||
def test__setitem__(self):
|
||||
self.store['max'] = 3
|
||||
self.assertEqual(self.store['max'], 3)
|
||||
|
||||
def test__delitem__(self):
|
||||
self.store['max'] = 3
|
||||
del self.store['max']
|
||||
self.assertEqual('max' in self.store, False)
|
||||
|
||||
def test_get(self):
|
||||
self.store['max'] = 3
|
||||
self.assertEqual(self.store.get('min'), None)
|
||||
|
||||
def test__cmp__(self):
|
||||
from shove import Shove
|
||||
tstore = Shove()
|
||||
self.store['max'] = 3
|
||||
tstore['max'] = 3
|
||||
self.assertEqual(self.store, tstore)
|
||||
|
||||
def test__len__(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.assertEqual(len(self.store), 2)
|
||||
|
||||
def test_close(self):
|
||||
self.store.close()
|
||||
self.assertEqual(self.store, None)
|
||||
|
||||
def test_clear(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
self.store.clear()
|
||||
self.assertEqual(len(self.store), 0)
|
||||
|
||||
def test_items(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = list(self.store.items())
|
||||
self.assertEqual(('min', 6) in slist, True)
|
||||
|
||||
def test_iteritems(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = list(self.store.iteritems())
|
||||
self.assertEqual(('min', 6) in slist, True)
|
||||
|
||||
def test_iterkeys(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = list(self.store.iterkeys())
|
||||
self.assertEqual('min' in slist, True)
|
||||
|
||||
def test_itervalues(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = list(self.store.itervalues())
|
||||
self.assertEqual(6 in slist, True)
|
||||
|
||||
def test_pop(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
item = self.store.pop('min')
|
||||
self.assertEqual(item, 6)
|
||||
|
||||
def test_popitem(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
item = self.store.popitem()
|
||||
self.assertEqual(len(item) + len(self.store), 4)
|
||||
|
||||
def test_setdefault(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
self.store.setdefault('how', 8)
|
||||
self.assertEqual(self.store['how'], 8)
|
||||
|
||||
def test_update(self):
|
||||
from shove import Shove
|
||||
tstore = Shove()
|
||||
tstore['max'] = 3
|
||||
tstore['min'] = 6
|
||||
tstore['pow'] = 7
|
||||
self.store['max'] = 2
|
||||
self.store['min'] = 3
|
||||
self.store['pow'] = 7
|
||||
self.store.update(tstore)
|
||||
self.assertEqual(self.store['min'], 6)
|
||||
|
||||
def test_values(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = self.store.values()
|
||||
self.assertEqual(6 in slist, True)
|
||||
|
||||
def test_keys(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = self.store.keys()
|
||||
self.assertEqual('min' in slist, True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
|
@ -1,133 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import unittest
|
||||
|
||||
|
||||
class TestDurusStore(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
from shove import Shove
|
||||
self.store = Shove('durus://test.durus', compress=True)
|
||||
|
||||
def tearDown(self):
|
||||
import os
|
||||
self.store.close()
|
||||
os.remove('test.durus')
|
||||
|
||||
def test__getitem__(self):
|
||||
self.store['max'] = 3
|
||||
self.assertEqual(self.store['max'], 3)
|
||||
|
||||
def test__setitem__(self):
|
||||
self.store['max'] = 3
|
||||
self.assertEqual(self.store['max'], 3)
|
||||
|
||||
def test__delitem__(self):
|
||||
self.store['max'] = 3
|
||||
del self.store['max']
|
||||
self.assertEqual('max' in self.store, False)
|
||||
|
||||
def test_get(self):
|
||||
self.store['max'] = 3
|
||||
self.assertEqual(self.store.get('min'), None)
|
||||
|
||||
def test__cmp__(self):
|
||||
from shove import Shove
|
||||
tstore = Shove()
|
||||
self.store['max'] = 3
|
||||
tstore['max'] = 3
|
||||
self.assertEqual(self.store, tstore)
|
||||
|
||||
def test__len__(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.assertEqual(len(self.store), 2)
|
||||
|
||||
def test_close(self):
|
||||
self.store.close()
|
||||
self.assertEqual(self.store, None)
|
||||
|
||||
def test_clear(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
self.store.clear()
|
||||
self.assertEqual(len(self.store), 0)
|
||||
|
||||
def test_items(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = list(self.store.items())
|
||||
self.assertEqual(('min', 6) in slist, True)
|
||||
|
||||
def test_iteritems(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = list(self.store.iteritems())
|
||||
self.assertEqual(('min', 6) in slist, True)
|
||||
|
||||
def test_iterkeys(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = list(self.store.iterkeys())
|
||||
self.assertEqual('min' in slist, True)
|
||||
|
||||
def test_itervalues(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = list(self.store.itervalues())
|
||||
self.assertEqual(6 in slist, True)
|
||||
|
||||
def test_pop(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
item = self.store.pop('min')
|
||||
self.assertEqual(item, 6)
|
||||
|
||||
def test_popitem(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
item = self.store.popitem()
|
||||
self.assertEqual(len(item) + len(self.store), 4)
|
||||
|
||||
def test_setdefault(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['powl'] = 7
|
||||
self.store.setdefault('pow', 8)
|
||||
self.assertEqual(self.store['pow'], 8)
|
||||
|
||||
def test_update(self):
|
||||
from shove import Shove
|
||||
tstore = Shove()
|
||||
tstore['max'] = 3
|
||||
tstore['min'] = 6
|
||||
tstore['pow'] = 7
|
||||
self.store['max'] = 2
|
||||
self.store['min'] = 3
|
||||
self.store['pow'] = 7
|
||||
self.store.update(tstore)
|
||||
self.assertEqual(self.store['min'], 6)
|
||||
|
||||
def test_values(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = self.store.values()
|
||||
self.assertEqual(6 in slist, True)
|
||||
|
||||
def test_keys(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = self.store.keys()
|
||||
self.assertEqual('min' in slist, True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
|
@ -1,58 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import unittest
|
||||
|
||||
|
||||
class TestFileCache(unittest.TestCase):
|
||||
|
||||
initstring = 'file://test'
|
||||
|
||||
def setUp(self):
|
||||
from shove.cache.file import FileCache
|
||||
self.cache = FileCache(self.initstring)
|
||||
|
||||
def tearDown(self):
|
||||
import os
|
||||
self.cache = None
|
||||
for x in os.listdir('test'):
|
||||
os.remove(os.path.join('test', x))
|
||||
os.rmdir('test')
|
||||
|
||||
def test_getitem(self):
|
||||
self.cache['test'] = 'test'
|
||||
self.assertEqual(self.cache['test'], 'test')
|
||||
|
||||
def test_setitem(self):
|
||||
self.cache['test'] = 'test'
|
||||
self.assertEqual(self.cache['test'], 'test')
|
||||
|
||||
def test_delitem(self):
|
||||
self.cache['test'] = 'test'
|
||||
del self.cache['test']
|
||||
self.assertEqual('test' in self.cache, False)
|
||||
|
||||
def test_get(self):
|
||||
self.assertEqual(self.cache.get('min'), None)
|
||||
|
||||
def test_timeout(self):
|
||||
import time
|
||||
from shove.cache.file import FileCache
|
||||
cache = FileCache(self.initstring, timeout=1)
|
||||
cache['test'] = 'test'
|
||||
time.sleep(2)
|
||||
|
||||
def tmp():
|
||||
cache['test']
|
||||
self.assertRaises(KeyError, tmp)
|
||||
|
||||
def test_cull(self):
|
||||
from shove.cache.file import FileCache
|
||||
cache = FileCache(self.initstring, max_entries=1)
|
||||
cache['test'] = 'test'
|
||||
cache['test2'] = 'test'
|
||||
num = len(cache)
|
||||
self.assertEquals(num, 1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
|
@ -1,140 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import unittest
|
||||
|
||||
|
||||
class TestFileStore(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
from shove import Shove
|
||||
self.store = Shove('file://test', compress=True)
|
||||
|
||||
def tearDown(self):
|
||||
import os
|
||||
self.store.close()
|
||||
for x in os.listdir('test'):
|
||||
os.remove(os.path.join('test', x))
|
||||
os.rmdir('test')
|
||||
|
||||
def test__getitem__(self):
|
||||
self.store['max'] = 3
|
||||
self.store.sync()
|
||||
self.assertEqual(self.store['max'], 3)
|
||||
|
||||
def test__setitem__(self):
|
||||
self.store['max'] = 3
|
||||
self.store.sync()
|
||||
self.assertEqual(self.store['max'], 3)
|
||||
|
||||
def test__delitem__(self):
|
||||
self.store['max'] = 3
|
||||
del self.store['max']
|
||||
self.assertEqual('max' in self.store, False)
|
||||
|
||||
def test_get(self):
|
||||
self.store['max'] = 3
|
||||
self.store.sync()
|
||||
self.assertEqual(self.store.get('min'), None)
|
||||
|
||||
def test__cmp__(self):
|
||||
from shove import Shove
|
||||
tstore = Shove()
|
||||
self.store['max'] = 3
|
||||
tstore['max'] = 3
|
||||
self.store.sync()
|
||||
tstore.sync()
|
||||
self.assertEqual(self.store, tstore)
|
||||
|
||||
def test__len__(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.assertEqual(len(self.store), 2)
|
||||
|
||||
def test_close(self):
|
||||
self.store.close()
|
||||
self.assertEqual(self.store, None)
|
||||
|
||||
def test_clear(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
self.store.clear()
|
||||
self.assertEqual(len(self.store), 0)
|
||||
|
||||
def test_items(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = list(self.store.items())
|
||||
self.assertEqual(('min', 6) in slist, True)
|
||||
|
||||
def test_iteritems(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = list(self.store.iteritems())
|
||||
self.assertEqual(('min', 6) in slist, True)
|
||||
|
||||
def test_iterkeys(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = list(self.store.iterkeys())
|
||||
self.assertEqual('min' in slist, True)
|
||||
|
||||
def test_itervalues(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = list(self.store.itervalues())
|
||||
self.assertEqual(6 in slist, True)
|
||||
|
||||
def test_pop(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
item = self.store.pop('min')
|
||||
self.assertEqual(item, 6)
|
||||
|
||||
def test_popitem(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
item = self.store.popitem()
|
||||
self.assertEqual(len(item) + len(self.store), 4)
|
||||
|
||||
def test_setdefault(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['powl'] = 7
|
||||
self.store.setdefault('pow', 8)
|
||||
self.assertEqual(self.store['pow'], 8)
|
||||
|
||||
def test_update(self):
|
||||
from shove import Shove
|
||||
tstore = Shove()
|
||||
tstore['max'] = 3
|
||||
tstore['min'] = 6
|
||||
tstore['pow'] = 7
|
||||
self.store['max'] = 2
|
||||
self.store['min'] = 3
|
||||
self.store['pow'] = 7
|
||||
self.store.update(tstore)
|
||||
self.assertEqual(self.store['min'], 6)
|
||||
|
||||
def test_values(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = self.store.values()
|
||||
self.assertEqual(6 in slist, True)
|
||||
|
||||
def test_keys(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = self.store.keys()
|
||||
self.assertEqual('min' in slist, True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
|
@ -1,149 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import unittest
|
||||
|
||||
|
||||
class TestFtpStore(unittest.TestCase):
|
||||
|
||||
ftpstring = 'put ftp string here'
|
||||
|
||||
def setUp(self):
|
||||
from shove import Shove
|
||||
self.store = Shove(self.ftpstring, compress=True)
|
||||
|
||||
def tearDown(self):
|
||||
self.store.clear()
|
||||
self.store.close()
|
||||
|
||||
def test__getitem__(self):
|
||||
self.store['max'] = 3
|
||||
self.store.sync()
|
||||
self.assertEqual(self.store['max'], 3)
|
||||
|
||||
def test__setitem__(self):
|
||||
self.store['max'] = 3
|
||||
self.store.sync()
|
||||
self.assertEqual(self.store['max'], 3)
|
||||
|
||||
def test__delitem__(self):
|
||||
self.store['max'] = 3
|
||||
del self.store['max']
|
||||
self.assertEqual('max' in self.store, False)
|
||||
|
||||
def test_get(self):
|
||||
self.store['max'] = 3
|
||||
self.store.sync()
|
||||
self.assertEqual(self.store.get('min'), None)
|
||||
|
||||
def test__cmp__(self):
|
||||
from shove import Shove
|
||||
tstore = Shove()
|
||||
self.store['max'] = 3
|
||||
tstore['max'] = 3
|
||||
self.store.sync()
|
||||
tstore.sync()
|
||||
self.assertEqual(self.store, tstore)
|
||||
|
||||
def test__len__(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store.sync()
|
||||
self.assertEqual(len(self.store), 2)
|
||||
|
||||
def test_clear(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
self.store.sync()
|
||||
self.store.clear()
|
||||
self.assertEqual(len(self.store), 0)
|
||||
|
||||
def test_items(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
self.store.sync()
|
||||
slist = list(self.store.items())
|
||||
self.assertEqual(('min', 6) in slist, True)
|
||||
|
||||
def test_iteritems(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
self.store.sync()
|
||||
slist = list(self.store.iteritems())
|
||||
self.assertEqual(('min', 6) in slist, True)
|
||||
|
||||
def test_iterkeys(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
self.store.sync()
|
||||
slist = list(self.store.iterkeys())
|
||||
self.assertEqual('min' in slist, True)
|
||||
|
||||
def test_itervalues(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
self.store.sync()
|
||||
slist = list(self.store.itervalues())
|
||||
self.assertEqual(6 in slist, True)
|
||||
|
||||
def test_pop(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store.sync()
|
||||
item = self.store.pop('min')
|
||||
self.assertEqual(item, 6)
|
||||
|
||||
def test_popitem(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
self.store.sync()
|
||||
item = self.store.popitem()
|
||||
self.store.sync()
|
||||
self.assertEqual(len(item) + len(self.store), 4)
|
||||
|
||||
def test_setdefault(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['powl'] = 7
|
||||
self.store.setdefault('pow', 8)
|
||||
self.store.sync()
|
||||
self.assertEqual(self.store['pow'], 8)
|
||||
|
||||
def test_update(self):
|
||||
from shove import Shove
|
||||
tstore = Shove()
|
||||
tstore['max'] = 3
|
||||
tstore['min'] = 6
|
||||
tstore['pow'] = 7
|
||||
self.store['max'] = 2
|
||||
self.store['min'] = 3
|
||||
self.store['pow'] = 7
|
||||
self.store.sync()
|
||||
self.store.update(tstore)
|
||||
self.store.sync()
|
||||
self.assertEqual(self.store['min'], 6)
|
||||
|
||||
def test_values(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
self.store.sync()
|
||||
slist = self.store.values()
|
||||
self.assertEqual(6 in slist, True)
|
||||
|
||||
def test_keys(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
self.store.sync()
|
||||
slist = self.store.keys()
|
||||
self.assertEqual('min' in slist, True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
|
@ -1,135 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import unittest2
|
||||
|
||||
|
||||
class TestHDF5Store(unittest2.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
from shove import Shove
|
||||
self.store = Shove('hdf5://test.hdf5/test')
|
||||
|
||||
def tearDown(self):
|
||||
import os
|
||||
self.store.close()
|
||||
try:
|
||||
os.remove('test.hdf5')
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def test__getitem__(self):
|
||||
self.store['max'] = 3
|
||||
self.assertEqual(self.store['max'], 3)
|
||||
|
||||
def test__setitem__(self):
|
||||
self.store['max'] = 3
|
||||
self.assertEqual(self.store['max'], 3)
|
||||
|
||||
def test__delitem__(self):
|
||||
self.store['max'] = 3
|
||||
del self.store['max']
|
||||
self.assertEqual('max' in self.store, False)
|
||||
|
||||
def test_get(self):
|
||||
self.store['max'] = 3
|
||||
self.assertEqual(self.store.get('min'), None)
|
||||
|
||||
def test__cmp__(self):
|
||||
from shove import Shove
|
||||
tstore = Shove()
|
||||
self.store['max'] = 3
|
||||
tstore['max'] = 3
|
||||
self.assertEqual(self.store, tstore)
|
||||
|
||||
def test__len__(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.assertEqual(len(self.store), 2)
|
||||
|
||||
def test_close(self):
|
||||
self.store.close()
|
||||
self.assertEqual(self.store, None)
|
||||
|
||||
def test_clear(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
self.store.clear()
|
||||
self.assertEqual(len(self.store), 0)
|
||||
|
||||
def test_items(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = list(self.store.items())
|
||||
self.assertEqual(('min', 6) in slist, True)
|
||||
|
||||
def test_iteritems(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = list(self.store.iteritems())
|
||||
self.assertEqual(('min', 6) in slist, True)
|
||||
|
||||
def test_iterkeys(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = list(self.store.iterkeys())
|
||||
self.assertEqual('min' in slist, True)
|
||||
|
||||
def test_itervalues(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = list(self.store.itervalues())
|
||||
self.assertEqual(6 in slist, True)
|
||||
|
||||
def test_pop(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
item = self.store.pop('min')
|
||||
self.assertEqual(item, 6)
|
||||
|
||||
def test_popitem(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
item = self.store.popitem()
|
||||
self.assertEqual(len(item) + len(self.store), 4)
|
||||
|
||||
def test_setdefault(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
self.store.setdefault('bow', 8)
|
||||
self.assertEqual(self.store['bow'], 8)
|
||||
|
||||
def test_update(self):
|
||||
from shove import Shove
|
||||
tstore = Shove()
|
||||
tstore['max'] = 3
|
||||
tstore['min'] = 6
|
||||
tstore['pow'] = 7
|
||||
self.store['max'] = 2
|
||||
self.store['min'] = 3
|
||||
self.store['pow'] = 7
|
||||
self.store.update(tstore)
|
||||
self.assertEqual(self.store['min'], 6)
|
||||
|
||||
def test_values(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = self.store.values()
|
||||
self.assertEqual(6 in slist, True)
|
||||
|
||||
def test_keys(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = self.store.keys()
|
||||
self.assertEqual('min' in slist, True)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest2.main()
|
|
@ -1,132 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import unittest2
|
||||
|
||||
|
||||
class TestLevelDBStore(unittest2.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
from shove import Shove
|
||||
self.store = Shove('leveldb://test', compress=True)
|
||||
|
||||
def tearDown(self):
|
||||
import shutil
|
||||
shutil.rmtree('test')
|
||||
|
||||
def test__getitem__(self):
|
||||
self.store['max'] = 3
|
||||
self.assertEqual(self.store['max'], 3)
|
||||
|
||||
def test__setitem__(self):
|
||||
self.store['max'] = 3
|
||||
self.assertEqual(self.store['max'], 3)
|
||||
|
||||
def test__delitem__(self):
|
||||
self.store['max'] = 3
|
||||
del self.store['max']
|
||||
self.assertEqual('max' in self.store, False)
|
||||
|
||||
def test_get(self):
|
||||
self.store['max'] = 3
|
||||
self.assertEqual(self.store.get('min'), None)
|
||||
|
||||
def test__cmp__(self):
|
||||
from shove import Shove
|
||||
tstore = Shove()
|
||||
self.store['max'] = 3
|
||||
tstore['max'] = 3
|
||||
self.assertEqual(self.store, tstore)
|
||||
|
||||
def test__len__(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.assertEqual(len(self.store), 2)
|
||||
|
||||
def test_close(self):
|
||||
self.store.close()
|
||||
self.assertEqual(self.store, None)
|
||||
|
||||
def test_clear(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
self.store.clear()
|
||||
self.assertEqual(len(self.store), 0)
|
||||
|
||||
def test_items(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = list(self.store.items())
|
||||
self.assertEqual(('min', 6) in slist, True)
|
||||
|
||||
def test_iteritems(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = list(self.store.iteritems())
|
||||
self.assertEqual(('min', 6) in slist, True)
|
||||
|
||||
def test_iterkeys(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = list(self.store.iterkeys())
|
||||
self.assertEqual('min' in slist, True)
|
||||
|
||||
def test_itervalues(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = list(self.store.itervalues())
|
||||
self.assertEqual(6 in slist, True)
|
||||
|
||||
def test_pop(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
item = self.store.pop('min')
|
||||
self.assertEqual(item, 6)
|
||||
|
||||
def test_popitem(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
item = self.store.popitem()
|
||||
self.assertEqual(len(item) + len(self.store), 4)
|
||||
|
||||
def test_setdefault(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
self.store.setdefault('bow', 8)
|
||||
self.assertEqual(self.store['bow'], 8)
|
||||
|
||||
def test_update(self):
|
||||
from shove import Shove
|
||||
tstore = Shove()
|
||||
tstore['max'] = 3
|
||||
tstore['min'] = 6
|
||||
tstore['pow'] = 7
|
||||
self.store['max'] = 2
|
||||
self.store['min'] = 3
|
||||
self.store['pow'] = 7
|
||||
self.store.update(tstore)
|
||||
self.assertEqual(self.store['min'], 6)
|
||||
|
||||
def test_values(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = self.store.values()
|
||||
self.assertEqual(6 in slist, True)
|
||||
|
||||
def test_keys(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = self.store.keys()
|
||||
self.assertEqual('min' in slist, True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest2.main()
|
|
@ -1,46 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import unittest
|
||||
|
||||
|
||||
class TestMemcached(unittest.TestCase):
|
||||
|
||||
initstring = 'memcache://localhost:11211'
|
||||
|
||||
def setUp(self):
|
||||
from shove.cache.memcached import MemCached
|
||||
self.cache = MemCached(self.initstring)
|
||||
|
||||
def tearDown(self):
|
||||
self.cache = None
|
||||
|
||||
def test_getitem(self):
|
||||
self.cache['test'] = 'test'
|
||||
self.assertEqual(self.cache['test'], 'test')
|
||||
|
||||
def test_setitem(self):
|
||||
self.cache['test'] = 'test'
|
||||
self.assertEqual(self.cache['test'], 'test')
|
||||
|
||||
def test_delitem(self):
|
||||
self.cache['test'] = 'test'
|
||||
del self.cache['test']
|
||||
self.assertEqual('test' in self.cache, False)
|
||||
|
||||
def test_get(self):
|
||||
self.assertEqual(self.cache.get('min'), None)
|
||||
|
||||
def test_timeout(self):
|
||||
import time
|
||||
from shove.cache.memcached import MemCached
|
||||
cache = MemCached(self.initstring, timeout=1)
|
||||
cache['test'] = 'test'
|
||||
time.sleep(1)
|
||||
|
||||
def tmp():
|
||||
cache['test']
|
||||
self.assertRaises(KeyError, tmp)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
|
@ -1,54 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import unittest
|
||||
|
||||
|
||||
class TestMemoryCache(unittest.TestCase):
|
||||
|
||||
initstring = 'memory://'
|
||||
|
||||
def setUp(self):
|
||||
from shove.cache.memory import MemoryCache
|
||||
self.cache = MemoryCache(self.initstring)
|
||||
|
||||
def tearDown(self):
|
||||
self.cache = None
|
||||
|
||||
def test_getitem(self):
|
||||
self.cache['test'] = 'test'
|
||||
self.assertEqual(self.cache['test'], 'test')
|
||||
|
||||
def test_setitem(self):
|
||||
self.cache['test'] = 'test'
|
||||
self.assertEqual(self.cache['test'], 'test')
|
||||
|
||||
def test_delitem(self):
|
||||
self.cache['test'] = 'test'
|
||||
del self.cache['test']
|
||||
self.assertEqual('test' in self.cache, False)
|
||||
|
||||
def test_get(self):
|
||||
self.assertEqual(self.cache.get('min'), None)
|
||||
|
||||
def test_timeout(self):
|
||||
import time
|
||||
from shove.cache.memory import MemoryCache
|
||||
cache = MemoryCache(self.initstring, timeout=1)
|
||||
cache['test'] = 'test'
|
||||
time.sleep(1)
|
||||
|
||||
def tmp():
|
||||
cache['test']
|
||||
self.assertRaises(KeyError, tmp)
|
||||
|
||||
def test_cull(self):
|
||||
from shove.cache.memory import MemoryCache
|
||||
cache = MemoryCache(self.initstring, max_entries=1)
|
||||
cache['test'] = 'test'
|
||||
cache['test2'] = 'test'
|
||||
cache['test2'] = 'test'
|
||||
self.assertEquals(len(cache), 1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
|
@ -1,135 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import unittest
|
||||
|
||||
|
||||
class TestMemoryStore(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
from shove import Shove
|
||||
self.store = Shove('memory://', compress=True)
|
||||
|
||||
def tearDown(self):
|
||||
self.store.close()
|
||||
|
||||
def test__getitem__(self):
|
||||
self.store['max'] = 3
|
||||
self.store.sync()
|
||||
self.assertEqual(self.store['max'], 3)
|
||||
|
||||
def test__setitem__(self):
|
||||
self.store['max'] = 3
|
||||
self.store.sync()
|
||||
self.assertEqual(self.store['max'], 3)
|
||||
|
||||
def test__delitem__(self):
|
||||
self.store['max'] = 3
|
||||
del self.store['max']
|
||||
self.assertEqual('max' in self.store, False)
|
||||
|
||||
def test_get(self):
|
||||
self.store['max'] = 3
|
||||
self.store.sync()
|
||||
self.assertEqual(self.store.get('min'), None)
|
||||
|
||||
def test__cmp__(self):
|
||||
from shove import Shove
|
||||
tstore = Shove()
|
||||
self.store['max'] = 3
|
||||
tstore['max'] = 3
|
||||
self.store.sync()
|
||||
tstore.sync()
|
||||
self.assertEqual(self.store, tstore)
|
||||
|
||||
def test__len__(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.assertEqual(len(self.store), 2)
|
||||
|
||||
def test_close(self):
|
||||
self.store.close()
|
||||
self.assertEqual(self.store, None)
|
||||
|
||||
def test_clear(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
self.store.clear()
|
||||
self.assertEqual(len(self.store), 0)
|
||||
|
||||
def test_items(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = list(self.store.items())
|
||||
self.assertEqual(('min', 6) in slist, True)
|
||||
|
||||
def test_iteritems(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = list(self.store.iteritems())
|
||||
self.assertEqual(('min', 6) in slist, True)
|
||||
|
||||
def test_iterkeys(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = list(self.store.iterkeys())
|
||||
self.assertEqual('min' in slist, True)
|
||||
|
||||
def test_itervalues(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = list(self.store.itervalues())
|
||||
self.assertEqual(6 in slist, True)
|
||||
|
||||
def test_pop(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
item = self.store.pop('min')
|
||||
self.assertEqual(item, 6)
|
||||
|
||||
def test_popitem(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
item = self.store.popitem()
|
||||
self.assertEqual(len(item) + len(self.store), 4)
|
||||
|
||||
def test_setdefault(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['powl'] = 7
|
||||
self.store.setdefault('pow', 8)
|
||||
self.assertEqual(self.store['pow'], 8)
|
||||
|
||||
def test_update(self):
|
||||
from shove import Shove
|
||||
tstore = Shove()
|
||||
tstore['max'] = 3
|
||||
tstore['min'] = 6
|
||||
tstore['pow'] = 7
|
||||
self.store['max'] = 2
|
||||
self.store['min'] = 3
|
||||
self.store['pow'] = 7
|
||||
self.store.update(tstore)
|
||||
self.assertEqual(self.store['min'], 6)
|
||||
|
||||
def test_values(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = self.store.values()
|
||||
self.assertEqual(6 in slist, True)
|
||||
|
||||
def test_keys(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = self.store.keys()
|
||||
self.assertEqual('min' in slist, True)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
|
@ -1,45 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import unittest
|
||||
|
||||
|
||||
class TestRedisCache(unittest.TestCase):
|
||||
|
||||
initstring = 'redis://localhost:6379/0'
|
||||
|
||||
def setUp(self):
|
||||
from shove.cache.redisdb import RedisCache
|
||||
self.cache = RedisCache(self.initstring)
|
||||
|
||||
def tearDown(self):
|
||||
self.cache = None
|
||||
|
||||
def test_getitem(self):
|
||||
self.cache['test'] = 'test'
|
||||
self.assertEqual(self.cache['test'], 'test')
|
||||
|
||||
def test_setitem(self):
|
||||
self.cache['test'] = 'test'
|
||||
self.assertEqual(self.cache['test'], 'test')
|
||||
|
||||
def test_delitem(self):
|
||||
self.cache['test'] = 'test'
|
||||
del self.cache['test']
|
||||
self.assertEqual('test' in self.cache, False)
|
||||
|
||||
def test_get(self):
|
||||
self.assertEqual(self.cache.get('min'), None)
|
||||
|
||||
def test_timeout(self):
|
||||
import time
|
||||
from shove.cache.redisdb import RedisCache
|
||||
cache = RedisCache(self.initstring, timeout=1)
|
||||
cache['test'] = 'test'
|
||||
time.sleep(3)
|
||||
def tmp(): #@IgnorePep8
|
||||
return cache['test']
|
||||
self.assertRaises(KeyError, tmp)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
|
@ -1,128 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import unittest
|
||||
|
||||
|
||||
class TestRedisStore(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
from shove import Shove
|
||||
self.store = Shove('redis://localhost:6379/0')
|
||||
|
||||
def tearDown(self):
|
||||
self.store.clear()
|
||||
self.store.close()
|
||||
|
||||
def test__getitem__(self):
|
||||
self.store['max'] = 3
|
||||
self.assertEqual(self.store['max'], 3)
|
||||
|
||||
def test__setitem__(self):
|
||||
self.store['max'] = 3
|
||||
self.assertEqual(self.store['max'], 3)
|
||||
|
||||
def test__delitem__(self):
|
||||
self.store['max'] = 3
|
||||
del self.store['max']
|
||||
self.assertEqual('max' in self.store, False)
|
||||
|
||||
def test_get(self):
|
||||
self.store['max'] = 3
|
||||
self.assertEqual(self.store.get('min'), None)
|
||||
|
||||
def test__cmp__(self):
|
||||
from shove import Shove
|
||||
tstore = Shove()
|
||||
self.store['max'] = 3
|
||||
tstore['max'] = 3
|
||||
self.assertEqual(self.store, tstore)
|
||||
|
||||
def test__len__(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.assertEqual(len(self.store), 2)
|
||||
|
||||
def test_clear(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
self.store.clear()
|
||||
self.assertEqual(len(self.store), 0)
|
||||
|
||||
def test_items(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = list(self.store.items())
|
||||
self.assertEqual(('min', 6) in slist, True)
|
||||
|
||||
def test_iteritems(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = list(self.store.iteritems())
|
||||
self.assertEqual(('min', 6) in slist, True)
|
||||
|
||||
def test_iterkeys(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = list(self.store.iterkeys())
|
||||
self.assertEqual('min' in slist, True)
|
||||
|
||||
def test_itervalues(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = list(self.store.itervalues())
|
||||
self.assertEqual(6 in slist, True)
|
||||
|
||||
def test_pop(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
item = self.store.pop('min')
|
||||
self.assertEqual(item, 6)
|
||||
|
||||
def test_popitem(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
item = self.store.popitem()
|
||||
self.assertEqual(len(item) + len(self.store), 4)
|
||||
|
||||
def test_setdefault(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['powl'] = 7
|
||||
self.store.setdefault('pow', 8)
|
||||
self.assertEqual(self.store.setdefault('pow', 8), 8)
|
||||
self.assertEqual(self.store['pow'], 8)
|
||||
|
||||
def test_update(self):
|
||||
from shove import Shove
|
||||
tstore = Shove()
|
||||
tstore['max'] = 3
|
||||
tstore['min'] = 6
|
||||
tstore['pow'] = 7
|
||||
self.store['max'] = 2
|
||||
self.store['min'] = 3
|
||||
self.store['pow'] = 7
|
||||
self.store.update(tstore)
|
||||
self.assertEqual(self.store['min'], 6)
|
||||
|
||||
def test_values(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = self.store.values()
|
||||
self.assertEqual(6 in slist, True)
|
||||
|
||||
def test_keys(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = self.store.keys()
|
||||
self.assertEqual('min' in slist, True)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
|
@ -1,149 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import unittest
|
||||
|
||||
|
||||
class TestS3Store(unittest.TestCase):
|
||||
|
||||
s3string = 's3 test string here'
|
||||
|
||||
def setUp(self):
|
||||
from shove import Shove
|
||||
self.store = Shove(self.s3string, compress=True)
|
||||
|
||||
def tearDown(self):
|
||||
self.store.clear()
|
||||
self.store.close()
|
||||
|
||||
def test__getitem__(self):
|
||||
self.store['max'] = 3
|
||||
self.store.sync()
|
||||
self.assertEqual(self.store['max'], 3)
|
||||
|
||||
def test__setitem__(self):
|
||||
self.store['max'] = 3
|
||||
self.store.sync()
|
||||
self.assertEqual(self.store['max'], 3)
|
||||
|
||||
def test__delitem__(self):
|
||||
self.store['max'] = 3
|
||||
del self.store['max']
|
||||
self.assertEqual('max' in self.store, False)
|
||||
|
||||
def test_get(self):
|
||||
self.store['max'] = 3
|
||||
self.store.sync()
|
||||
self.assertEqual(self.store.get('min'), None)
|
||||
|
||||
def test__cmp__(self):
|
||||
from shove import Shove
|
||||
tstore = Shove()
|
||||
self.store['max'] = 3
|
||||
tstore['max'] = 3
|
||||
self.store.sync()
|
||||
tstore.sync()
|
||||
self.assertEqual(self.store, tstore)
|
||||
|
||||
def test__len__(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store.sync()
|
||||
self.assertEqual(len(self.store), 2)
|
||||
|
||||
def test_clear(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
self.store.sync()
|
||||
self.store.clear()
|
||||
self.assertEqual(len(self.store), 0)
|
||||
|
||||
def test_items(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
self.store.sync()
|
||||
slist = list(self.store.items())
|
||||
self.assertEqual(('min', 6) in slist, True)
|
||||
|
||||
def test_iteritems(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
self.store.sync()
|
||||
slist = list(self.store.iteritems())
|
||||
self.assertEqual(('min', 6) in slist, True)
|
||||
|
||||
def test_iterkeys(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
self.store.sync()
|
||||
slist = list(self.store.iterkeys())
|
||||
self.assertEqual('min' in slist, True)
|
||||
|
||||
def test_itervalues(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
self.store.sync()
|
||||
slist = list(self.store.itervalues())
|
||||
self.assertEqual(6 in slist, True)
|
||||
|
||||
def test_pop(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store.sync()
|
||||
item = self.store.pop('min')
|
||||
self.assertEqual(item, 6)
|
||||
|
||||
def test_popitem(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
self.store.sync()
|
||||
item = self.store.popitem()
|
||||
self.store.sync()
|
||||
self.assertEqual(len(item) + len(self.store), 4)
|
||||
|
||||
def test_setdefault(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['powl'] = 7
|
||||
self.store.setdefault('pow', 8)
|
||||
self.store.sync()
|
||||
self.assertEqual(self.store['pow'], 8)
|
||||
|
||||
def test_update(self):
|
||||
from shove import Shove
|
||||
tstore = Shove()
|
||||
tstore['max'] = 3
|
||||
tstore['min'] = 6
|
||||
tstore['pow'] = 7
|
||||
self.store['max'] = 2
|
||||
self.store['min'] = 3
|
||||
self.store['pow'] = 7
|
||||
self.store.sync()
|
||||
self.store.update(tstore)
|
||||
self.store.sync()
|
||||
self.assertEqual(self.store['min'], 6)
|
||||
|
||||
def test_values(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
self.store.sync()
|
||||
slist = self.store.values()
|
||||
self.assertEqual(6 in slist, True)
|
||||
|
||||
def test_keys(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
self.store.sync()
|
||||
slist = self.store.keys()
|
||||
self.assertEqual('min' in slist, True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
|
@ -1,54 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import unittest
|
||||
|
||||
|
||||
class TestSimpleCache(unittest.TestCase):
|
||||
|
||||
initstring = 'simple://'
|
||||
|
||||
def setUp(self):
|
||||
from shove.cache.simple import SimpleCache
|
||||
self.cache = SimpleCache(self.initstring)
|
||||
|
||||
def tearDown(self):
|
||||
self.cache = None
|
||||
|
||||
def test_getitem(self):
|
||||
self.cache['test'] = 'test'
|
||||
self.assertEqual(self.cache['test'], 'test')
|
||||
|
||||
def test_setitem(self):
|
||||
self.cache['test'] = 'test'
|
||||
self.assertEqual(self.cache['test'], 'test')
|
||||
|
||||
def test_delitem(self):
|
||||
self.cache['test'] = 'test'
|
||||
del self.cache['test']
|
||||
self.assertEqual('test' in self.cache, False)
|
||||
|
||||
def test_get(self):
|
||||
self.assertEqual(self.cache.get('min'), None)
|
||||
|
||||
def test_timeout(self):
|
||||
import time
|
||||
from shove.cache.simple import SimpleCache
|
||||
cache = SimpleCache(self.initstring, timeout=1)
|
||||
cache['test'] = 'test'
|
||||
time.sleep(1)
|
||||
|
||||
def tmp():
|
||||
cache['test']
|
||||
self.assertRaises(KeyError, tmp)
|
||||
|
||||
def test_cull(self):
|
||||
from shove.cache.simple import SimpleCache
|
||||
cache = SimpleCache(self.initstring, max_entries=1)
|
||||
cache['test'] = 'test'
|
||||
cache['test2'] = 'test'
|
||||
cache['test2'] = 'test'
|
||||
self.assertEquals(len(cache), 1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
|
@ -1,135 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import unittest
|
||||
|
||||
|
||||
class TestSimpleStore(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
from shove import Shove
|
||||
self.store = Shove('simple://', compress=True)
|
||||
|
||||
def tearDown(self):
|
||||
self.store.close()
|
||||
|
||||
def test__getitem__(self):
|
||||
self.store['max'] = 3
|
||||
self.store.sync()
|
||||
self.assertEqual(self.store['max'], 3)
|
||||
|
||||
def test__setitem__(self):
|
||||
self.store['max'] = 3
|
||||
self.store.sync()
|
||||
self.assertEqual(self.store['max'], 3)
|
||||
|
||||
def test__delitem__(self):
|
||||
self.store['max'] = 3
|
||||
del self.store['max']
|
||||
self.assertEqual('max' in self.store, False)
|
||||
|
||||
def test_get(self):
|
||||
self.store['max'] = 3
|
||||
self.store.sync()
|
||||
self.assertEqual(self.store.get('min'), None)
|
||||
|
||||
def test__cmp__(self):
|
||||
from shove import Shove
|
||||
tstore = Shove()
|
||||
self.store['max'] = 3
|
||||
tstore['max'] = 3
|
||||
self.store.sync()
|
||||
tstore.sync()
|
||||
self.assertEqual(self.store, tstore)
|
||||
|
||||
def test__len__(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.assertEqual(len(self.store), 2)
|
||||
|
||||
def test_close(self):
|
||||
self.store.close()
|
||||
self.assertEqual(self.store, None)
|
||||
|
||||
def test_clear(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
self.store.clear()
|
||||
self.assertEqual(len(self.store), 0)
|
||||
|
||||
def test_items(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = list(self.store.items())
|
||||
self.assertEqual(('min', 6) in slist, True)
|
||||
|
||||
def test_iteritems(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = list(self.store.iteritems())
|
||||
self.assertEqual(('min', 6) in slist, True)
|
||||
|
||||
def test_iterkeys(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = list(self.store.iterkeys())
|
||||
self.assertEqual('min' in slist, True)
|
||||
|
||||
def test_itervalues(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = list(self.store.itervalues())
|
||||
self.assertEqual(6 in slist, True)
|
||||
|
||||
def test_pop(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
item = self.store.pop('min')
|
||||
self.assertEqual(item, 6)
|
||||
|
||||
def test_popitem(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
item = self.store.popitem()
|
||||
self.assertEqual(len(item) + len(self.store), 4)
|
||||
|
||||
def test_setdefault(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['powl'] = 7
|
||||
self.store.setdefault('pow', 8)
|
||||
self.assertEqual(self.store['pow'], 8)
|
||||
|
||||
def test_update(self):
|
||||
from shove import Shove
|
||||
tstore = Shove()
|
||||
tstore['max'] = 3
|
||||
tstore['min'] = 6
|
||||
tstore['pow'] = 7
|
||||
self.store['max'] = 2
|
||||
self.store['min'] = 3
|
||||
self.store['pow'] = 7
|
||||
self.store.update(tstore)
|
||||
self.assertEqual(self.store['min'], 6)
|
||||
|
||||
def test_values(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = self.store.values()
|
||||
self.assertEqual(6 in slist, True)
|
||||
|
||||
def test_keys(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = self.store.keys()
|
||||
self.assertEqual('min' in slist, True)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
|
@ -1,148 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import unittest
|
||||
|
||||
|
||||
class TestSvnStore(unittest.TestCase):
|
||||
|
||||
svnstring = 'SVN test string here'
|
||||
|
||||
def setUp(self):
|
||||
from shove import Shove
|
||||
self.store = Shove(self.svnstring, compress=True)
|
||||
|
||||
def tearDown(self):
|
||||
self.store.clear()
|
||||
self.store.close()
|
||||
|
||||
def test__getitem__(self):
|
||||
self.store['max'] = 3
|
||||
self.store.sync()
|
||||
self.assertEqual(self.store['max'], 3)
|
||||
|
||||
def test__setitem__(self):
|
||||
self.store['max'] = 3
|
||||
self.store.sync()
|
||||
self.assertEqual(self.store['max'], 3)
|
||||
|
||||
def test__delitem__(self):
|
||||
self.store['max'] = 3
|
||||
del self.store['max']
|
||||
self.assertEqual('max' in self.store, False)
|
||||
|
||||
def test_get(self):
|
||||
self.store['max'] = 3
|
||||
self.store.sync()
|
||||
self.assertEqual(self.store.get('min'), None)
|
||||
|
||||
def test__cmp__(self):
|
||||
from shove import Shove
|
||||
tstore = Shove()
|
||||
self.store['max'] = 3
|
||||
tstore['max'] = 3
|
||||
self.store.sync()
|
||||
tstore.sync()
|
||||
self.assertEqual(self.store, tstore)
|
||||
|
||||
def test__len__(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store.sync()
|
||||
self.assertEqual(len(self.store), 2)
|
||||
|
||||
def test_clear(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
self.store.sync()
|
||||
self.store.clear()
|
||||
self.assertEqual(len(self.store), 0)
|
||||
|
||||
def test_items(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
self.store.sync()
|
||||
slist = list(self.store.items())
|
||||
self.assertEqual(('min', 6) in slist, True)
|
||||
|
||||
def test_iteritems(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
self.store.sync()
|
||||
slist = list(self.store.iteritems())
|
||||
self.assertEqual(('min', 6) in slist, True)
|
||||
|
||||
def test_iterkeys(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
self.store.sync()
|
||||
slist = list(self.store.iterkeys())
|
||||
self.assertEqual('min' in slist, True)
|
||||
|
||||
def test_itervalues(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
self.store.sync()
|
||||
slist = list(self.store.itervalues())
|
||||
self.assertEqual(6 in slist, True)
|
||||
|
||||
def test_pop(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store.sync()
|
||||
item = self.store.pop('min')
|
||||
self.assertEqual(item, 6)
|
||||
|
||||
def test_popitem(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
self.store.sync()
|
||||
item = self.store.popitem()
|
||||
self.store.sync()
|
||||
self.assertEqual(len(item) + len(self.store), 4)
|
||||
|
||||
def test_setdefault(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['powl'] = 7
|
||||
self.store.setdefault('pow', 8)
|
||||
self.store.sync()
|
||||
self.assertEqual(self.store['pow'], 8)
|
||||
|
||||
def test_update(self):
|
||||
from shove import Shove
|
||||
tstore = Shove()
|
||||
tstore['max'] = 3
|
||||
tstore['min'] = 6
|
||||
tstore['pow'] = 7
|
||||
self.store['max'] = 2
|
||||
self.store['min'] = 3
|
||||
self.store['pow'] = 7
|
||||
self.store.sync()
|
||||
self.store.update(tstore)
|
||||
self.store.sync()
|
||||
self.assertEqual(self.store['min'], 6)
|
||||
|
||||
def test_values(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
self.store.sync()
|
||||
slist = self.store.values()
|
||||
self.assertEqual(6 in slist, True)
|
||||
|
||||
def test_keys(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
self.store.sync()
|
||||
slist = self.store.keys()
|
||||
self.assertEqual('min' in slist, True)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
|
@ -1,138 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import unittest
|
||||
|
||||
|
||||
class TestZodbStore(unittest.TestCase):
|
||||
|
||||
init = 'zodb://test.db'
|
||||
|
||||
def setUp(self):
|
||||
from shove import Shove
|
||||
self.store = Shove(self.init, compress=True)
|
||||
|
||||
def tearDown(self):
|
||||
self.store.close()
|
||||
import os
|
||||
os.remove('test.db')
|
||||
os.remove('test.db.index')
|
||||
os.remove('test.db.tmp')
|
||||
os.remove('test.db.lock')
|
||||
|
||||
def test__getitem__(self):
|
||||
self.store['max'] = 3
|
||||
self.assertEqual(self.store['max'], 3)
|
||||
|
||||
def test__setitem__(self):
|
||||
self.store['max'] = 3
|
||||
self.assertEqual(self.store['max'], 3)
|
||||
|
||||
def test__delitem__(self):
|
||||
self.store['max'] = 3
|
||||
del self.store['max']
|
||||
self.assertEqual('max' in self.store, False)
|
||||
|
||||
def test_get(self):
|
||||
self.store['max'] = 3
|
||||
self.assertEqual(self.store.get('min'), None)
|
||||
|
||||
def test__cmp__(self):
|
||||
from shove import Shove
|
||||
tstore = Shove()
|
||||
self.store['max'] = 3
|
||||
tstore['max'] = 3
|
||||
self.assertEqual(self.store, tstore)
|
||||
|
||||
def test__len__(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.assertEqual(len(self.store), 2)
|
||||
|
||||
def test_close(self):
|
||||
self.store.close()
|
||||
self.assertEqual(self.store, None)
|
||||
|
||||
def test_clear(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
self.store.clear()
|
||||
self.assertEqual(len(self.store), 0)
|
||||
|
||||
def test_items(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = list(self.store.items())
|
||||
self.assertEqual(('min', 6) in slist, True)
|
||||
|
||||
def test_iteritems(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = list(self.store.iteritems())
|
||||
self.assertEqual(('min', 6) in slist, True)
|
||||
|
||||
def test_iterkeys(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = list(self.store.iterkeys())
|
||||
self.assertEqual('min' in slist, True)
|
||||
|
||||
def test_itervalues(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = list(self.store.itervalues())
|
||||
self.assertEqual(6 in slist, True)
|
||||
|
||||
def test_pop(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
item = self.store.pop('min')
|
||||
self.assertEqual(item, 6)
|
||||
|
||||
def test_popitem(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
item = self.store.popitem()
|
||||
self.assertEqual(len(item) + len(self.store), 4)
|
||||
|
||||
def test_setdefault(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['powl'] = 7
|
||||
self.store.setdefault('pow', 8)
|
||||
self.assertEqual(self.store['pow'], 8)
|
||||
|
||||
def test_update(self):
|
||||
from shove import Shove
|
||||
tstore = Shove()
|
||||
tstore['max'] = 3
|
||||
tstore['min'] = 6
|
||||
tstore['pow'] = 7
|
||||
self.store['max'] = 2
|
||||
self.store['min'] = 3
|
||||
self.store['pow'] = 7
|
||||
self.store.update(tstore)
|
||||
self.assertEqual(self.store['min'], 6)
|
||||
|
||||
def test_values(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = self.store.values()
|
||||
self.assertEqual(6 in slist, True)
|
||||
|
||||
def test_keys(self):
|
||||
self.store['max'] = 3
|
||||
self.store['min'] = 6
|
||||
self.store['pow'] = 7
|
||||
slist = self.store.keys()
|
||||
self.assertEqual('min' in slist, True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
|
@ -1,133 +0,0 @@
|
|||
# sqlalchemy/__init__.py
|
||||
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
|
||||
from .sql import (
|
||||
alias,
|
||||
and_,
|
||||
asc,
|
||||
between,
|
||||
bindparam,
|
||||
case,
|
||||
cast,
|
||||
collate,
|
||||
delete,
|
||||
desc,
|
||||
distinct,
|
||||
except_,
|
||||
except_all,
|
||||
exists,
|
||||
extract,
|
||||
false,
|
||||
func,
|
||||
insert,
|
||||
intersect,
|
||||
intersect_all,
|
||||
join,
|
||||
literal,
|
||||
literal_column,
|
||||
modifier,
|
||||
not_,
|
||||
null,
|
||||
or_,
|
||||
outerjoin,
|
||||
outparam,
|
||||
over,
|
||||
select,
|
||||
subquery,
|
||||
text,
|
||||
true,
|
||||
tuple_,
|
||||
type_coerce,
|
||||
union,
|
||||
union_all,
|
||||
update,
|
||||
)
|
||||
|
||||
from .types import (
|
||||
BIGINT,
|
||||
BINARY,
|
||||
BLOB,
|
||||
BOOLEAN,
|
||||
BigInteger,
|
||||
Binary,
|
||||
Boolean,
|
||||
CHAR,
|
||||
CLOB,
|
||||
DATE,
|
||||
DATETIME,
|
||||
DECIMAL,
|
||||
Date,
|
||||
DateTime,
|
||||
Enum,
|
||||
FLOAT,
|
||||
Float,
|
||||
INT,
|
||||
INTEGER,
|
||||
Integer,
|
||||
Interval,
|
||||
LargeBinary,
|
||||
NCHAR,
|
||||
NVARCHAR,
|
||||
NUMERIC,
|
||||
Numeric,
|
||||
PickleType,
|
||||
REAL,
|
||||
SMALLINT,
|
||||
SmallInteger,
|
||||
String,
|
||||
TEXT,
|
||||
TIME,
|
||||
TIMESTAMP,
|
||||
Text,
|
||||
Time,
|
||||
TypeDecorator,
|
||||
Unicode,
|
||||
UnicodeText,
|
||||
VARBINARY,
|
||||
VARCHAR,
|
||||
)
|
||||
|
||||
|
||||
from .schema import (
|
||||
CheckConstraint,
|
||||
Column,
|
||||
ColumnDefault,
|
||||
Constraint,
|
||||
DefaultClause,
|
||||
FetchedValue,
|
||||
ForeignKey,
|
||||
ForeignKeyConstraint,
|
||||
Index,
|
||||
MetaData,
|
||||
PassiveDefault,
|
||||
PrimaryKeyConstraint,
|
||||
Sequence,
|
||||
Table,
|
||||
ThreadLocalMetaData,
|
||||
UniqueConstraint,
|
||||
DDL,
|
||||
)
|
||||
|
||||
|
||||
from .inspection import inspect
|
||||
from .engine import create_engine, engine_from_config
|
||||
|
||||
__version__ = '0.9.4'
|
||||
|
||||
def __go(lcls):
|
||||
global __all__
|
||||
|
||||
from . import events
|
||||
from . import util as _sa_util
|
||||
|
||||
import inspect as _inspect
|
||||
|
||||
__all__ = sorted(name for name, obj in lcls.items()
|
||||
if not (name.startswith('_') or _inspect.ismodule(obj)))
|
||||
|
||||
_sa_util.dependencies.resolve_all("sqlalchemy")
|
||||
__go(locals())
|
|
@ -1,706 +0,0 @@
|
|||
/*
|
||||
processors.c
|
||||
Copyright (C) 2010-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
|
||||
Copyright (C) 2010-2011 Gaetan de Menten gdementen@gmail.com
|
||||
|
||||
This module is part of SQLAlchemy and is released under
|
||||
the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
*/
|
||||
|
||||
#include <Python.h>
|
||||
#include <datetime.h>
|
||||
|
||||
#define MODULE_NAME "cprocessors"
|
||||
#define MODULE_DOC "Module containing C versions of data processing functions."
|
||||
|
||||
#if PY_VERSION_HEX < 0x02050000 && !defined(PY_SSIZE_T_MIN)
|
||||
typedef int Py_ssize_t;
|
||||
#define PY_SSIZE_T_MAX INT_MAX
|
||||
#define PY_SSIZE_T_MIN INT_MIN
|
||||
#endif
|
||||
|
||||
static PyObject *
|
||||
int_to_boolean(PyObject *self, PyObject *arg)
|
||||
{
|
||||
long l = 0;
|
||||
PyObject *res;
|
||||
|
||||
if (arg == Py_None)
|
||||
Py_RETURN_NONE;
|
||||
|
||||
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
l = PyLong_AsLong(arg);
|
||||
#else
|
||||
l = PyInt_AsLong(arg);
|
||||
#endif
|
||||
if (l == 0) {
|
||||
res = Py_False;
|
||||
} else if (l == 1) {
|
||||
res = Py_True;
|
||||
} else if ((l == -1) && PyErr_Occurred()) {
|
||||
/* -1 can be either the actual value, or an error flag. */
|
||||
return NULL;
|
||||
} else {
|
||||
PyErr_SetString(PyExc_ValueError,
|
||||
"int_to_boolean only accepts None, 0 or 1");
|
||||
return NULL;
|
||||
}
|
||||
|
||||
Py_INCREF(res);
|
||||
return res;
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
to_str(PyObject *self, PyObject *arg)
|
||||
{
|
||||
if (arg == Py_None)
|
||||
Py_RETURN_NONE;
|
||||
|
||||
return PyObject_Str(arg);
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
to_float(PyObject *self, PyObject *arg)
|
||||
{
|
||||
if (arg == Py_None)
|
||||
Py_RETURN_NONE;
|
||||
|
||||
return PyNumber_Float(arg);
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
str_to_datetime(PyObject *self, PyObject *arg)
|
||||
{
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
PyObject *bytes;
|
||||
PyObject *err_bytes;
|
||||
#endif
|
||||
const char *str;
|
||||
int numparsed;
|
||||
unsigned int year, month, day, hour, minute, second, microsecond = 0;
|
||||
PyObject *err_repr;
|
||||
|
||||
if (arg == Py_None)
|
||||
Py_RETURN_NONE;
|
||||
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
bytes = PyUnicode_AsASCIIString(arg);
|
||||
if (bytes == NULL)
|
||||
str = NULL;
|
||||
else
|
||||
str = PyBytes_AS_STRING(bytes);
|
||||
#else
|
||||
str = PyString_AsString(arg);
|
||||
#endif
|
||||
if (str == NULL) {
|
||||
err_repr = PyObject_Repr(arg);
|
||||
if (err_repr == NULL)
|
||||
return NULL;
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
err_bytes = PyUnicode_AsASCIIString(err_repr);
|
||||
if (err_bytes == NULL)
|
||||
return NULL;
|
||||
PyErr_Format(
|
||||
PyExc_ValueError,
|
||||
"Couldn't parse datetime string '%.200s' "
|
||||
"- value is not a string.",
|
||||
PyBytes_AS_STRING(err_bytes));
|
||||
Py_DECREF(err_bytes);
|
||||
#else
|
||||
PyErr_Format(
|
||||
PyExc_ValueError,
|
||||
"Couldn't parse datetime string '%.200s' "
|
||||
"- value is not a string.",
|
||||
PyString_AsString(err_repr));
|
||||
#endif
|
||||
Py_DECREF(err_repr);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
/* microseconds are optional */
|
||||
/*
|
||||
TODO: this is slightly less picky than the Python version which would
|
||||
not accept "2000-01-01 00:00:00.". I don't know which is better, but they
|
||||
should be coherent.
|
||||
*/
|
||||
numparsed = sscanf(str, "%4u-%2u-%2u %2u:%2u:%2u.%6u", &year, &month, &day,
|
||||
&hour, &minute, &second, µsecond);
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
Py_DECREF(bytes);
|
||||
#endif
|
||||
if (numparsed < 6) {
|
||||
err_repr = PyObject_Repr(arg);
|
||||
if (err_repr == NULL)
|
||||
return NULL;
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
err_bytes = PyUnicode_AsASCIIString(err_repr);
|
||||
if (err_bytes == NULL)
|
||||
return NULL;
|
||||
PyErr_Format(
|
||||
PyExc_ValueError,
|
||||
"Couldn't parse datetime string: %.200s",
|
||||
PyBytes_AS_STRING(err_bytes));
|
||||
Py_DECREF(err_bytes);
|
||||
#else
|
||||
PyErr_Format(
|
||||
PyExc_ValueError,
|
||||
"Couldn't parse datetime string: %.200s",
|
||||
PyString_AsString(err_repr));
|
||||
#endif
|
||||
Py_DECREF(err_repr);
|
||||
return NULL;
|
||||
}
|
||||
return PyDateTime_FromDateAndTime(year, month, day,
|
||||
hour, minute, second, microsecond);
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
str_to_time(PyObject *self, PyObject *arg)
|
||||
{
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
PyObject *bytes;
|
||||
PyObject *err_bytes;
|
||||
#endif
|
||||
const char *str;
|
||||
int numparsed;
|
||||
unsigned int hour, minute, second, microsecond = 0;
|
||||
PyObject *err_repr;
|
||||
|
||||
if (arg == Py_None)
|
||||
Py_RETURN_NONE;
|
||||
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
bytes = PyUnicode_AsASCIIString(arg);
|
||||
if (bytes == NULL)
|
||||
str = NULL;
|
||||
else
|
||||
str = PyBytes_AS_STRING(bytes);
|
||||
#else
|
||||
str = PyString_AsString(arg);
|
||||
#endif
|
||||
if (str == NULL) {
|
||||
err_repr = PyObject_Repr(arg);
|
||||
if (err_repr == NULL)
|
||||
return NULL;
|
||||
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
err_bytes = PyUnicode_AsASCIIString(err_repr);
|
||||
if (err_bytes == NULL)
|
||||
return NULL;
|
||||
PyErr_Format(
|
||||
PyExc_ValueError,
|
||||
"Couldn't parse time string '%.200s' - value is not a string.",
|
||||
PyBytes_AS_STRING(err_bytes));
|
||||
Py_DECREF(err_bytes);
|
||||
#else
|
||||
PyErr_Format(
|
||||
PyExc_ValueError,
|
||||
"Couldn't parse time string '%.200s' - value is not a string.",
|
||||
PyString_AsString(err_repr));
|
||||
#endif
|
||||
Py_DECREF(err_repr);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
/* microseconds are optional */
|
||||
/*
|
||||
TODO: this is slightly less picky than the Python version which would
|
||||
not accept "00:00:00.". I don't know which is better, but they should be
|
||||
coherent.
|
||||
*/
|
||||
numparsed = sscanf(str, "%2u:%2u:%2u.%6u", &hour, &minute, &second,
|
||||
µsecond);
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
Py_DECREF(bytes);
|
||||
#endif
|
||||
if (numparsed < 3) {
|
||||
err_repr = PyObject_Repr(arg);
|
||||
if (err_repr == NULL)
|
||||
return NULL;
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
err_bytes = PyUnicode_AsASCIIString(err_repr);
|
||||
if (err_bytes == NULL)
|
||||
return NULL;
|
||||
PyErr_Format(
|
||||
PyExc_ValueError,
|
||||
"Couldn't parse time string: %.200s",
|
||||
PyBytes_AS_STRING(err_bytes));
|
||||
Py_DECREF(err_bytes);
|
||||
#else
|
||||
PyErr_Format(
|
||||
PyExc_ValueError,
|
||||
"Couldn't parse time string: %.200s",
|
||||
PyString_AsString(err_repr));
|
||||
#endif
|
||||
Py_DECREF(err_repr);
|
||||
return NULL;
|
||||
}
|
||||
return PyTime_FromTime(hour, minute, second, microsecond);
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
str_to_date(PyObject *self, PyObject *arg)
|
||||
{
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
PyObject *bytes;
|
||||
PyObject *err_bytes;
|
||||
#endif
|
||||
const char *str;
|
||||
int numparsed;
|
||||
unsigned int year, month, day;
|
||||
PyObject *err_repr;
|
||||
|
||||
if (arg == Py_None)
|
||||
Py_RETURN_NONE;
|
||||
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
bytes = PyUnicode_AsASCIIString(arg);
|
||||
if (bytes == NULL)
|
||||
str = NULL;
|
||||
else
|
||||
str = PyBytes_AS_STRING(bytes);
|
||||
#else
|
||||
str = PyString_AsString(arg);
|
||||
#endif
|
||||
if (str == NULL) {
|
||||
err_repr = PyObject_Repr(arg);
|
||||
if (err_repr == NULL)
|
||||
return NULL;
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
err_bytes = PyUnicode_AsASCIIString(err_repr);
|
||||
if (err_bytes == NULL)
|
||||
return NULL;
|
||||
PyErr_Format(
|
||||
PyExc_ValueError,
|
||||
"Couldn't parse date string '%.200s' - value is not a string.",
|
||||
PyBytes_AS_STRING(err_bytes));
|
||||
Py_DECREF(err_bytes);
|
||||
#else
|
||||
PyErr_Format(
|
||||
PyExc_ValueError,
|
||||
"Couldn't parse date string '%.200s' - value is not a string.",
|
||||
PyString_AsString(err_repr));
|
||||
#endif
|
||||
Py_DECREF(err_repr);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
numparsed = sscanf(str, "%4u-%2u-%2u", &year, &month, &day);
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
Py_DECREF(bytes);
|
||||
#endif
|
||||
if (numparsed != 3) {
|
||||
err_repr = PyObject_Repr(arg);
|
||||
if (err_repr == NULL)
|
||||
return NULL;
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
err_bytes = PyUnicode_AsASCIIString(err_repr);
|
||||
if (err_bytes == NULL)
|
||||
return NULL;
|
||||
PyErr_Format(
|
||||
PyExc_ValueError,
|
||||
"Couldn't parse date string: %.200s",
|
||||
PyBytes_AS_STRING(err_bytes));
|
||||
Py_DECREF(err_bytes);
|
||||
#else
|
||||
PyErr_Format(
|
||||
PyExc_ValueError,
|
||||
"Couldn't parse date string: %.200s",
|
||||
PyString_AsString(err_repr));
|
||||
#endif
|
||||
Py_DECREF(err_repr);
|
||||
return NULL;
|
||||
}
|
||||
return PyDate_FromDate(year, month, day);
|
||||
}
|
||||
|
||||
|
||||
/***********
|
||||
* Structs *
|
||||
***********/
|
||||
|
||||
typedef struct {
|
||||
PyObject_HEAD
|
||||
PyObject *encoding;
|
||||
PyObject *errors;
|
||||
} UnicodeResultProcessor;
|
||||
|
||||
typedef struct {
|
||||
PyObject_HEAD
|
||||
PyObject *type;
|
||||
PyObject *format;
|
||||
} DecimalResultProcessor;
|
||||
|
||||
|
||||
|
||||
/**************************
|
||||
* UnicodeResultProcessor *
|
||||
**************************/
|
||||
|
||||
static int
|
||||
UnicodeResultProcessor_init(UnicodeResultProcessor *self, PyObject *args,
|
||||
PyObject *kwds)
|
||||
{
|
||||
PyObject *encoding, *errors = NULL;
|
||||
static char *kwlist[] = {"encoding", "errors", NULL};
|
||||
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
if (!PyArg_ParseTupleAndKeywords(args, kwds, "U|U:__init__", kwlist,
|
||||
&encoding, &errors))
|
||||
return -1;
|
||||
#else
|
||||
if (!PyArg_ParseTupleAndKeywords(args, kwds, "S|S:__init__", kwlist,
|
||||
&encoding, &errors))
|
||||
return -1;
|
||||
#endif
|
||||
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
encoding = PyUnicode_AsASCIIString(encoding);
|
||||
#else
|
||||
Py_INCREF(encoding);
|
||||
#endif
|
||||
self->encoding = encoding;
|
||||
|
||||
if (errors) {
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
errors = PyUnicode_AsASCIIString(errors);
|
||||
#else
|
||||
Py_INCREF(errors);
|
||||
#endif
|
||||
} else {
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
errors = PyBytes_FromString("strict");
|
||||
#else
|
||||
errors = PyString_FromString("strict");
|
||||
#endif
|
||||
if (errors == NULL)
|
||||
return -1;
|
||||
}
|
||||
self->errors = errors;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
UnicodeResultProcessor_process(UnicodeResultProcessor *self, PyObject *value)
|
||||
{
|
||||
const char *encoding, *errors;
|
||||
char *str;
|
||||
Py_ssize_t len;
|
||||
|
||||
if (value == Py_None)
|
||||
Py_RETURN_NONE;
|
||||
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
if (PyBytes_AsStringAndSize(value, &str, &len))
|
||||
return NULL;
|
||||
|
||||
encoding = PyBytes_AS_STRING(self->encoding);
|
||||
errors = PyBytes_AS_STRING(self->errors);
|
||||
#else
|
||||
if (PyString_AsStringAndSize(value, &str, &len))
|
||||
return NULL;
|
||||
|
||||
encoding = PyString_AS_STRING(self->encoding);
|
||||
errors = PyString_AS_STRING(self->errors);
|
||||
#endif
|
||||
|
||||
return PyUnicode_Decode(str, len, encoding, errors);
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
UnicodeResultProcessor_conditional_process(UnicodeResultProcessor *self, PyObject *value)
|
||||
{
|
||||
const char *encoding, *errors;
|
||||
char *str;
|
||||
Py_ssize_t len;
|
||||
|
||||
if (value == Py_None)
|
||||
Py_RETURN_NONE;
|
||||
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
if (PyUnicode_Check(value) == 1) {
|
||||
Py_INCREF(value);
|
||||
return value;
|
||||
}
|
||||
|
||||
if (PyBytes_AsStringAndSize(value, &str, &len))
|
||||
return NULL;
|
||||
|
||||
encoding = PyBytes_AS_STRING(self->encoding);
|
||||
errors = PyBytes_AS_STRING(self->errors);
|
||||
#else
|
||||
|
||||
if (PyUnicode_Check(value) == 1) {
|
||||
Py_INCREF(value);
|
||||
return value;
|
||||
}
|
||||
|
||||
if (PyString_AsStringAndSize(value, &str, &len))
|
||||
return NULL;
|
||||
|
||||
|
||||
encoding = PyString_AS_STRING(self->encoding);
|
||||
errors = PyString_AS_STRING(self->errors);
|
||||
#endif
|
||||
|
||||
return PyUnicode_Decode(str, len, encoding, errors);
|
||||
}
|
||||
|
||||
static void
|
||||
UnicodeResultProcessor_dealloc(UnicodeResultProcessor *self)
|
||||
{
|
||||
Py_XDECREF(self->encoding);
|
||||
Py_XDECREF(self->errors);
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
Py_TYPE(self)->tp_free((PyObject*)self);
|
||||
#else
|
||||
self->ob_type->tp_free((PyObject*)self);
|
||||
#endif
|
||||
}
|
||||
|
||||
static PyMethodDef UnicodeResultProcessor_methods[] = {
|
||||
{"process", (PyCFunction)UnicodeResultProcessor_process, METH_O,
|
||||
"The value processor itself."},
|
||||
{"conditional_process", (PyCFunction)UnicodeResultProcessor_conditional_process, METH_O,
|
||||
"Conditional version of the value processor."},
|
||||
{NULL} /* Sentinel */
|
||||
};
|
||||
|
||||
static PyTypeObject UnicodeResultProcessorType = {
|
||||
PyVarObject_HEAD_INIT(NULL, 0)
|
||||
"sqlalchemy.cprocessors.UnicodeResultProcessor", /* tp_name */
|
||||
sizeof(UnicodeResultProcessor), /* tp_basicsize */
|
||||
0, /* tp_itemsize */
|
||||
(destructor)UnicodeResultProcessor_dealloc, /* tp_dealloc */
|
||||
0, /* tp_print */
|
||||
0, /* tp_getattr */
|
||||
0, /* tp_setattr */
|
||||
0, /* tp_compare */
|
||||
0, /* tp_repr */
|
||||
0, /* tp_as_number */
|
||||
0, /* tp_as_sequence */
|
||||
0, /* tp_as_mapping */
|
||||
0, /* tp_hash */
|
||||
0, /* tp_call */
|
||||
0, /* tp_str */
|
||||
0, /* tp_getattro */
|
||||
0, /* tp_setattro */
|
||||
0, /* tp_as_buffer */
|
||||
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
|
||||
"UnicodeResultProcessor objects", /* tp_doc */
|
||||
0, /* tp_traverse */
|
||||
0, /* tp_clear */
|
||||
0, /* tp_richcompare */
|
||||
0, /* tp_weaklistoffset */
|
||||
0, /* tp_iter */
|
||||
0, /* tp_iternext */
|
||||
UnicodeResultProcessor_methods, /* tp_methods */
|
||||
0, /* tp_members */
|
||||
0, /* tp_getset */
|
||||
0, /* tp_base */
|
||||
0, /* tp_dict */
|
||||
0, /* tp_descr_get */
|
||||
0, /* tp_descr_set */
|
||||
0, /* tp_dictoffset */
|
||||
(initproc)UnicodeResultProcessor_init, /* tp_init */
|
||||
0, /* tp_alloc */
|
||||
0, /* tp_new */
|
||||
};
|
||||
|
||||
/**************************
|
||||
* DecimalResultProcessor *
|
||||
**************************/
|
||||
|
||||
static int
|
||||
DecimalResultProcessor_init(DecimalResultProcessor *self, PyObject *args,
|
||||
PyObject *kwds)
|
||||
{
|
||||
PyObject *type, *format;
|
||||
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
if (!PyArg_ParseTuple(args, "OU", &type, &format))
|
||||
#else
|
||||
if (!PyArg_ParseTuple(args, "OS", &type, &format))
|
||||
#endif
|
||||
return -1;
|
||||
|
||||
Py_INCREF(type);
|
||||
self->type = type;
|
||||
|
||||
Py_INCREF(format);
|
||||
self->format = format;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
DecimalResultProcessor_process(DecimalResultProcessor *self, PyObject *value)
|
||||
{
|
||||
PyObject *str, *result, *args;
|
||||
|
||||
if (value == Py_None)
|
||||
Py_RETURN_NONE;
|
||||
|
||||
/* Decimal does not accept float values directly */
|
||||
/* SQLite can also give us an integer here (see [ticket:2432]) */
|
||||
/* XXX: starting with Python 3.1, we could use Decimal.from_float(f),
|
||||
but the result wouldn't be the same */
|
||||
|
||||
args = PyTuple_Pack(1, value);
|
||||
if (args == NULL)
|
||||
return NULL;
|
||||
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
str = PyUnicode_Format(self->format, args);
|
||||
#else
|
||||
str = PyString_Format(self->format, args);
|
||||
#endif
|
||||
|
||||
Py_DECREF(args);
|
||||
if (str == NULL)
|
||||
return NULL;
|
||||
|
||||
result = PyObject_CallFunctionObjArgs(self->type, str, NULL);
|
||||
Py_DECREF(str);
|
||||
return result;
|
||||
}
|
||||
|
||||
static void
|
||||
DecimalResultProcessor_dealloc(DecimalResultProcessor *self)
|
||||
{
|
||||
Py_XDECREF(self->type);
|
||||
Py_XDECREF(self->format);
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
Py_TYPE(self)->tp_free((PyObject*)self);
|
||||
#else
|
||||
self->ob_type->tp_free((PyObject*)self);
|
||||
#endif
|
||||
}
|
||||
|
||||
static PyMethodDef DecimalResultProcessor_methods[] = {
|
||||
{"process", (PyCFunction)DecimalResultProcessor_process, METH_O,
|
||||
"The value processor itself."},
|
||||
{NULL} /* Sentinel */
|
||||
};
|
||||
|
||||
static PyTypeObject DecimalResultProcessorType = {
|
||||
PyVarObject_HEAD_INIT(NULL, 0)
|
||||
"sqlalchemy.DecimalResultProcessor", /* tp_name */
|
||||
sizeof(DecimalResultProcessor), /* tp_basicsize */
|
||||
0, /* tp_itemsize */
|
||||
(destructor)DecimalResultProcessor_dealloc, /* tp_dealloc */
|
||||
0, /* tp_print */
|
||||
0, /* tp_getattr */
|
||||
0, /* tp_setattr */
|
||||
0, /* tp_compare */
|
||||
0, /* tp_repr */
|
||||
0, /* tp_as_number */
|
||||
0, /* tp_as_sequence */
|
||||
0, /* tp_as_mapping */
|
||||
0, /* tp_hash */
|
||||
0, /* tp_call */
|
||||
0, /* tp_str */
|
||||
0, /* tp_getattro */
|
||||
0, /* tp_setattro */
|
||||
0, /* tp_as_buffer */
|
||||
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
|
||||
"DecimalResultProcessor objects", /* tp_doc */
|
||||
0, /* tp_traverse */
|
||||
0, /* tp_clear */
|
||||
0, /* tp_richcompare */
|
||||
0, /* tp_weaklistoffset */
|
||||
0, /* tp_iter */
|
||||
0, /* tp_iternext */
|
||||
DecimalResultProcessor_methods, /* tp_methods */
|
||||
0, /* tp_members */
|
||||
0, /* tp_getset */
|
||||
0, /* tp_base */
|
||||
0, /* tp_dict */
|
||||
0, /* tp_descr_get */
|
||||
0, /* tp_descr_set */
|
||||
0, /* tp_dictoffset */
|
||||
(initproc)DecimalResultProcessor_init, /* tp_init */
|
||||
0, /* tp_alloc */
|
||||
0, /* tp_new */
|
||||
};
|
||||
|
||||
static PyMethodDef module_methods[] = {
|
||||
{"int_to_boolean", int_to_boolean, METH_O,
|
||||
"Convert an integer to a boolean."},
|
||||
{"to_str", to_str, METH_O,
|
||||
"Convert any value to its string representation."},
|
||||
{"to_float", to_float, METH_O,
|
||||
"Convert any value to its floating point representation."},
|
||||
{"str_to_datetime", str_to_datetime, METH_O,
|
||||
"Convert an ISO string to a datetime.datetime object."},
|
||||
{"str_to_time", str_to_time, METH_O,
|
||||
"Convert an ISO string to a datetime.time object."},
|
||||
{"str_to_date", str_to_date, METH_O,
|
||||
"Convert an ISO string to a datetime.date object."},
|
||||
{NULL, NULL, 0, NULL} /* Sentinel */
|
||||
};
|
||||
|
||||
#ifndef PyMODINIT_FUNC /* declarations for DLL import/export */
|
||||
#define PyMODINIT_FUNC void
|
||||
#endif
|
||||
|
||||
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
|
||||
static struct PyModuleDef module_def = {
|
||||
PyModuleDef_HEAD_INIT,
|
||||
MODULE_NAME,
|
||||
MODULE_DOC,
|
||||
-1,
|
||||
module_methods
|
||||
};
|
||||
|
||||
#define INITERROR return NULL
|
||||
|
||||
PyMODINIT_FUNC
|
||||
PyInit_cprocessors(void)
|
||||
|
||||
#else
|
||||
|
||||
#define INITERROR return
|
||||
|
||||
PyMODINIT_FUNC
|
||||
initcprocessors(void)
|
||||
|
||||
#endif
|
||||
|
||||
{
|
||||
PyObject *m;
|
||||
|
||||
UnicodeResultProcessorType.tp_new = PyType_GenericNew;
|
||||
if (PyType_Ready(&UnicodeResultProcessorType) < 0)
|
||||
INITERROR;
|
||||
|
||||
DecimalResultProcessorType.tp_new = PyType_GenericNew;
|
||||
if (PyType_Ready(&DecimalResultProcessorType) < 0)
|
||||
INITERROR;
|
||||
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
m = PyModule_Create(&module_def);
|
||||
#else
|
||||
m = Py_InitModule3(MODULE_NAME, module_methods, MODULE_DOC);
|
||||
#endif
|
||||
if (m == NULL)
|
||||
INITERROR;
|
||||
|
||||
PyDateTime_IMPORT;
|
||||
|
||||
Py_INCREF(&UnicodeResultProcessorType);
|
||||
PyModule_AddObject(m, "UnicodeResultProcessor",
|
||||
(PyObject *)&UnicodeResultProcessorType);
|
||||
|
||||
Py_INCREF(&DecimalResultProcessorType);
|
||||
PyModule_AddObject(m, "DecimalResultProcessor",
|
||||
(PyObject *)&DecimalResultProcessorType);
|
||||
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
return m;
|
||||
#endif
|
||||
}
|
|
@ -1,718 +0,0 @@
|
|||
/*
|
||||
resultproxy.c
|
||||
Copyright (C) 2010-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
|
||||
Copyright (C) 2010-2011 Gaetan de Menten gdementen@gmail.com
|
||||
|
||||
This module is part of SQLAlchemy and is released under
|
||||
the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
*/
|
||||
|
||||
#include <Python.h>
|
||||
|
||||
#define MODULE_NAME "cresultproxy"
|
||||
#define MODULE_DOC "Module containing C versions of core ResultProxy classes."
|
||||
|
||||
#if PY_VERSION_HEX < 0x02050000 && !defined(PY_SSIZE_T_MIN)
|
||||
typedef int Py_ssize_t;
|
||||
#define PY_SSIZE_T_MAX INT_MAX
|
||||
#define PY_SSIZE_T_MIN INT_MIN
|
||||
typedef Py_ssize_t (*lenfunc)(PyObject *);
|
||||
#define PyInt_FromSsize_t(x) PyInt_FromLong(x)
|
||||
typedef intargfunc ssizeargfunc;
|
||||
#endif
|
||||
|
||||
|
||||
/***********
|
||||
* Structs *
|
||||
***********/
|
||||
|
||||
typedef struct {
|
||||
PyObject_HEAD
|
||||
PyObject *parent;
|
||||
PyObject *row;
|
||||
PyObject *processors;
|
||||
PyObject *keymap;
|
||||
} BaseRowProxy;
|
||||
|
||||
/****************
|
||||
* BaseRowProxy *
|
||||
****************/
|
||||
|
||||
static PyObject *
|
||||
safe_rowproxy_reconstructor(PyObject *self, PyObject *args)
|
||||
{
|
||||
PyObject *cls, *state, *tmp;
|
||||
BaseRowProxy *obj;
|
||||
|
||||
if (!PyArg_ParseTuple(args, "OO", &cls, &state))
|
||||
return NULL;
|
||||
|
||||
obj = (BaseRowProxy *)PyObject_CallMethod(cls, "__new__", "O", cls);
|
||||
if (obj == NULL)
|
||||
return NULL;
|
||||
|
||||
tmp = PyObject_CallMethod((PyObject *)obj, "__setstate__", "O", state);
|
||||
if (tmp == NULL) {
|
||||
Py_DECREF(obj);
|
||||
return NULL;
|
||||
}
|
||||
Py_DECREF(tmp);
|
||||
|
||||
if (obj->parent == NULL || obj->row == NULL ||
|
||||
obj->processors == NULL || obj->keymap == NULL) {
|
||||
PyErr_SetString(PyExc_RuntimeError,
|
||||
"__setstate__ for BaseRowProxy subclasses must set values "
|
||||
"for parent, row, processors and keymap");
|
||||
Py_DECREF(obj);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
return (PyObject *)obj;
|
||||
}
|
||||
|
||||
static int
|
||||
BaseRowProxy_init(BaseRowProxy *self, PyObject *args, PyObject *kwds)
|
||||
{
|
||||
PyObject *parent, *row, *processors, *keymap;
|
||||
|
||||
if (!PyArg_UnpackTuple(args, "BaseRowProxy", 4, 4,
|
||||
&parent, &row, &processors, &keymap))
|
||||
return -1;
|
||||
|
||||
Py_INCREF(parent);
|
||||
self->parent = parent;
|
||||
|
||||
if (!PySequence_Check(row)) {
|
||||
PyErr_SetString(PyExc_TypeError, "row must be a sequence");
|
||||
return -1;
|
||||
}
|
||||
Py_INCREF(row);
|
||||
self->row = row;
|
||||
|
||||
if (!PyList_CheckExact(processors)) {
|
||||
PyErr_SetString(PyExc_TypeError, "processors must be a list");
|
||||
return -1;
|
||||
}
|
||||
Py_INCREF(processors);
|
||||
self->processors = processors;
|
||||
|
||||
if (!PyDict_CheckExact(keymap)) {
|
||||
PyErr_SetString(PyExc_TypeError, "keymap must be a dict");
|
||||
return -1;
|
||||
}
|
||||
Py_INCREF(keymap);
|
||||
self->keymap = keymap;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/* We need the reduce method because otherwise the default implementation
|
||||
* does very weird stuff for pickle protocol 0 and 1. It calls
|
||||
* BaseRowProxy.__new__(RowProxy_instance) upon *pickling*.
|
||||
*/
|
||||
static PyObject *
|
||||
BaseRowProxy_reduce(PyObject *self)
|
||||
{
|
||||
PyObject *method, *state;
|
||||
PyObject *module, *reconstructor, *cls;
|
||||
|
||||
method = PyObject_GetAttrString(self, "__getstate__");
|
||||
if (method == NULL)
|
||||
return NULL;
|
||||
|
||||
state = PyObject_CallObject(method, NULL);
|
||||
Py_DECREF(method);
|
||||
if (state == NULL)
|
||||
return NULL;
|
||||
|
||||
module = PyImport_ImportModule("sqlalchemy.engine.result");
|
||||
if (module == NULL)
|
||||
return NULL;
|
||||
|
||||
reconstructor = PyObject_GetAttrString(module, "rowproxy_reconstructor");
|
||||
Py_DECREF(module);
|
||||
if (reconstructor == NULL) {
|
||||
Py_DECREF(state);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
cls = PyObject_GetAttrString(self, "__class__");
|
||||
if (cls == NULL) {
|
||||
Py_DECREF(reconstructor);
|
||||
Py_DECREF(state);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
return Py_BuildValue("(N(NN))", reconstructor, cls, state);
|
||||
}
|
||||
|
||||
static void
|
||||
BaseRowProxy_dealloc(BaseRowProxy *self)
|
||||
{
|
||||
Py_XDECREF(self->parent);
|
||||
Py_XDECREF(self->row);
|
||||
Py_XDECREF(self->processors);
|
||||
Py_XDECREF(self->keymap);
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
Py_TYPE(self)->tp_free((PyObject *)self);
|
||||
#else
|
||||
self->ob_type->tp_free((PyObject *)self);
|
||||
#endif
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
BaseRowProxy_processvalues(PyObject *values, PyObject *processors, int astuple)
|
||||
{
|
||||
Py_ssize_t num_values, num_processors;
|
||||
PyObject **valueptr, **funcptr, **resultptr;
|
||||
PyObject *func, *result, *processed_value, *values_fastseq;
|
||||
|
||||
num_values = PySequence_Length(values);
|
||||
num_processors = PyList_Size(processors);
|
||||
if (num_values != num_processors) {
|
||||
PyErr_Format(PyExc_RuntimeError,
|
||||
"number of values in row (%d) differ from number of column "
|
||||
"processors (%d)",
|
||||
(int)num_values, (int)num_processors);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
if (astuple) {
|
||||
result = PyTuple_New(num_values);
|
||||
} else {
|
||||
result = PyList_New(num_values);
|
||||
}
|
||||
if (result == NULL)
|
||||
return NULL;
|
||||
|
||||
values_fastseq = PySequence_Fast(values, "row must be a sequence");
|
||||
if (values_fastseq == NULL)
|
||||
return NULL;
|
||||
|
||||
valueptr = PySequence_Fast_ITEMS(values_fastseq);
|
||||
funcptr = PySequence_Fast_ITEMS(processors);
|
||||
resultptr = PySequence_Fast_ITEMS(result);
|
||||
while (--num_values >= 0) {
|
||||
func = *funcptr;
|
||||
if (func != Py_None) {
|
||||
processed_value = PyObject_CallFunctionObjArgs(func, *valueptr,
|
||||
NULL);
|
||||
if (processed_value == NULL) {
|
||||
Py_DECREF(values_fastseq);
|
||||
Py_DECREF(result);
|
||||
return NULL;
|
||||
}
|
||||
*resultptr = processed_value;
|
||||
} else {
|
||||
Py_INCREF(*valueptr);
|
||||
*resultptr = *valueptr;
|
||||
}
|
||||
valueptr++;
|
||||
funcptr++;
|
||||
resultptr++;
|
||||
}
|
||||
Py_DECREF(values_fastseq);
|
||||
return result;
|
||||
}
|
||||
|
||||
static PyListObject *
|
||||
BaseRowProxy_values(BaseRowProxy *self)
|
||||
{
|
||||
return (PyListObject *)BaseRowProxy_processvalues(self->row,
|
||||
self->processors, 0);
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
BaseRowProxy_iter(BaseRowProxy *self)
|
||||
{
|
||||
PyObject *values, *result;
|
||||
|
||||
values = BaseRowProxy_processvalues(self->row, self->processors, 1);
|
||||
if (values == NULL)
|
||||
return NULL;
|
||||
|
||||
result = PyObject_GetIter(values);
|
||||
Py_DECREF(values);
|
||||
if (result == NULL)
|
||||
return NULL;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
static Py_ssize_t
|
||||
BaseRowProxy_length(BaseRowProxy *self)
|
||||
{
|
||||
return PySequence_Length(self->row);
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
BaseRowProxy_subscript(BaseRowProxy *self, PyObject *key)
|
||||
{
|
||||
PyObject *processors, *values;
|
||||
PyObject *processor, *value, *processed_value;
|
||||
PyObject *row, *record, *result, *indexobject;
|
||||
PyObject *exc_module, *exception, *cstr_obj;
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
PyObject *bytes;
|
||||
#endif
|
||||
char *cstr_key;
|
||||
long index;
|
||||
int key_fallback = 0;
|
||||
int tuple_check = 0;
|
||||
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
if (PyInt_CheckExact(key)) {
|
||||
index = PyInt_AS_LONG(key);
|
||||
}
|
||||
#endif
|
||||
|
||||
if (PyLong_CheckExact(key)) {
|
||||
index = PyLong_AsLong(key);
|
||||
if ((index == -1) && PyErr_Occurred())
|
||||
/* -1 can be either the actual value, or an error flag. */
|
||||
return NULL;
|
||||
} else if (PySlice_Check(key)) {
|
||||
values = PyObject_GetItem(self->row, key);
|
||||
if (values == NULL)
|
||||
return NULL;
|
||||
|
||||
processors = PyObject_GetItem(self->processors, key);
|
||||
if (processors == NULL) {
|
||||
Py_DECREF(values);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
result = BaseRowProxy_processvalues(values, processors, 1);
|
||||
Py_DECREF(values);
|
||||
Py_DECREF(processors);
|
||||
return result;
|
||||
} else {
|
||||
record = PyDict_GetItem((PyObject *)self->keymap, key);
|
||||
if (record == NULL) {
|
||||
record = PyObject_CallMethod(self->parent, "_key_fallback",
|
||||
"O", key);
|
||||
if (record == NULL)
|
||||
return NULL;
|
||||
key_fallback = 1;
|
||||
}
|
||||
|
||||
indexobject = PyTuple_GetItem(record, 2);
|
||||
if (indexobject == NULL)
|
||||
return NULL;
|
||||
|
||||
if (key_fallback) {
|
||||
Py_DECREF(record);
|
||||
}
|
||||
|
||||
if (indexobject == Py_None) {
|
||||
exc_module = PyImport_ImportModule("sqlalchemy.exc");
|
||||
if (exc_module == NULL)
|
||||
return NULL;
|
||||
|
||||
exception = PyObject_GetAttrString(exc_module,
|
||||
"InvalidRequestError");
|
||||
Py_DECREF(exc_module);
|
||||
if (exception == NULL)
|
||||
return NULL;
|
||||
|
||||
// wow. this seems quite excessive.
|
||||
cstr_obj = PyObject_Str(key);
|
||||
if (cstr_obj == NULL)
|
||||
return NULL;
|
||||
|
||||
/*
|
||||
FIXME: raise encoding error exception (in both versions below)
|
||||
if the key contains non-ascii chars, instead of an
|
||||
InvalidRequestError without any message like in the
|
||||
python version.
|
||||
*/
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
bytes = PyUnicode_AsASCIIString(cstr_obj);
|
||||
if (bytes == NULL)
|
||||
return NULL;
|
||||
cstr_key = PyBytes_AS_STRING(bytes);
|
||||
#else
|
||||
cstr_key = PyString_AsString(cstr_obj);
|
||||
#endif
|
||||
if (cstr_key == NULL) {
|
||||
Py_DECREF(cstr_obj);
|
||||
return NULL;
|
||||
}
|
||||
Py_DECREF(cstr_obj);
|
||||
|
||||
PyErr_Format(exception,
|
||||
"Ambiguous column name '%.200s' in result set! "
|
||||
"try 'use_labels' option on select statement.", cstr_key);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
index = PyLong_AsLong(indexobject);
|
||||
#else
|
||||
index = PyInt_AsLong(indexobject);
|
||||
#endif
|
||||
if ((index == -1) && PyErr_Occurred())
|
||||
/* -1 can be either the actual value, or an error flag. */
|
||||
return NULL;
|
||||
}
|
||||
processor = PyList_GetItem(self->processors, index);
|
||||
if (processor == NULL)
|
||||
return NULL;
|
||||
|
||||
row = self->row;
|
||||
if (PyTuple_CheckExact(row)) {
|
||||
value = PyTuple_GetItem(row, index);
|
||||
tuple_check = 1;
|
||||
}
|
||||
else {
|
||||
value = PySequence_GetItem(row, index);
|
||||
tuple_check = 0;
|
||||
}
|
||||
|
||||
if (value == NULL)
|
||||
return NULL;
|
||||
|
||||
if (processor != Py_None) {
|
||||
processed_value = PyObject_CallFunctionObjArgs(processor, value, NULL);
|
||||
if (!tuple_check) {
|
||||
Py_DECREF(value);
|
||||
}
|
||||
return processed_value;
|
||||
} else {
|
||||
if (tuple_check) {
|
||||
Py_INCREF(value);
|
||||
}
|
||||
return value;
|
||||
}
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
BaseRowProxy_getitem(PyObject *self, Py_ssize_t i)
|
||||
{
|
||||
PyObject *index;
|
||||
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
index = PyLong_FromSsize_t(i);
|
||||
#else
|
||||
index = PyInt_FromSsize_t(i);
|
||||
#endif
|
||||
return BaseRowProxy_subscript((BaseRowProxy*)self, index);
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
BaseRowProxy_getattro(BaseRowProxy *self, PyObject *name)
|
||||
{
|
||||
PyObject *tmp;
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
PyObject *err_bytes;
|
||||
#endif
|
||||
|
||||
if (!(tmp = PyObject_GenericGetAttr((PyObject *)self, name))) {
|
||||
if (!PyErr_ExceptionMatches(PyExc_AttributeError))
|
||||
return NULL;
|
||||
PyErr_Clear();
|
||||
}
|
||||
else
|
||||
return tmp;
|
||||
|
||||
tmp = BaseRowProxy_subscript(self, name);
|
||||
if (tmp == NULL && PyErr_ExceptionMatches(PyExc_KeyError)) {
|
||||
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
err_bytes = PyUnicode_AsASCIIString(name);
|
||||
if (err_bytes == NULL)
|
||||
return NULL;
|
||||
PyErr_Format(
|
||||
PyExc_AttributeError,
|
||||
"Could not locate column in row for column '%.200s'",
|
||||
PyBytes_AS_STRING(err_bytes)
|
||||
);
|
||||
#else
|
||||
PyErr_Format(
|
||||
PyExc_AttributeError,
|
||||
"Could not locate column in row for column '%.200s'",
|
||||
PyString_AsString(name)
|
||||
);
|
||||
#endif
|
||||
return NULL;
|
||||
}
|
||||
return tmp;
|
||||
}
|
||||
|
||||
/***********************
|
||||
* getters and setters *
|
||||
***********************/
|
||||
|
||||
static PyObject *
|
||||
BaseRowProxy_getparent(BaseRowProxy *self, void *closure)
|
||||
{
|
||||
Py_INCREF(self->parent);
|
||||
return self->parent;
|
||||
}
|
||||
|
||||
static int
|
||||
BaseRowProxy_setparent(BaseRowProxy *self, PyObject *value, void *closure)
|
||||
{
|
||||
PyObject *module, *cls;
|
||||
|
||||
if (value == NULL) {
|
||||
PyErr_SetString(PyExc_TypeError,
|
||||
"Cannot delete the 'parent' attribute");
|
||||
return -1;
|
||||
}
|
||||
|
||||
module = PyImport_ImportModule("sqlalchemy.engine.result");
|
||||
if (module == NULL)
|
||||
return -1;
|
||||
|
||||
cls = PyObject_GetAttrString(module, "ResultMetaData");
|
||||
Py_DECREF(module);
|
||||
if (cls == NULL)
|
||||
return -1;
|
||||
|
||||
if (PyObject_IsInstance(value, cls) != 1) {
|
||||
PyErr_SetString(PyExc_TypeError,
|
||||
"The 'parent' attribute value must be an instance of "
|
||||
"ResultMetaData");
|
||||
return -1;
|
||||
}
|
||||
Py_DECREF(cls);
|
||||
Py_XDECREF(self->parent);
|
||||
Py_INCREF(value);
|
||||
self->parent = value;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
BaseRowProxy_getrow(BaseRowProxy *self, void *closure)
|
||||
{
|
||||
Py_INCREF(self->row);
|
||||
return self->row;
|
||||
}
|
||||
|
||||
static int
|
||||
BaseRowProxy_setrow(BaseRowProxy *self, PyObject *value, void *closure)
|
||||
{
|
||||
if (value == NULL) {
|
||||
PyErr_SetString(PyExc_TypeError,
|
||||
"Cannot delete the 'row' attribute");
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (!PySequence_Check(value)) {
|
||||
PyErr_SetString(PyExc_TypeError,
|
||||
"The 'row' attribute value must be a sequence");
|
||||
return -1;
|
||||
}
|
||||
|
||||
Py_XDECREF(self->row);
|
||||
Py_INCREF(value);
|
||||
self->row = value;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
BaseRowProxy_getprocessors(BaseRowProxy *self, void *closure)
|
||||
{
|
||||
Py_INCREF(self->processors);
|
||||
return self->processors;
|
||||
}
|
||||
|
||||
static int
|
||||
BaseRowProxy_setprocessors(BaseRowProxy *self, PyObject *value, void *closure)
|
||||
{
|
||||
if (value == NULL) {
|
||||
PyErr_SetString(PyExc_TypeError,
|
||||
"Cannot delete the 'processors' attribute");
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (!PyList_CheckExact(value)) {
|
||||
PyErr_SetString(PyExc_TypeError,
|
||||
"The 'processors' attribute value must be a list");
|
||||
return -1;
|
||||
}
|
||||
|
||||
Py_XDECREF(self->processors);
|
||||
Py_INCREF(value);
|
||||
self->processors = value;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
BaseRowProxy_getkeymap(BaseRowProxy *self, void *closure)
|
||||
{
|
||||
Py_INCREF(self->keymap);
|
||||
return self->keymap;
|
||||
}
|
||||
|
||||
static int
|
||||
BaseRowProxy_setkeymap(BaseRowProxy *self, PyObject *value, void *closure)
|
||||
{
|
||||
if (value == NULL) {
|
||||
PyErr_SetString(PyExc_TypeError,
|
||||
"Cannot delete the 'keymap' attribute");
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (!PyDict_CheckExact(value)) {
|
||||
PyErr_SetString(PyExc_TypeError,
|
||||
"The 'keymap' attribute value must be a dict");
|
||||
return -1;
|
||||
}
|
||||
|
||||
Py_XDECREF(self->keymap);
|
||||
Py_INCREF(value);
|
||||
self->keymap = value;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
static PyGetSetDef BaseRowProxy_getseters[] = {
|
||||
{"_parent",
|
||||
(getter)BaseRowProxy_getparent, (setter)BaseRowProxy_setparent,
|
||||
"ResultMetaData",
|
||||
NULL},
|
||||
{"_row",
|
||||
(getter)BaseRowProxy_getrow, (setter)BaseRowProxy_setrow,
|
||||
"Original row tuple",
|
||||
NULL},
|
||||
{"_processors",
|
||||
(getter)BaseRowProxy_getprocessors, (setter)BaseRowProxy_setprocessors,
|
||||
"list of type processors",
|
||||
NULL},
|
||||
{"_keymap",
|
||||
(getter)BaseRowProxy_getkeymap, (setter)BaseRowProxy_setkeymap,
|
||||
"Key to (processor, index) dict",
|
||||
NULL},
|
||||
{NULL}
|
||||
};
|
||||
|
||||
static PyMethodDef BaseRowProxy_methods[] = {
|
||||
{"values", (PyCFunction)BaseRowProxy_values, METH_NOARGS,
|
||||
"Return the values represented by this BaseRowProxy as a list."},
|
||||
{"__reduce__", (PyCFunction)BaseRowProxy_reduce, METH_NOARGS,
|
||||
"Pickle support method."},
|
||||
{NULL} /* Sentinel */
|
||||
};
|
||||
|
||||
static PySequenceMethods BaseRowProxy_as_sequence = {
|
||||
(lenfunc)BaseRowProxy_length, /* sq_length */
|
||||
0, /* sq_concat */
|
||||
0, /* sq_repeat */
|
||||
(ssizeargfunc)BaseRowProxy_getitem, /* sq_item */
|
||||
0, /* sq_slice */
|
||||
0, /* sq_ass_item */
|
||||
0, /* sq_ass_slice */
|
||||
0, /* sq_contains */
|
||||
0, /* sq_inplace_concat */
|
||||
0, /* sq_inplace_repeat */
|
||||
};
|
||||
|
||||
static PyMappingMethods BaseRowProxy_as_mapping = {
|
||||
(lenfunc)BaseRowProxy_length, /* mp_length */
|
||||
(binaryfunc)BaseRowProxy_subscript, /* mp_subscript */
|
||||
0 /* mp_ass_subscript */
|
||||
};
|
||||
|
||||
static PyTypeObject BaseRowProxyType = {
|
||||
PyVarObject_HEAD_INIT(NULL, 0)
|
||||
"sqlalchemy.cresultproxy.BaseRowProxy", /* tp_name */
|
||||
sizeof(BaseRowProxy), /* tp_basicsize */
|
||||
0, /* tp_itemsize */
|
||||
(destructor)BaseRowProxy_dealloc, /* tp_dealloc */
|
||||
0, /* tp_print */
|
||||
0, /* tp_getattr */
|
||||
0, /* tp_setattr */
|
||||
0, /* tp_compare */
|
||||
0, /* tp_repr */
|
||||
0, /* tp_as_number */
|
||||
&BaseRowProxy_as_sequence, /* tp_as_sequence */
|
||||
&BaseRowProxy_as_mapping, /* tp_as_mapping */
|
||||
0, /* tp_hash */
|
||||
0, /* tp_call */
|
||||
0, /* tp_str */
|
||||
(getattrofunc)BaseRowProxy_getattro,/* tp_getattro */
|
||||
0, /* tp_setattro */
|
||||
0, /* tp_as_buffer */
|
||||
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
|
||||
"BaseRowProxy is a abstract base class for RowProxy", /* tp_doc */
|
||||
0, /* tp_traverse */
|
||||
0, /* tp_clear */
|
||||
0, /* tp_richcompare */
|
||||
0, /* tp_weaklistoffset */
|
||||
(getiterfunc)BaseRowProxy_iter, /* tp_iter */
|
||||
0, /* tp_iternext */
|
||||
BaseRowProxy_methods, /* tp_methods */
|
||||
0, /* tp_members */
|
||||
BaseRowProxy_getseters, /* tp_getset */
|
||||
0, /* tp_base */
|
||||
0, /* tp_dict */
|
||||
0, /* tp_descr_get */
|
||||
0, /* tp_descr_set */
|
||||
0, /* tp_dictoffset */
|
||||
(initproc)BaseRowProxy_init, /* tp_init */
|
||||
0, /* tp_alloc */
|
||||
0 /* tp_new */
|
||||
};
|
||||
|
||||
static PyMethodDef module_methods[] = {
|
||||
{"safe_rowproxy_reconstructor", safe_rowproxy_reconstructor, METH_VARARGS,
|
||||
"reconstruct a RowProxy instance from its pickled form."},
|
||||
{NULL, NULL, 0, NULL} /* Sentinel */
|
||||
};
|
||||
|
||||
#ifndef PyMODINIT_FUNC /* declarations for DLL import/export */
|
||||
#define PyMODINIT_FUNC void
|
||||
#endif
|
||||
|
||||
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
|
||||
static struct PyModuleDef module_def = {
|
||||
PyModuleDef_HEAD_INIT,
|
||||
MODULE_NAME,
|
||||
MODULE_DOC,
|
||||
-1,
|
||||
module_methods
|
||||
};
|
||||
|
||||
#define INITERROR return NULL
|
||||
|
||||
PyMODINIT_FUNC
|
||||
PyInit_cresultproxy(void)
|
||||
|
||||
#else
|
||||
|
||||
#define INITERROR return
|
||||
|
||||
PyMODINIT_FUNC
|
||||
initcresultproxy(void)
|
||||
|
||||
#endif
|
||||
|
||||
{
|
||||
PyObject *m;
|
||||
|
||||
BaseRowProxyType.tp_new = PyType_GenericNew;
|
||||
if (PyType_Ready(&BaseRowProxyType) < 0)
|
||||
INITERROR;
|
||||
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
m = PyModule_Create(&module_def);
|
||||
#else
|
||||
m = Py_InitModule3(MODULE_NAME, module_methods, MODULE_DOC);
|
||||
#endif
|
||||
if (m == NULL)
|
||||
INITERROR;
|
||||
|
||||
Py_INCREF(&BaseRowProxyType);
|
||||
PyModule_AddObject(m, "BaseRowProxy", (PyObject *)&BaseRowProxyType);
|
||||
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
return m;
|
||||
#endif
|
||||
}
|
|
@ -1,225 +0,0 @@
|
|||
/*
|
||||
utils.c
|
||||
Copyright (C) 2012-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
|
||||
|
||||
This module is part of SQLAlchemy and is released under
|
||||
the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
*/
|
||||
|
||||
#include <Python.h>
|
||||
|
||||
#define MODULE_NAME "cutils"
|
||||
#define MODULE_DOC "Module containing C versions of utility functions."
|
||||
|
||||
/*
|
||||
Given arguments from the calling form *multiparams, **params,
|
||||
return a list of bind parameter structures, usually a list of
|
||||
dictionaries.
|
||||
|
||||
In the case of 'raw' execution which accepts positional parameters,
|
||||
it may be a list of tuples or lists.
|
||||
|
||||
*/
|
||||
static PyObject *
|
||||
distill_params(PyObject *self, PyObject *args)
|
||||
{
|
||||
PyObject *multiparams, *params;
|
||||
PyObject *enclosing_list, *double_enclosing_list;
|
||||
PyObject *zero_element, *zero_element_item;
|
||||
Py_ssize_t multiparam_size, zero_element_length;
|
||||
|
||||
if (!PyArg_UnpackTuple(args, "_distill_params", 2, 2, &multiparams, ¶ms)) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
if (multiparams != Py_None) {
|
||||
multiparam_size = PyTuple_Size(multiparams);
|
||||
if (multiparam_size < 0) {
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
else {
|
||||
multiparam_size = 0;
|
||||
}
|
||||
|
||||
if (multiparam_size == 0) {
|
||||
if (params != Py_None && PyDict_Size(params) != 0) {
|
||||
enclosing_list = PyList_New(1);
|
||||
if (enclosing_list == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
Py_INCREF(params);
|
||||
if (PyList_SetItem(enclosing_list, 0, params) == -1) {
|
||||
Py_DECREF(params);
|
||||
Py_DECREF(enclosing_list);
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
else {
|
||||
enclosing_list = PyList_New(0);
|
||||
if (enclosing_list == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
return enclosing_list;
|
||||
}
|
||||
else if (multiparam_size == 1) {
|
||||
zero_element = PyTuple_GetItem(multiparams, 0);
|
||||
if (PyTuple_Check(zero_element) || PyList_Check(zero_element)) {
|
||||
zero_element_length = PySequence_Length(zero_element);
|
||||
|
||||
if (zero_element_length != 0) {
|
||||
zero_element_item = PySequence_GetItem(zero_element, 0);
|
||||
if (zero_element_item == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
else {
|
||||
zero_element_item = NULL;
|
||||
}
|
||||
|
||||
if (zero_element_length == 0 ||
|
||||
(
|
||||
PyObject_HasAttrString(zero_element_item, "__iter__") &&
|
||||
!PyObject_HasAttrString(zero_element_item, "strip")
|
||||
)
|
||||
) {
|
||||
/*
|
||||
* execute(stmt, [{}, {}, {}, ...])
|
||||
* execute(stmt, [(), (), (), ...])
|
||||
*/
|
||||
Py_XDECREF(zero_element_item);
|
||||
Py_INCREF(zero_element);
|
||||
return zero_element;
|
||||
}
|
||||
else {
|
||||
/*
|
||||
* execute(stmt, ("value", "value"))
|
||||
*/
|
||||
Py_XDECREF(zero_element_item);
|
||||
enclosing_list = PyList_New(1);
|
||||
if (enclosing_list == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
Py_INCREF(zero_element);
|
||||
if (PyList_SetItem(enclosing_list, 0, zero_element) == -1) {
|
||||
Py_DECREF(zero_element);
|
||||
Py_DECREF(enclosing_list);
|
||||
return NULL;
|
||||
}
|
||||
return enclosing_list;
|
||||
}
|
||||
}
|
||||
else if (PyObject_HasAttrString(zero_element, "keys")) {
|
||||
/*
|
||||
* execute(stmt, {"key":"value"})
|
||||
*/
|
||||
enclosing_list = PyList_New(1);
|
||||
if (enclosing_list == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
Py_INCREF(zero_element);
|
||||
if (PyList_SetItem(enclosing_list, 0, zero_element) == -1) {
|
||||
Py_DECREF(zero_element);
|
||||
Py_DECREF(enclosing_list);
|
||||
return NULL;
|
||||
}
|
||||
return enclosing_list;
|
||||
} else {
|
||||
enclosing_list = PyList_New(1);
|
||||
if (enclosing_list == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
double_enclosing_list = PyList_New(1);
|
||||
if (double_enclosing_list == NULL) {
|
||||
Py_DECREF(enclosing_list);
|
||||
return NULL;
|
||||
}
|
||||
Py_INCREF(zero_element);
|
||||
if (PyList_SetItem(enclosing_list, 0, zero_element) == -1) {
|
||||
Py_DECREF(zero_element);
|
||||
Py_DECREF(enclosing_list);
|
||||
Py_DECREF(double_enclosing_list);
|
||||
return NULL;
|
||||
}
|
||||
if (PyList_SetItem(double_enclosing_list, 0, enclosing_list) == -1) {
|
||||
Py_DECREF(zero_element);
|
||||
Py_DECREF(enclosing_list);
|
||||
Py_DECREF(double_enclosing_list);
|
||||
return NULL;
|
||||
}
|
||||
return double_enclosing_list;
|
||||
}
|
||||
}
|
||||
else {
|
||||
zero_element = PyTuple_GetItem(multiparams, 0);
|
||||
if (PyObject_HasAttrString(zero_element, "__iter__") &&
|
||||
!PyObject_HasAttrString(zero_element, "strip")
|
||||
) {
|
||||
Py_INCREF(multiparams);
|
||||
return multiparams;
|
||||
}
|
||||
else {
|
||||
enclosing_list = PyList_New(1);
|
||||
if (enclosing_list == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
Py_INCREF(multiparams);
|
||||
if (PyList_SetItem(enclosing_list, 0, multiparams) == -1) {
|
||||
Py_DECREF(multiparams);
|
||||
Py_DECREF(enclosing_list);
|
||||
return NULL;
|
||||
}
|
||||
return enclosing_list;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static PyMethodDef module_methods[] = {
|
||||
{"_distill_params", distill_params, METH_VARARGS,
|
||||
"Distill an execute() parameter structure."},
|
||||
{NULL, NULL, 0, NULL} /* Sentinel */
|
||||
};
|
||||
|
||||
#ifndef PyMODINIT_FUNC /* declarations for DLL import/export */
|
||||
#define PyMODINIT_FUNC void
|
||||
#endif
|
||||
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
|
||||
static struct PyModuleDef module_def = {
|
||||
PyModuleDef_HEAD_INIT,
|
||||
MODULE_NAME,
|
||||
MODULE_DOC,
|
||||
-1,
|
||||
module_methods
|
||||
};
|
||||
#endif
|
||||
|
||||
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
PyMODINIT_FUNC
|
||||
PyInit_cutils(void)
|
||||
#else
|
||||
PyMODINIT_FUNC
|
||||
initcutils(void)
|
||||
#endif
|
||||
{
|
||||
PyObject *m;
|
||||
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
m = PyModule_Create(&module_def);
|
||||
#else
|
||||
m = Py_InitModule3(MODULE_NAME, module_methods, MODULE_DOC);
|
||||
#endif
|
||||
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
if (m == NULL)
|
||||
return NULL;
|
||||
return m;
|
||||
#else
|
||||
if (m == NULL)
|
||||
return;
|
||||
#endif
|
||||
}
|
||||
|
|
@ -1,9 +0,0 @@
|
|||
# connectors/__init__.py
|
||||
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
|
||||
class Connector(object):
|
||||
pass
|
|
@ -1,149 +0,0 @@
|
|||
# connectors/mxodbc.py
|
||||
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
"""
|
||||
Provide an SQLALchemy connector for the eGenix mxODBC commercial
|
||||
Python adapter for ODBC. This is not a free product, but eGenix
|
||||
provides SQLAlchemy with a license for use in continuous integration
|
||||
testing.
|
||||
|
||||
This has been tested for use with mxODBC 3.1.2 on SQL Server 2005
|
||||
and 2008, using the SQL Server Native driver. However, it is
|
||||
possible for this to be used on other database platforms.
|
||||
|
||||
For more info on mxODBC, see http://www.egenix.com/
|
||||
|
||||
"""
|
||||
|
||||
import sys
|
||||
import re
|
||||
import warnings
|
||||
|
||||
from . import Connector
|
||||
|
||||
|
||||
class MxODBCConnector(Connector):
|
||||
driver = 'mxodbc'
|
||||
|
||||
supports_sane_multi_rowcount = False
|
||||
supports_unicode_statements = True
|
||||
supports_unicode_binds = True
|
||||
|
||||
supports_native_decimal = True
|
||||
|
||||
@classmethod
|
||||
def dbapi(cls):
|
||||
# this classmethod will normally be replaced by an instance
|
||||
# attribute of the same name, so this is normally only called once.
|
||||
cls._load_mx_exceptions()
|
||||
platform = sys.platform
|
||||
if platform == 'win32':
|
||||
from mx.ODBC import Windows as module
|
||||
# this can be the string "linux2", and possibly others
|
||||
elif 'linux' in platform:
|
||||
from mx.ODBC import unixODBC as module
|
||||
elif platform == 'darwin':
|
||||
from mx.ODBC import iODBC as module
|
||||
else:
|
||||
raise ImportError("Unrecognized platform for mxODBC import")
|
||||
return module
|
||||
|
||||
@classmethod
|
||||
def _load_mx_exceptions(cls):
|
||||
""" Import mxODBC exception classes into the module namespace,
|
||||
as if they had been imported normally. This is done here
|
||||
to avoid requiring all SQLAlchemy users to install mxODBC.
|
||||
"""
|
||||
global InterfaceError, ProgrammingError
|
||||
from mx.ODBC import InterfaceError
|
||||
from mx.ODBC import ProgrammingError
|
||||
|
||||
def on_connect(self):
|
||||
def connect(conn):
|
||||
conn.stringformat = self.dbapi.MIXED_STRINGFORMAT
|
||||
conn.datetimeformat = self.dbapi.PYDATETIME_DATETIMEFORMAT
|
||||
conn.decimalformat = self.dbapi.DECIMAL_DECIMALFORMAT
|
||||
conn.errorhandler = self._error_handler()
|
||||
return connect
|
||||
|
||||
def _error_handler(self):
|
||||
""" Return a handler that adjusts mxODBC's raised Warnings to
|
||||
emit Python standard warnings.
|
||||
"""
|
||||
from mx.ODBC.Error import Warning as MxOdbcWarning
|
||||
|
||||
def error_handler(connection, cursor, errorclass, errorvalue):
|
||||
if issubclass(errorclass, MxOdbcWarning):
|
||||
errorclass.__bases__ = (Warning,)
|
||||
warnings.warn(message=str(errorvalue),
|
||||
category=errorclass,
|
||||
stacklevel=2)
|
||||
else:
|
||||
raise errorclass(errorvalue)
|
||||
return error_handler
|
||||
|
||||
def create_connect_args(self, url):
|
||||
""" Return a tuple of *args,**kwargs for creating a connection.
|
||||
|
||||
The mxODBC 3.x connection constructor looks like this:
|
||||
|
||||
connect(dsn, user='', password='',
|
||||
clear_auto_commit=1, errorhandler=None)
|
||||
|
||||
This method translates the values in the provided uri
|
||||
into args and kwargs needed to instantiate an mxODBC Connection.
|
||||
|
||||
The arg 'errorhandler' is not used by SQLAlchemy and will
|
||||
not be populated.
|
||||
|
||||
"""
|
||||
opts = url.translate_connect_args(username='user')
|
||||
opts.update(url.query)
|
||||
args = opts.pop('host')
|
||||
opts.pop('port', None)
|
||||
opts.pop('database', None)
|
||||
return (args,), opts
|
||||
|
||||
def is_disconnect(self, e, connection, cursor):
|
||||
# TODO: eGenix recommends checking connection.closed here
|
||||
# Does that detect dropped connections ?
|
||||
if isinstance(e, self.dbapi.ProgrammingError):
|
||||
return "connection already closed" in str(e)
|
||||
elif isinstance(e, self.dbapi.Error):
|
||||
return '[08S01]' in str(e)
|
||||
else:
|
||||
return False
|
||||
|
||||
def _get_server_version_info(self, connection):
|
||||
# eGenix suggests using conn.dbms_version instead
|
||||
# of what we're doing here
|
||||
dbapi_con = connection.connection
|
||||
version = []
|
||||
r = re.compile('[.\-]')
|
||||
# 18 == pyodbc.SQL_DBMS_VER
|
||||
for n in r.split(dbapi_con.getinfo(18)[1]):
|
||||
try:
|
||||
version.append(int(n))
|
||||
except ValueError:
|
||||
version.append(n)
|
||||
return tuple(version)
|
||||
|
||||
def _get_direct(self, context):
|
||||
if context:
|
||||
native_odbc_execute = context.execution_options.\
|
||||
get('native_odbc_execute', 'auto')
|
||||
# default to direct=True in all cases, is more generally
|
||||
# compatible especially with SQL Server
|
||||
return False if native_odbc_execute is True else True
|
||||
else:
|
||||
return True
|
||||
|
||||
def do_executemany(self, cursor, statement, parameters, context=None):
|
||||
cursor.executemany(
|
||||
statement, parameters, direct=self._get_direct(context))
|
||||
|
||||
def do_execute(self, cursor, statement, parameters, context=None):
|
||||
cursor.execute(statement, parameters, direct=self._get_direct(context))
|
|
@ -1,144 +0,0 @@
|
|||
# connectors/mysqldb.py
|
||||
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
"""Define behaviors common to MySQLdb dialects.
|
||||
|
||||
Currently includes MySQL and Drizzle.
|
||||
|
||||
"""
|
||||
|
||||
from . import Connector
|
||||
from ..engine import base as engine_base, default
|
||||
from ..sql import operators as sql_operators
|
||||
from .. import exc, log, schema, sql, types as sqltypes, util, processors
|
||||
import re
|
||||
|
||||
|
||||
# the subclassing of Connector by all classes
|
||||
# here is not strictly necessary
|
||||
|
||||
|
||||
class MySQLDBExecutionContext(Connector):
|
||||
|
||||
@property
|
||||
def rowcount(self):
|
||||
if hasattr(self, '_rowcount'):
|
||||
return self._rowcount
|
||||
else:
|
||||
return self.cursor.rowcount
|
||||
|
||||
|
||||
class MySQLDBCompiler(Connector):
|
||||
def visit_mod_binary(self, binary, operator, **kw):
|
||||
return self.process(binary.left, **kw) + " %% " + \
|
||||
self.process(binary.right, **kw)
|
||||
|
||||
def post_process_text(self, text):
|
||||
return text.replace('%', '%%')
|
||||
|
||||
|
||||
class MySQLDBIdentifierPreparer(Connector):
|
||||
|
||||
def _escape_identifier(self, value):
|
||||
value = value.replace(self.escape_quote, self.escape_to_quote)
|
||||
return value.replace("%", "%%")
|
||||
|
||||
|
||||
class MySQLDBConnector(Connector):
|
||||
driver = 'mysqldb'
|
||||
supports_unicode_statements = False
|
||||
supports_sane_rowcount = True
|
||||
supports_sane_multi_rowcount = True
|
||||
|
||||
supports_native_decimal = True
|
||||
|
||||
default_paramstyle = 'format'
|
||||
|
||||
@classmethod
|
||||
def dbapi(cls):
|
||||
# is overridden when pymysql is used
|
||||
return __import__('MySQLdb')
|
||||
|
||||
|
||||
def do_executemany(self, cursor, statement, parameters, context=None):
|
||||
rowcount = cursor.executemany(statement, parameters)
|
||||
if context is not None:
|
||||
context._rowcount = rowcount
|
||||
|
||||
def create_connect_args(self, url):
|
||||
opts = url.translate_connect_args(database='db', username='user',
|
||||
password='passwd')
|
||||
opts.update(url.query)
|
||||
|
||||
util.coerce_kw_type(opts, 'compress', bool)
|
||||
util.coerce_kw_type(opts, 'connect_timeout', int)
|
||||
util.coerce_kw_type(opts, 'read_timeout', int)
|
||||
util.coerce_kw_type(opts, 'client_flag', int)
|
||||
util.coerce_kw_type(opts, 'local_infile', int)
|
||||
# Note: using either of the below will cause all strings to be returned
|
||||
# as Unicode, both in raw SQL operations and with column types like
|
||||
# String and MSString.
|
||||
util.coerce_kw_type(opts, 'use_unicode', bool)
|
||||
util.coerce_kw_type(opts, 'charset', str)
|
||||
|
||||
# Rich values 'cursorclass' and 'conv' are not supported via
|
||||
# query string.
|
||||
|
||||
ssl = {}
|
||||
keys = ['ssl_ca', 'ssl_key', 'ssl_cert', 'ssl_capath', 'ssl_cipher']
|
||||
for key in keys:
|
||||
if key in opts:
|
||||
ssl[key[4:]] = opts[key]
|
||||
util.coerce_kw_type(ssl, key[4:], str)
|
||||
del opts[key]
|
||||
if ssl:
|
||||
opts['ssl'] = ssl
|
||||
|
||||
# FOUND_ROWS must be set in CLIENT_FLAGS to enable
|
||||
# supports_sane_rowcount.
|
||||
client_flag = opts.get('client_flag', 0)
|
||||
if self.dbapi is not None:
|
||||
try:
|
||||
CLIENT_FLAGS = __import__(
|
||||
self.dbapi.__name__ + '.constants.CLIENT'
|
||||
).constants.CLIENT
|
||||
client_flag |= CLIENT_FLAGS.FOUND_ROWS
|
||||
except (AttributeError, ImportError):
|
||||
self.supports_sane_rowcount = False
|
||||
opts['client_flag'] = client_flag
|
||||
return [[], opts]
|
||||
|
||||
def _get_server_version_info(self, connection):
|
||||
dbapi_con = connection.connection
|
||||
version = []
|
||||
r = re.compile('[.\-]')
|
||||
for n in r.split(dbapi_con.get_server_info()):
|
||||
try:
|
||||
version.append(int(n))
|
||||
except ValueError:
|
||||
version.append(n)
|
||||
return tuple(version)
|
||||
|
||||
def _extract_error_code(self, exception):
|
||||
return exception.args[0]
|
||||
|
||||
def _detect_charset(self, connection):
|
||||
"""Sniff out the character set in use for connection results."""
|
||||
|
||||
try:
|
||||
# note: the SQL here would be
|
||||
# "SHOW VARIABLES LIKE 'character_set%%'"
|
||||
cset_name = connection.connection.character_set_name
|
||||
except AttributeError:
|
||||
util.warn(
|
||||
"No 'character_set_name' can be detected with "
|
||||
"this MySQL-Python version; "
|
||||
"please upgrade to a recent version of MySQL-Python. "
|
||||
"Assuming latin1.")
|
||||
return 'latin1'
|
||||
else:
|
||||
return cset_name()
|
||||
|
|
@ -1,170 +0,0 @@
|
|||
# connectors/pyodbc.py
|
||||
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
from . import Connector
|
||||
from .. import util
|
||||
|
||||
|
||||
import sys
|
||||
import re
|
||||
|
||||
|
||||
class PyODBCConnector(Connector):
|
||||
driver = 'pyodbc'
|
||||
|
||||
supports_sane_multi_rowcount = False
|
||||
|
||||
if util.py2k:
|
||||
# PyODBC unicode is broken on UCS-4 builds
|
||||
supports_unicode = sys.maxunicode == 65535
|
||||
supports_unicode_statements = supports_unicode
|
||||
|
||||
supports_native_decimal = True
|
||||
default_paramstyle = 'named'
|
||||
|
||||
# for non-DSN connections, this should
|
||||
# hold the desired driver name
|
||||
pyodbc_driver_name = None
|
||||
|
||||
# will be set to True after initialize()
|
||||
# if the freetds.so is detected
|
||||
freetds = False
|
||||
|
||||
# will be set to the string version of
|
||||
# the FreeTDS driver if freetds is detected
|
||||
freetds_driver_version = None
|
||||
|
||||
# will be set to True after initialize()
|
||||
# if the libessqlsrv.so is detected
|
||||
easysoft = False
|
||||
|
||||
def __init__(self, supports_unicode_binds=None, **kw):
|
||||
super(PyODBCConnector, self).__init__(**kw)
|
||||
self._user_supports_unicode_binds = supports_unicode_binds
|
||||
|
||||
@classmethod
|
||||
def dbapi(cls):
|
||||
return __import__('pyodbc')
|
||||
|
||||
def create_connect_args(self, url):
|
||||
opts = url.translate_connect_args(username='user')
|
||||
opts.update(url.query)
|
||||
|
||||
keys = opts
|
||||
query = url.query
|
||||
|
||||
connect_args = {}
|
||||
for param in ('ansi', 'unicode_results', 'autocommit'):
|
||||
if param in keys:
|
||||
connect_args[param] = util.asbool(keys.pop(param))
|
||||
|
||||
if 'odbc_connect' in keys:
|
||||
connectors = [util.unquote_plus(keys.pop('odbc_connect'))]
|
||||
else:
|
||||
dsn_connection = 'dsn' in keys or \
|
||||
('host' in keys and 'database' not in keys)
|
||||
if dsn_connection:
|
||||
connectors = ['dsn=%s' % (keys.pop('host', '') or \
|
||||
keys.pop('dsn', ''))]
|
||||
else:
|
||||
port = ''
|
||||
if 'port' in keys and not 'port' in query:
|
||||
port = ',%d' % int(keys.pop('port'))
|
||||
|
||||
connectors = ["DRIVER={%s}" %
|
||||
keys.pop('driver', self.pyodbc_driver_name),
|
||||
'Server=%s%s' % (keys.pop('host', ''), port),
|
||||
'Database=%s' % keys.pop('database', '')]
|
||||
|
||||
user = keys.pop("user", None)
|
||||
if user:
|
||||
connectors.append("UID=%s" % user)
|
||||
connectors.append("PWD=%s" % keys.pop('password', ''))
|
||||
else:
|
||||
connectors.append("Trusted_Connection=Yes")
|
||||
|
||||
# if set to 'Yes', the ODBC layer will try to automagically
|
||||
# convert textual data from your database encoding to your
|
||||
# client encoding. This should obviously be set to 'No' if
|
||||
# you query a cp1253 encoded database from a latin1 client...
|
||||
if 'odbc_autotranslate' in keys:
|
||||
connectors.append("AutoTranslate=%s" %
|
||||
keys.pop("odbc_autotranslate"))
|
||||
|
||||
connectors.extend(['%s=%s' % (k, v) for k, v in keys.items()])
|
||||
return [[";".join(connectors)], connect_args]
|
||||
|
||||
def is_disconnect(self, e, connection, cursor):
|
||||
if isinstance(e, self.dbapi.ProgrammingError):
|
||||
return "The cursor's connection has been closed." in str(e) or \
|
||||
'Attempt to use a closed connection.' in str(e)
|
||||
elif isinstance(e, self.dbapi.Error):
|
||||
return '[08S01]' in str(e)
|
||||
else:
|
||||
return False
|
||||
|
||||
def initialize(self, connection):
|
||||
# determine FreeTDS first. can't issue SQL easily
|
||||
# without getting unicode_statements/binds set up.
|
||||
|
||||
pyodbc = self.dbapi
|
||||
|
||||
dbapi_con = connection.connection
|
||||
|
||||
_sql_driver_name = dbapi_con.getinfo(pyodbc.SQL_DRIVER_NAME)
|
||||
self.freetds = bool(re.match(r".*libtdsodbc.*\.so", _sql_driver_name
|
||||
))
|
||||
self.easysoft = bool(re.match(r".*libessqlsrv.*\.so", _sql_driver_name
|
||||
))
|
||||
|
||||
if self.freetds:
|
||||
self.freetds_driver_version = dbapi_con.getinfo(
|
||||
pyodbc.SQL_DRIVER_VER)
|
||||
|
||||
self.supports_unicode_statements = (
|
||||
not util.py2k or
|
||||
(not self.freetds and not self.easysoft)
|
||||
)
|
||||
|
||||
if self._user_supports_unicode_binds is not None:
|
||||
self.supports_unicode_binds = self._user_supports_unicode_binds
|
||||
elif util.py2k:
|
||||
self.supports_unicode_binds = (
|
||||
not self.freetds or self.freetds_driver_version >= '0.91'
|
||||
) and not self.easysoft
|
||||
else:
|
||||
self.supports_unicode_binds = True
|
||||
|
||||
# run other initialization which asks for user name, etc.
|
||||
super(PyODBCConnector, self).initialize(connection)
|
||||
|
||||
def _dbapi_version(self):
|
||||
if not self.dbapi:
|
||||
return ()
|
||||
return self._parse_dbapi_version(self.dbapi.version)
|
||||
|
||||
def _parse_dbapi_version(self, vers):
|
||||
m = re.match(
|
||||
r'(?:py.*-)?([\d\.]+)(?:-(\w+))?',
|
||||
vers
|
||||
)
|
||||
if not m:
|
||||
return ()
|
||||
vers = tuple([int(x) for x in m.group(1).split(".")])
|
||||
if m.group(2):
|
||||
vers += (m.group(2),)
|
||||
return vers
|
||||
|
||||
def _get_server_version_info(self, connection):
|
||||
dbapi_con = connection.connection
|
||||
version = []
|
||||
r = re.compile('[.\-]')
|
||||
for n in r.split(dbapi_con.getinfo(self.dbapi.SQL_DBMS_VER)):
|
||||
try:
|
||||
version.append(int(n))
|
||||
except ValueError:
|
||||
version.append(n)
|
||||
return tuple(version)
|
|
@ -1,59 +0,0 @@
|
|||
# connectors/zxJDBC.py
|
||||
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
import sys
|
||||
from . import Connector
|
||||
|
||||
|
||||
class ZxJDBCConnector(Connector):
|
||||
driver = 'zxjdbc'
|
||||
|
||||
supports_sane_rowcount = False
|
||||
supports_sane_multi_rowcount = False
|
||||
|
||||
supports_unicode_binds = True
|
||||
supports_unicode_statements = sys.version > '2.5.0+'
|
||||
description_encoding = None
|
||||
default_paramstyle = 'qmark'
|
||||
|
||||
jdbc_db_name = None
|
||||
jdbc_driver_name = None
|
||||
|
||||
@classmethod
|
||||
def dbapi(cls):
|
||||
from com.ziclix.python.sql import zxJDBC
|
||||
return zxJDBC
|
||||
|
||||
def _driver_kwargs(self):
|
||||
"""Return kw arg dict to be sent to connect()."""
|
||||
return {}
|
||||
|
||||
def _create_jdbc_url(self, url):
|
||||
"""Create a JDBC url from a :class:`~sqlalchemy.engine.url.URL`"""
|
||||
return 'jdbc:%s://%s%s/%s' % (self.jdbc_db_name, url.host,
|
||||
url.port is not None
|
||||
and ':%s' % url.port or '',
|
||||
url.database)
|
||||
|
||||
def create_connect_args(self, url):
|
||||
opts = self._driver_kwargs()
|
||||
opts.update(url.query)
|
||||
return [
|
||||
[self._create_jdbc_url(url),
|
||||
url.username, url.password,
|
||||
self.jdbc_driver_name],
|
||||
opts]
|
||||
|
||||
def is_disconnect(self, e, connection, cursor):
|
||||
if not isinstance(e, self.dbapi.ProgrammingError):
|
||||
return False
|
||||
e = str(e)
|
||||
return 'connection is closed' in e or 'cursor is closed' in e
|
||||
|
||||
def _get_server_version_info(self, connection):
|
||||
# use connection.connection.dbversion, and parse appropriately
|
||||
# to get a tuple
|
||||
raise NotImplementedError()
|
|
@ -1,31 +0,0 @@
|
|||
# databases/__init__.py
|
||||
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
"""Include imports from the sqlalchemy.dialects package for backwards
|
||||
compatibility with pre 0.6 versions.
|
||||
|
||||
"""
|
||||
from ..dialects.sqlite import base as sqlite
|
||||
from ..dialects.postgresql import base as postgresql
|
||||
postgres = postgresql
|
||||
from ..dialects.mysql import base as mysql
|
||||
from ..dialects.drizzle import base as drizzle
|
||||
from ..dialects.oracle import base as oracle
|
||||
from ..dialects.firebird import base as firebird
|
||||
from ..dialects.mssql import base as mssql
|
||||
from ..dialects.sybase import base as sybase
|
||||
|
||||
|
||||
__all__ = (
|
||||
'drizzle',
|
||||
'firebird',
|
||||
'mssql',
|
||||
'mysql',
|
||||
'postgresql',
|
||||
'sqlite',
|
||||
'oracle',
|
||||
'sybase',
|
||||
)
|
|
@ -1,44 +0,0 @@
|
|||
# dialects/__init__.py
|
||||
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
__all__ = (
|
||||
'drizzle',
|
||||
'firebird',
|
||||
'mssql',
|
||||
'mysql',
|
||||
'oracle',
|
||||
'postgresql',
|
||||
'sqlite',
|
||||
'sybase',
|
||||
)
|
||||
|
||||
from .. import util
|
||||
|
||||
def _auto_fn(name):
|
||||
"""default dialect importer.
|
||||
|
||||
plugs into the :class:`.PluginLoader`
|
||||
as a first-hit system.
|
||||
|
||||
"""
|
||||
if "." in name:
|
||||
dialect, driver = name.split(".")
|
||||
else:
|
||||
dialect = name
|
||||
driver = "base"
|
||||
try:
|
||||
module = __import__('sqlalchemy.dialects.%s' % (dialect, )).dialects
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
module = getattr(module, dialect)
|
||||
if hasattr(module, driver):
|
||||
module = getattr(module, driver)
|
||||
return lambda: module.dialect
|
||||
else:
|
||||
return None
|
||||
|
||||
registry = util.PluginLoader("sqlalchemy.dialects", auto_fn=_auto_fn)
|
|
@ -1,22 +0,0 @@
|
|||
from sqlalchemy.dialects.drizzle import base, mysqldb
|
||||
|
||||
base.dialect = mysqldb.dialect
|
||||
|
||||
from sqlalchemy.dialects.drizzle.base import \
|
||||
BIGINT, BINARY, BLOB, \
|
||||
BOOLEAN, CHAR, DATE, \
|
||||
DATETIME, DECIMAL, DOUBLE, \
|
||||
ENUM, FLOAT, INTEGER, \
|
||||
NUMERIC, REAL, TEXT, \
|
||||
TIME, TIMESTAMP, VARBINARY, \
|
||||
VARCHAR, dialect
|
||||
|
||||
__all__ = (
|
||||
'BIGINT', 'BINARY', 'BLOB',
|
||||
'BOOLEAN', 'CHAR', 'DATE',
|
||||
'DATETIME', 'DECIMAL', 'DOUBLE',
|
||||
'ENUM', 'FLOAT', 'INTEGER',
|
||||
'NUMERIC', 'REAL', 'TEXT',
|
||||
'TIME', 'TIMESTAMP', 'VARBINARY',
|
||||
'VARCHAR', 'dialect'
|
||||
)
|
|
@ -1,498 +0,0 @@
|
|||
# drizzle/base.py
|
||||
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
|
||||
# Copyright (C) 2010-2011 Monty Taylor <mordred@inaugust.com>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
|
||||
"""
|
||||
|
||||
.. dialect:: drizzle
|
||||
:name: Drizzle
|
||||
|
||||
Drizzle is a variant of MySQL. Unlike MySQL, Drizzle's default storage engine
|
||||
is InnoDB (transactions, foreign-keys) rather than MyISAM. For more
|
||||
`Notable Differences <http://docs.drizzle.org/mysql_differences.html>`_, visit
|
||||
the `Drizzle Documentation <http://docs.drizzle.org/index.html>`_.
|
||||
|
||||
The SQLAlchemy Drizzle dialect leans heavily on the MySQL dialect, so much of
|
||||
the :doc:`SQLAlchemy MySQL <mysql>` documentation is also relevant.
|
||||
|
||||
|
||||
"""
|
||||
|
||||
from sqlalchemy import exc
|
||||
from sqlalchemy import log
|
||||
from sqlalchemy import types as sqltypes
|
||||
from sqlalchemy.engine import reflection
|
||||
from sqlalchemy.dialects.mysql import base as mysql_dialect
|
||||
from sqlalchemy.types import DATE, DATETIME, BOOLEAN, TIME, \
|
||||
BLOB, BINARY, VARBINARY
|
||||
|
||||
|
||||
class _NumericType(object):
|
||||
"""Base for Drizzle numeric types."""
|
||||
|
||||
def __init__(self, **kw):
|
||||
super(_NumericType, self).__init__(**kw)
|
||||
|
||||
|
||||
class _FloatType(_NumericType, sqltypes.Float):
|
||||
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
|
||||
if isinstance(self, (REAL, DOUBLE)) and \
|
||||
(
|
||||
(precision is None and scale is not None) or
|
||||
(precision is not None and scale is None)
|
||||
):
|
||||
raise exc.ArgumentError(
|
||||
"You must specify both precision and scale or omit "
|
||||
"both altogether.")
|
||||
|
||||
super(_FloatType, self).__init__(precision=precision,
|
||||
asdecimal=asdecimal, **kw)
|
||||
self.scale = scale
|
||||
|
||||
|
||||
class _StringType(mysql_dialect._StringType):
|
||||
"""Base for Drizzle string types."""
|
||||
|
||||
def __init__(self, collation=None, binary=False, **kw):
|
||||
kw['national'] = False
|
||||
super(_StringType, self).__init__(collation=collation, binary=binary,
|
||||
**kw)
|
||||
|
||||
|
||||
class NUMERIC(_NumericType, sqltypes.NUMERIC):
|
||||
"""Drizzle NUMERIC type."""
|
||||
|
||||
__visit_name__ = 'NUMERIC'
|
||||
|
||||
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
|
||||
"""Construct a NUMERIC.
|
||||
|
||||
:param precision: Total digits in this number. If scale and precision
|
||||
are both None, values are stored to limits allowed by the server.
|
||||
|
||||
:param scale: The number of digits after the decimal point.
|
||||
|
||||
"""
|
||||
|
||||
super(NUMERIC, self).__init__(precision=precision, scale=scale,
|
||||
asdecimal=asdecimal, **kw)
|
||||
|
||||
|
||||
class DECIMAL(_NumericType, sqltypes.DECIMAL):
|
||||
"""Drizzle DECIMAL type."""
|
||||
|
||||
__visit_name__ = 'DECIMAL'
|
||||
|
||||
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
|
||||
"""Construct a DECIMAL.
|
||||
|
||||
:param precision: Total digits in this number. If scale and precision
|
||||
are both None, values are stored to limits allowed by the server.
|
||||
|
||||
:param scale: The number of digits after the decimal point.
|
||||
|
||||
"""
|
||||
super(DECIMAL, self).__init__(precision=precision, scale=scale,
|
||||
asdecimal=asdecimal, **kw)
|
||||
|
||||
|
||||
class DOUBLE(_FloatType):
|
||||
"""Drizzle DOUBLE type."""
|
||||
|
||||
__visit_name__ = 'DOUBLE'
|
||||
|
||||
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
|
||||
"""Construct a DOUBLE.
|
||||
|
||||
:param precision: Total digits in this number. If scale and precision
|
||||
are both None, values are stored to limits allowed by the server.
|
||||
|
||||
:param scale: The number of digits after the decimal point.
|
||||
|
||||
"""
|
||||
|
||||
super(DOUBLE, self).__init__(precision=precision, scale=scale,
|
||||
asdecimal=asdecimal, **kw)
|
||||
|
||||
|
||||
class REAL(_FloatType, sqltypes.REAL):
|
||||
"""Drizzle REAL type."""
|
||||
|
||||
__visit_name__ = 'REAL'
|
||||
|
||||
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
|
||||
"""Construct a REAL.
|
||||
|
||||
:param precision: Total digits in this number. If scale and precision
|
||||
are both None, values are stored to limits allowed by the server.
|
||||
|
||||
:param scale: The number of digits after the decimal point.
|
||||
|
||||
"""
|
||||
|
||||
super(REAL, self).__init__(precision=precision, scale=scale,
|
||||
asdecimal=asdecimal, **kw)
|
||||
|
||||
|
||||
class FLOAT(_FloatType, sqltypes.FLOAT):
|
||||
"""Drizzle FLOAT type."""
|
||||
|
||||
__visit_name__ = 'FLOAT'
|
||||
|
||||
def __init__(self, precision=None, scale=None, asdecimal=False, **kw):
|
||||
"""Construct a FLOAT.
|
||||
|
||||
:param precision: Total digits in this number. If scale and precision
|
||||
are both None, values are stored to limits allowed by the server.
|
||||
|
||||
:param scale: The number of digits after the decimal point.
|
||||
|
||||
"""
|
||||
|
||||
super(FLOAT, self).__init__(precision=precision, scale=scale,
|
||||
asdecimal=asdecimal, **kw)
|
||||
|
||||
def bind_processor(self, dialect):
|
||||
return None
|
||||
|
||||
|
||||
class INTEGER(sqltypes.INTEGER):
|
||||
"""Drizzle INTEGER type."""
|
||||
|
||||
__visit_name__ = 'INTEGER'
|
||||
|
||||
def __init__(self, **kw):
|
||||
"""Construct an INTEGER."""
|
||||
|
||||
super(INTEGER, self).__init__(**kw)
|
||||
|
||||
|
||||
class BIGINT(sqltypes.BIGINT):
|
||||
"""Drizzle BIGINTEGER type."""
|
||||
|
||||
__visit_name__ = 'BIGINT'
|
||||
|
||||
def __init__(self, **kw):
|
||||
"""Construct a BIGINTEGER."""
|
||||
|
||||
super(BIGINT, self).__init__(**kw)
|
||||
|
||||
|
||||
class TIME(mysql_dialect.TIME):
|
||||
"""Drizzle TIME type."""
|
||||
|
||||
|
||||
class TIMESTAMP(sqltypes.TIMESTAMP):
|
||||
"""Drizzle TIMESTAMP type."""
|
||||
|
||||
__visit_name__ = 'TIMESTAMP'
|
||||
|
||||
|
||||
class TEXT(_StringType, sqltypes.TEXT):
|
||||
"""Drizzle TEXT type, for text up to 2^16 characters."""
|
||||
|
||||
__visit_name__ = 'TEXT'
|
||||
|
||||
def __init__(self, length=None, **kw):
|
||||
"""Construct a TEXT.
|
||||
|
||||
:param length: Optional, if provided the server may optimize storage
|
||||
by substituting the smallest TEXT type sufficient to store
|
||||
``length`` characters.
|
||||
|
||||
:param collation: Optional, a column-level collation for this string
|
||||
value. Takes precedence to 'binary' short-hand.
|
||||
|
||||
:param binary: Defaults to False: short-hand, pick the binary
|
||||
collation type that matches the column's character set. Generates
|
||||
BINARY in schema. This does not affect the type of data stored,
|
||||
only the collation of character data.
|
||||
|
||||
"""
|
||||
|
||||
super(TEXT, self).__init__(length=length, **kw)
|
||||
|
||||
|
||||
class VARCHAR(_StringType, sqltypes.VARCHAR):
|
||||
"""Drizzle VARCHAR type, for variable-length character data."""
|
||||
|
||||
__visit_name__ = 'VARCHAR'
|
||||
|
||||
def __init__(self, length=None, **kwargs):
|
||||
"""Construct a VARCHAR.
|
||||
|
||||
:param collation: Optional, a column-level collation for this string
|
||||
value. Takes precedence to 'binary' short-hand.
|
||||
|
||||
:param binary: Defaults to False: short-hand, pick the binary
|
||||
collation type that matches the column's character set. Generates
|
||||
BINARY in schema. This does not affect the type of data stored,
|
||||
only the collation of character data.
|
||||
|
||||
"""
|
||||
|
||||
super(VARCHAR, self).__init__(length=length, **kwargs)
|
||||
|
||||
|
||||
class CHAR(_StringType, sqltypes.CHAR):
|
||||
"""Drizzle CHAR type, for fixed-length character data."""
|
||||
|
||||
__visit_name__ = 'CHAR'
|
||||
|
||||
def __init__(self, length=None, **kwargs):
|
||||
"""Construct a CHAR.
|
||||
|
||||
:param length: Maximum data length, in characters.
|
||||
|
||||
:param binary: Optional, use the default binary collation for the
|
||||
national character set. This does not affect the type of data
|
||||
stored, use a BINARY type for binary data.
|
||||
|
||||
:param collation: Optional, request a particular collation. Must be
|
||||
compatible with the national character set.
|
||||
|
||||
"""
|
||||
|
||||
super(CHAR, self).__init__(length=length, **kwargs)
|
||||
|
||||
|
||||
class ENUM(mysql_dialect.ENUM):
|
||||
"""Drizzle ENUM type."""
|
||||
|
||||
def __init__(self, *enums, **kw):
|
||||
"""Construct an ENUM.
|
||||
|
||||
Example:
|
||||
|
||||
Column('myenum', ENUM("foo", "bar", "baz"))
|
||||
|
||||
:param enums: The range of valid values for this ENUM. Values will be
|
||||
quoted when generating the schema according to the quoting flag (see
|
||||
below).
|
||||
|
||||
:param strict: Defaults to False: ensure that a given value is in this
|
||||
ENUM's range of permissible values when inserting or updating rows.
|
||||
Note that Drizzle will not raise a fatal error if you attempt to
|
||||
store an out of range value- an alternate value will be stored
|
||||
instead.
|
||||
(See Drizzle ENUM documentation.)
|
||||
|
||||
:param collation: Optional, a column-level collation for this string
|
||||
value. Takes precedence to 'binary' short-hand.
|
||||
|
||||
:param binary: Defaults to False: short-hand, pick the binary
|
||||
collation type that matches the column's character set. Generates
|
||||
BINARY in schema. This does not affect the type of data stored,
|
||||
only the collation of character data.
|
||||
|
||||
:param quoting: Defaults to 'auto': automatically determine enum value
|
||||
quoting. If all enum values are surrounded by the same quoting
|
||||
character, then use 'quoted' mode. Otherwise, use 'unquoted' mode.
|
||||
|
||||
'quoted': values in enums are already quoted, they will be used
|
||||
directly when generating the schema - this usage is deprecated.
|
||||
|
||||
'unquoted': values in enums are not quoted, they will be escaped and
|
||||
surrounded by single quotes when generating the schema.
|
||||
|
||||
Previous versions of this type always required manually quoted
|
||||
values to be supplied; future versions will always quote the string
|
||||
literals for you. This is a transitional option.
|
||||
|
||||
"""
|
||||
|
||||
super(ENUM, self).__init__(*enums, **kw)
|
||||
|
||||
|
||||
class _DrizzleBoolean(sqltypes.Boolean):
|
||||
def get_dbapi_type(self, dbapi):
|
||||
return dbapi.NUMERIC
|
||||
|
||||
|
||||
colspecs = {
|
||||
sqltypes.Numeric: NUMERIC,
|
||||
sqltypes.Float: FLOAT,
|
||||
sqltypes.Time: TIME,
|
||||
sqltypes.Enum: ENUM,
|
||||
sqltypes.Boolean: _DrizzleBoolean,
|
||||
}
|
||||
|
||||
|
||||
# All the types we have in Drizzle
|
||||
ischema_names = {
|
||||
'BIGINT': BIGINT,
|
||||
'BINARY': BINARY,
|
||||
'BLOB': BLOB,
|
||||
'BOOLEAN': BOOLEAN,
|
||||
'CHAR': CHAR,
|
||||
'DATE': DATE,
|
||||
'DATETIME': DATETIME,
|
||||
'DECIMAL': DECIMAL,
|
||||
'DOUBLE': DOUBLE,
|
||||
'ENUM': ENUM,
|
||||
'FLOAT': FLOAT,
|
||||
'INT': INTEGER,
|
||||
'INTEGER': INTEGER,
|
||||
'NUMERIC': NUMERIC,
|
||||
'TEXT': TEXT,
|
||||
'TIME': TIME,
|
||||
'TIMESTAMP': TIMESTAMP,
|
||||
'VARBINARY': VARBINARY,
|
||||
'VARCHAR': VARCHAR,
|
||||
}
|
||||
|
||||
|
||||
class DrizzleCompiler(mysql_dialect.MySQLCompiler):
|
||||
|
||||
def visit_typeclause(self, typeclause):
|
||||
type_ = typeclause.type.dialect_impl(self.dialect)
|
||||
if isinstance(type_, sqltypes.Integer):
|
||||
return 'INTEGER'
|
||||
else:
|
||||
return super(DrizzleCompiler, self).visit_typeclause(typeclause)
|
||||
|
||||
def visit_cast(self, cast, **kwargs):
|
||||
type_ = self.process(cast.typeclause)
|
||||
if type_ is None:
|
||||
return self.process(cast.clause)
|
||||
|
||||
return 'CAST(%s AS %s)' % (self.process(cast.clause), type_)
|
||||
|
||||
|
||||
class DrizzleDDLCompiler(mysql_dialect.MySQLDDLCompiler):
|
||||
pass
|
||||
|
||||
|
||||
class DrizzleTypeCompiler(mysql_dialect.MySQLTypeCompiler):
|
||||
def _extend_numeric(self, type_, spec):
|
||||
return spec
|
||||
|
||||
def _extend_string(self, type_, defaults, spec):
|
||||
"""Extend a string-type declaration with standard SQL
|
||||
COLLATE annotations and Drizzle specific extensions.
|
||||
|
||||
"""
|
||||
|
||||
def attr(name):
|
||||
return getattr(type_, name, defaults.get(name))
|
||||
|
||||
if attr('collation'):
|
||||
collation = 'COLLATE %s' % type_.collation
|
||||
elif attr('binary'):
|
||||
collation = 'BINARY'
|
||||
else:
|
||||
collation = None
|
||||
|
||||
return ' '.join([c for c in (spec, collation)
|
||||
if c is not None])
|
||||
|
||||
def visit_NCHAR(self, type):
|
||||
raise NotImplementedError("Drizzle does not support NCHAR")
|
||||
|
||||
def visit_NVARCHAR(self, type):
|
||||
raise NotImplementedError("Drizzle does not support NVARCHAR")
|
||||
|
||||
def visit_FLOAT(self, type_):
|
||||
if type_.scale is not None and type_.precision is not None:
|
||||
return "FLOAT(%s, %s)" % (type_.precision, type_.scale)
|
||||
else:
|
||||
return "FLOAT"
|
||||
|
||||
def visit_BOOLEAN(self, type_):
|
||||
return "BOOLEAN"
|
||||
|
||||
def visit_BLOB(self, type_):
|
||||
return "BLOB"
|
||||
|
||||
|
||||
class DrizzleExecutionContext(mysql_dialect.MySQLExecutionContext):
|
||||
pass
|
||||
|
||||
|
||||
class DrizzleIdentifierPreparer(mysql_dialect.MySQLIdentifierPreparer):
|
||||
pass
|
||||
|
||||
|
||||
@log.class_logger
|
||||
class DrizzleDialect(mysql_dialect.MySQLDialect):
|
||||
"""Details of the Drizzle dialect.
|
||||
|
||||
Not used directly in application code.
|
||||
"""
|
||||
|
||||
name = 'drizzle'
|
||||
|
||||
_supports_cast = True
|
||||
supports_sequences = False
|
||||
supports_native_boolean = True
|
||||
supports_views = False
|
||||
|
||||
default_paramstyle = 'format'
|
||||
colspecs = colspecs
|
||||
|
||||
statement_compiler = DrizzleCompiler
|
||||
ddl_compiler = DrizzleDDLCompiler
|
||||
type_compiler = DrizzleTypeCompiler
|
||||
ischema_names = ischema_names
|
||||
preparer = DrizzleIdentifierPreparer
|
||||
|
||||
def on_connect(self):
|
||||
"""Force autocommit - Drizzle Bug#707842 doesn't set this properly"""
|
||||
|
||||
def connect(conn):
|
||||
conn.autocommit(False)
|
||||
return connect
|
||||
|
||||
@reflection.cache
|
||||
def get_table_names(self, connection, schema=None, **kw):
|
||||
"""Return a Unicode SHOW TABLES from a given schema."""
|
||||
|
||||
if schema is not None:
|
||||
current_schema = schema
|
||||
else:
|
||||
current_schema = self.default_schema_name
|
||||
|
||||
charset = 'utf8'
|
||||
rp = connection.execute("SHOW TABLES FROM %s" %
|
||||
self.identifier_preparer.quote_identifier(current_schema))
|
||||
return [row[0] for row in self._compat_fetchall(rp, charset=charset)]
|
||||
|
||||
@reflection.cache
|
||||
def get_view_names(self, connection, schema=None, **kw):
|
||||
raise NotImplementedError
|
||||
|
||||
def _detect_casing(self, connection):
|
||||
"""Sniff out identifier case sensitivity.
|
||||
|
||||
Cached per-connection. This value can not change without a server
|
||||
restart.
|
||||
"""
|
||||
|
||||
return 0
|
||||
|
||||
def _detect_collations(self, connection):
|
||||
"""Pull the active COLLATIONS list from the server.
|
||||
|
||||
Cached per-connection.
|
||||
"""
|
||||
|
||||
collations = {}
|
||||
charset = self._connection_charset
|
||||
rs = connection.execute(
|
||||
'SELECT CHARACTER_SET_NAME, COLLATION_NAME FROM'
|
||||
' data_dictionary.COLLATIONS')
|
||||
for row in self._compat_fetchall(rs, charset):
|
||||
collations[row[0]] = row[1]
|
||||
return collations
|
||||
|
||||
def _detect_ansiquotes(self, connection):
|
||||
"""Detect and adjust for the ANSI_QUOTES sql mode."""
|
||||
|
||||
self._server_ansiquotes = False
|
||||
self._backslash_escapes = False
|
||||
|
||||
|
|
@ -1,48 +0,0 @@
|
|||
"""
|
||||
.. dialect:: drizzle+mysqldb
|
||||
:name: MySQL-Python
|
||||
:dbapi: mysqldb
|
||||
:connectstring: drizzle+mysqldb://<user>:<password>@<host>[:<port>]/<dbname>
|
||||
:url: http://sourceforge.net/projects/mysql-python
|
||||
|
||||
|
||||
"""
|
||||
|
||||
from sqlalchemy.dialects.drizzle.base import (
|
||||
DrizzleDialect,
|
||||
DrizzleExecutionContext,
|
||||
DrizzleCompiler,
|
||||
DrizzleIdentifierPreparer)
|
||||
from sqlalchemy.connectors.mysqldb import (
|
||||
MySQLDBExecutionContext,
|
||||
MySQLDBCompiler,
|
||||
MySQLDBIdentifierPreparer,
|
||||
MySQLDBConnector)
|
||||
|
||||
|
||||
class DrizzleExecutionContext_mysqldb(MySQLDBExecutionContext,
|
||||
DrizzleExecutionContext):
|
||||
pass
|
||||
|
||||
|
||||
class DrizzleCompiler_mysqldb(MySQLDBCompiler, DrizzleCompiler):
|
||||
pass
|
||||
|
||||
|
||||
class DrizzleIdentifierPreparer_mysqldb(MySQLDBIdentifierPreparer,
|
||||
DrizzleIdentifierPreparer):
|
||||
pass
|
||||
|
||||
|
||||
class DrizzleDialect_mysqldb(MySQLDBConnector, DrizzleDialect):
|
||||
execution_ctx_cls = DrizzleExecutionContext_mysqldb
|
||||
statement_compiler = DrizzleCompiler_mysqldb
|
||||
preparer = DrizzleIdentifierPreparer_mysqldb
|
||||
|
||||
def _detect_charset(self, connection):
|
||||
"""Sniff out the character set in use for connection results."""
|
||||
|
||||
return 'utf8'
|
||||
|
||||
|
||||
dialect = DrizzleDialect_mysqldb
|
|
@ -1,20 +0,0 @@
|
|||
# firebird/__init__.py
|
||||
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
from sqlalchemy.dialects.firebird import base, kinterbasdb, fdb
|
||||
|
||||
base.dialect = fdb.dialect
|
||||
|
||||
from sqlalchemy.dialects.firebird.base import \
|
||||
SMALLINT, BIGINT, FLOAT, FLOAT, DATE, TIME, \
|
||||
TEXT, NUMERIC, FLOAT, TIMESTAMP, VARCHAR, CHAR, BLOB,\
|
||||
dialect
|
||||
|
||||
__all__ = (
|
||||
'SMALLINT', 'BIGINT', 'FLOAT', 'FLOAT', 'DATE', 'TIME',
|
||||
'TEXT', 'NUMERIC', 'FLOAT', 'TIMESTAMP', 'VARCHAR', 'CHAR', 'BLOB',
|
||||
'dialect'
|
||||
)
|
|
@ -1,738 +0,0 @@
|
|||
# firebird/base.py
|
||||
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
"""
|
||||
|
||||
.. dialect:: firebird
|
||||
:name: Firebird
|
||||
|
||||
Firebird Dialects
|
||||
-----------------
|
||||
|
||||
Firebird offers two distinct dialects_ (not to be confused with a
|
||||
SQLAlchemy ``Dialect``):
|
||||
|
||||
dialect 1
|
||||
This is the old syntax and behaviour, inherited from Interbase pre-6.0.
|
||||
|
||||
dialect 3
|
||||
This is the newer and supported syntax, introduced in Interbase 6.0.
|
||||
|
||||
The SQLAlchemy Firebird dialect detects these versions and
|
||||
adjusts its representation of SQL accordingly. However,
|
||||
support for dialect 1 is not well tested and probably has
|
||||
incompatibilities.
|
||||
|
||||
Locking Behavior
|
||||
----------------
|
||||
|
||||
Firebird locks tables aggressively. For this reason, a DROP TABLE may
|
||||
hang until other transactions are released. SQLAlchemy does its best
|
||||
to release transactions as quickly as possible. The most common cause
|
||||
of hanging transactions is a non-fully consumed result set, i.e.::
|
||||
|
||||
result = engine.execute("select * from table")
|
||||
row = result.fetchone()
|
||||
return
|
||||
|
||||
Where above, the ``ResultProxy`` has not been fully consumed. The
|
||||
connection will be returned to the pool and the transactional state
|
||||
rolled back once the Python garbage collector reclaims the objects
|
||||
which hold onto the connection, which often occurs asynchronously.
|
||||
The above use case can be alleviated by calling ``first()`` on the
|
||||
``ResultProxy`` which will fetch the first row and immediately close
|
||||
all remaining cursor/connection resources.
|
||||
|
||||
RETURNING support
|
||||
-----------------
|
||||
|
||||
Firebird 2.0 supports returning a result set from inserts, and 2.1
|
||||
extends that to deletes and updates. This is generically exposed by
|
||||
the SQLAlchemy ``returning()`` method, such as::
|
||||
|
||||
# INSERT..RETURNING
|
||||
result = table.insert().returning(table.c.col1, table.c.col2).\\
|
||||
values(name='foo')
|
||||
print result.fetchall()
|
||||
|
||||
# UPDATE..RETURNING
|
||||
raises = empl.update().returning(empl.c.id, empl.c.salary).\\
|
||||
where(empl.c.sales>100).\\
|
||||
values(dict(salary=empl.c.salary * 1.1))
|
||||
print raises.fetchall()
|
||||
|
||||
|
||||
.. _dialects: http://mc-computing.com/Databases/Firebird/SQL_Dialect.html
|
||||
|
||||
"""
|
||||
|
||||
import datetime
|
||||
|
||||
from sqlalchemy import schema as sa_schema
|
||||
from sqlalchemy import exc, types as sqltypes, sql, util
|
||||
from sqlalchemy.sql import expression
|
||||
from sqlalchemy.engine import base, default, reflection
|
||||
from sqlalchemy.sql import compiler
|
||||
|
||||
|
||||
from sqlalchemy.types import (BIGINT, BLOB, DATE, FLOAT, INTEGER, NUMERIC,
|
||||
SMALLINT, TEXT, TIME, TIMESTAMP, Integer)
|
||||
|
||||
|
||||
RESERVED_WORDS = set([
|
||||
"active", "add", "admin", "after", "all", "alter", "and", "any", "as",
|
||||
"asc", "ascending", "at", "auto", "avg", "before", "begin", "between",
|
||||
"bigint", "bit_length", "blob", "both", "by", "case", "cast", "char",
|
||||
"character", "character_length", "char_length", "check", "close",
|
||||
"collate", "column", "commit", "committed", "computed", "conditional",
|
||||
"connect", "constraint", "containing", "count", "create", "cross",
|
||||
"cstring", "current", "current_connection", "current_date",
|
||||
"current_role", "current_time", "current_timestamp",
|
||||
"current_transaction", "current_user", "cursor", "database", "date",
|
||||
"day", "dec", "decimal", "declare", "default", "delete", "desc",
|
||||
"descending", "disconnect", "distinct", "do", "domain", "double",
|
||||
"drop", "else", "end", "entry_point", "escape", "exception",
|
||||
"execute", "exists", "exit", "external", "extract", "fetch", "file",
|
||||
"filter", "float", "for", "foreign", "from", "full", "function",
|
||||
"gdscode", "generator", "gen_id", "global", "grant", "group",
|
||||
"having", "hour", "if", "in", "inactive", "index", "inner",
|
||||
"input_type", "insensitive", "insert", "int", "integer", "into", "is",
|
||||
"isolation", "join", "key", "leading", "left", "length", "level",
|
||||
"like", "long", "lower", "manual", "max", "maximum_segment", "merge",
|
||||
"min", "minute", "module_name", "month", "names", "national",
|
||||
"natural", "nchar", "no", "not", "null", "numeric", "octet_length",
|
||||
"of", "on", "only", "open", "option", "or", "order", "outer",
|
||||
"output_type", "overflow", "page", "pages", "page_size", "parameter",
|
||||
"password", "plan", "position", "post_event", "precision", "primary",
|
||||
"privileges", "procedure", "protected", "rdb$db_key", "read", "real",
|
||||
"record_version", "recreate", "recursive", "references", "release",
|
||||
"reserv", "reserving", "retain", "returning_values", "returns",
|
||||
"revoke", "right", "rollback", "rows", "row_count", "savepoint",
|
||||
"schema", "second", "segment", "select", "sensitive", "set", "shadow",
|
||||
"shared", "singular", "size", "smallint", "snapshot", "some", "sort",
|
||||
"sqlcode", "stability", "start", "starting", "starts", "statistics",
|
||||
"sub_type", "sum", "suspend", "table", "then", "time", "timestamp",
|
||||
"to", "trailing", "transaction", "trigger", "trim", "uncommitted",
|
||||
"union", "unique", "update", "upper", "user", "using", "value",
|
||||
"values", "varchar", "variable", "varying", "view", "wait", "when",
|
||||
"where", "while", "with", "work", "write", "year",
|
||||
])
|
||||
|
||||
|
||||
class _StringType(sqltypes.String):
|
||||
"""Base for Firebird string types."""
|
||||
|
||||
def __init__(self, charset=None, **kw):
|
||||
self.charset = charset
|
||||
super(_StringType, self).__init__(**kw)
|
||||
|
||||
|
||||
class VARCHAR(_StringType, sqltypes.VARCHAR):
|
||||
"""Firebird VARCHAR type"""
|
||||
__visit_name__ = 'VARCHAR'
|
||||
|
||||
def __init__(self, length=None, **kwargs):
|
||||
super(VARCHAR, self).__init__(length=length, **kwargs)
|
||||
|
||||
|
||||
class CHAR(_StringType, sqltypes.CHAR):
|
||||
"""Firebird CHAR type"""
|
||||
__visit_name__ = 'CHAR'
|
||||
|
||||
def __init__(self, length=None, **kwargs):
|
||||
super(CHAR, self).__init__(length=length, **kwargs)
|
||||
|
||||
|
||||
class _FBDateTime(sqltypes.DateTime):
|
||||
def bind_processor(self, dialect):
|
||||
def process(value):
|
||||
if type(value) == datetime.date:
|
||||
return datetime.datetime(value.year, value.month, value.day)
|
||||
else:
|
||||
return value
|
||||
return process
|
||||
|
||||
colspecs = {
|
||||
sqltypes.DateTime: _FBDateTime
|
||||
}
|
||||
|
||||
ischema_names = {
|
||||
'SHORT': SMALLINT,
|
||||
'LONG': INTEGER,
|
||||
'QUAD': FLOAT,
|
||||
'FLOAT': FLOAT,
|
||||
'DATE': DATE,
|
||||
'TIME': TIME,
|
||||
'TEXT': TEXT,
|
||||
'INT64': BIGINT,
|
||||
'DOUBLE': FLOAT,
|
||||
'TIMESTAMP': TIMESTAMP,
|
||||
'VARYING': VARCHAR,
|
||||
'CSTRING': CHAR,
|
||||
'BLOB': BLOB,
|
||||
}
|
||||
|
||||
|
||||
# TODO: date conversion types (should be implemented as _FBDateTime,
|
||||
# _FBDate, etc. as bind/result functionality is required)
|
||||
|
||||
class FBTypeCompiler(compiler.GenericTypeCompiler):
|
||||
def visit_boolean(self, type_):
|
||||
return self.visit_SMALLINT(type_)
|
||||
|
||||
def visit_datetime(self, type_):
|
||||
return self.visit_TIMESTAMP(type_)
|
||||
|
||||
def visit_TEXT(self, type_):
|
||||
return "BLOB SUB_TYPE 1"
|
||||
|
||||
def visit_BLOB(self, type_):
|
||||
return "BLOB SUB_TYPE 0"
|
||||
|
||||
def _extend_string(self, type_, basic):
|
||||
charset = getattr(type_, 'charset', None)
|
||||
if charset is None:
|
||||
return basic
|
||||
else:
|
||||
return '%s CHARACTER SET %s' % (basic, charset)
|
||||
|
||||
def visit_CHAR(self, type_):
|
||||
basic = super(FBTypeCompiler, self).visit_CHAR(type_)
|
||||
return self._extend_string(type_, basic)
|
||||
|
||||
def visit_VARCHAR(self, type_):
|
||||
if not type_.length:
|
||||
raise exc.CompileError(
|
||||
"VARCHAR requires a length on dialect %s" %
|
||||
self.dialect.name)
|
||||
basic = super(FBTypeCompiler, self).visit_VARCHAR(type_)
|
||||
return self._extend_string(type_, basic)
|
||||
|
||||
|
||||
class FBCompiler(sql.compiler.SQLCompiler):
|
||||
"""Firebird specific idiosyncrasies"""
|
||||
|
||||
ansi_bind_rules = True
|
||||
|
||||
#def visit_contains_op_binary(self, binary, operator, **kw):
|
||||
# cant use CONTAINING b.c. it's case insensitive.
|
||||
|
||||
#def visit_notcontains_op_binary(self, binary, operator, **kw):
|
||||
# cant use NOT CONTAINING b.c. it's case insensitive.
|
||||
|
||||
def visit_now_func(self, fn, **kw):
|
||||
return "CURRENT_TIMESTAMP"
|
||||
|
||||
def visit_startswith_op_binary(self, binary, operator, **kw):
|
||||
return '%s STARTING WITH %s' % (
|
||||
binary.left._compiler_dispatch(self, **kw),
|
||||
binary.right._compiler_dispatch(self, **kw))
|
||||
|
||||
def visit_notstartswith_op_binary(self, binary, operator, **kw):
|
||||
return '%s NOT STARTING WITH %s' % (
|
||||
binary.left._compiler_dispatch(self, **kw),
|
||||
binary.right._compiler_dispatch(self, **kw))
|
||||
|
||||
def visit_mod_binary(self, binary, operator, **kw):
|
||||
return "mod(%s, %s)" % (
|
||||
self.process(binary.left, **kw),
|
||||
self.process(binary.right, **kw))
|
||||
|
||||
def visit_alias(self, alias, asfrom=False, **kwargs):
|
||||
if self.dialect._version_two:
|
||||
return super(FBCompiler, self).\
|
||||
visit_alias(alias, asfrom=asfrom, **kwargs)
|
||||
else:
|
||||
# Override to not use the AS keyword which FB 1.5 does not like
|
||||
if asfrom:
|
||||
alias_name = isinstance(alias.name,
|
||||
expression._truncated_label) and \
|
||||
self._truncated_identifier("alias",
|
||||
alias.name) or alias.name
|
||||
|
||||
return self.process(
|
||||
alias.original, asfrom=asfrom, **kwargs) + \
|
||||
" " + \
|
||||
self.preparer.format_alias(alias, alias_name)
|
||||
else:
|
||||
return self.process(alias.original, **kwargs)
|
||||
|
||||
def visit_substring_func(self, func, **kw):
|
||||
s = self.process(func.clauses.clauses[0])
|
||||
start = self.process(func.clauses.clauses[1])
|
||||
if len(func.clauses.clauses) > 2:
|
||||
length = self.process(func.clauses.clauses[2])
|
||||
return "SUBSTRING(%s FROM %s FOR %s)" % (s, start, length)
|
||||
else:
|
||||
return "SUBSTRING(%s FROM %s)" % (s, start)
|
||||
|
||||
def visit_length_func(self, function, **kw):
|
||||
if self.dialect._version_two:
|
||||
return "char_length" + self.function_argspec(function)
|
||||
else:
|
||||
return "strlen" + self.function_argspec(function)
|
||||
|
||||
visit_char_length_func = visit_length_func
|
||||
|
||||
def function_argspec(self, func, **kw):
|
||||
# TODO: this probably will need to be
|
||||
# narrowed to a fixed list, some no-arg functions
|
||||
# may require parens - see similar example in the oracle
|
||||
# dialect
|
||||
if func.clauses is not None and len(func.clauses):
|
||||
return self.process(func.clause_expr, **kw)
|
||||
else:
|
||||
return ""
|
||||
|
||||
def default_from(self):
|
||||
return " FROM rdb$database"
|
||||
|
||||
def visit_sequence(self, seq):
|
||||
return "gen_id(%s, 1)" % self.preparer.format_sequence(seq)
|
||||
|
||||
def get_select_precolumns(self, select):
|
||||
"""Called when building a ``SELECT`` statement, position is just
|
||||
before column list Firebird puts the limit and offset right
|
||||
after the ``SELECT``...
|
||||
"""
|
||||
|
||||
result = ""
|
||||
if select._limit:
|
||||
result += "FIRST %s " % self.process(sql.literal(select._limit))
|
||||
if select._offset:
|
||||
result += "SKIP %s " % self.process(sql.literal(select._offset))
|
||||
if select._distinct:
|
||||
result += "DISTINCT "
|
||||
return result
|
||||
|
||||
def limit_clause(self, select):
|
||||
"""Already taken care of in the `get_select_precolumns` method."""
|
||||
|
||||
return ""
|
||||
|
||||
def returning_clause(self, stmt, returning_cols):
|
||||
columns = [
|
||||
self._label_select_column(None, c, True, False, {})
|
||||
for c in expression._select_iterables(returning_cols)
|
||||
]
|
||||
|
||||
return 'RETURNING ' + ', '.join(columns)
|
||||
|
||||
|
||||
class FBDDLCompiler(sql.compiler.DDLCompiler):
|
||||
"""Firebird syntactic idiosyncrasies"""
|
||||
|
||||
def visit_create_sequence(self, create):
|
||||
"""Generate a ``CREATE GENERATOR`` statement for the sequence."""
|
||||
|
||||
# no syntax for these
|
||||
# http://www.firebirdsql.org/manual/generatorguide-sqlsyntax.html
|
||||
if create.element.start is not None:
|
||||
raise NotImplemented(
|
||||
"Firebird SEQUENCE doesn't support START WITH")
|
||||
if create.element.increment is not None:
|
||||
raise NotImplemented(
|
||||
"Firebird SEQUENCE doesn't support INCREMENT BY")
|
||||
|
||||
if self.dialect._version_two:
|
||||
return "CREATE SEQUENCE %s" % \
|
||||
self.preparer.format_sequence(create.element)
|
||||
else:
|
||||
return "CREATE GENERATOR %s" % \
|
||||
self.preparer.format_sequence(create.element)
|
||||
|
||||
def visit_drop_sequence(self, drop):
|
||||
"""Generate a ``DROP GENERATOR`` statement for the sequence."""
|
||||
|
||||
if self.dialect._version_two:
|
||||
return "DROP SEQUENCE %s" % \
|
||||
self.preparer.format_sequence(drop.element)
|
||||
else:
|
||||
return "DROP GENERATOR %s" % \
|
||||
self.preparer.format_sequence(drop.element)
|
||||
|
||||
|
||||
class FBIdentifierPreparer(sql.compiler.IdentifierPreparer):
|
||||
"""Install Firebird specific reserved words."""
|
||||
|
||||
reserved_words = RESERVED_WORDS
|
||||
illegal_initial_characters = compiler.ILLEGAL_INITIAL_CHARACTERS.union(['_'])
|
||||
|
||||
def __init__(self, dialect):
|
||||
super(FBIdentifierPreparer, self).__init__(dialect, omit_schema=True)
|
||||
|
||||
|
||||
class FBExecutionContext(default.DefaultExecutionContext):
|
||||
def fire_sequence(self, seq, type_):
|
||||
"""Get the next value from the sequence using ``gen_id()``."""
|
||||
|
||||
return self._execute_scalar(
|
||||
"SELECT gen_id(%s, 1) FROM rdb$database" %
|
||||
self.dialect.identifier_preparer.format_sequence(seq),
|
||||
type_
|
||||
)
|
||||
|
||||
|
||||
class FBDialect(default.DefaultDialect):
|
||||
"""Firebird dialect"""
|
||||
|
||||
name = 'firebird'
|
||||
|
||||
max_identifier_length = 31
|
||||
|
||||
supports_sequences = True
|
||||
sequences_optional = False
|
||||
supports_default_values = True
|
||||
postfetch_lastrowid = False
|
||||
|
||||
supports_native_boolean = False
|
||||
|
||||
requires_name_normalize = True
|
||||
supports_empty_insert = False
|
||||
|
||||
statement_compiler = FBCompiler
|
||||
ddl_compiler = FBDDLCompiler
|
||||
preparer = FBIdentifierPreparer
|
||||
type_compiler = FBTypeCompiler
|
||||
execution_ctx_cls = FBExecutionContext
|
||||
|
||||
colspecs = colspecs
|
||||
ischema_names = ischema_names
|
||||
|
||||
construct_arguments = []
|
||||
|
||||
# defaults to dialect ver. 3,
|
||||
# will be autodetected off upon
|
||||
# first connect
|
||||
_version_two = True
|
||||
|
||||
def initialize(self, connection):
|
||||
super(FBDialect, self).initialize(connection)
|
||||
self._version_two = ('firebird' in self.server_version_info and \
|
||||
self.server_version_info >= (2, )
|
||||
) or \
|
||||
('interbase' in self.server_version_info and \
|
||||
self.server_version_info >= (6, )
|
||||
)
|
||||
|
||||
if not self._version_two:
|
||||
# TODO: whatever other pre < 2.0 stuff goes here
|
||||
self.ischema_names = ischema_names.copy()
|
||||
self.ischema_names['TIMESTAMP'] = sqltypes.DATE
|
||||
self.colspecs = {
|
||||
sqltypes.DateTime: sqltypes.DATE
|
||||
}
|
||||
|
||||
self.implicit_returning = self._version_two and \
|
||||
self.__dict__.get('implicit_returning', True)
|
||||
|
||||
def normalize_name(self, name):
|
||||
# Remove trailing spaces: FB uses a CHAR() type,
|
||||
# that is padded with spaces
|
||||
name = name and name.rstrip()
|
||||
if name is None:
|
||||
return None
|
||||
elif name.upper() == name and \
|
||||
not self.identifier_preparer._requires_quotes(name.lower()):
|
||||
return name.lower()
|
||||
else:
|
||||
return name
|
||||
|
||||
def denormalize_name(self, name):
|
||||
if name is None:
|
||||
return None
|
||||
elif name.lower() == name and \
|
||||
not self.identifier_preparer._requires_quotes(name.lower()):
|
||||
return name.upper()
|
||||
else:
|
||||
return name
|
||||
|
||||
def has_table(self, connection, table_name, schema=None):
|
||||
"""Return ``True`` if the given table exists, ignoring
|
||||
the `schema`."""
|
||||
|
||||
tblqry = """
|
||||
SELECT 1 AS has_table FROM rdb$database
|
||||
WHERE EXISTS (SELECT rdb$relation_name
|
||||
FROM rdb$relations
|
||||
WHERE rdb$relation_name=?)
|
||||
"""
|
||||
c = connection.execute(tblqry, [self.denormalize_name(table_name)])
|
||||
return c.first() is not None
|
||||
|
||||
def has_sequence(self, connection, sequence_name, schema=None):
|
||||
"""Return ``True`` if the given sequence (generator) exists."""
|
||||
|
||||
genqry = """
|
||||
SELECT 1 AS has_sequence FROM rdb$database
|
||||
WHERE EXISTS (SELECT rdb$generator_name
|
||||
FROM rdb$generators
|
||||
WHERE rdb$generator_name=?)
|
||||
"""
|
||||
c = connection.execute(genqry, [self.denormalize_name(sequence_name)])
|
||||
return c.first() is not None
|
||||
|
||||
@reflection.cache
|
||||
def get_table_names(self, connection, schema=None, **kw):
|
||||
# there are two queries commonly mentioned for this.
|
||||
# this one, using view_blr, is at the Firebird FAQ among other places:
|
||||
# http://www.firebirdfaq.org/faq174/
|
||||
s = """
|
||||
select rdb$relation_name
|
||||
from rdb$relations
|
||||
where rdb$view_blr is null
|
||||
and (rdb$system_flag is null or rdb$system_flag = 0);
|
||||
"""
|
||||
|
||||
# the other query is this one. It's not clear if there's really
|
||||
# any difference between these two. This link:
|
||||
# http://www.alberton.info/firebird_sql_meta_info.html#.Ur3vXfZGni8
|
||||
# states them as interchangeable. Some discussion at [ticket:2898]
|
||||
# SELECT DISTINCT rdb$relation_name
|
||||
# FROM rdb$relation_fields
|
||||
# WHERE rdb$system_flag=0 AND rdb$view_context IS NULL
|
||||
|
||||
return [self.normalize_name(row[0]) for row in connection.execute(s)]
|
||||
|
||||
@reflection.cache
|
||||
def get_view_names(self, connection, schema=None, **kw):
|
||||
# see http://www.firebirdfaq.org/faq174/
|
||||
s = """
|
||||
select rdb$relation_name
|
||||
from rdb$relations
|
||||
where rdb$view_blr is not null
|
||||
and (rdb$system_flag is null or rdb$system_flag = 0);
|
||||
"""
|
||||
return [self.normalize_name(row[0]) for row in connection.execute(s)]
|
||||
|
||||
@reflection.cache
|
||||
def get_view_definition(self, connection, view_name, schema=None, **kw):
|
||||
qry = """
|
||||
SELECT rdb$view_source AS view_source
|
||||
FROM rdb$relations
|
||||
WHERE rdb$relation_name=?
|
||||
"""
|
||||
rp = connection.execute(qry, [self.denormalize_name(view_name)])
|
||||
row = rp.first()
|
||||
if row:
|
||||
return row['view_source']
|
||||
else:
|
||||
return None
|
||||
|
||||
@reflection.cache
|
||||
def get_pk_constraint(self, connection, table_name, schema=None, **kw):
|
||||
# Query to extract the PK/FK constrained fields of the given table
|
||||
keyqry = """
|
||||
SELECT se.rdb$field_name AS fname
|
||||
FROM rdb$relation_constraints rc
|
||||
JOIN rdb$index_segments se ON rc.rdb$index_name=se.rdb$index_name
|
||||
WHERE rc.rdb$constraint_type=? AND rc.rdb$relation_name=?
|
||||
"""
|
||||
tablename = self.denormalize_name(table_name)
|
||||
# get primary key fields
|
||||
c = connection.execute(keyqry, ["PRIMARY KEY", tablename])
|
||||
pkfields = [self.normalize_name(r['fname']) for r in c.fetchall()]
|
||||
return {'constrained_columns': pkfields, 'name': None}
|
||||
|
||||
@reflection.cache
|
||||
def get_column_sequence(self, connection,
|
||||
table_name, column_name,
|
||||
schema=None, **kw):
|
||||
tablename = self.denormalize_name(table_name)
|
||||
colname = self.denormalize_name(column_name)
|
||||
# Heuristic-query to determine the generator associated to a PK field
|
||||
genqry = """
|
||||
SELECT trigdep.rdb$depended_on_name AS fgenerator
|
||||
FROM rdb$dependencies tabdep
|
||||
JOIN rdb$dependencies trigdep
|
||||
ON tabdep.rdb$dependent_name=trigdep.rdb$dependent_name
|
||||
AND trigdep.rdb$depended_on_type=14
|
||||
AND trigdep.rdb$dependent_type=2
|
||||
JOIN rdb$triggers trig ON
|
||||
trig.rdb$trigger_name=tabdep.rdb$dependent_name
|
||||
WHERE tabdep.rdb$depended_on_name=?
|
||||
AND tabdep.rdb$depended_on_type=0
|
||||
AND trig.rdb$trigger_type=1
|
||||
AND tabdep.rdb$field_name=?
|
||||
AND (SELECT count(*)
|
||||
FROM rdb$dependencies trigdep2
|
||||
WHERE trigdep2.rdb$dependent_name = trigdep.rdb$dependent_name) = 2
|
||||
"""
|
||||
genr = connection.execute(genqry, [tablename, colname]).first()
|
||||
if genr is not None:
|
||||
return dict(name=self.normalize_name(genr['fgenerator']))
|
||||
|
||||
@reflection.cache
|
||||
def get_columns(self, connection, table_name, schema=None, **kw):
|
||||
# Query to extract the details of all the fields of the given table
|
||||
tblqry = """
|
||||
SELECT r.rdb$field_name AS fname,
|
||||
r.rdb$null_flag AS null_flag,
|
||||
t.rdb$type_name AS ftype,
|
||||
f.rdb$field_sub_type AS stype,
|
||||
f.rdb$field_length/
|
||||
COALESCE(cs.rdb$bytes_per_character,1) AS flen,
|
||||
f.rdb$field_precision AS fprec,
|
||||
f.rdb$field_scale AS fscale,
|
||||
COALESCE(r.rdb$default_source,
|
||||
f.rdb$default_source) AS fdefault
|
||||
FROM rdb$relation_fields r
|
||||
JOIN rdb$fields f ON r.rdb$field_source=f.rdb$field_name
|
||||
JOIN rdb$types t
|
||||
ON t.rdb$type=f.rdb$field_type AND
|
||||
t.rdb$field_name='RDB$FIELD_TYPE'
|
||||
LEFT JOIN rdb$character_sets cs ON
|
||||
f.rdb$character_set_id=cs.rdb$character_set_id
|
||||
WHERE f.rdb$system_flag=0 AND r.rdb$relation_name=?
|
||||
ORDER BY r.rdb$field_position
|
||||
"""
|
||||
# get the PK, used to determine the eventual associated sequence
|
||||
pk_constraint = self.get_pk_constraint(connection, table_name)
|
||||
pkey_cols = pk_constraint['constrained_columns']
|
||||
|
||||
tablename = self.denormalize_name(table_name)
|
||||
# get all of the fields for this table
|
||||
c = connection.execute(tblqry, [tablename])
|
||||
cols = []
|
||||
while True:
|
||||
row = c.fetchone()
|
||||
if row is None:
|
||||
break
|
||||
name = self.normalize_name(row['fname'])
|
||||
orig_colname = row['fname']
|
||||
|
||||
# get the data type
|
||||
colspec = row['ftype'].rstrip()
|
||||
coltype = self.ischema_names.get(colspec)
|
||||
if coltype is None:
|
||||
util.warn("Did not recognize type '%s' of column '%s'" %
|
||||
(colspec, name))
|
||||
coltype = sqltypes.NULLTYPE
|
||||
elif issubclass(coltype, Integer) and row['fprec'] != 0:
|
||||
coltype = NUMERIC(
|
||||
precision=row['fprec'],
|
||||
scale=row['fscale'] * -1)
|
||||
elif colspec in ('VARYING', 'CSTRING'):
|
||||
coltype = coltype(row['flen'])
|
||||
elif colspec == 'TEXT':
|
||||
coltype = TEXT(row['flen'])
|
||||
elif colspec == 'BLOB':
|
||||
if row['stype'] == 1:
|
||||
coltype = TEXT()
|
||||
else:
|
||||
coltype = BLOB()
|
||||
else:
|
||||
coltype = coltype()
|
||||
|
||||
# does it have a default value?
|
||||
defvalue = None
|
||||
if row['fdefault'] is not None:
|
||||
# the value comes down as "DEFAULT 'value'": there may be
|
||||
# more than one whitespace around the "DEFAULT" keyword
|
||||
# and it may also be lower case
|
||||
# (see also http://tracker.firebirdsql.org/browse/CORE-356)
|
||||
defexpr = row['fdefault'].lstrip()
|
||||
assert defexpr[:8].rstrip().upper() == \
|
||||
'DEFAULT', "Unrecognized default value: %s" % \
|
||||
defexpr
|
||||
defvalue = defexpr[8:].strip()
|
||||
if defvalue == 'NULL':
|
||||
# Redundant
|
||||
defvalue = None
|
||||
col_d = {
|
||||
'name': name,
|
||||
'type': coltype,
|
||||
'nullable': not bool(row['null_flag']),
|
||||
'default': defvalue,
|
||||
'autoincrement': defvalue is None
|
||||
}
|
||||
|
||||
if orig_colname.lower() == orig_colname:
|
||||
col_d['quote'] = True
|
||||
|
||||
# if the PK is a single field, try to see if its linked to
|
||||
# a sequence thru a trigger
|
||||
if len(pkey_cols) == 1 and name == pkey_cols[0]:
|
||||
seq_d = self.get_column_sequence(connection, tablename, name)
|
||||
if seq_d is not None:
|
||||
col_d['sequence'] = seq_d
|
||||
|
||||
cols.append(col_d)
|
||||
return cols
|
||||
|
||||
@reflection.cache
|
||||
def get_foreign_keys(self, connection, table_name, schema=None, **kw):
|
||||
# Query to extract the details of each UK/FK of the given table
|
||||
fkqry = """
|
||||
SELECT rc.rdb$constraint_name AS cname,
|
||||
cse.rdb$field_name AS fname,
|
||||
ix2.rdb$relation_name AS targetrname,
|
||||
se.rdb$field_name AS targetfname
|
||||
FROM rdb$relation_constraints rc
|
||||
JOIN rdb$indices ix1 ON ix1.rdb$index_name=rc.rdb$index_name
|
||||
JOIN rdb$indices ix2 ON ix2.rdb$index_name=ix1.rdb$foreign_key
|
||||
JOIN rdb$index_segments cse ON
|
||||
cse.rdb$index_name=ix1.rdb$index_name
|
||||
JOIN rdb$index_segments se
|
||||
ON se.rdb$index_name=ix2.rdb$index_name
|
||||
AND se.rdb$field_position=cse.rdb$field_position
|
||||
WHERE rc.rdb$constraint_type=? AND rc.rdb$relation_name=?
|
||||
ORDER BY se.rdb$index_name, se.rdb$field_position
|
||||
"""
|
||||
tablename = self.denormalize_name(table_name)
|
||||
|
||||
c = connection.execute(fkqry, ["FOREIGN KEY", tablename])
|
||||
fks = util.defaultdict(lambda: {
|
||||
'name': None,
|
||||
'constrained_columns': [],
|
||||
'referred_schema': None,
|
||||
'referred_table': None,
|
||||
'referred_columns': []
|
||||
})
|
||||
|
||||
for row in c:
|
||||
cname = self.normalize_name(row['cname'])
|
||||
fk = fks[cname]
|
||||
if not fk['name']:
|
||||
fk['name'] = cname
|
||||
fk['referred_table'] = self.normalize_name(row['targetrname'])
|
||||
fk['constrained_columns'].append(
|
||||
self.normalize_name(row['fname']))
|
||||
fk['referred_columns'].append(
|
||||
self.normalize_name(row['targetfname']))
|
||||
return list(fks.values())
|
||||
|
||||
@reflection.cache
|
||||
def get_indexes(self, connection, table_name, schema=None, **kw):
|
||||
qry = """
|
||||
SELECT ix.rdb$index_name AS index_name,
|
||||
ix.rdb$unique_flag AS unique_flag,
|
||||
ic.rdb$field_name AS field_name
|
||||
FROM rdb$indices ix
|
||||
JOIN rdb$index_segments ic
|
||||
ON ix.rdb$index_name=ic.rdb$index_name
|
||||
LEFT OUTER JOIN rdb$relation_constraints
|
||||
ON rdb$relation_constraints.rdb$index_name =
|
||||
ic.rdb$index_name
|
||||
WHERE ix.rdb$relation_name=? AND ix.rdb$foreign_key IS NULL
|
||||
AND rdb$relation_constraints.rdb$constraint_type IS NULL
|
||||
ORDER BY index_name, ic.rdb$field_position
|
||||
"""
|
||||
c = connection.execute(qry, [self.denormalize_name(table_name)])
|
||||
|
||||
indexes = util.defaultdict(dict)
|
||||
for row in c:
|
||||
indexrec = indexes[row['index_name']]
|
||||
if 'name' not in indexrec:
|
||||
indexrec['name'] = self.normalize_name(row['index_name'])
|
||||
indexrec['column_names'] = []
|
||||
indexrec['unique'] = bool(row['unique_flag'])
|
||||
|
||||
indexrec['column_names'].append(
|
||||
self.normalize_name(row['field_name']))
|
||||
|
||||
return list(indexes.values())
|
||||
|
|
@ -1,115 +0,0 @@
|
|||
# firebird/fdb.py
|
||||
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
"""
|
||||
.. dialect:: firebird+fdb
|
||||
:name: fdb
|
||||
:dbapi: pyodbc
|
||||
:connectstring: firebird+fdb://user:password@host:port/path/to/db[?key=value&key=value...]
|
||||
:url: http://pypi.python.org/pypi/fdb/
|
||||
|
||||
fdb is a kinterbasdb compatible DBAPI for Firebird.
|
||||
|
||||
.. versionadded:: 0.8 - Support for the fdb Firebird driver.
|
||||
|
||||
.. versionchanged:: 0.9 - The fdb dialect is now the default dialect
|
||||
under the ``firebird://`` URL space, as ``fdb`` is now the official
|
||||
Python driver for Firebird.
|
||||
|
||||
Arguments
|
||||
----------
|
||||
|
||||
The ``fdb`` dialect is based on the :mod:`sqlalchemy.dialects.firebird.kinterbasdb`
|
||||
dialect, however does not accept every argument that Kinterbasdb does.
|
||||
|
||||
* ``enable_rowcount`` - True by default, setting this to False disables
|
||||
the usage of "cursor.rowcount" with the
|
||||
Kinterbasdb dialect, which SQLAlchemy ordinarily calls upon automatically
|
||||
after any UPDATE or DELETE statement. When disabled, SQLAlchemy's
|
||||
ResultProxy will return -1 for result.rowcount. The rationale here is
|
||||
that Kinterbasdb requires a second round trip to the database when
|
||||
.rowcount is called - since SQLA's resultproxy automatically closes
|
||||
the cursor after a non-result-returning statement, rowcount must be
|
||||
called, if at all, before the result object is returned. Additionally,
|
||||
cursor.rowcount may not return correct results with older versions
|
||||
of Firebird, and setting this flag to False will also cause the
|
||||
SQLAlchemy ORM to ignore its usage. The behavior can also be controlled on a
|
||||
per-execution basis using the ``enable_rowcount`` option with
|
||||
:meth:`.Connection.execution_options`::
|
||||
|
||||
conn = engine.connect().execution_options(enable_rowcount=True)
|
||||
r = conn.execute(stmt)
|
||||
print r.rowcount
|
||||
|
||||
* ``retaining`` - False by default. Setting this to True will pass the
|
||||
``retaining=True`` keyword argument to the ``.commit()`` and ``.rollback()``
|
||||
methods of the DBAPI connection, which can improve performance in some
|
||||
situations, but apparently with significant caveats.
|
||||
Please read the fdb and/or kinterbasdb DBAPI documentation in order to
|
||||
understand the implications of this flag.
|
||||
|
||||
.. versionadded:: 0.8.2 - ``retaining`` keyword argument specifying
|
||||
transaction retaining behavior - in 0.8 it defaults to ``True``
|
||||
for backwards compatibility.
|
||||
|
||||
.. versionchanged:: 0.9.0 - the ``retaining`` flag defaults to ``False``.
|
||||
In 0.8 it defaulted to ``True``.
|
||||
|
||||
.. seealso::
|
||||
|
||||
http://pythonhosted.org/fdb/usage-guide.html#retaining-transactions - information
|
||||
on the "retaining" flag.
|
||||
|
||||
"""
|
||||
|
||||
from .kinterbasdb import FBDialect_kinterbasdb
|
||||
from ... import util
|
||||
|
||||
|
||||
class FBDialect_fdb(FBDialect_kinterbasdb):
|
||||
|
||||
def __init__(self, enable_rowcount=True,
|
||||
retaining=False, **kwargs):
|
||||
super(FBDialect_fdb, self).__init__(
|
||||
enable_rowcount=enable_rowcount,
|
||||
retaining=retaining, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def dbapi(cls):
|
||||
return __import__('fdb')
|
||||
|
||||
def create_connect_args(self, url):
|
||||
opts = url.translate_connect_args(username='user')
|
||||
if opts.get('port'):
|
||||
opts['host'] = "%s/%s" % (opts['host'], opts['port'])
|
||||
del opts['port']
|
||||
opts.update(url.query)
|
||||
|
||||
util.coerce_kw_type(opts, 'type_conv', int)
|
||||
|
||||
return ([], opts)
|
||||
|
||||
def _get_server_version_info(self, connection):
|
||||
"""Get the version of the Firebird server used by a connection.
|
||||
|
||||
Returns a tuple of (`major`, `minor`, `build`), three integers
|
||||
representing the version of the attached server.
|
||||
"""
|
||||
|
||||
# This is the simpler approach (the other uses the services api),
|
||||
# that for backward compatibility reasons returns a string like
|
||||
# LI-V6.3.3.12981 Firebird 2.0
|
||||
# where the first version is a fake one resembling the old
|
||||
# Interbase signature.
|
||||
|
||||
isc_info_firebird_version = 103
|
||||
fbconn = connection.connection
|
||||
|
||||
version = fbconn.db_info(isc_info_firebird_version)
|
||||
|
||||
return self._parse_version_info(version)
|
||||
|
||||
dialect = FBDialect_fdb
|
|
@ -1,179 +0,0 @@
|
|||
# firebird/kinterbasdb.py
|
||||
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
"""
|
||||
.. dialect:: firebird+kinterbasdb
|
||||
:name: kinterbasdb
|
||||
:dbapi: kinterbasdb
|
||||
:connectstring: firebird+kinterbasdb://user:password@host:port/path/to/db[?key=value&key=value...]
|
||||
:url: http://firebirdsql.org/index.php?op=devel&sub=python
|
||||
|
||||
Arguments
|
||||
----------
|
||||
|
||||
The Kinterbasdb backend accepts the ``enable_rowcount`` and ``retaining``
|
||||
arguments accepted by the :mod:`sqlalchemy.dialects.firebird.fdb` dialect. In addition, it
|
||||
also accepts the following:
|
||||
|
||||
* ``type_conv`` - select the kind of mapping done on the types: by default
|
||||
SQLAlchemy uses 200 with Unicode, datetime and decimal support. See
|
||||
the linked documents below for further information.
|
||||
|
||||
* ``concurrency_level`` - set the backend policy with regards to threading
|
||||
issues: by default SQLAlchemy uses policy 1. See the linked documents
|
||||
below for futher information.
|
||||
|
||||
.. seealso::
|
||||
|
||||
http://sourceforge.net/projects/kinterbasdb
|
||||
|
||||
http://kinterbasdb.sourceforge.net/dist_docs/usage.html#adv_param_conv_dynamic_type_translation
|
||||
|
||||
http://kinterbasdb.sourceforge.net/dist_docs/usage.html#special_issue_concurrency
|
||||
|
||||
"""
|
||||
|
||||
from .base import FBDialect, FBExecutionContext
|
||||
from ... import util, types as sqltypes
|
||||
from re import match
|
||||
import decimal
|
||||
|
||||
|
||||
class _kinterbasdb_numeric(object):
|
||||
def bind_processor(self, dialect):
|
||||
def process(value):
|
||||
if isinstance(value, decimal.Decimal):
|
||||
return str(value)
|
||||
else:
|
||||
return value
|
||||
return process
|
||||
|
||||
class _FBNumeric_kinterbasdb(_kinterbasdb_numeric, sqltypes.Numeric):
|
||||
pass
|
||||
|
||||
class _FBFloat_kinterbasdb(_kinterbasdb_numeric, sqltypes.Float):
|
||||
pass
|
||||
|
||||
|
||||
class FBExecutionContext_kinterbasdb(FBExecutionContext):
|
||||
@property
|
||||
def rowcount(self):
|
||||
if self.execution_options.get('enable_rowcount',
|
||||
self.dialect.enable_rowcount):
|
||||
return self.cursor.rowcount
|
||||
else:
|
||||
return -1
|
||||
|
||||
|
||||
class FBDialect_kinterbasdb(FBDialect):
|
||||
driver = 'kinterbasdb'
|
||||
supports_sane_rowcount = False
|
||||
supports_sane_multi_rowcount = False
|
||||
execution_ctx_cls = FBExecutionContext_kinterbasdb
|
||||
|
||||
supports_native_decimal = True
|
||||
|
||||
colspecs = util.update_copy(
|
||||
FBDialect.colspecs,
|
||||
{
|
||||
sqltypes.Numeric: _FBNumeric_kinterbasdb,
|
||||
sqltypes.Float: _FBFloat_kinterbasdb,
|
||||
}
|
||||
|
||||
)
|
||||
|
||||
def __init__(self, type_conv=200, concurrency_level=1,
|
||||
enable_rowcount=True,
|
||||
retaining=False, **kwargs):
|
||||
super(FBDialect_kinterbasdb, self).__init__(**kwargs)
|
||||
self.enable_rowcount = enable_rowcount
|
||||
self.type_conv = type_conv
|
||||
self.concurrency_level = concurrency_level
|
||||
self.retaining = retaining
|
||||
if enable_rowcount:
|
||||
self.supports_sane_rowcount = True
|
||||
|
||||
@classmethod
|
||||
def dbapi(cls):
|
||||
return __import__('kinterbasdb')
|
||||
|
||||
def do_execute(self, cursor, statement, parameters, context=None):
|
||||
# kinterbase does not accept a None, but wants an empty list
|
||||
# when there are no arguments.
|
||||
cursor.execute(statement, parameters or [])
|
||||
|
||||
def do_rollback(self, dbapi_connection):
|
||||
dbapi_connection.rollback(self.retaining)
|
||||
|
||||
def do_commit(self, dbapi_connection):
|
||||
dbapi_connection.commit(self.retaining)
|
||||
|
||||
def create_connect_args(self, url):
|
||||
opts = url.translate_connect_args(username='user')
|
||||
if opts.get('port'):
|
||||
opts['host'] = "%s/%s" % (opts['host'], opts['port'])
|
||||
del opts['port']
|
||||
opts.update(url.query)
|
||||
|
||||
util.coerce_kw_type(opts, 'type_conv', int)
|
||||
|
||||
type_conv = opts.pop('type_conv', self.type_conv)
|
||||
concurrency_level = opts.pop('concurrency_level',
|
||||
self.concurrency_level)
|
||||
|
||||
if self.dbapi is not None:
|
||||
initialized = getattr(self.dbapi, 'initialized', None)
|
||||
if initialized is None:
|
||||
# CVS rev 1.96 changed the name of the attribute:
|
||||
# http://kinterbasdb.cvs.sourceforge.net/viewvc/kinterbasdb/
|
||||
# Kinterbasdb-3.0/__init__.py?r1=1.95&r2=1.96
|
||||
initialized = getattr(self.dbapi, '_initialized', False)
|
||||
if not initialized:
|
||||
self.dbapi.init(type_conv=type_conv,
|
||||
concurrency_level=concurrency_level)
|
||||
return ([], opts)
|
||||
|
||||
def _get_server_version_info(self, connection):
|
||||
"""Get the version of the Firebird server used by a connection.
|
||||
|
||||
Returns a tuple of (`major`, `minor`, `build`), three integers
|
||||
representing the version of the attached server.
|
||||
"""
|
||||
|
||||
# This is the simpler approach (the other uses the services api),
|
||||
# that for backward compatibility reasons returns a string like
|
||||
# LI-V6.3.3.12981 Firebird 2.0
|
||||
# where the first version is a fake one resembling the old
|
||||
# Interbase signature.
|
||||
|
||||
fbconn = connection.connection
|
||||
version = fbconn.server_version
|
||||
|
||||
return self._parse_version_info(version)
|
||||
|
||||
def _parse_version_info(self, version):
|
||||
m = match('\w+-V(\d+)\.(\d+)\.(\d+)\.(\d+)( \w+ (\d+)\.(\d+))?', version)
|
||||
if not m:
|
||||
raise AssertionError(
|
||||
"Could not determine version from string '%s'" % version)
|
||||
|
||||
if m.group(5) != None:
|
||||
return tuple([int(x) for x in m.group(6, 7, 4)] + ['firebird'])
|
||||
else:
|
||||
return tuple([int(x) for x in m.group(1, 2, 3)] + ['interbase'])
|
||||
|
||||
def is_disconnect(self, e, connection, cursor):
|
||||
if isinstance(e, (self.dbapi.OperationalError,
|
||||
self.dbapi.ProgrammingError)):
|
||||
msg = str(e)
|
||||
return ('Unable to complete network request to host' in msg or
|
||||
'Invalid connection state' in msg or
|
||||
'Invalid cursor state' in msg or
|
||||
'connection shutdown' in msg)
|
||||
else:
|
||||
return False
|
||||
|
||||
dialect = FBDialect_kinterbasdb
|
|
@ -1,26 +0,0 @@
|
|||
# mssql/__init__.py
|
||||
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
from sqlalchemy.dialects.mssql import base, pyodbc, adodbapi, \
|
||||
pymssql, zxjdbc, mxodbc
|
||||
|
||||
base.dialect = pyodbc.dialect
|
||||
|
||||
from sqlalchemy.dialects.mssql.base import \
|
||||
INTEGER, BIGINT, SMALLINT, TINYINT, VARCHAR, NVARCHAR, CHAR, \
|
||||
NCHAR, TEXT, NTEXT, DECIMAL, NUMERIC, FLOAT, DATETIME,\
|
||||
DATETIME2, DATETIMEOFFSET, DATE, TIME, SMALLDATETIME, \
|
||||
BINARY, VARBINARY, BIT, REAL, IMAGE, TIMESTAMP,\
|
||||
MONEY, SMALLMONEY, UNIQUEIDENTIFIER, SQL_VARIANT, dialect
|
||||
|
||||
|
||||
__all__ = (
|
||||
'INTEGER', 'BIGINT', 'SMALLINT', 'TINYINT', 'VARCHAR', 'NVARCHAR', 'CHAR',
|
||||
'NCHAR', 'TEXT', 'NTEXT', 'DECIMAL', 'NUMERIC', 'FLOAT', 'DATETIME',
|
||||
'DATETIME2', 'DATETIMEOFFSET', 'DATE', 'TIME', 'SMALLDATETIME',
|
||||
'BINARY', 'VARBINARY', 'BIT', 'REAL', 'IMAGE', 'TIMESTAMP',
|
||||
'MONEY', 'SMALLMONEY', 'UNIQUEIDENTIFIER', 'SQL_VARIANT', 'dialect'
|
||||
)
|
|
@ -1,79 +0,0 @@
|
|||
# mssql/adodbapi.py
|
||||
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
"""
|
||||
.. dialect:: mssql+adodbapi
|
||||
:name: adodbapi
|
||||
:dbapi: adodbapi
|
||||
:connectstring: mssql+adodbapi://<username>:<password>@<dsnname>
|
||||
:url: http://adodbapi.sourceforge.net/
|
||||
|
||||
.. note::
|
||||
|
||||
The adodbapi dialect is not implemented SQLAlchemy versions 0.6 and
|
||||
above at this time.
|
||||
|
||||
"""
|
||||
import datetime
|
||||
from sqlalchemy import types as sqltypes, util
|
||||
from sqlalchemy.dialects.mssql.base import MSDateTime, MSDialect
|
||||
import sys
|
||||
|
||||
|
||||
class MSDateTime_adodbapi(MSDateTime):
|
||||
def result_processor(self, dialect, coltype):
|
||||
def process(value):
|
||||
# adodbapi will return datetimes with empty time
|
||||
# values as datetime.date() objects.
|
||||
# Promote them back to full datetime.datetime()
|
||||
if type(value) is datetime.date:
|
||||
return datetime.datetime(value.year, value.month, value.day)
|
||||
return value
|
||||
return process
|
||||
|
||||
|
||||
class MSDialect_adodbapi(MSDialect):
|
||||
supports_sane_rowcount = True
|
||||
supports_sane_multi_rowcount = True
|
||||
supports_unicode = sys.maxunicode == 65535
|
||||
supports_unicode_statements = True
|
||||
driver = 'adodbapi'
|
||||
|
||||
@classmethod
|
||||
def import_dbapi(cls):
|
||||
import adodbapi as module
|
||||
return module
|
||||
|
||||
colspecs = util.update_copy(
|
||||
MSDialect.colspecs,
|
||||
{
|
||||
sqltypes.DateTime: MSDateTime_adodbapi
|
||||
}
|
||||
)
|
||||
|
||||
def create_connect_args(self, url):
|
||||
keys = url.query
|
||||
|
||||
connectors = ["Provider=SQLOLEDB"]
|
||||
if 'port' in keys:
|
||||
connectors.append("Data Source=%s, %s" %
|
||||
(keys.get("host"), keys.get("port")))
|
||||
else:
|
||||
connectors.append("Data Source=%s" % keys.get("host"))
|
||||
connectors.append("Initial Catalog=%s" % keys.get("database"))
|
||||
user = keys.get("user")
|
||||
if user:
|
||||
connectors.append("User Id=%s" % user)
|
||||
connectors.append("Password=%s" % keys.get("password", ""))
|
||||
else:
|
||||
connectors.append("Integrated Security=SSPI")
|
||||
return [[";".join(connectors)], {}]
|
||||
|
||||
def is_disconnect(self, e, connection, cursor):
|
||||
return isinstance(e, self.dbapi.adodbapi.DatabaseError) and \
|
||||
"'connection failure'" in str(e)
|
||||
|
||||
dialect = MSDialect_adodbapi
|
File diff suppressed because it is too large
Load diff
|
@ -1,114 +0,0 @@
|
|||
# mssql/information_schema.py
|
||||
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
# TODO: should be using the sys. catalog with SQL Server, not information schema
|
||||
|
||||
from ... import Table, MetaData, Column
|
||||
from ...types import String, Unicode, UnicodeText, Integer, TypeDecorator
|
||||
from ... import cast
|
||||
from ... import util
|
||||
from ...sql import expression
|
||||
from ...ext.compiler import compiles
|
||||
|
||||
ischema = MetaData()
|
||||
|
||||
class CoerceUnicode(TypeDecorator):
|
||||
impl = Unicode
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
if util.py2k and isinstance(value, util.binary_type):
|
||||
value = value.decode(dialect.encoding)
|
||||
return value
|
||||
|
||||
def bind_expression(self, bindvalue):
|
||||
return _cast_on_2005(bindvalue)
|
||||
|
||||
class _cast_on_2005(expression.ColumnElement):
|
||||
def __init__(self, bindvalue):
|
||||
self.bindvalue = bindvalue
|
||||
|
||||
@compiles(_cast_on_2005)
|
||||
def _compile(element, compiler, **kw):
|
||||
from . import base
|
||||
if compiler.dialect.server_version_info < base.MS_2005_VERSION:
|
||||
return compiler.process(element.bindvalue, **kw)
|
||||
else:
|
||||
return compiler.process(cast(element.bindvalue, Unicode), **kw)
|
||||
|
||||
schemata = Table("SCHEMATA", ischema,
|
||||
Column("CATALOG_NAME", CoerceUnicode, key="catalog_name"),
|
||||
Column("SCHEMA_NAME", CoerceUnicode, key="schema_name"),
|
||||
Column("SCHEMA_OWNER", CoerceUnicode, key="schema_owner"),
|
||||
schema="INFORMATION_SCHEMA")
|
||||
|
||||
tables = Table("TABLES", ischema,
|
||||
Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"),
|
||||
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
|
||||
Column("TABLE_NAME", CoerceUnicode, key="table_name"),
|
||||
Column("TABLE_TYPE", String(convert_unicode=True), key="table_type"),
|
||||
schema="INFORMATION_SCHEMA")
|
||||
|
||||
columns = Table("COLUMNS", ischema,
|
||||
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
|
||||
Column("TABLE_NAME", CoerceUnicode, key="table_name"),
|
||||
Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
|
||||
Column("IS_NULLABLE", Integer, key="is_nullable"),
|
||||
Column("DATA_TYPE", String, key="data_type"),
|
||||
Column("ORDINAL_POSITION", Integer, key="ordinal_position"),
|
||||
Column("CHARACTER_MAXIMUM_LENGTH", Integer, key="character_maximum_length"),
|
||||
Column("NUMERIC_PRECISION", Integer, key="numeric_precision"),
|
||||
Column("NUMERIC_SCALE", Integer, key="numeric_scale"),
|
||||
Column("COLUMN_DEFAULT", Integer, key="column_default"),
|
||||
Column("COLLATION_NAME", String, key="collation_name"),
|
||||
schema="INFORMATION_SCHEMA")
|
||||
|
||||
constraints = Table("TABLE_CONSTRAINTS", ischema,
|
||||
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
|
||||
Column("TABLE_NAME", CoerceUnicode, key="table_name"),
|
||||
Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
|
||||
Column("CONSTRAINT_TYPE", String(convert_unicode=True), key="constraint_type"),
|
||||
schema="INFORMATION_SCHEMA")
|
||||
|
||||
column_constraints = Table("CONSTRAINT_COLUMN_USAGE", ischema,
|
||||
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
|
||||
Column("TABLE_NAME", CoerceUnicode, key="table_name"),
|
||||
Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
|
||||
Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
|
||||
schema="INFORMATION_SCHEMA")
|
||||
|
||||
key_constraints = Table("KEY_COLUMN_USAGE", ischema,
|
||||
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
|
||||
Column("TABLE_NAME", CoerceUnicode, key="table_name"),
|
||||
Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
|
||||
Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
|
||||
Column("ORDINAL_POSITION", Integer, key="ordinal_position"),
|
||||
schema="INFORMATION_SCHEMA")
|
||||
|
||||
ref_constraints = Table("REFERENTIAL_CONSTRAINTS", ischema,
|
||||
Column("CONSTRAINT_CATALOG", CoerceUnicode, key="constraint_catalog"),
|
||||
Column("CONSTRAINT_SCHEMA", CoerceUnicode, key="constraint_schema"),
|
||||
Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
|
||||
# TODO: is CATLOG misspelled ?
|
||||
Column("UNIQUE_CONSTRAINT_CATLOG", CoerceUnicode,
|
||||
key="unique_constraint_catalog"),
|
||||
|
||||
Column("UNIQUE_CONSTRAINT_SCHEMA", CoerceUnicode,
|
||||
key="unique_constraint_schema"),
|
||||
Column("UNIQUE_CONSTRAINT_NAME", CoerceUnicode,
|
||||
key="unique_constraint_name"),
|
||||
Column("MATCH_OPTION", String, key="match_option"),
|
||||
Column("UPDATE_RULE", String, key="update_rule"),
|
||||
Column("DELETE_RULE", String, key="delete_rule"),
|
||||
schema="INFORMATION_SCHEMA")
|
||||
|
||||
views = Table("VIEWS", ischema,
|
||||
Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"),
|
||||
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
|
||||
Column("TABLE_NAME", CoerceUnicode, key="table_name"),
|
||||
Column("VIEW_DEFINITION", CoerceUnicode, key="view_definition"),
|
||||
Column("CHECK_OPTION", String, key="check_option"),
|
||||
Column("IS_UPDATABLE", String, key="is_updatable"),
|
||||
schema="INFORMATION_SCHEMA")
|
|
@ -1,111 +0,0 @@
|
|||
# mssql/mxodbc.py
|
||||
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
"""
|
||||
.. dialect:: mssql+mxodbc
|
||||
:name: mxODBC
|
||||
:dbapi: mxodbc
|
||||
:connectstring: mssql+mxodbc://<username>:<password>@<dsnname>
|
||||
:url: http://www.egenix.com/
|
||||
|
||||
Execution Modes
|
||||
---------------
|
||||
|
||||
mxODBC features two styles of statement execution, using the
|
||||
``cursor.execute()`` and ``cursor.executedirect()`` methods (the second being
|
||||
an extension to the DBAPI specification). The former makes use of a particular
|
||||
API call specific to the SQL Server Native Client ODBC driver known
|
||||
SQLDescribeParam, while the latter does not.
|
||||
|
||||
mxODBC apparently only makes repeated use of a single prepared statement
|
||||
when SQLDescribeParam is used. The advantage to prepared statement reuse is
|
||||
one of performance. The disadvantage is that SQLDescribeParam has a limited
|
||||
set of scenarios in which bind parameters are understood, including that they
|
||||
cannot be placed within the argument lists of function calls, anywhere outside
|
||||
the FROM, or even within subqueries within the FROM clause - making the usage
|
||||
of bind parameters within SELECT statements impossible for all but the most
|
||||
simplistic statements.
|
||||
|
||||
For this reason, the mxODBC dialect uses the "native" mode by default only for
|
||||
INSERT, UPDATE, and DELETE statements, and uses the escaped string mode for
|
||||
all other statements.
|
||||
|
||||
This behavior can be controlled via
|
||||
:meth:`~sqlalchemy.sql.expression.Executable.execution_options` using the
|
||||
``native_odbc_execute`` flag with a value of ``True`` or ``False``, where a
|
||||
value of ``True`` will unconditionally use native bind parameters and a value
|
||||
of ``False`` will unconditionally use string-escaped parameters.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
from ... import types as sqltypes
|
||||
from ...connectors.mxodbc import MxODBCConnector
|
||||
from .pyodbc import MSExecutionContext_pyodbc, _MSNumeric_pyodbc
|
||||
from .base import (MSDialect,
|
||||
MSSQLStrictCompiler,
|
||||
_MSDateTime, _MSDate, _MSTime)
|
||||
|
||||
|
||||
class _MSNumeric_mxodbc(_MSNumeric_pyodbc):
|
||||
"""Include pyodbc's numeric processor.
|
||||
"""
|
||||
|
||||
|
||||
class _MSDate_mxodbc(_MSDate):
|
||||
def bind_processor(self, dialect):
|
||||
def process(value):
|
||||
if value is not None:
|
||||
return "%s-%s-%s" % (value.year, value.month, value.day)
|
||||
else:
|
||||
return None
|
||||
return process
|
||||
|
||||
|
||||
class _MSTime_mxodbc(_MSTime):
|
||||
def bind_processor(self, dialect):
|
||||
def process(value):
|
||||
if value is not None:
|
||||
return "%s:%s:%s" % (value.hour, value.minute, value.second)
|
||||
else:
|
||||
return None
|
||||
return process
|
||||
|
||||
|
||||
class MSExecutionContext_mxodbc(MSExecutionContext_pyodbc):
|
||||
"""
|
||||
The pyodbc execution context is useful for enabling
|
||||
SELECT SCOPE_IDENTITY in cases where OUTPUT clause
|
||||
does not work (tables with insert triggers).
|
||||
"""
|
||||
#todo - investigate whether the pyodbc execution context
|
||||
# is really only being used in cases where OUTPUT
|
||||
# won't work.
|
||||
|
||||
|
||||
class MSDialect_mxodbc(MxODBCConnector, MSDialect):
|
||||
|
||||
# this is only needed if "native ODBC" mode is used,
|
||||
# which is now disabled by default.
|
||||
#statement_compiler = MSSQLStrictCompiler
|
||||
|
||||
execution_ctx_cls = MSExecutionContext_mxodbc
|
||||
|
||||
# flag used by _MSNumeric_mxodbc
|
||||
_need_decimal_fix = True
|
||||
|
||||
colspecs = {
|
||||
sqltypes.Numeric: _MSNumeric_mxodbc,
|
||||
sqltypes.DateTime: _MSDateTime,
|
||||
sqltypes.Date: _MSDate_mxodbc,
|
||||
sqltypes.Time: _MSTime_mxodbc,
|
||||
}
|
||||
|
||||
def __init__(self, description_encoding=None, **params):
|
||||
super(MSDialect_mxodbc, self).__init__(**params)
|
||||
self.description_encoding = description_encoding
|
||||
|
||||
dialect = MSDialect_mxodbc
|
|
@ -1,92 +0,0 @@
|
|||
# mssql/pymssql.py
|
||||
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
"""
|
||||
.. dialect:: mssql+pymssql
|
||||
:name: pymssql
|
||||
:dbapi: pymssql
|
||||
:connectstring: mssql+pymssql://<username>:<password>@<freetds_name>?charset=utf8
|
||||
:url: http://pymssql.org/
|
||||
|
||||
pymssql is a Python module that provides a Python DBAPI interface around
|
||||
`FreeTDS <http://www.freetds.org/>`_. Compatible builds are available for
|
||||
Linux, MacOSX and Windows platforms.
|
||||
|
||||
"""
|
||||
from .base import MSDialect
|
||||
from ... import types as sqltypes, util, processors
|
||||
import re
|
||||
|
||||
|
||||
class _MSNumeric_pymssql(sqltypes.Numeric):
|
||||
def result_processor(self, dialect, type_):
|
||||
if not self.asdecimal:
|
||||
return processors.to_float
|
||||
else:
|
||||
return sqltypes.Numeric.result_processor(self, dialect, type_)
|
||||
|
||||
|
||||
class MSDialect_pymssql(MSDialect):
|
||||
supports_sane_rowcount = False
|
||||
driver = 'pymssql'
|
||||
|
||||
colspecs = util.update_copy(
|
||||
MSDialect.colspecs,
|
||||
{
|
||||
sqltypes.Numeric: _MSNumeric_pymssql,
|
||||
sqltypes.Float: sqltypes.Float,
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def dbapi(cls):
|
||||
module = __import__('pymssql')
|
||||
# pymmsql doesn't have a Binary method. we use string
|
||||
# TODO: monkeypatching here is less than ideal
|
||||
module.Binary = lambda x: x if hasattr(x, 'decode') else str(x)
|
||||
|
||||
client_ver = tuple(int(x) for x in module.__version__.split("."))
|
||||
if client_ver < (1, ):
|
||||
util.warn("The pymssql dialect expects at least "
|
||||
"the 1.0 series of the pymssql DBAPI.")
|
||||
return module
|
||||
|
||||
def __init__(self, **params):
|
||||
super(MSDialect_pymssql, self).__init__(**params)
|
||||
self.use_scope_identity = True
|
||||
|
||||
def _get_server_version_info(self, connection):
|
||||
vers = connection.scalar("select @@version")
|
||||
m = re.match(
|
||||
r"Microsoft SQL Server.*? - (\d+).(\d+).(\d+).(\d+)", vers)
|
||||
if m:
|
||||
return tuple(int(x) for x in m.group(1, 2, 3, 4))
|
||||
else:
|
||||
return None
|
||||
|
||||
def create_connect_args(self, url):
|
||||
opts = url.translate_connect_args(username='user')
|
||||
opts.update(url.query)
|
||||
port = opts.pop('port', None)
|
||||
if port and 'host' in opts:
|
||||
opts['host'] = "%s:%s" % (opts['host'], port)
|
||||
return [[], opts]
|
||||
|
||||
def is_disconnect(self, e, connection, cursor):
|
||||
for msg in (
|
||||
"Adaptive Server connection timed out",
|
||||
"Net-Lib error during Connection reset by peer",
|
||||
"message 20003", # connection timeout
|
||||
"Error 10054",
|
||||
"Not connected to any MS SQL server",
|
||||
"Connection is closed"
|
||||
):
|
||||
if msg in str(e):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
dialect = MSDialect_pymssql
|
|
@ -1,260 +0,0 @@
|
|||
# mssql/pyodbc.py
|
||||
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
"""
|
||||
.. dialect:: mssql+pyodbc
|
||||
:name: PyODBC
|
||||
:dbapi: pyodbc
|
||||
:connectstring: mssql+pyodbc://<username>:<password>@<dsnname>
|
||||
:url: http://pypi.python.org/pypi/pyodbc/
|
||||
|
||||
Additional Connection Examples
|
||||
-------------------------------
|
||||
|
||||
Examples of pyodbc connection string URLs:
|
||||
|
||||
* ``mssql+pyodbc://mydsn`` - connects using the specified DSN named ``mydsn``.
|
||||
The connection string that is created will appear like::
|
||||
|
||||
dsn=mydsn;Trusted_Connection=Yes
|
||||
|
||||
* ``mssql+pyodbc://user:pass@mydsn`` - connects using the DSN named
|
||||
``mydsn`` passing in the ``UID`` and ``PWD`` information. The
|
||||
connection string that is created will appear like::
|
||||
|
||||
dsn=mydsn;UID=user;PWD=pass
|
||||
|
||||
* ``mssql+pyodbc://user:pass@mydsn/?LANGUAGE=us_english`` - connects
|
||||
using the DSN named ``mydsn`` passing in the ``UID`` and ``PWD``
|
||||
information, plus the additional connection configuration option
|
||||
``LANGUAGE``. The connection string that is created will appear
|
||||
like::
|
||||
|
||||
dsn=mydsn;UID=user;PWD=pass;LANGUAGE=us_english
|
||||
|
||||
* ``mssql+pyodbc://user:pass@host/db`` - connects using a connection
|
||||
that would appear like::
|
||||
|
||||
DRIVER={SQL Server};Server=host;Database=db;UID=user;PWD=pass
|
||||
|
||||
* ``mssql+pyodbc://user:pass@host:123/db`` - connects using a connection
|
||||
string which includes the port
|
||||
information using the comma syntax. This will create the following
|
||||
connection string::
|
||||
|
||||
DRIVER={SQL Server};Server=host,123;Database=db;UID=user;PWD=pass
|
||||
|
||||
* ``mssql+pyodbc://user:pass@host/db?port=123`` - connects using a connection
|
||||
string that includes the port
|
||||
information as a separate ``port`` keyword. This will create the
|
||||
following connection string::
|
||||
|
||||
DRIVER={SQL Server};Server=host;Database=db;UID=user;PWD=pass;port=123
|
||||
|
||||
* ``mssql+pyodbc://user:pass@host/db?driver=MyDriver`` - connects using a connection
|
||||
string that includes a custom
|
||||
ODBC driver name. This will create the following connection string::
|
||||
|
||||
DRIVER={MyDriver};Server=host;Database=db;UID=user;PWD=pass
|
||||
|
||||
If you require a connection string that is outside the options
|
||||
presented above, use the ``odbc_connect`` keyword to pass in a
|
||||
urlencoded connection string. What gets passed in will be urldecoded
|
||||
and passed directly.
|
||||
|
||||
For example::
|
||||
|
||||
mssql+pyodbc:///?odbc_connect=dsn%3Dmydsn%3BDatabase%3Ddb
|
||||
|
||||
would create the following connection string::
|
||||
|
||||
dsn=mydsn;Database=db
|
||||
|
||||
Encoding your connection string can be easily accomplished through
|
||||
the python shell. For example::
|
||||
|
||||
>>> import urllib
|
||||
>>> urllib.quote_plus('dsn=mydsn;Database=db')
|
||||
'dsn%3Dmydsn%3BDatabase%3Ddb'
|
||||
|
||||
Unicode Binds
|
||||
-------------
|
||||
|
||||
The current state of PyODBC on a unix backend with FreeTDS and/or
|
||||
EasySoft is poor regarding unicode; different OS platforms and versions of UnixODBC
|
||||
versus IODBC versus FreeTDS/EasySoft versus PyODBC itself dramatically
|
||||
alter how strings are received. The PyODBC dialect attempts to use all the information
|
||||
it knows to determine whether or not a Python unicode literal can be
|
||||
passed directly to the PyODBC driver or not; while SQLAlchemy can encode
|
||||
these to bytestrings first, some users have reported that PyODBC mis-handles
|
||||
bytestrings for certain encodings and requires a Python unicode object,
|
||||
while the author has observed widespread cases where a Python unicode
|
||||
is completely misinterpreted by PyODBC, particularly when dealing with
|
||||
the information schema tables used in table reflection, and the value
|
||||
must first be encoded to a bytestring.
|
||||
|
||||
It is for this reason that whether or not unicode literals for bound
|
||||
parameters be sent to PyODBC can be controlled using the
|
||||
``supports_unicode_binds`` parameter to ``create_engine()``. When
|
||||
left at its default of ``None``, the PyODBC dialect will use its
|
||||
best guess as to whether or not the driver deals with unicode literals
|
||||
well. When ``False``, unicode literals will be encoded first, and when
|
||||
``True`` unicode literals will be passed straight through. This is an interim
|
||||
flag that hopefully should not be needed when the unicode situation stabilizes
|
||||
for unix + PyODBC.
|
||||
|
||||
.. versionadded:: 0.7.7
|
||||
``supports_unicode_binds`` parameter to ``create_engine()``\ .
|
||||
|
||||
"""
|
||||
|
||||
from .base import MSExecutionContext, MSDialect
|
||||
from ...connectors.pyodbc import PyODBCConnector
|
||||
from ... import types as sqltypes, util
|
||||
import decimal
|
||||
|
||||
class _ms_numeric_pyodbc(object):
|
||||
|
||||
"""Turns Decimals with adjusted() < 0 or > 7 into strings.
|
||||
|
||||
The routines here are needed for older pyodbc versions
|
||||
as well as current mxODBC versions.
|
||||
|
||||
"""
|
||||
|
||||
def bind_processor(self, dialect):
|
||||
|
||||
super_process = super(_ms_numeric_pyodbc, self).\
|
||||
bind_processor(dialect)
|
||||
|
||||
if not dialect._need_decimal_fix:
|
||||
return super_process
|
||||
|
||||
def process(value):
|
||||
if self.asdecimal and \
|
||||
isinstance(value, decimal.Decimal):
|
||||
|
||||
adjusted = value.adjusted()
|
||||
if adjusted < 0:
|
||||
return self._small_dec_to_string(value)
|
||||
elif adjusted > 7:
|
||||
return self._large_dec_to_string(value)
|
||||
|
||||
if super_process:
|
||||
return super_process(value)
|
||||
else:
|
||||
return value
|
||||
return process
|
||||
|
||||
# these routines needed for older versions of pyodbc.
|
||||
# as of 2.1.8 this logic is integrated.
|
||||
|
||||
def _small_dec_to_string(self, value):
|
||||
return "%s0.%s%s" % (
|
||||
(value < 0 and '-' or ''),
|
||||
'0' * (abs(value.adjusted()) - 1),
|
||||
"".join([str(nint) for nint in value.as_tuple()[1]]))
|
||||
|
||||
def _large_dec_to_string(self, value):
|
||||
_int = value.as_tuple()[1]
|
||||
if 'E' in str(value):
|
||||
result = "%s%s%s" % (
|
||||
(value < 0 and '-' or ''),
|
||||
"".join([str(s) for s in _int]),
|
||||
"0" * (value.adjusted() - (len(_int) - 1)))
|
||||
else:
|
||||
if (len(_int) - 1) > value.adjusted():
|
||||
result = "%s%s.%s" % (
|
||||
(value < 0 and '-' or ''),
|
||||
"".join(
|
||||
[str(s) for s in _int][0:value.adjusted() + 1]),
|
||||
"".join(
|
||||
[str(s) for s in _int][value.adjusted() + 1:]))
|
||||
else:
|
||||
result = "%s%s" % (
|
||||
(value < 0 and '-' or ''),
|
||||
"".join(
|
||||
[str(s) for s in _int][0:value.adjusted() + 1]))
|
||||
return result
|
||||
|
||||
class _MSNumeric_pyodbc(_ms_numeric_pyodbc, sqltypes.Numeric):
|
||||
pass
|
||||
|
||||
class _MSFloat_pyodbc(_ms_numeric_pyodbc, sqltypes.Float):
|
||||
pass
|
||||
|
||||
class MSExecutionContext_pyodbc(MSExecutionContext):
|
||||
_embedded_scope_identity = False
|
||||
|
||||
def pre_exec(self):
|
||||
"""where appropriate, issue "select scope_identity()" in the same
|
||||
statement.
|
||||
|
||||
Background on why "scope_identity()" is preferable to "@@identity":
|
||||
http://msdn.microsoft.com/en-us/library/ms190315.aspx
|
||||
|
||||
Background on why we attempt to embed "scope_identity()" into the same
|
||||
statement as the INSERT:
|
||||
http://code.google.com/p/pyodbc/wiki/FAQs#How_do_I_retrieve_autogenerated/identity_values?
|
||||
|
||||
"""
|
||||
|
||||
super(MSExecutionContext_pyodbc, self).pre_exec()
|
||||
|
||||
# don't embed the scope_identity select into an
|
||||
# "INSERT .. DEFAULT VALUES"
|
||||
if self._select_lastrowid and \
|
||||
self.dialect.use_scope_identity and \
|
||||
len(self.parameters[0]):
|
||||
self._embedded_scope_identity = True
|
||||
|
||||
self.statement += "; select scope_identity()"
|
||||
|
||||
def post_exec(self):
|
||||
if self._embedded_scope_identity:
|
||||
# Fetch the last inserted id from the manipulated statement
|
||||
# We may have to skip over a number of result sets with
|
||||
# no data (due to triggers, etc.)
|
||||
while True:
|
||||
try:
|
||||
# fetchall() ensures the cursor is consumed
|
||||
# without closing it (FreeTDS particularly)
|
||||
row = self.cursor.fetchall()[0]
|
||||
break
|
||||
except self.dialect.dbapi.Error as e:
|
||||
# no way around this - nextset() consumes the previous set
|
||||
# so we need to just keep flipping
|
||||
self.cursor.nextset()
|
||||
|
||||
self._lastrowid = int(row[0])
|
||||
else:
|
||||
super(MSExecutionContext_pyodbc, self).post_exec()
|
||||
|
||||
|
||||
class MSDialect_pyodbc(PyODBCConnector, MSDialect):
|
||||
|
||||
execution_ctx_cls = MSExecutionContext_pyodbc
|
||||
|
||||
pyodbc_driver_name = 'SQL Server'
|
||||
|
||||
colspecs = util.update_copy(
|
||||
MSDialect.colspecs,
|
||||
{
|
||||
sqltypes.Numeric: _MSNumeric_pyodbc,
|
||||
sqltypes.Float: _MSFloat_pyodbc
|
||||
}
|
||||
)
|
||||
|
||||
def __init__(self, description_encoding=None, **params):
|
||||
super(MSDialect_pyodbc, self).__init__(**params)
|
||||
self.description_encoding = description_encoding
|
||||
self.use_scope_identity = self.use_scope_identity and \
|
||||
self.dbapi and \
|
||||
hasattr(self.dbapi.Cursor, 'nextset')
|
||||
self._need_decimal_fix = self.dbapi and \
|
||||
self._dbapi_version() < (2, 1, 8)
|
||||
|
||||
dialect = MSDialect_pyodbc
|
|
@ -1,65 +0,0 @@
|
|||
# mssql/zxjdbc.py
|
||||
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
"""
|
||||
.. dialect:: mssql+zxjdbc
|
||||
:name: zxJDBC for Jython
|
||||
:dbapi: zxjdbc
|
||||
:connectstring: mssql+zxjdbc://user:pass@host:port/dbname[?key=value&key=value...]
|
||||
:driverurl: http://jtds.sourceforge.net/
|
||||
|
||||
|
||||
"""
|
||||
from ...connectors.zxJDBC import ZxJDBCConnector
|
||||
from .base import MSDialect, MSExecutionContext
|
||||
from ... import engine
|
||||
|
||||
|
||||
class MSExecutionContext_zxjdbc(MSExecutionContext):
|
||||
|
||||
_embedded_scope_identity = False
|
||||
|
||||
def pre_exec(self):
|
||||
super(MSExecutionContext_zxjdbc, self).pre_exec()
|
||||
# scope_identity after the fact returns null in jTDS so we must
|
||||
# embed it
|
||||
if self._select_lastrowid and self.dialect.use_scope_identity:
|
||||
self._embedded_scope_identity = True
|
||||
self.statement += "; SELECT scope_identity()"
|
||||
|
||||
def post_exec(self):
|
||||
if self._embedded_scope_identity:
|
||||
while True:
|
||||
try:
|
||||
row = self.cursor.fetchall()[0]
|
||||
break
|
||||
except self.dialect.dbapi.Error:
|
||||
self.cursor.nextset()
|
||||
self._lastrowid = int(row[0])
|
||||
|
||||
if (self.isinsert or self.isupdate or self.isdelete) and \
|
||||
self.compiled.returning:
|
||||
self._result_proxy = engine.FullyBufferedResultProxy(self)
|
||||
|
||||
if self._enable_identity_insert:
|
||||
table = self.dialect.identifier_preparer.format_table(
|
||||
self.compiled.statement.table)
|
||||
self.cursor.execute("SET IDENTITY_INSERT %s OFF" % table)
|
||||
|
||||
|
||||
class MSDialect_zxjdbc(ZxJDBCConnector, MSDialect):
|
||||
jdbc_db_name = 'jtds:sqlserver'
|
||||
jdbc_driver_name = 'net.sourceforge.jtds.jdbc.Driver'
|
||||
|
||||
execution_ctx_cls = MSExecutionContext_zxjdbc
|
||||
|
||||
def _get_server_version_info(self, connection):
|
||||
return tuple(
|
||||
int(x)
|
||||
for x in connection.connection.dbversion.split('.')
|
||||
)
|
||||
|
||||
dialect = MSDialect_zxjdbc
|
|
@ -1,28 +0,0 @@
|
|||
# mysql/__init__.py
|
||||
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
from . import base, mysqldb, oursql, \
|
||||
pyodbc, zxjdbc, mysqlconnector, pymysql,\
|
||||
gaerdbms, cymysql
|
||||
|
||||
# default dialect
|
||||
base.dialect = mysqldb.dialect
|
||||
|
||||
from .base import \
|
||||
BIGINT, BINARY, BIT, BLOB, BOOLEAN, CHAR, DATE, DATETIME, \
|
||||
DECIMAL, DOUBLE, ENUM, DECIMAL,\
|
||||
FLOAT, INTEGER, INTEGER, LONGBLOB, LONGTEXT, MEDIUMBLOB, \
|
||||
MEDIUMINT, MEDIUMTEXT, NCHAR, \
|
||||
NVARCHAR, NUMERIC, SET, SMALLINT, REAL, TEXT, TIME, TIMESTAMP, \
|
||||
TINYBLOB, TINYINT, TINYTEXT,\
|
||||
VARBINARY, VARCHAR, YEAR, dialect
|
||||
|
||||
__all__ = (
|
||||
'BIGINT', 'BINARY', 'BIT', 'BLOB', 'BOOLEAN', 'CHAR', 'DATE', 'DATETIME', 'DECIMAL', 'DOUBLE',
|
||||
'ENUM', 'DECIMAL', 'FLOAT', 'INTEGER', 'INTEGER', 'LONGBLOB', 'LONGTEXT', 'MEDIUMBLOB', 'MEDIUMINT',
|
||||
'MEDIUMTEXT', 'NCHAR', 'NVARCHAR', 'NUMERIC', 'SET', 'SMALLINT', 'REAL', 'TEXT', 'TIME', 'TIMESTAMP',
|
||||
'TINYBLOB', 'TINYINT', 'TINYTEXT', 'VARBINARY', 'VARCHAR', 'YEAR', 'dialect'
|
||||
)
|
File diff suppressed because it is too large
Load diff
|
@ -1,84 +0,0 @@
|
|||
# mysql/cymysql.py
|
||||
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
"""
|
||||
|
||||
.. dialect:: mysql+cymysql
|
||||
:name: CyMySQL
|
||||
:dbapi: cymysql
|
||||
:connectstring: mysql+cymysql://<username>:<password>@<host>/<dbname>[?<options>]
|
||||
:url: https://github.com/nakagami/CyMySQL
|
||||
|
||||
"""
|
||||
import re
|
||||
|
||||
from .mysqldb import MySQLDialect_mysqldb
|
||||
from .base import (BIT, MySQLDialect)
|
||||
from ... import util
|
||||
|
||||
class _cymysqlBIT(BIT):
|
||||
def result_processor(self, dialect, coltype):
|
||||
"""Convert a MySQL's 64 bit, variable length binary string to a long.
|
||||
"""
|
||||
|
||||
def process(value):
|
||||
if value is not None:
|
||||
v = 0
|
||||
for i in util.iterbytes(value):
|
||||
v = v << 8 | i
|
||||
return v
|
||||
return value
|
||||
return process
|
||||
|
||||
|
||||
class MySQLDialect_cymysql(MySQLDialect_mysqldb):
|
||||
driver = 'cymysql'
|
||||
|
||||
description_encoding = None
|
||||
supports_sane_rowcount = True
|
||||
supports_sane_multi_rowcount = False
|
||||
supports_unicode_statements = True
|
||||
|
||||
colspecs = util.update_copy(
|
||||
MySQLDialect.colspecs,
|
||||
{
|
||||
BIT: _cymysqlBIT,
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def dbapi(cls):
|
||||
return __import__('cymysql')
|
||||
|
||||
def _get_server_version_info(self, connection):
|
||||
dbapi_con = connection.connection
|
||||
version = []
|
||||
r = re.compile('[.\-]')
|
||||
for n in r.split(dbapi_con.server_version):
|
||||
try:
|
||||
version.append(int(n))
|
||||
except ValueError:
|
||||
version.append(n)
|
||||
return tuple(version)
|
||||
|
||||
def _detect_charset(self, connection):
|
||||
return connection.connection.charset
|
||||
|
||||
def _extract_error_code(self, exception):
|
||||
return exception.errno
|
||||
|
||||
def is_disconnect(self, e, connection, cursor):
|
||||
if isinstance(e, self.dbapi.OperationalError):
|
||||
return self._extract_error_code(e) in \
|
||||
(2006, 2013, 2014, 2045, 2055)
|
||||
elif isinstance(e, self.dbapi.InterfaceError):
|
||||
# if underlying connection is closed,
|
||||
# this is the error you get
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
dialect = MySQLDialect_cymysql
|
|
@ -1,84 +0,0 @@
|
|||
# mysql/gaerdbms.py
|
||||
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
"""
|
||||
.. dialect:: mysql+gaerdbms
|
||||
:name: Google Cloud SQL
|
||||
:dbapi: rdbms
|
||||
:connectstring: mysql+gaerdbms:///<dbname>?instance=<instancename>
|
||||
:url: https://developers.google.com/appengine/docs/python/cloud-sql/developers-guide
|
||||
|
||||
This dialect is based primarily on the :mod:`.mysql.mysqldb` dialect with minimal
|
||||
changes.
|
||||
|
||||
.. versionadded:: 0.7.8
|
||||
|
||||
|
||||
Pooling
|
||||
-------
|
||||
|
||||
Google App Engine connections appear to be randomly recycled,
|
||||
so the dialect does not pool connections. The :class:`.NullPool`
|
||||
implementation is installed within the :class:`.Engine` by
|
||||
default.
|
||||
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from .mysqldb import MySQLDialect_mysqldb
|
||||
from ...pool import NullPool
|
||||
import re
|
||||
|
||||
|
||||
def _is_dev_environment():
|
||||
return os.environ.get('SERVER_SOFTWARE', '').startswith('Development/')
|
||||
|
||||
|
||||
class MySQLDialect_gaerdbms(MySQLDialect_mysqldb):
|
||||
|
||||
@classmethod
|
||||
def dbapi(cls):
|
||||
# from django:
|
||||
# http://code.google.com/p/googleappengine/source/
|
||||
# browse/trunk/python/google/storage/speckle/
|
||||
# python/django/backend/base.py#118
|
||||
# see also [ticket:2649]
|
||||
# see also http://stackoverflow.com/q/14224679/34549
|
||||
from google.appengine.api import apiproxy_stub_map
|
||||
|
||||
if _is_dev_environment():
|
||||
from google.appengine.api import rdbms_mysqldb
|
||||
return rdbms_mysqldb
|
||||
elif apiproxy_stub_map.apiproxy.GetStub('rdbms'):
|
||||
from google.storage.speckle.python.api import rdbms_apiproxy
|
||||
return rdbms_apiproxy
|
||||
else:
|
||||
from google.storage.speckle.python.api import rdbms_googleapi
|
||||
return rdbms_googleapi
|
||||
|
||||
@classmethod
|
||||
def get_pool_class(cls, url):
|
||||
# Cloud SQL connections die at any moment
|
||||
return NullPool
|
||||
|
||||
def create_connect_args(self, url):
|
||||
opts = url.translate_connect_args()
|
||||
if not _is_dev_environment():
|
||||
# 'dsn' and 'instance' are because we are skipping
|
||||
# the traditional google.api.rdbms wrapper
|
||||
opts['dsn'] = ''
|
||||
opts['instance'] = url.query['instance']
|
||||
return [], opts
|
||||
|
||||
def _extract_error_code(self, exception):
|
||||
match = re.compile(r"^(\d+)L?:|^\((\d+)L?,").match(str(exception))
|
||||
# The rdbms api will wrap then re-raise some types of errors
|
||||
# making this regex return no matches.
|
||||
code = match.group(1) or match.group(2) if match else None
|
||||
if code:
|
||||
return int(code)
|
||||
|
||||
dialect = MySQLDialect_gaerdbms
|
|
@ -1,131 +0,0 @@
|
|||
# mysql/mysqlconnector.py
|
||||
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
"""
|
||||
.. dialect:: mysql+mysqlconnector
|
||||
:name: MySQL Connector/Python
|
||||
:dbapi: myconnpy
|
||||
:connectstring: mysql+mysqlconnector://<user>:<password>@<host>[:<port>]/<dbname>
|
||||
:url: http://dev.mysql.com/downloads/connector/python/
|
||||
|
||||
|
||||
"""
|
||||
|
||||
from .base import (MySQLDialect,
|
||||
MySQLExecutionContext, MySQLCompiler, MySQLIdentifierPreparer,
|
||||
BIT)
|
||||
|
||||
from ... import util
|
||||
|
||||
|
||||
class MySQLExecutionContext_mysqlconnector(MySQLExecutionContext):
|
||||
|
||||
def get_lastrowid(self):
|
||||
return self.cursor.lastrowid
|
||||
|
||||
|
||||
class MySQLCompiler_mysqlconnector(MySQLCompiler):
|
||||
def visit_mod_binary(self, binary, operator, **kw):
|
||||
return self.process(binary.left, **kw) + " %% " + \
|
||||
self.process(binary.right, **kw)
|
||||
|
||||
def post_process_text(self, text):
|
||||
return text.replace('%', '%%')
|
||||
|
||||
|
||||
class MySQLIdentifierPreparer_mysqlconnector(MySQLIdentifierPreparer):
|
||||
|
||||
def _escape_identifier(self, value):
|
||||
value = value.replace(self.escape_quote, self.escape_to_quote)
|
||||
return value.replace("%", "%%")
|
||||
|
||||
|
||||
class _myconnpyBIT(BIT):
|
||||
def result_processor(self, dialect, coltype):
|
||||
"""MySQL-connector already converts mysql bits, so."""
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class MySQLDialect_mysqlconnector(MySQLDialect):
|
||||
driver = 'mysqlconnector'
|
||||
|
||||
if util.py2k:
|
||||
supports_unicode_statements = False
|
||||
supports_unicode_binds = True
|
||||
|
||||
supports_sane_rowcount = True
|
||||
supports_sane_multi_rowcount = True
|
||||
|
||||
supports_native_decimal = True
|
||||
|
||||
default_paramstyle = 'format'
|
||||
execution_ctx_cls = MySQLExecutionContext_mysqlconnector
|
||||
statement_compiler = MySQLCompiler_mysqlconnector
|
||||
|
||||
preparer = MySQLIdentifierPreparer_mysqlconnector
|
||||
|
||||
colspecs = util.update_copy(
|
||||
MySQLDialect.colspecs,
|
||||
{
|
||||
BIT: _myconnpyBIT,
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def dbapi(cls):
|
||||
from mysql import connector
|
||||
return connector
|
||||
|
||||
def create_connect_args(self, url):
|
||||
opts = url.translate_connect_args(username='user')
|
||||
|
||||
opts.update(url.query)
|
||||
|
||||
util.coerce_kw_type(opts, 'buffered', bool)
|
||||
util.coerce_kw_type(opts, 'raise_on_warnings', bool)
|
||||
opts.setdefault('buffered', True)
|
||||
opts.setdefault('raise_on_warnings', True)
|
||||
|
||||
# FOUND_ROWS must be set in ClientFlag to enable
|
||||
# supports_sane_rowcount.
|
||||
if self.dbapi is not None:
|
||||
try:
|
||||
from mysql.connector.constants import ClientFlag
|
||||
client_flags = opts.get('client_flags', ClientFlag.get_default())
|
||||
client_flags |= ClientFlag.FOUND_ROWS
|
||||
opts['client_flags'] = client_flags
|
||||
except:
|
||||
pass
|
||||
return [[], opts]
|
||||
|
||||
def _get_server_version_info(self, connection):
|
||||
dbapi_con = connection.connection
|
||||
version = dbapi_con.get_server_version()
|
||||
return tuple(version)
|
||||
|
||||
def _detect_charset(self, connection):
|
||||
return connection.connection.charset
|
||||
|
||||
def _extract_error_code(self, exception):
|
||||
return exception.errno
|
||||
|
||||
def is_disconnect(self, e, connection, cursor):
|
||||
errnos = (2006, 2013, 2014, 2045, 2055, 2048)
|
||||
exceptions = (self.dbapi.OperationalError, self.dbapi.InterfaceError)
|
||||
if isinstance(e, exceptions):
|
||||
return e.errno in errnos or \
|
||||
"MySQL Connection not available." in str(e)
|
||||
else:
|
||||
return False
|
||||
|
||||
def _compat_fetchall(self, rp, charset=None):
|
||||
return rp.fetchall()
|
||||
|
||||
def _compat_fetchone(self, rp, charset=None):
|
||||
return rp.fetchone()
|
||||
|
||||
dialect = MySQLDialect_mysqlconnector
|
|
@ -1,94 +0,0 @@
|
|||
# mysql/mysqldb.py
|
||||
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
"""
|
||||
|
||||
.. dialect:: mysql+mysqldb
|
||||
:name: MySQL-Python
|
||||
:dbapi: mysqldb
|
||||
:connectstring: mysql+mysqldb://<user>:<password>@<host>[:<port>]/<dbname>
|
||||
:url: http://sourceforge.net/projects/mysql-python
|
||||
|
||||
|
||||
Unicode
|
||||
-------
|
||||
|
||||
MySQLdb requires a "charset" parameter to be passed in order for it
|
||||
to handle non-ASCII characters correctly. When this parameter is passed,
|
||||
MySQLdb will also implicitly set the "use_unicode" flag to true, which means
|
||||
that it will return Python unicode objects instead of bytestrings.
|
||||
However, SQLAlchemy's decode process, when C extensions are enabled,
|
||||
is orders of magnitude faster than that of MySQLdb as it does not call into
|
||||
Python functions to do so. Therefore, the **recommended URL to use for
|
||||
unicode** will include both charset and use_unicode=0::
|
||||
|
||||
create_engine("mysql+mysqldb://user:pass@host/dbname?charset=utf8&use_unicode=0")
|
||||
|
||||
As of this writing, MySQLdb only runs on Python 2. It is not known how
|
||||
MySQLdb behaves on Python 3 as far as unicode decoding.
|
||||
|
||||
|
||||
Known Issues
|
||||
-------------
|
||||
|
||||
MySQL-python version 1.2.2 has a serious memory leak related
|
||||
to unicode conversion, a feature which is disabled via ``use_unicode=0``.
|
||||
It is strongly advised to use the latest version of MySQL-Python.
|
||||
|
||||
"""
|
||||
|
||||
from .base import (MySQLDialect, MySQLExecutionContext,
|
||||
MySQLCompiler, MySQLIdentifierPreparer)
|
||||
from ...connectors.mysqldb import (
|
||||
MySQLDBExecutionContext,
|
||||
MySQLDBCompiler,
|
||||
MySQLDBIdentifierPreparer,
|
||||
MySQLDBConnector
|
||||
)
|
||||
from .base import TEXT
|
||||
from ... import sql
|
||||
|
||||
class MySQLExecutionContext_mysqldb(MySQLDBExecutionContext, MySQLExecutionContext):
|
||||
pass
|
||||
|
||||
|
||||
class MySQLCompiler_mysqldb(MySQLDBCompiler, MySQLCompiler):
|
||||
pass
|
||||
|
||||
|
||||
class MySQLIdentifierPreparer_mysqldb(MySQLDBIdentifierPreparer, MySQLIdentifierPreparer):
|
||||
pass
|
||||
|
||||
|
||||
class MySQLDialect_mysqldb(MySQLDBConnector, MySQLDialect):
|
||||
execution_ctx_cls = MySQLExecutionContext_mysqldb
|
||||
statement_compiler = MySQLCompiler_mysqldb
|
||||
preparer = MySQLIdentifierPreparer_mysqldb
|
||||
|
||||
def _check_unicode_returns(self, connection):
|
||||
# work around issue fixed in
|
||||
# https://github.com/farcepest/MySQLdb1/commit/cd44524fef63bd3fcb71947392326e9742d520e8
|
||||
# specific issue w/ the utf8_bin collation and unicode returns
|
||||
|
||||
has_utf8_bin = connection.scalar(
|
||||
"show collation where %s = 'utf8' and %s = 'utf8_bin'"
|
||||
% (
|
||||
self.identifier_preparer.quote("Charset"),
|
||||
self.identifier_preparer.quote("Collation")
|
||||
))
|
||||
if has_utf8_bin:
|
||||
additional_tests = [
|
||||
sql.collate(sql.cast(
|
||||
sql.literal_column(
|
||||
"'test collated returns'"),
|
||||
TEXT(charset='utf8')), "utf8_bin")
|
||||
]
|
||||
else:
|
||||
additional_tests = []
|
||||
return super(MySQLDBConnector, self)._check_unicode_returns(
|
||||
connection, additional_tests)
|
||||
|
||||
dialect = MySQLDialect_mysqldb
|
|
@ -1,261 +0,0 @@
|
|||
# mysql/oursql.py
|
||||
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
"""
|
||||
|
||||
.. dialect:: mysql+oursql
|
||||
:name: OurSQL
|
||||
:dbapi: oursql
|
||||
:connectstring: mysql+oursql://<user>:<password>@<host>[:<port>]/<dbname>
|
||||
:url: http://packages.python.org/oursql/
|
||||
|
||||
Unicode
|
||||
-------
|
||||
|
||||
oursql defaults to using ``utf8`` as the connection charset, but other
|
||||
encodings may be used instead. Like the MySQL-Python driver, unicode support
|
||||
can be completely disabled::
|
||||
|
||||
# oursql sets the connection charset to utf8 automatically; all strings come
|
||||
# back as utf8 str
|
||||
create_engine('mysql+oursql:///mydb?use_unicode=0')
|
||||
|
||||
To not automatically use ``utf8`` and instead use whatever the connection
|
||||
defaults to, there is a separate parameter::
|
||||
|
||||
# use the default connection charset; all strings come back as unicode
|
||||
create_engine('mysql+oursql:///mydb?default_charset=1')
|
||||
|
||||
# use latin1 as the connection charset; all strings come back as unicode
|
||||
create_engine('mysql+oursql:///mydb?charset=latin1')
|
||||
"""
|
||||
|
||||
import re
|
||||
|
||||
from .base import (BIT, MySQLDialect, MySQLExecutionContext)
|
||||
from ... import types as sqltypes, util
|
||||
|
||||
|
||||
class _oursqlBIT(BIT):
|
||||
def result_processor(self, dialect, coltype):
|
||||
"""oursql already converts mysql bits, so."""
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class MySQLExecutionContext_oursql(MySQLExecutionContext):
|
||||
|
||||
@property
|
||||
def plain_query(self):
|
||||
return self.execution_options.get('_oursql_plain_query', False)
|
||||
|
||||
|
||||
class MySQLDialect_oursql(MySQLDialect):
|
||||
driver = 'oursql'
|
||||
|
||||
if util.py2k:
|
||||
supports_unicode_binds = True
|
||||
supports_unicode_statements = True
|
||||
|
||||
supports_native_decimal = True
|
||||
|
||||
supports_sane_rowcount = True
|
||||
supports_sane_multi_rowcount = True
|
||||
execution_ctx_cls = MySQLExecutionContext_oursql
|
||||
|
||||
colspecs = util.update_copy(
|
||||
MySQLDialect.colspecs,
|
||||
{
|
||||
sqltypes.Time: sqltypes.Time,
|
||||
BIT: _oursqlBIT,
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def dbapi(cls):
|
||||
return __import__('oursql')
|
||||
|
||||
def do_execute(self, cursor, statement, parameters, context=None):
|
||||
"""Provide an implementation of *cursor.execute(statement, parameters)*."""
|
||||
|
||||
if context and context.plain_query:
|
||||
cursor.execute(statement, plain_query=True)
|
||||
else:
|
||||
cursor.execute(statement, parameters)
|
||||
|
||||
def do_begin(self, connection):
|
||||
connection.cursor().execute('BEGIN', plain_query=True)
|
||||
|
||||
def _xa_query(self, connection, query, xid):
|
||||
if util.py2k:
|
||||
arg = connection.connection._escape_string(xid)
|
||||
else:
|
||||
charset = self._connection_charset
|
||||
arg = connection.connection._escape_string(xid.encode(charset)).decode(charset)
|
||||
arg = "'%s'" % arg
|
||||
connection.execution_options(_oursql_plain_query=True).execute(query % arg)
|
||||
|
||||
# Because mysql is bad, these methods have to be
|
||||
# reimplemented to use _PlainQuery. Basically, some queries
|
||||
# refuse to return any data if they're run through
|
||||
# the parameterized query API, or refuse to be parameterized
|
||||
# in the first place.
|
||||
def do_begin_twophase(self, connection, xid):
|
||||
self._xa_query(connection, 'XA BEGIN %s', xid)
|
||||
|
||||
def do_prepare_twophase(self, connection, xid):
|
||||
self._xa_query(connection, 'XA END %s', xid)
|
||||
self._xa_query(connection, 'XA PREPARE %s', xid)
|
||||
|
||||
def do_rollback_twophase(self, connection, xid, is_prepared=True,
|
||||
recover=False):
|
||||
if not is_prepared:
|
||||
self._xa_query(connection, 'XA END %s', xid)
|
||||
self._xa_query(connection, 'XA ROLLBACK %s', xid)
|
||||
|
||||
def do_commit_twophase(self, connection, xid, is_prepared=True,
|
||||
recover=False):
|
||||
if not is_prepared:
|
||||
self.do_prepare_twophase(connection, xid)
|
||||
self._xa_query(connection, 'XA COMMIT %s', xid)
|
||||
|
||||
# Q: why didn't we need all these "plain_query" overrides earlier ?
|
||||
# am i on a newer/older version of OurSQL ?
|
||||
def has_table(self, connection, table_name, schema=None):
|
||||
return MySQLDialect.has_table(
|
||||
self,
|
||||
connection.connect().execution_options(_oursql_plain_query=True),
|
||||
table_name,
|
||||
schema
|
||||
)
|
||||
|
||||
def get_table_options(self, connection, table_name, schema=None, **kw):
|
||||
return MySQLDialect.get_table_options(
|
||||
self,
|
||||
connection.connect().execution_options(_oursql_plain_query=True),
|
||||
table_name,
|
||||
schema=schema,
|
||||
**kw
|
||||
)
|
||||
|
||||
def get_columns(self, connection, table_name, schema=None, **kw):
|
||||
return MySQLDialect.get_columns(
|
||||
self,
|
||||
connection.connect().execution_options(_oursql_plain_query=True),
|
||||
table_name,
|
||||
schema=schema,
|
||||
**kw
|
||||
)
|
||||
|
||||
def get_view_names(self, connection, schema=None, **kw):
|
||||
return MySQLDialect.get_view_names(
|
||||
self,
|
||||
connection.connect().execution_options(_oursql_plain_query=True),
|
||||
schema=schema,
|
||||
**kw
|
||||
)
|
||||
|
||||
def get_table_names(self, connection, schema=None, **kw):
|
||||
return MySQLDialect.get_table_names(
|
||||
self,
|
||||
connection.connect().execution_options(_oursql_plain_query=True),
|
||||
schema
|
||||
)
|
||||
|
||||
def get_schema_names(self, connection, **kw):
|
||||
return MySQLDialect.get_schema_names(
|
||||
self,
|
||||
connection.connect().execution_options(_oursql_plain_query=True),
|
||||
**kw
|
||||
)
|
||||
|
||||
def initialize(self, connection):
|
||||
return MySQLDialect.initialize(
|
||||
self,
|
||||
connection.execution_options(_oursql_plain_query=True)
|
||||
)
|
||||
|
||||
def _show_create_table(self, connection, table, charset=None,
|
||||
full_name=None):
|
||||
return MySQLDialect._show_create_table(
|
||||
self,
|
||||
connection.contextual_connect(close_with_result=True).
|
||||
execution_options(_oursql_plain_query=True),
|
||||
table, charset, full_name
|
||||
)
|
||||
|
||||
def is_disconnect(self, e, connection, cursor):
|
||||
if isinstance(e, self.dbapi.ProgrammingError):
|
||||
return e.errno is None and 'cursor' not in e.args[1] and e.args[1].endswith('closed')
|
||||
else:
|
||||
return e.errno in (2006, 2013, 2014, 2045, 2055)
|
||||
|
||||
def create_connect_args(self, url):
|
||||
opts = url.translate_connect_args(database='db', username='user',
|
||||
password='passwd')
|
||||
opts.update(url.query)
|
||||
|
||||
util.coerce_kw_type(opts, 'port', int)
|
||||
util.coerce_kw_type(opts, 'compress', bool)
|
||||
util.coerce_kw_type(opts, 'autoping', bool)
|
||||
util.coerce_kw_type(opts, 'raise_on_warnings', bool)
|
||||
|
||||
util.coerce_kw_type(opts, 'default_charset', bool)
|
||||
if opts.pop('default_charset', False):
|
||||
opts['charset'] = None
|
||||
else:
|
||||
util.coerce_kw_type(opts, 'charset', str)
|
||||
opts['use_unicode'] = opts.get('use_unicode', True)
|
||||
util.coerce_kw_type(opts, 'use_unicode', bool)
|
||||
|
||||
# FOUND_ROWS must be set in CLIENT_FLAGS to enable
|
||||
# supports_sane_rowcount.
|
||||
opts.setdefault('found_rows', True)
|
||||
|
||||
ssl = {}
|
||||
for key in ['ssl_ca', 'ssl_key', 'ssl_cert',
|
||||
'ssl_capath', 'ssl_cipher']:
|
||||
if key in opts:
|
||||
ssl[key[4:]] = opts[key]
|
||||
util.coerce_kw_type(ssl, key[4:], str)
|
||||
del opts[key]
|
||||
if ssl:
|
||||
opts['ssl'] = ssl
|
||||
|
||||
return [[], opts]
|
||||
|
||||
def _get_server_version_info(self, connection):
|
||||
dbapi_con = connection.connection
|
||||
version = []
|
||||
r = re.compile('[.\-]')
|
||||
for n in r.split(dbapi_con.server_info):
|
||||
try:
|
||||
version.append(int(n))
|
||||
except ValueError:
|
||||
version.append(n)
|
||||
return tuple(version)
|
||||
|
||||
def _extract_error_code(self, exception):
|
||||
return exception.errno
|
||||
|
||||
def _detect_charset(self, connection):
|
||||
"""Sniff out the character set in use for connection results."""
|
||||
|
||||
return connection.connection.charset
|
||||
|
||||
def _compat_fetchall(self, rp, charset=None):
|
||||
"""oursql isn't super-broken like MySQLdb, yaaay."""
|
||||
return rp.fetchall()
|
||||
|
||||
def _compat_fetchone(self, rp, charset=None):
|
||||
"""oursql isn't super-broken like MySQLdb, yaaay."""
|
||||
return rp.fetchone()
|
||||
|
||||
def _compat_first(self, rp, charset=None):
|
||||
return rp.first()
|
||||
|
||||
|
||||
dialect = MySQLDialect_oursql
|
|
@ -1,45 +0,0 @@
|
|||
# mysql/pymysql.py
|
||||
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
"""
|
||||
|
||||
.. dialect:: mysql+pymysql
|
||||
:name: PyMySQL
|
||||
:dbapi: pymysql
|
||||
:connectstring: mysql+pymysql://<username>:<password>@<host>/<dbname>[?<options>]
|
||||
:url: http://code.google.com/p/pymysql/
|
||||
|
||||
MySQL-Python Compatibility
|
||||
--------------------------
|
||||
|
||||
The pymysql DBAPI is a pure Python port of the MySQL-python (MySQLdb) driver,
|
||||
and targets 100% compatibility. Most behavioral notes for MySQL-python apply to
|
||||
the pymysql driver as well.
|
||||
|
||||
"""
|
||||
|
||||
from .mysqldb import MySQLDialect_mysqldb
|
||||
from ...util import py3k
|
||||
|
||||
class MySQLDialect_pymysql(MySQLDialect_mysqldb):
|
||||
driver = 'pymysql'
|
||||
|
||||
description_encoding = None
|
||||
if py3k:
|
||||
supports_unicode_statements = True
|
||||
|
||||
|
||||
@classmethod
|
||||
def dbapi(cls):
|
||||
return __import__('pymysql')
|
||||
|
||||
if py3k:
|
||||
def _extract_error_code(self, exception):
|
||||
if isinstance(exception.args[0], Exception):
|
||||
exception = exception.args[0]
|
||||
return exception.args[0]
|
||||
|
||||
dialect = MySQLDialect_pymysql
|
|
@ -1,80 +0,0 @@
|
|||
# mysql/pyodbc.py
|
||||
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
"""
|
||||
|
||||
|
||||
.. dialect:: mysql+pyodbc
|
||||
:name: PyODBC
|
||||
:dbapi: pyodbc
|
||||
:connectstring: mysql+pyodbc://<username>:<password>@<dsnname>
|
||||
:url: http://pypi.python.org/pypi/pyodbc/
|
||||
|
||||
|
||||
Limitations
|
||||
-----------
|
||||
|
||||
The mysql-pyodbc dialect is subject to unresolved character encoding issues
|
||||
which exist within the current ODBC drivers available.
|
||||
(see http://code.google.com/p/pyodbc/issues/detail?id=25). Consider usage
|
||||
of OurSQL, MySQLdb, or MySQL-connector/Python.
|
||||
|
||||
"""
|
||||
|
||||
from .base import MySQLDialect, MySQLExecutionContext
|
||||
from ...connectors.pyodbc import PyODBCConnector
|
||||
from ... import util
|
||||
import re
|
||||
|
||||
|
||||
class MySQLExecutionContext_pyodbc(MySQLExecutionContext):
|
||||
|
||||
def get_lastrowid(self):
|
||||
cursor = self.create_cursor()
|
||||
cursor.execute("SELECT LAST_INSERT_ID()")
|
||||
lastrowid = cursor.fetchone()[0]
|
||||
cursor.close()
|
||||
return lastrowid
|
||||
|
||||
|
||||
class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect):
|
||||
supports_unicode_statements = False
|
||||
execution_ctx_cls = MySQLExecutionContext_pyodbc
|
||||
|
||||
pyodbc_driver_name = "MySQL"
|
||||
|
||||
def __init__(self, **kw):
|
||||
# deal with http://code.google.com/p/pyodbc/issues/detail?id=25
|
||||
kw.setdefault('convert_unicode', True)
|
||||
super(MySQLDialect_pyodbc, self).__init__(**kw)
|
||||
|
||||
def _detect_charset(self, connection):
|
||||
"""Sniff out the character set in use for connection results."""
|
||||
|
||||
# Prefer 'character_set_results' for the current connection over the
|
||||
# value in the driver. SET NAMES or individual variable SETs will
|
||||
# change the charset without updating the driver's view of the world.
|
||||
#
|
||||
# If it's decided that issuing that sort of SQL leaves you SOL, then
|
||||
# this can prefer the driver value.
|
||||
rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'")
|
||||
opts = dict([(row[0], row[1]) for row in self._compat_fetchall(rs)])
|
||||
for key in ('character_set_connection', 'character_set'):
|
||||
if opts.get(key, None):
|
||||
return opts[key]
|
||||
|
||||
util.warn("Could not detect the connection character set. Assuming latin1.")
|
||||
return 'latin1'
|
||||
|
||||
def _extract_error_code(self, exception):
|
||||
m = re.compile(r"\((\d+)\)").search(str(exception.args))
|
||||
c = m.group(1)
|
||||
if c:
|
||||
return int(c)
|
||||
else:
|
||||
return None
|
||||
|
||||
dialect = MySQLDialect_pyodbc
|
|
@ -1,111 +0,0 @@
|
|||
# mysql/zxjdbc.py
|
||||
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
"""
|
||||
|
||||
.. dialect:: mysql+zxjdbc
|
||||
:name: zxjdbc for Jython
|
||||
:dbapi: zxjdbc
|
||||
:connectstring: mysql+zxjdbc://<user>:<password>@<hostname>[:<port>]/<database>
|
||||
:driverurl: http://dev.mysql.com/downloads/connector/j/
|
||||
|
||||
Character Sets
|
||||
--------------
|
||||
|
||||
SQLAlchemy zxjdbc dialects pass unicode straight through to the
|
||||
zxjdbc/JDBC layer. To allow multiple character sets to be sent from the
|
||||
MySQL Connector/J JDBC driver, by default SQLAlchemy sets its
|
||||
``characterEncoding`` connection property to ``UTF-8``. It may be
|
||||
overriden via a ``create_engine`` URL parameter.
|
||||
|
||||
"""
|
||||
import re
|
||||
|
||||
from ... import types as sqltypes, util
|
||||
from ...connectors.zxJDBC import ZxJDBCConnector
|
||||
from .base import BIT, MySQLDialect, MySQLExecutionContext
|
||||
|
||||
|
||||
class _ZxJDBCBit(BIT):
|
||||
def result_processor(self, dialect, coltype):
|
||||
"""Converts boolean or byte arrays from MySQL Connector/J to longs."""
|
||||
def process(value):
|
||||
if value is None:
|
||||
return value
|
||||
if isinstance(value, bool):
|
||||
return int(value)
|
||||
v = 0
|
||||
for i in value:
|
||||
v = v << 8 | (i & 0xff)
|
||||
value = v
|
||||
return value
|
||||
return process
|
||||
|
||||
|
||||
class MySQLExecutionContext_zxjdbc(MySQLExecutionContext):
|
||||
def get_lastrowid(self):
|
||||
cursor = self.create_cursor()
|
||||
cursor.execute("SELECT LAST_INSERT_ID()")
|
||||
lastrowid = cursor.fetchone()[0]
|
||||
cursor.close()
|
||||
return lastrowid
|
||||
|
||||
|
||||
class MySQLDialect_zxjdbc(ZxJDBCConnector, MySQLDialect):
|
||||
jdbc_db_name = 'mysql'
|
||||
jdbc_driver_name = 'com.mysql.jdbc.Driver'
|
||||
|
||||
execution_ctx_cls = MySQLExecutionContext_zxjdbc
|
||||
|
||||
colspecs = util.update_copy(
|
||||
MySQLDialect.colspecs,
|
||||
{
|
||||
sqltypes.Time: sqltypes.Time,
|
||||
BIT: _ZxJDBCBit
|
||||
}
|
||||
)
|
||||
|
||||
def _detect_charset(self, connection):
|
||||
"""Sniff out the character set in use for connection results."""
|
||||
# Prefer 'character_set_results' for the current connection over the
|
||||
# value in the driver. SET NAMES or individual variable SETs will
|
||||
# change the charset without updating the driver's view of the world.
|
||||
#
|
||||
# If it's decided that issuing that sort of SQL leaves you SOL, then
|
||||
# this can prefer the driver value.
|
||||
rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'")
|
||||
opts = dict((row[0], row[1]) for row in self._compat_fetchall(rs))
|
||||
for key in ('character_set_connection', 'character_set'):
|
||||
if opts.get(key, None):
|
||||
return opts[key]
|
||||
|
||||
util.warn("Could not detect the connection character set. Assuming latin1.")
|
||||
return 'latin1'
|
||||
|
||||
def _driver_kwargs(self):
|
||||
"""return kw arg dict to be sent to connect()."""
|
||||
return dict(characterEncoding='UTF-8', yearIsDateType='false')
|
||||
|
||||
def _extract_error_code(self, exception):
|
||||
# e.g.: DBAPIError: (Error) Table 'test.u2' doesn't exist
|
||||
# [SQLCode: 1146], [SQLState: 42S02] 'DESCRIBE `u2`' ()
|
||||
m = re.compile(r"\[SQLCode\: (\d+)\]").search(str(exception.args))
|
||||
c = m.group(1)
|
||||
if c:
|
||||
return int(c)
|
||||
|
||||
def _get_server_version_info(self, connection):
|
||||
dbapi_con = connection.connection
|
||||
version = []
|
||||
r = re.compile('[.\-]')
|
||||
for n in r.split(dbapi_con.dbversion):
|
||||
try:
|
||||
version.append(int(n))
|
||||
except ValueError:
|
||||
version.append(n)
|
||||
return tuple(version)
|
||||
|
||||
dialect = MySQLDialect_zxjdbc
|
|
@ -1,23 +0,0 @@
|
|||
# oracle/__init__.py
|
||||
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
from sqlalchemy.dialects.oracle import base, cx_oracle, zxjdbc
|
||||
|
||||
base.dialect = cx_oracle.dialect
|
||||
|
||||
from sqlalchemy.dialects.oracle.base import \
|
||||
VARCHAR, NVARCHAR, CHAR, DATE, NUMBER,\
|
||||
BLOB, BFILE, CLOB, NCLOB, TIMESTAMP, RAW,\
|
||||
FLOAT, DOUBLE_PRECISION, LONG, dialect, INTERVAL,\
|
||||
VARCHAR2, NVARCHAR2, ROWID, dialect
|
||||
|
||||
|
||||
__all__ = (
|
||||
'VARCHAR', 'NVARCHAR', 'CHAR', 'DATE', 'NUMBER',
|
||||
'BLOB', 'BFILE', 'CLOB', 'NCLOB', 'TIMESTAMP', 'RAW',
|
||||
'FLOAT', 'DOUBLE_PRECISION', 'LONG', 'dialect', 'INTERVAL',
|
||||
'VARCHAR2', 'NVARCHAR2', 'ROWID'
|
||||
)
|
File diff suppressed because it is too large
Load diff
|
@ -1,941 +0,0 @@
|
|||
# oracle/cx_oracle.py
|
||||
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
"""
|
||||
|
||||
.. dialect:: oracle+cx_oracle
|
||||
:name: cx-Oracle
|
||||
:dbapi: cx_oracle
|
||||
:connectstring: oracle+cx_oracle://user:pass@host:port/dbname[?key=value&key=value...]
|
||||
:url: http://cx-oracle.sourceforge.net/
|
||||
|
||||
Additional Connect Arguments
|
||||
----------------------------
|
||||
|
||||
When connecting with ``dbname`` present, the host, port, and dbname tokens are
|
||||
converted to a TNS name using
|
||||
the cx_oracle ``makedsn()`` function. Otherwise, the host token is taken
|
||||
directly as a TNS name.
|
||||
|
||||
Additional arguments which may be specified either as query string arguments
|
||||
on the URL, or as keyword arguments to :func:`.create_engine()` are:
|
||||
|
||||
* ``allow_twophase`` - enable two-phase transactions. Defaults to ``True``.
|
||||
|
||||
* ``arraysize`` - set the cx_oracle.arraysize value on cursors, defaulted
|
||||
to 50. This setting is significant with cx_Oracle as the contents of LOB
|
||||
objects are only readable within a "live" row (e.g. within a batch of
|
||||
50 rows).
|
||||
|
||||
* ``auto_convert_lobs`` - defaults to True; See :ref:`cx_oracle_lob`.
|
||||
|
||||
* ``auto_setinputsizes`` - the cx_oracle.setinputsizes() call is issued for
|
||||
all bind parameters. This is required for LOB datatypes but can be
|
||||
disabled to reduce overhead. Defaults to ``True``. Specific types
|
||||
can be excluded from this process using the ``exclude_setinputsizes``
|
||||
parameter.
|
||||
|
||||
* ``coerce_to_unicode`` - see :ref:`cx_oracle_unicode` for detail.
|
||||
|
||||
* ``coerce_to_decimal`` - see :ref:`cx_oracle_numeric` for detail.
|
||||
|
||||
* ``exclude_setinputsizes`` - a tuple or list of string DBAPI type names to
|
||||
be excluded from the "auto setinputsizes" feature. The type names here
|
||||
must match DBAPI types that are found in the "cx_Oracle" module namespace,
|
||||
such as cx_Oracle.UNICODE, cx_Oracle.NCLOB, etc. Defaults to
|
||||
``(STRING, UNICODE)``.
|
||||
|
||||
.. versionadded:: 0.8 specific DBAPI types can be excluded from the
|
||||
auto_setinputsizes feature via the exclude_setinputsizes attribute.
|
||||
|
||||
* ``mode`` - This is given the string value of SYSDBA or SYSOPER, or alternatively
|
||||
an integer value. This value is only available as a URL query string
|
||||
argument.
|
||||
|
||||
* ``threaded`` - enable multithreaded access to cx_oracle connections. Defaults
|
||||
to ``True``. Note that this is the opposite default of the cx_Oracle DBAPI
|
||||
itself.
|
||||
|
||||
.. _cx_oracle_unicode:
|
||||
|
||||
Unicode
|
||||
-------
|
||||
|
||||
The cx_Oracle DBAPI as of version 5 fully supports unicode, and has the ability
|
||||
to return string results as Python unicode objects natively.
|
||||
|
||||
When used in Python 3, cx_Oracle returns all strings as Python unicode objects
|
||||
(that is, plain ``str`` in Python 3). In Python 2, it will return as Python
|
||||
unicode those column values that are of type ``NVARCHAR`` or ``NCLOB``. For
|
||||
column values that are of type ``VARCHAR`` or other non-unicode string types,
|
||||
it will return values as Python strings (e.g. bytestrings).
|
||||
|
||||
The cx_Oracle SQLAlchemy dialect presents two different options for the use case of
|
||||
returning ``VARCHAR`` column values as Python unicode objects under Python 2:
|
||||
|
||||
* the cx_Oracle DBAPI has the ability to coerce all string results to Python
|
||||
unicode objects unconditionally using output type handlers. This has
|
||||
the advantage that the unicode conversion is global to all statements
|
||||
at the cx_Oracle driver level, meaning it works with raw textual SQL
|
||||
statements that have no typing information associated. However, this system
|
||||
has been observed to incur signfiicant performance overhead, not only because
|
||||
it takes effect for all string values unconditionally, but also because cx_Oracle under
|
||||
Python 2 seems to use a pure-Python function call in order to do the
|
||||
decode operation, which under cPython can orders of magnitude slower
|
||||
than doing it using C functions alone.
|
||||
|
||||
* SQLAlchemy has unicode-decoding services built in, and when using SQLAlchemy's
|
||||
C extensions, these functions do not use any Python function calls and
|
||||
are very fast. The disadvantage to this approach is that the unicode
|
||||
conversion only takes effect for statements where the :class:`.Unicode` type
|
||||
or :class:`.String` type with ``convert_unicode=True`` is explicitly
|
||||
associated with the result column. This is the case for any ORM or Core
|
||||
query or SQL expression as well as for a :func:`.text` construct that specifies
|
||||
output column types, so in the vast majority of cases this is not an issue.
|
||||
However, when sending a completely raw string to :meth:`.Connection.execute`,
|
||||
this typing information isn't present, unless the string is handled
|
||||
within a :func:`.text` construct that adds typing information.
|
||||
|
||||
As of version 0.9.2 of SQLAlchemy, the default approach is to use SQLAlchemy's
|
||||
typing system. This keeps cx_Oracle's expensive Python 2 approach
|
||||
disabled unless the user explicitly wants it. Under Python 3, SQLAlchemy detects
|
||||
that cx_Oracle is returning unicode objects natively and cx_Oracle's system
|
||||
is used.
|
||||
|
||||
To re-enable cx_Oracle's output type handler under Python 2, the
|
||||
``coerce_to_unicode=True`` flag (new in 0.9.4) can be passed to
|
||||
:func:`.create_engine`::
|
||||
|
||||
engine = create_engine("oracle+cx_oracle://dsn", coerce_to_unicode=True)
|
||||
|
||||
Alternatively, to run a pure string SQL statement and get ``VARCHAR`` results
|
||||
as Python unicode under Python 2 without using cx_Oracle's native handlers,
|
||||
the :func:`.text` feature can be used::
|
||||
|
||||
from sqlalchemy import text, Unicode
|
||||
result = conn.execute(text("select username from user").columns(username=Unicode))
|
||||
|
||||
.. versionchanged:: 0.9.2 cx_Oracle's outputtypehandlers are no longer used for
|
||||
unicode results of non-unicode datatypes in Python 2, after they were identified as a major
|
||||
performance bottleneck. SQLAlchemy's own unicode facilities are used
|
||||
instead.
|
||||
|
||||
.. versionadded:: 0.9.4 Added the ``coerce_to_unicode`` flag, to re-enable
|
||||
cx_Oracle's outputtypehandler and revert to pre-0.9.2 behavior.
|
||||
|
||||
.. _cx_oracle_returning:
|
||||
|
||||
RETURNING Support
|
||||
-----------------
|
||||
|
||||
The cx_oracle DBAPI supports a limited subset of Oracle's already limited RETURNING support.
|
||||
Typically, results can only be guaranteed for at most one column being returned;
|
||||
this is the typical case when SQLAlchemy uses RETURNING to get just the value of a
|
||||
primary-key-associated sequence value. Additional column expressions will
|
||||
cause problems in a non-determinative way, due to cx_oracle's lack of support for
|
||||
the OCI_DATA_AT_EXEC API which is required for more complex RETURNING scenarios.
|
||||
|
||||
For this reason, stability may be enhanced by disabling RETURNING support completely;
|
||||
SQLAlchemy otherwise will use RETURNING to fetch newly sequence-generated
|
||||
primary keys. As illustrated in :ref:`oracle_returning`::
|
||||
|
||||
engine = create_engine("oracle://scott:tiger@dsn", implicit_returning=False)
|
||||
|
||||
.. seealso::
|
||||
|
||||
http://docs.oracle.com/cd/B10501_01/appdev.920/a96584/oci05bnd.htm#420693 - OCI documentation for RETURNING
|
||||
|
||||
http://sourceforge.net/mailarchive/message.php?msg_id=31338136 - cx_oracle developer commentary
|
||||
|
||||
.. _cx_oracle_lob:
|
||||
|
||||
LOB Objects
|
||||
-----------
|
||||
|
||||
cx_oracle returns oracle LOBs using the cx_oracle.LOB object. SQLAlchemy converts
|
||||
these to strings so that the interface of the Binary type is consistent with that of
|
||||
other backends, and so that the linkage to a live cursor is not needed in scenarios
|
||||
like result.fetchmany() and result.fetchall(). This means that by default, LOB
|
||||
objects are fully fetched unconditionally by SQLAlchemy, and the linkage to a live
|
||||
cursor is broken.
|
||||
|
||||
To disable this processing, pass ``auto_convert_lobs=False`` to :func:`.create_engine()`.
|
||||
|
||||
Two Phase Transaction Support
|
||||
-----------------------------
|
||||
|
||||
Two Phase transactions are implemented using XA transactions, and are known
|
||||
to work in a rudimental fashion with recent versions of cx_Oracle
|
||||
as of SQLAlchemy 0.8.0b2, 0.7.10. However, the mechanism is not yet
|
||||
considered to be robust and should still be regarded as experimental.
|
||||
|
||||
In particular, the cx_Oracle DBAPI as recently as 5.1.2 has a bug regarding
|
||||
two phase which prevents
|
||||
a particular DBAPI connection from being consistently usable in both
|
||||
prepared transactions as well as traditional DBAPI usage patterns; therefore
|
||||
once a particular connection is used via :meth:`.Connection.begin_prepared`,
|
||||
all subsequent usages of the underlying DBAPI connection must be within
|
||||
the context of prepared transactions.
|
||||
|
||||
The default behavior of :class:`.Engine` is to maintain a pool of DBAPI
|
||||
connections. Therefore, due to the above glitch, a DBAPI connection that has
|
||||
been used in a two-phase operation, and is then returned to the pool, will
|
||||
not be usable in a non-two-phase context. To avoid this situation,
|
||||
the application can make one of several choices:
|
||||
|
||||
* Disable connection pooling using :class:`.NullPool`
|
||||
|
||||
* Ensure that the particular :class:`.Engine` in use is only used
|
||||
for two-phase operations. A :class:`.Engine` bound to an ORM
|
||||
:class:`.Session` which includes ``twophase=True`` will consistently
|
||||
use the two-phase transaction style.
|
||||
|
||||
* For ad-hoc two-phase operations without disabling pooling, the DBAPI
|
||||
connection in use can be evicted from the connection pool using the
|
||||
:meth:`.Connection.detach` method.
|
||||
|
||||
.. versionchanged:: 0.8.0b2,0.7.10
|
||||
Support for cx_oracle prepared transactions has been implemented
|
||||
and tested.
|
||||
|
||||
.. _cx_oracle_numeric:
|
||||
|
||||
Precision Numerics
|
||||
------------------
|
||||
|
||||
The SQLAlchemy dialect goes through a lot of steps to ensure
|
||||
that decimal numbers are sent and received with full accuracy.
|
||||
An "outputtypehandler" callable is associated with each
|
||||
cx_oracle connection object which detects numeric types and
|
||||
receives them as string values, instead of receiving a Python
|
||||
``float`` directly, which is then passed to the Python
|
||||
``Decimal`` constructor. The :class:`.Numeric` and
|
||||
:class:`.Float` types under the cx_oracle dialect are aware of
|
||||
this behavior, and will coerce the ``Decimal`` to ``float`` if
|
||||
the ``asdecimal`` flag is ``False`` (default on :class:`.Float`,
|
||||
optional on :class:`.Numeric`).
|
||||
|
||||
Because the handler coerces to ``Decimal`` in all cases first,
|
||||
the feature can detract significantly from performance.
|
||||
If precision numerics aren't required, the decimal handling
|
||||
can be disabled by passing the flag ``coerce_to_decimal=False``
|
||||
to :func:`.create_engine`::
|
||||
|
||||
engine = create_engine("oracle+cx_oracle://dsn", coerce_to_decimal=False)
|
||||
|
||||
.. versionadded:: 0.7.6
|
||||
Add the ``coerce_to_decimal`` flag.
|
||||
|
||||
Another alternative to performance is to use the
|
||||
`cdecimal <http://pypi.python.org/pypi/cdecimal/>`_ library;
|
||||
see :class:`.Numeric` for additional notes.
|
||||
|
||||
The handler attempts to use the "precision" and "scale"
|
||||
attributes of the result set column to best determine if
|
||||
subsequent incoming values should be received as ``Decimal`` as
|
||||
opposed to int (in which case no processing is added). There are
|
||||
several scenarios where OCI_ does not provide unambiguous data
|
||||
as to the numeric type, including some situations where
|
||||
individual rows may return a combination of floating point and
|
||||
integer values. Certain values for "precision" and "scale" have
|
||||
been observed to determine this scenario. When it occurs, the
|
||||
outputtypehandler receives as string and then passes off to a
|
||||
processing function which detects, for each returned value, if a
|
||||
decimal point is present, and if so converts to ``Decimal``,
|
||||
otherwise to int. The intention is that simple int-based
|
||||
statements like "SELECT my_seq.nextval() FROM DUAL" continue to
|
||||
return ints and not ``Decimal`` objects, and that any kind of
|
||||
floating point value is received as a string so that there is no
|
||||
floating point loss of precision.
|
||||
|
||||
The "decimal point is present" logic itself is also sensitive to
|
||||
locale. Under OCI_, this is controlled by the NLS_LANG
|
||||
environment variable. Upon first connection, the dialect runs a
|
||||
test to determine the current "decimal" character, which can be
|
||||
a comma "," for european locales. From that point forward the
|
||||
outputtypehandler uses that character to represent a decimal
|
||||
point. Note that cx_oracle 5.0.3 or greater is required
|
||||
when dealing with numerics with locale settings that don't use
|
||||
a period "." as the decimal character.
|
||||
|
||||
.. versionchanged:: 0.6.6
|
||||
The outputtypehandler supports the case where the locale uses a
|
||||
comma "," character to represent a decimal point.
|
||||
|
||||
.. _OCI: http://www.oracle.com/technetwork/database/features/oci/index.html
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
|
||||
from .base import OracleCompiler, OracleDialect, OracleExecutionContext
|
||||
from . import base as oracle
|
||||
from ...engine import result as _result
|
||||
from sqlalchemy import types as sqltypes, util, exc, processors
|
||||
import random
|
||||
import collections
|
||||
import decimal
|
||||
import re
|
||||
|
||||
|
||||
class _OracleNumeric(sqltypes.Numeric):
|
||||
def bind_processor(self, dialect):
|
||||
# cx_oracle accepts Decimal objects and floats
|
||||
return None
|
||||
|
||||
def result_processor(self, dialect, coltype):
|
||||
# we apply a cx_oracle type handler to all connections
|
||||
# that converts floating point strings to Decimal().
|
||||
# However, in some subquery situations, Oracle doesn't
|
||||
# give us enough information to determine int or Decimal.
|
||||
# It could even be int/Decimal differently on each row,
|
||||
# regardless of the scale given for the originating type.
|
||||
# So we still need an old school isinstance() handler
|
||||
# here for decimals.
|
||||
|
||||
if dialect.supports_native_decimal:
|
||||
if self.asdecimal:
|
||||
fstring = "%%.%df" % self._effective_decimal_return_scale
|
||||
|
||||
def to_decimal(value):
|
||||
if value is None:
|
||||
return None
|
||||
elif isinstance(value, decimal.Decimal):
|
||||
return value
|
||||
else:
|
||||
return decimal.Decimal(fstring % value)
|
||||
|
||||
return to_decimal
|
||||
else:
|
||||
if self.precision is None and self.scale is None:
|
||||
return processors.to_float
|
||||
elif not getattr(self, '_is_oracle_number', False) \
|
||||
and self.scale is not None:
|
||||
return processors.to_float
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
# cx_oracle 4 behavior, will assume
|
||||
# floats
|
||||
return super(_OracleNumeric, self).\
|
||||
result_processor(dialect, coltype)
|
||||
|
||||
|
||||
class _OracleDate(sqltypes.Date):
|
||||
def bind_processor(self, dialect):
|
||||
return None
|
||||
|
||||
def result_processor(self, dialect, coltype):
|
||||
def process(value):
|
||||
if value is not None:
|
||||
return value.date()
|
||||
else:
|
||||
return value
|
||||
return process
|
||||
|
||||
|
||||
class _LOBMixin(object):
|
||||
def result_processor(self, dialect, coltype):
|
||||
if not dialect.auto_convert_lobs:
|
||||
# return the cx_oracle.LOB directly.
|
||||
return None
|
||||
|
||||
def process(value):
|
||||
if value is not None:
|
||||
return value.read()
|
||||
else:
|
||||
return value
|
||||
return process
|
||||
|
||||
|
||||
class _NativeUnicodeMixin(object):
|
||||
if util.py2k:
|
||||
def bind_processor(self, dialect):
|
||||
if dialect._cx_oracle_with_unicode:
|
||||
def process(value):
|
||||
if value is None:
|
||||
return value
|
||||
else:
|
||||
return unicode(value)
|
||||
return process
|
||||
else:
|
||||
return super(_NativeUnicodeMixin, self).bind_processor(dialect)
|
||||
|
||||
# we apply a connection output handler that returns
|
||||
# unicode in all cases, so the "native_unicode" flag
|
||||
# will be set for the default String.result_processor.
|
||||
|
||||
|
||||
class _OracleChar(_NativeUnicodeMixin, sqltypes.CHAR):
|
||||
def get_dbapi_type(self, dbapi):
|
||||
return dbapi.FIXED_CHAR
|
||||
|
||||
|
||||
class _OracleNVarChar(_NativeUnicodeMixin, sqltypes.NVARCHAR):
|
||||
def get_dbapi_type(self, dbapi):
|
||||
return getattr(dbapi, 'UNICODE', dbapi.STRING)
|
||||
|
||||
|
||||
class _OracleText(_LOBMixin, sqltypes.Text):
|
||||
def get_dbapi_type(self, dbapi):
|
||||
return dbapi.CLOB
|
||||
|
||||
|
||||
class _OracleLong(oracle.LONG):
|
||||
# a raw LONG is a text type, but does *not*
|
||||
# get the LobMixin with cx_oracle.
|
||||
|
||||
def get_dbapi_type(self, dbapi):
|
||||
return dbapi.LONG_STRING
|
||||
|
||||
class _OracleString(_NativeUnicodeMixin, sqltypes.String):
|
||||
pass
|
||||
|
||||
|
||||
class _OracleUnicodeText(_LOBMixin, _NativeUnicodeMixin, sqltypes.UnicodeText):
|
||||
def get_dbapi_type(self, dbapi):
|
||||
return dbapi.NCLOB
|
||||
|
||||
def result_processor(self, dialect, coltype):
|
||||
lob_processor = _LOBMixin.result_processor(self, dialect, coltype)
|
||||
if lob_processor is None:
|
||||
return None
|
||||
|
||||
string_processor = sqltypes.UnicodeText.result_processor(self, dialect, coltype)
|
||||
|
||||
if string_processor is None:
|
||||
return lob_processor
|
||||
else:
|
||||
def process(value):
|
||||
return string_processor(lob_processor(value))
|
||||
return process
|
||||
|
||||
|
||||
class _OracleInteger(sqltypes.Integer):
|
||||
def result_processor(self, dialect, coltype):
|
||||
def to_int(val):
|
||||
if val is not None:
|
||||
val = int(val)
|
||||
return val
|
||||
return to_int
|
||||
|
||||
|
||||
class _OracleBinary(_LOBMixin, sqltypes.LargeBinary):
|
||||
def get_dbapi_type(self, dbapi):
|
||||
return dbapi.BLOB
|
||||
|
||||
def bind_processor(self, dialect):
|
||||
return None
|
||||
|
||||
|
||||
class _OracleInterval(oracle.INTERVAL):
|
||||
def get_dbapi_type(self, dbapi):
|
||||
return dbapi.INTERVAL
|
||||
|
||||
|
||||
class _OracleRaw(oracle.RAW):
|
||||
pass
|
||||
|
||||
|
||||
class _OracleRowid(oracle.ROWID):
|
||||
def get_dbapi_type(self, dbapi):
|
||||
return dbapi.ROWID
|
||||
|
||||
|
||||
class OracleCompiler_cx_oracle(OracleCompiler):
|
||||
def bindparam_string(self, name, **kw):
|
||||
quote = getattr(name, 'quote', None)
|
||||
if quote is True or quote is not False and \
|
||||
self.preparer._bindparam_requires_quotes(name):
|
||||
quoted_name = '"%s"' % name
|
||||
self._quoted_bind_names[name] = quoted_name
|
||||
return OracleCompiler.bindparam_string(self, quoted_name, **kw)
|
||||
else:
|
||||
return OracleCompiler.bindparam_string(self, name, **kw)
|
||||
|
||||
|
||||
class OracleExecutionContext_cx_oracle(OracleExecutionContext):
|
||||
|
||||
def pre_exec(self):
|
||||
quoted_bind_names = \
|
||||
getattr(self.compiled, '_quoted_bind_names', None)
|
||||
if quoted_bind_names:
|
||||
if not self.dialect.supports_unicode_statements:
|
||||
# if DBAPI doesn't accept unicode statements,
|
||||
# keys in self.parameters would have been encoded
|
||||
# here. so convert names in quoted_bind_names
|
||||
# to encoded as well.
|
||||
quoted_bind_names = \
|
||||
dict(
|
||||
(fromname.encode(self.dialect.encoding),
|
||||
toname.encode(self.dialect.encoding))
|
||||
for fromname, toname in
|
||||
quoted_bind_names.items()
|
||||
)
|
||||
for param in self.parameters:
|
||||
for fromname, toname in quoted_bind_names.items():
|
||||
param[toname] = param[fromname]
|
||||
del param[fromname]
|
||||
|
||||
if self.dialect.auto_setinputsizes:
|
||||
# cx_oracle really has issues when you setinputsizes
|
||||
# on String, including that outparams/RETURNING
|
||||
# breaks for varchars
|
||||
self.set_input_sizes(quoted_bind_names,
|
||||
exclude_types=self.dialect.exclude_setinputsizes
|
||||
)
|
||||
|
||||
# if a single execute, check for outparams
|
||||
if len(self.compiled_parameters) == 1:
|
||||
for bindparam in self.compiled.binds.values():
|
||||
if bindparam.isoutparam:
|
||||
dbtype = bindparam.type.dialect_impl(self.dialect).\
|
||||
get_dbapi_type(self.dialect.dbapi)
|
||||
if not hasattr(self, 'out_parameters'):
|
||||
self.out_parameters = {}
|
||||
if dbtype is None:
|
||||
raise exc.InvalidRequestError(
|
||||
"Cannot create out parameter for parameter "
|
||||
"%r - it's type %r is not supported by"
|
||||
" cx_oracle" %
|
||||
(bindparam.key, bindparam.type)
|
||||
)
|
||||
name = self.compiled.bind_names[bindparam]
|
||||
self.out_parameters[name] = self.cursor.var(dbtype)
|
||||
self.parameters[0][quoted_bind_names.get(name, name)] = \
|
||||
self.out_parameters[name]
|
||||
|
||||
def create_cursor(self):
|
||||
c = self._dbapi_connection.cursor()
|
||||
if self.dialect.arraysize:
|
||||
c.arraysize = self.dialect.arraysize
|
||||
|
||||
return c
|
||||
|
||||
def get_result_proxy(self):
|
||||
if hasattr(self, 'out_parameters') and self.compiled.returning:
|
||||
returning_params = dict(
|
||||
(k, v.getvalue())
|
||||
for k, v in self.out_parameters.items()
|
||||
)
|
||||
return ReturningResultProxy(self, returning_params)
|
||||
|
||||
result = None
|
||||
if self.cursor.description is not None:
|
||||
for column in self.cursor.description:
|
||||
type_code = column[1]
|
||||
if type_code in self.dialect._cx_oracle_binary_types:
|
||||
result = _result.BufferedColumnResultProxy(self)
|
||||
|
||||
if result is None:
|
||||
result = _result.ResultProxy(self)
|
||||
|
||||
if hasattr(self, 'out_parameters'):
|
||||
if self.compiled_parameters is not None and \
|
||||
len(self.compiled_parameters) == 1:
|
||||
result.out_parameters = out_parameters = {}
|
||||
|
||||
for bind, name in self.compiled.bind_names.items():
|
||||
if name in self.out_parameters:
|
||||
type = bind.type
|
||||
impl_type = type.dialect_impl(self.dialect)
|
||||
dbapi_type = impl_type.get_dbapi_type(self.dialect.dbapi)
|
||||
result_processor = impl_type.\
|
||||
result_processor(self.dialect,
|
||||
dbapi_type)
|
||||
if result_processor is not None:
|
||||
out_parameters[name] = \
|
||||
result_processor(self.out_parameters[name].getvalue())
|
||||
else:
|
||||
out_parameters[name] = self.out_parameters[name].getvalue()
|
||||
else:
|
||||
result.out_parameters = dict(
|
||||
(k, v.getvalue())
|
||||
for k, v in self.out_parameters.items()
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class OracleExecutionContext_cx_oracle_with_unicode(OracleExecutionContext_cx_oracle):
|
||||
"""Support WITH_UNICODE in Python 2.xx.
|
||||
|
||||
WITH_UNICODE allows cx_Oracle's Python 3 unicode handling
|
||||
behavior under Python 2.x. This mode in some cases disallows
|
||||
and in other cases silently passes corrupted data when
|
||||
non-Python-unicode strings (a.k.a. plain old Python strings)
|
||||
are passed as arguments to connect(), the statement sent to execute(),
|
||||
or any of the bind parameter keys or values sent to execute().
|
||||
This optional context therefore ensures that all statements are
|
||||
passed as Python unicode objects.
|
||||
|
||||
"""
|
||||
def __init__(self, *arg, **kw):
|
||||
OracleExecutionContext_cx_oracle.__init__(self, *arg, **kw)
|
||||
self.statement = util.text_type(self.statement)
|
||||
|
||||
def _execute_scalar(self, stmt):
|
||||
return super(OracleExecutionContext_cx_oracle_with_unicode, self).\
|
||||
_execute_scalar(util.text_type(stmt))
|
||||
|
||||
|
||||
class ReturningResultProxy(_result.FullyBufferedResultProxy):
|
||||
"""Result proxy which stuffs the _returning clause + outparams into the fetch."""
|
||||
|
||||
def __init__(self, context, returning_params):
|
||||
self._returning_params = returning_params
|
||||
super(ReturningResultProxy, self).__init__(context)
|
||||
|
||||
def _cursor_description(self):
|
||||
returning = self.context.compiled.returning
|
||||
return [
|
||||
("ret_%d" % i, None)
|
||||
for i, col in enumerate(returning)
|
||||
]
|
||||
|
||||
def _buffer_rows(self):
|
||||
return collections.deque([tuple(self._returning_params["ret_%d" % i]
|
||||
for i, c in enumerate(self._returning_params))])
|
||||
|
||||
|
||||
class OracleDialect_cx_oracle(OracleDialect):
|
||||
execution_ctx_cls = OracleExecutionContext_cx_oracle
|
||||
statement_compiler = OracleCompiler_cx_oracle
|
||||
|
||||
driver = "cx_oracle"
|
||||
|
||||
colspecs = colspecs = {
|
||||
sqltypes.Numeric: _OracleNumeric,
|
||||
sqltypes.Date: _OracleDate, # generic type, assume datetime.date is desired
|
||||
sqltypes.LargeBinary: _OracleBinary,
|
||||
sqltypes.Boolean: oracle._OracleBoolean,
|
||||
sqltypes.Interval: _OracleInterval,
|
||||
oracle.INTERVAL: _OracleInterval,
|
||||
sqltypes.Text: _OracleText,
|
||||
sqltypes.String: _OracleString,
|
||||
sqltypes.UnicodeText: _OracleUnicodeText,
|
||||
sqltypes.CHAR: _OracleChar,
|
||||
|
||||
# a raw LONG is a text type, but does *not*
|
||||
# get the LobMixin with cx_oracle.
|
||||
oracle.LONG: _OracleLong,
|
||||
|
||||
# this is only needed for OUT parameters.
|
||||
# it would be nice if we could not use it otherwise.
|
||||
sqltypes.Integer: _OracleInteger,
|
||||
|
||||
oracle.RAW: _OracleRaw,
|
||||
sqltypes.Unicode: _OracleNVarChar,
|
||||
sqltypes.NVARCHAR: _OracleNVarChar,
|
||||
oracle.ROWID: _OracleRowid,
|
||||
}
|
||||
|
||||
execute_sequence_format = list
|
||||
|
||||
def __init__(self,
|
||||
auto_setinputsizes=True,
|
||||
exclude_setinputsizes=("STRING", "UNICODE"),
|
||||
auto_convert_lobs=True,
|
||||
threaded=True,
|
||||
allow_twophase=True,
|
||||
coerce_to_decimal=True,
|
||||
coerce_to_unicode=False,
|
||||
arraysize=50, **kwargs):
|
||||
OracleDialect.__init__(self, **kwargs)
|
||||
self.threaded = threaded
|
||||
self.arraysize = arraysize
|
||||
self.allow_twophase = allow_twophase
|
||||
self.supports_timestamp = self.dbapi is None or \
|
||||
hasattr(self.dbapi, 'TIMESTAMP')
|
||||
self.auto_setinputsizes = auto_setinputsizes
|
||||
self.auto_convert_lobs = auto_convert_lobs
|
||||
|
||||
if hasattr(self.dbapi, 'version'):
|
||||
self.cx_oracle_ver = tuple([int(x) for x in
|
||||
self.dbapi.version.split('.')])
|
||||
else:
|
||||
self.cx_oracle_ver = (0, 0, 0)
|
||||
|
||||
def types(*names):
|
||||
return set(
|
||||
getattr(self.dbapi, name, None) for name in names
|
||||
).difference([None])
|
||||
|
||||
self.exclude_setinputsizes = types(*(exclude_setinputsizes or ()))
|
||||
self._cx_oracle_string_types = types("STRING", "UNICODE",
|
||||
"NCLOB", "CLOB")
|
||||
self._cx_oracle_unicode_types = types("UNICODE", "NCLOB")
|
||||
self._cx_oracle_binary_types = types("BFILE", "CLOB", "NCLOB", "BLOB")
|
||||
self.supports_unicode_binds = self.cx_oracle_ver >= (5, 0)
|
||||
|
||||
self.coerce_to_unicode = (
|
||||
self.cx_oracle_ver >= (5, 0) and
|
||||
coerce_to_unicode
|
||||
)
|
||||
|
||||
self.supports_native_decimal = (
|
||||
self.cx_oracle_ver >= (5, 0) and
|
||||
coerce_to_decimal
|
||||
)
|
||||
|
||||
self._cx_oracle_native_nvarchar = self.cx_oracle_ver >= (5, 0)
|
||||
|
||||
if self.cx_oracle_ver is None:
|
||||
# this occurs in tests with mock DBAPIs
|
||||
self._cx_oracle_string_types = set()
|
||||
self._cx_oracle_with_unicode = False
|
||||
elif self.cx_oracle_ver >= (5,) and not hasattr(self.dbapi, 'UNICODE'):
|
||||
# cx_Oracle WITH_UNICODE mode. *only* python
|
||||
# unicode objects accepted for anything
|
||||
self.supports_unicode_statements = True
|
||||
self.supports_unicode_binds = True
|
||||
self._cx_oracle_with_unicode = True
|
||||
|
||||
if util.py2k:
|
||||
# There's really no reason to run with WITH_UNICODE under Python 2.x.
|
||||
# Give the user a hint.
|
||||
util.warn(
|
||||
"cx_Oracle is compiled under Python 2.xx using the "
|
||||
"WITH_UNICODE flag. Consider recompiling cx_Oracle "
|
||||
"without this flag, which is in no way necessary for full "
|
||||
"support of Unicode. Otherwise, all string-holding bind "
|
||||
"parameters must be explicitly typed using SQLAlchemy's "
|
||||
"String type or one of its subtypes,"
|
||||
"or otherwise be passed as Python unicode. "
|
||||
"Plain Python strings passed as bind parameters will be "
|
||||
"silently corrupted by cx_Oracle."
|
||||
)
|
||||
self.execution_ctx_cls = \
|
||||
OracleExecutionContext_cx_oracle_with_unicode
|
||||
else:
|
||||
self._cx_oracle_with_unicode = False
|
||||
|
||||
if self.cx_oracle_ver is None or \
|
||||
not self.auto_convert_lobs or \
|
||||
not hasattr(self.dbapi, 'CLOB'):
|
||||
self.dbapi_type_map = {}
|
||||
else:
|
||||
# only use this for LOB objects. using it for strings, dates
|
||||
# etc. leads to a little too much magic, reflection doesn't know if it should
|
||||
# expect encoded strings or unicodes, etc.
|
||||
self.dbapi_type_map = {
|
||||
self.dbapi.CLOB: oracle.CLOB(),
|
||||
self.dbapi.NCLOB: oracle.NCLOB(),
|
||||
self.dbapi.BLOB: oracle.BLOB(),
|
||||
self.dbapi.BINARY: oracle.RAW(),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def dbapi(cls):
|
||||
import cx_Oracle
|
||||
return cx_Oracle
|
||||
|
||||
def initialize(self, connection):
|
||||
super(OracleDialect_cx_oracle, self).initialize(connection)
|
||||
if self._is_oracle_8:
|
||||
self.supports_unicode_binds = False
|
||||
self._detect_decimal_char(connection)
|
||||
|
||||
def _detect_decimal_char(self, connection):
|
||||
"""detect if the decimal separator character is not '.', as
|
||||
is the case with european locale settings for NLS_LANG.
|
||||
|
||||
cx_oracle itself uses similar logic when it formats Python
|
||||
Decimal objects to strings on the bind side (as of 5.0.3),
|
||||
as Oracle sends/receives string numerics only in the
|
||||
current locale.
|
||||
|
||||
"""
|
||||
if self.cx_oracle_ver < (5,):
|
||||
# no output type handlers before version 5
|
||||
return
|
||||
|
||||
cx_Oracle = self.dbapi
|
||||
conn = connection.connection
|
||||
|
||||
# override the output_type_handler that's
|
||||
# on the cx_oracle connection with a plain
|
||||
# one on the cursor
|
||||
|
||||
def output_type_handler(cursor, name, defaultType,
|
||||
size, precision, scale):
|
||||
return cursor.var(
|
||||
cx_Oracle.STRING,
|
||||
255, arraysize=cursor.arraysize)
|
||||
|
||||
cursor = conn.cursor()
|
||||
cursor.outputtypehandler = output_type_handler
|
||||
cursor.execute("SELECT 0.1 FROM DUAL")
|
||||
val = cursor.fetchone()[0]
|
||||
cursor.close()
|
||||
char = re.match(r"([\.,])", val).group(1)
|
||||
if char != '.':
|
||||
_detect_decimal = self._detect_decimal
|
||||
self._detect_decimal = \
|
||||
lambda value: _detect_decimal(value.replace(char, '.'))
|
||||
self._to_decimal = \
|
||||
lambda value: decimal.Decimal(value.replace(char, '.'))
|
||||
|
||||
def _detect_decimal(self, value):
|
||||
if "." in value:
|
||||
return decimal.Decimal(value)
|
||||
else:
|
||||
return int(value)
|
||||
|
||||
_to_decimal = decimal.Decimal
|
||||
|
||||
def on_connect(self):
|
||||
if self.cx_oracle_ver < (5,):
|
||||
# no output type handlers before version 5
|
||||
return
|
||||
|
||||
cx_Oracle = self.dbapi
|
||||
|
||||
def output_type_handler(cursor, name, defaultType,
|
||||
size, precision, scale):
|
||||
# convert all NUMBER with precision + positive scale to Decimal
|
||||
# this almost allows "native decimal" mode.
|
||||
if self.supports_native_decimal and \
|
||||
defaultType == cx_Oracle.NUMBER and \
|
||||
precision and scale > 0:
|
||||
return cursor.var(
|
||||
cx_Oracle.STRING,
|
||||
255,
|
||||
outconverter=self._to_decimal,
|
||||
arraysize=cursor.arraysize)
|
||||
# if NUMBER with zero precision and 0 or neg scale, this appears
|
||||
# to indicate "ambiguous". Use a slower converter that will
|
||||
# make a decision based on each value received - the type
|
||||
# may change from row to row (!). This kills
|
||||
# off "native decimal" mode, handlers still needed.
|
||||
elif self.supports_native_decimal and \
|
||||
defaultType == cx_Oracle.NUMBER \
|
||||
and not precision and scale <= 0:
|
||||
return cursor.var(
|
||||
cx_Oracle.STRING,
|
||||
255,
|
||||
outconverter=self._detect_decimal,
|
||||
arraysize=cursor.arraysize)
|
||||
# allow all strings to come back natively as Unicode
|
||||
elif self.coerce_to_unicode and \
|
||||
defaultType in (cx_Oracle.STRING, cx_Oracle.FIXED_CHAR):
|
||||
return cursor.var(util.text_type, size, cursor.arraysize)
|
||||
|
||||
def on_connect(conn):
|
||||
conn.outputtypehandler = output_type_handler
|
||||
|
||||
return on_connect
|
||||
|
||||
def create_connect_args(self, url):
|
||||
dialect_opts = dict(url.query)
|
||||
for opt in ('use_ansi', 'auto_setinputsizes', 'auto_convert_lobs',
|
||||
'threaded', 'allow_twophase'):
|
||||
if opt in dialect_opts:
|
||||
util.coerce_kw_type(dialect_opts, opt, bool)
|
||||
setattr(self, opt, dialect_opts[opt])
|
||||
|
||||
if url.database:
|
||||
# if we have a database, then we have a remote host
|
||||
port = url.port
|
||||
if port:
|
||||
port = int(port)
|
||||
else:
|
||||
port = 1521
|
||||
dsn = self.dbapi.makedsn(url.host, port, url.database)
|
||||
else:
|
||||
# we have a local tnsname
|
||||
dsn = url.host
|
||||
|
||||
opts = dict(
|
||||
user=url.username,
|
||||
password=url.password,
|
||||
dsn=dsn,
|
||||
threaded=self.threaded,
|
||||
twophase=self.allow_twophase,
|
||||
)
|
||||
|
||||
if util.py2k:
|
||||
if self._cx_oracle_with_unicode:
|
||||
for k, v in opts.items():
|
||||
if isinstance(v, str):
|
||||
opts[k] = unicode(v)
|
||||
else:
|
||||
for k, v in opts.items():
|
||||
if isinstance(v, unicode):
|
||||
opts[k] = str(v)
|
||||
|
||||
if 'mode' in url.query:
|
||||
opts['mode'] = url.query['mode']
|
||||
if isinstance(opts['mode'], util.string_types):
|
||||
mode = opts['mode'].upper()
|
||||
if mode == 'SYSDBA':
|
||||
opts['mode'] = self.dbapi.SYSDBA
|
||||
elif mode == 'SYSOPER':
|
||||
opts['mode'] = self.dbapi.SYSOPER
|
||||
else:
|
||||
util.coerce_kw_type(opts, 'mode', int)
|
||||
return ([], opts)
|
||||
|
||||
def _get_server_version_info(self, connection):
|
||||
return tuple(
|
||||
int(x)
|
||||
for x in connection.connection.version.split('.')
|
||||
)
|
||||
|
||||
def is_disconnect(self, e, connection, cursor):
|
||||
error, = e.args
|
||||
if isinstance(e, self.dbapi.InterfaceError):
|
||||
return "not connected" in str(e)
|
||||
elif hasattr(error, 'code'):
|
||||
# ORA-00028: your session has been killed
|
||||
# ORA-03114: not connected to ORACLE
|
||||
# ORA-03113: end-of-file on communication channel
|
||||
# ORA-03135: connection lost contact
|
||||
# ORA-01033: ORACLE initialization or shutdown in progress
|
||||
# ORA-02396: exceeded maximum idle time, please connect again
|
||||
# TODO: Others ?
|
||||
return error.code in (28, 3114, 3113, 3135, 1033, 2396)
|
||||
else:
|
||||
return False
|
||||
|
||||
def create_xid(self):
|
||||
"""create a two-phase transaction ID.
|
||||
|
||||
this id will be passed to do_begin_twophase(), do_rollback_twophase(),
|
||||
do_commit_twophase(). its format is unspecified."""
|
||||
|
||||
id = random.randint(0, 2 ** 128)
|
||||
return (0x1234, "%032x" % id, "%032x" % 9)
|
||||
|
||||
def do_executemany(self, cursor, statement, parameters, context=None):
|
||||
if isinstance(parameters, tuple):
|
||||
parameters = list(parameters)
|
||||
cursor.executemany(statement, parameters)
|
||||
|
||||
def do_begin_twophase(self, connection, xid):
|
||||
connection.connection.begin(*xid)
|
||||
|
||||
def do_prepare_twophase(self, connection, xid):
|
||||
result = connection.connection.prepare()
|
||||
connection.info['cx_oracle_prepared'] = result
|
||||
|
||||
def do_rollback_twophase(self, connection, xid, is_prepared=True,
|
||||
recover=False):
|
||||
self.do_rollback(connection.connection)
|
||||
|
||||
def do_commit_twophase(self, connection, xid, is_prepared=True,
|
||||
recover=False):
|
||||
if not is_prepared:
|
||||
self.do_commit(connection.connection)
|
||||
else:
|
||||
oci_prepared = connection.info['cx_oracle_prepared']
|
||||
if oci_prepared:
|
||||
self.do_commit(connection.connection)
|
||||
|
||||
def do_recover_twophase(self, connection):
|
||||
connection.info.pop('cx_oracle_prepared', None)
|
||||
|
||||
dialect = OracleDialect_cx_oracle
|
|
@ -1,218 +0,0 @@
|
|||
# oracle/zxjdbc.py
|
||||
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
"""
|
||||
.. dialect:: oracle+zxjdbc
|
||||
:name: zxJDBC for Jython
|
||||
:dbapi: zxjdbc
|
||||
:connectstring: oracle+zxjdbc://user:pass@host/dbname
|
||||
:driverurl: http://www.oracle.com/technology/software/tech/java/sqlj_jdbc/index.html.
|
||||
|
||||
"""
|
||||
import decimal
|
||||
import re
|
||||
|
||||
from sqlalchemy import sql, types as sqltypes, util
|
||||
from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector
|
||||
from sqlalchemy.dialects.oracle.base import OracleCompiler, OracleDialect, OracleExecutionContext
|
||||
from sqlalchemy.engine import result as _result
|
||||
from sqlalchemy.sql import expression
|
||||
import collections
|
||||
|
||||
SQLException = zxJDBC = None
|
||||
|
||||
|
||||
class _ZxJDBCDate(sqltypes.Date):
|
||||
|
||||
def result_processor(self, dialect, coltype):
|
||||
def process(value):
|
||||
if value is None:
|
||||
return None
|
||||
else:
|
||||
return value.date()
|
||||
return process
|
||||
|
||||
|
||||
class _ZxJDBCNumeric(sqltypes.Numeric):
|
||||
|
||||
def result_processor(self, dialect, coltype):
|
||||
#XXX: does the dialect return Decimal or not???
|
||||
# if it does (in all cases), we could use a None processor as well as
|
||||
# the to_float generic processor
|
||||
if self.asdecimal:
|
||||
def process(value):
|
||||
if isinstance(value, decimal.Decimal):
|
||||
return value
|
||||
else:
|
||||
return decimal.Decimal(str(value))
|
||||
else:
|
||||
def process(value):
|
||||
if isinstance(value, decimal.Decimal):
|
||||
return float(value)
|
||||
else:
|
||||
return value
|
||||
return process
|
||||
|
||||
|
||||
class OracleCompiler_zxjdbc(OracleCompiler):
|
||||
|
||||
def returning_clause(self, stmt, returning_cols):
|
||||
self.returning_cols = list(expression._select_iterables(returning_cols))
|
||||
|
||||
# within_columns_clause=False so that labels (foo AS bar) don't render
|
||||
columns = [self.process(c, within_columns_clause=False, result_map=self.result_map)
|
||||
for c in self.returning_cols]
|
||||
|
||||
if not hasattr(self, 'returning_parameters'):
|
||||
self.returning_parameters = []
|
||||
|
||||
binds = []
|
||||
for i, col in enumerate(self.returning_cols):
|
||||
dbtype = col.type.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi)
|
||||
self.returning_parameters.append((i + 1, dbtype))
|
||||
|
||||
bindparam = sql.bindparam("ret_%d" % i, value=ReturningParam(dbtype))
|
||||
self.binds[bindparam.key] = bindparam
|
||||
binds.append(self.bindparam_string(self._truncate_bindparam(bindparam)))
|
||||
|
||||
return 'RETURNING ' + ', '.join(columns) + " INTO " + ", ".join(binds)
|
||||
|
||||
|
||||
class OracleExecutionContext_zxjdbc(OracleExecutionContext):
|
||||
|
||||
def pre_exec(self):
|
||||
if hasattr(self.compiled, 'returning_parameters'):
|
||||
# prepare a zxJDBC statement so we can grab its underlying
|
||||
# OraclePreparedStatement's getReturnResultSet later
|
||||
self.statement = self.cursor.prepare(self.statement)
|
||||
|
||||
def get_result_proxy(self):
|
||||
if hasattr(self.compiled, 'returning_parameters'):
|
||||
rrs = None
|
||||
try:
|
||||
try:
|
||||
rrs = self.statement.__statement__.getReturnResultSet()
|
||||
next(rrs)
|
||||
except SQLException as sqle:
|
||||
msg = '%s [SQLCode: %d]' % (sqle.getMessage(), sqle.getErrorCode())
|
||||
if sqle.getSQLState() is not None:
|
||||
msg += ' [SQLState: %s]' % sqle.getSQLState()
|
||||
raise zxJDBC.Error(msg)
|
||||
else:
|
||||
row = tuple(self.cursor.datahandler.getPyObject(rrs, index, dbtype)
|
||||
for index, dbtype in self.compiled.returning_parameters)
|
||||
return ReturningResultProxy(self, row)
|
||||
finally:
|
||||
if rrs is not None:
|
||||
try:
|
||||
rrs.close()
|
||||
except SQLException:
|
||||
pass
|
||||
self.statement.close()
|
||||
|
||||
return _result.ResultProxy(self)
|
||||
|
||||
def create_cursor(self):
|
||||
cursor = self._dbapi_connection.cursor()
|
||||
cursor.datahandler = self.dialect.DataHandler(cursor.datahandler)
|
||||
return cursor
|
||||
|
||||
|
||||
class ReturningResultProxy(_result.FullyBufferedResultProxy):
|
||||
|
||||
"""ResultProxy backed by the RETURNING ResultSet results."""
|
||||
|
||||
def __init__(self, context, returning_row):
|
||||
self._returning_row = returning_row
|
||||
super(ReturningResultProxy, self).__init__(context)
|
||||
|
||||
def _cursor_description(self):
|
||||
ret = []
|
||||
for c in self.context.compiled.returning_cols:
|
||||
if hasattr(c, 'name'):
|
||||
ret.append((c.name, c.type))
|
||||
else:
|
||||
ret.append((c.anon_label, c.type))
|
||||
return ret
|
||||
|
||||
def _buffer_rows(self):
|
||||
return collections.deque([self._returning_row])
|
||||
|
||||
|
||||
class ReturningParam(object):
|
||||
|
||||
"""A bindparam value representing a RETURNING parameter.
|
||||
|
||||
Specially handled by OracleReturningDataHandler.
|
||||
"""
|
||||
|
||||
def __init__(self, type):
|
||||
self.type = type
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, ReturningParam):
|
||||
return self.type == other.type
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other):
|
||||
if isinstance(other, ReturningParam):
|
||||
return self.type != other.type
|
||||
return NotImplemented
|
||||
|
||||
def __repr__(self):
|
||||
kls = self.__class__
|
||||
return '<%s.%s object at 0x%x type=%s>' % (kls.__module__, kls.__name__, id(self),
|
||||
self.type)
|
||||
|
||||
|
||||
class OracleDialect_zxjdbc(ZxJDBCConnector, OracleDialect):
|
||||
jdbc_db_name = 'oracle'
|
||||
jdbc_driver_name = 'oracle.jdbc.OracleDriver'
|
||||
|
||||
statement_compiler = OracleCompiler_zxjdbc
|
||||
execution_ctx_cls = OracleExecutionContext_zxjdbc
|
||||
|
||||
colspecs = util.update_copy(
|
||||
OracleDialect.colspecs,
|
||||
{
|
||||
sqltypes.Date: _ZxJDBCDate,
|
||||
sqltypes.Numeric: _ZxJDBCNumeric
|
||||
}
|
||||
)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(OracleDialect_zxjdbc, self).__init__(*args, **kwargs)
|
||||
global SQLException, zxJDBC
|
||||
from java.sql import SQLException
|
||||
from com.ziclix.python.sql import zxJDBC
|
||||
from com.ziclix.python.sql.handler import OracleDataHandler
|
||||
|
||||
class OracleReturningDataHandler(OracleDataHandler):
|
||||
"""zxJDBC DataHandler that specially handles ReturningParam."""
|
||||
|
||||
def setJDBCObject(self, statement, index, object, dbtype=None):
|
||||
if type(object) is ReturningParam:
|
||||
statement.registerReturnParameter(index, object.type)
|
||||
elif dbtype is None:
|
||||
OracleDataHandler.setJDBCObject(
|
||||
self, statement, index, object)
|
||||
else:
|
||||
OracleDataHandler.setJDBCObject(
|
||||
self, statement, index, object, dbtype)
|
||||
self.DataHandler = OracleReturningDataHandler
|
||||
|
||||
def initialize(self, connection):
|
||||
super(OracleDialect_zxjdbc, self).initialize(connection)
|
||||
self.implicit_returning = connection.connection.driverversion >= '10.2'
|
||||
|
||||
def _create_jdbc_url(self, url):
|
||||
return 'jdbc:oracle:thin:@%s:%s:%s' % (url.host, url.port or 1521, url.database)
|
||||
|
||||
def _get_server_version_info(self, connection):
|
||||
version = re.search(r'Release ([\d\.]+)', connection.connection.dbversion).group(1)
|
||||
return tuple(int(x) for x in version.split('.'))
|
||||
|
||||
dialect = OracleDialect_zxjdbc
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue