diff --git a/CHANGES.md b/CHANGES.md index 8bc988bb..0d3029b1 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,4 +1,11 @@ -### 3.32.8 (2024-10-07 00:30:00 UTC) +### 3.33.0 (2024-0x-xx xx:xx:00 UTC) + +* Update filelock 3.14.0 (8556141) to 3.15.4 (9a979df) +* Update package resource API 68.2.2 (8ad627d) to 70.1.1 (222ebf9) +* Update urllib3 2.2.1 (54d6edf) to 2.2.2 (27e2a5c) + + +### 3.32.8 (2024-10-07 00:30:00 UTC) * Change min required Python version to 3.9 * Change add support for Python 3.9.20, 3.10.15, 3.11.10, 3.12.7 diff --git a/lib/filelock/__init__.py b/lib/filelock/__init__.py index 006299d2..c9d8c5b8 100644 --- a/lib/filelock/__init__.py +++ b/lib/filelock/__init__.py @@ -17,6 +17,13 @@ from ._error import Timeout from ._soft import SoftFileLock from ._unix import UnixFileLock, has_fcntl from ._windows import WindowsFileLock +from .asyncio import ( + AsyncAcquireReturnProxy, + AsyncSoftFileLock, + AsyncUnixFileLock, + AsyncWindowsFileLock, + BaseAsyncFileLock, +) from .version import version #: version of the project as a string @@ -25,23 +32,34 @@ __version__: str = version if sys.platform == "win32": # pragma: win32 cover _FileLock: type[BaseFileLock] = WindowsFileLock + _AsyncFileLock: type[BaseAsyncFileLock] = AsyncWindowsFileLock else: # pragma: win32 no cover # noqa: PLR5501 if has_fcntl: _FileLock: type[BaseFileLock] = UnixFileLock + _AsyncFileLock: type[BaseAsyncFileLock] = AsyncUnixFileLock else: _FileLock = SoftFileLock + _AsyncFileLock = AsyncSoftFileLock if warnings is not None: warnings.warn("only soft file lock is available", stacklevel=2) if TYPE_CHECKING: FileLock = SoftFileLock + AsyncFileLock = AsyncSoftFileLock else: #: Alias for the lock, which should be used for the current platform. FileLock = _FileLock + AsyncFileLock = _AsyncFileLock __all__ = [ "AcquireReturnProxy", + "AsyncAcquireReturnProxy", + "AsyncFileLock", + "AsyncSoftFileLock", + "AsyncUnixFileLock", + "AsyncWindowsFileLock", + "BaseAsyncFileLock", "BaseFileLock", "FileLock", "SoftFileLock", diff --git a/lib/filelock/_api.py b/lib/filelock/_api.py index fd87972c..771a559a 100644 --- a/lib/filelock/_api.py +++ b/lib/filelock/_api.py @@ -1,14 +1,15 @@ from __future__ import annotations import contextlib +import inspect import logging import os import time import warnings -from abc import ABC, abstractmethod +from abc import ABCMeta, abstractmethod from dataclasses import dataclass from threading import local -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from weakref import WeakValueDictionary from ._error import Timeout @@ -77,35 +78,71 @@ class ThreadLocalFileContext(FileLockContext, local): """A thread local version of the ``FileLockContext`` class.""" -class BaseFileLock(ABC, contextlib.ContextDecorator): - """Abstract base class for a file lock object.""" - - _instances: WeakValueDictionary[str, BaseFileLock] - - def __new__( # noqa: PLR0913 +class FileLockMeta(ABCMeta): + def __call__( # noqa: PLR0913 cls, lock_file: str | os.PathLike[str], timeout: float = -1, mode: int = 0o644, - thread_local: bool = True, # noqa: ARG003, FBT001, FBT002 + thread_local: bool = True, # noqa: FBT001, FBT002 *, - blocking: bool = True, # noqa: ARG003 + blocking: bool = True, is_singleton: bool = False, - **kwargs: dict[str, Any], # capture remaining kwargs for subclasses # noqa: ARG003 - ) -> Self: - """Create a new lock object or if specified return the singleton instance for the lock file.""" - if not is_singleton: - return super().__new__(cls) + **kwargs: Any, # capture 
remaining kwargs for subclasses # noqa: ANN401 + ) -> BaseFileLock: + if is_singleton: + instance = cls._instances.get(str(lock_file)) # type: ignore[attr-defined] + if instance: + params_to_check = { + "thread_local": (thread_local, instance.is_thread_local()), + "timeout": (timeout, instance.timeout), + "mode": (mode, instance.mode), + "blocking": (blocking, instance.blocking), + } - instance = cls._instances.get(str(lock_file)) - if not instance: - instance = super().__new__(cls) - cls._instances[str(lock_file)] = instance - elif timeout != instance.timeout or mode != instance.mode: - msg = "Singleton lock instances cannot be initialized with differing arguments" - raise ValueError(msg) + non_matching_params = { + name: (passed_param, set_param) + for name, (passed_param, set_param) in params_to_check.items() + if passed_param != set_param + } + if not non_matching_params: + return cast(BaseFileLock, instance) - return instance # type: ignore[return-value] # https://github.com/python/mypy/issues/15322 + # parameters do not match; raise error + msg = "Singleton lock instances cannot be initialized with differing arguments" + msg += "\nNon-matching arguments: " + for param_name, (passed_param, set_param) in non_matching_params.items(): + msg += f"\n\t{param_name} (existing lock has {set_param} but {passed_param} was passed)" + raise ValueError(msg) + + # Workaround to make `__init__`'s params optional in subclasses + # E.g. virtualenv changes the signature of the `__init__` method in the `BaseFileLock` class descendant + # (https://github.com/tox-dev/filelock/pull/340) + + all_params = { + "timeout": timeout, + "mode": mode, + "thread_local": thread_local, + "blocking": blocking, + "is_singleton": is_singleton, + **kwargs, + } + + present_params = inspect.signature(cls.__init__).parameters # type: ignore[misc] + init_params = {key: value for key, value in all_params.items() if key in present_params} + + instance = super().__call__(lock_file, **init_params) + + if is_singleton: + cls._instances[str(lock_file)] = instance # type: ignore[attr-defined] + + return cast(BaseFileLock, instance) + + +class BaseFileLock(contextlib.ContextDecorator, metaclass=FileLockMeta): + """Abstract base class for a file lock object.""" + + _instances: WeakValueDictionary[str, BaseFileLock] def __init_subclass__(cls, **kwargs: dict[str, Any]) -> None: """Setup unique state for lock subclasses.""" diff --git a/lib/filelock/asyncio.py b/lib/filelock/asyncio.py new file mode 100644 index 00000000..f5848c89 --- /dev/null +++ b/lib/filelock/asyncio.py @@ -0,0 +1,342 @@ +"""An asyncio-based implementation of the file lock.""" + +from __future__ import annotations + +import asyncio +import contextlib +import logging +import os +import time +from dataclasses import dataclass +from threading import local +from typing import TYPE_CHECKING, Any, Callable, NoReturn, cast + +from ._api import BaseFileLock, FileLockContext, FileLockMeta +from ._error import Timeout +from ._soft import SoftFileLock +from ._unix import UnixFileLock +from ._windows import WindowsFileLock + +if TYPE_CHECKING: + import sys + from concurrent import futures + from types import TracebackType + + if sys.version_info >= (3, 11): # pragma: no cover (py311+) + from typing import Self + else: # pragma: no cover ( None: # noqa: D107 + self.lock = lock + + async def __aenter__(self) -> BaseAsyncFileLock: # noqa: D105 + return self.lock + + async def __aexit__( # noqa: D105 + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, 
+ traceback: TracebackType | None, + ) -> None: + await self.lock.release() + + +class AsyncFileLockMeta(FileLockMeta): + def __call__( # type: ignore[override] # noqa: PLR0913 + cls, # noqa: N805 + lock_file: str | os.PathLike[str], + timeout: float = -1, + mode: int = 0o644, + thread_local: bool = False, # noqa: FBT001, FBT002 + *, + blocking: bool = True, + is_singleton: bool = False, + loop: asyncio.AbstractEventLoop | None = None, + run_in_executor: bool = True, + executor: futures.Executor | None = None, + ) -> BaseAsyncFileLock: + if thread_local and run_in_executor: + msg = "run_in_executor is not supported when thread_local is True" + raise ValueError(msg) + instance = super().__call__( + lock_file=lock_file, + timeout=timeout, + mode=mode, + thread_local=thread_local, + blocking=blocking, + is_singleton=is_singleton, + loop=loop, + run_in_executor=run_in_executor, + executor=executor, + ) + return cast(BaseAsyncFileLock, instance) + + +class BaseAsyncFileLock(BaseFileLock, metaclass=AsyncFileLockMeta): + """Base class for asynchronous file locks.""" + + def __init__( # noqa: PLR0913 + self, + lock_file: str | os.PathLike[str], + timeout: float = -1, + mode: int = 0o644, + thread_local: bool = False, # noqa: FBT001, FBT002 + *, + blocking: bool = True, + is_singleton: bool = False, + loop: asyncio.AbstractEventLoop | None = None, + run_in_executor: bool = True, + executor: futures.Executor | None = None, + ) -> None: + """ + Create a new lock object. + + :param lock_file: path to the file + :param timeout: default timeout when acquiring the lock, in seconds. It will be used as fallback value in \ + the acquire method, if no timeout value (``None``) is given. If you want to disable the timeout, set it \ + to a negative value. A timeout of 0 means that there is exactly one attempt to acquire the file lock. + :param mode: file permissions for the lockfile + :param thread_local: Whether this object's internal context should be thread local or not. If this is set to \ + ``False`` then the lock will be reentrant across threads. + :param blocking: whether the lock should be blocking or not + :param is_singleton: If this is set to ``True`` then only one instance of this class will be created \ + per lock file. This is useful if you want to use the lock object for reentrant locking without needing \ + to pass the same object around. + :param loop: The event loop to use. If not specified, the running event loop will be used. + :param run_in_executor: If this is set to ``True`` then the lock will be acquired in an executor. + :param executor: The executor to use. If not specified, the default executor will be used. + + """ + self._is_thread_local = thread_local + self._is_singleton = is_singleton + + # Create the context. Note that external code should not work with the context directly and should instead use + # properties of this class. 
+ kwargs: dict[str, Any] = { + "lock_file": os.fspath(lock_file), + "timeout": timeout, + "mode": mode, + "blocking": blocking, + "loop": loop, + "run_in_executor": run_in_executor, + "executor": executor, + } + self._context: AsyncFileLockContext = (AsyncThreadLocalFileContext if thread_local else AsyncFileLockContext)( + **kwargs + ) + + @property + def run_in_executor(self) -> bool: + """::return: whether run in executor.""" + return self._context.run_in_executor + + @property + def executor(self) -> futures.Executor | None: + """::return: the executor.""" + return self._context.executor + + @executor.setter + def executor(self, value: futures.Executor | None) -> None: # pragma: no cover + """ + Change the executor. + + :param value: the new executor or ``None`` + :type value: futures.Executor | None + + """ + self._context.executor = value + + @property + def loop(self) -> asyncio.AbstractEventLoop | None: + """::return: the event loop.""" + return self._context.loop + + async def acquire( # type: ignore[override] + self, + timeout: float | None = None, + poll_interval: float = 0.05, + *, + blocking: bool | None = None, + ) -> AsyncAcquireReturnProxy: + """ + Try to acquire the file lock. + + :param timeout: maximum wait time for acquiring the lock, ``None`` means use the default + :attr:`~BaseFileLock.timeout` is and if ``timeout < 0``, there is no timeout and + this method will block until the lock could be acquired + :param poll_interval: interval of trying to acquire the lock file + :param blocking: defaults to True. If False, function will return immediately if it cannot obtain a lock on the + first attempt. Otherwise, this method will block until the timeout expires or the lock is acquired. + :raises Timeout: if fails to acquire lock within the timeout period + :return: a context object that will unlock the file when the context is exited + + .. code-block:: python + + # You can use this method in the context manager (recommended) + with lock.acquire(): + pass + + # Or use an equivalent try-finally construct: + lock.acquire() + try: + pass + finally: + lock.release() + + """ + # Use the default timeout, if no timeout is provided. + if timeout is None: + timeout = self._context.timeout + + if blocking is None: + blocking = self._context.blocking + + # Increment the number right at the beginning. We can still undo it, if something fails. + self._context.lock_counter += 1 + + lock_id = id(self) + lock_filename = self.lock_file + start_time = time.perf_counter() + try: + while True: + if not self.is_locked: + _LOGGER.debug("Attempting to acquire lock %s on %s", lock_id, lock_filename) + await self._run_internal_method(self._acquire) + if self.is_locked: + _LOGGER.debug("Lock %s acquired on %s", lock_id, lock_filename) + break + if blocking is False: + _LOGGER.debug("Failed to immediately acquire lock %s on %s", lock_id, lock_filename) + raise Timeout(lock_filename) # noqa: TRY301 + if 0 <= timeout < time.perf_counter() - start_time: + _LOGGER.debug("Timeout on acquiring lock %s on %s", lock_id, lock_filename) + raise Timeout(lock_filename) # noqa: TRY301 + msg = "Lock %s not acquired on %s, waiting %s seconds ..." + _LOGGER.debug(msg, lock_id, lock_filename, poll_interval) + await asyncio.sleep(poll_interval) + except BaseException: # Something did go wrong, so decrement the counter. 
+ self._context.lock_counter = max(0, self._context.lock_counter - 1) + raise + return AsyncAcquireReturnProxy(lock=self) + + async def release(self, force: bool = False) -> None: # type: ignore[override] # noqa: FBT001, FBT002 + """ + Releases the file lock. Please note, that the lock is only completely released, if the lock counter is 0. + Also note, that the lock file itself is not automatically deleted. + + :param force: If true, the lock counter is ignored and the lock is released in every case/ + + """ + if self.is_locked: + self._context.lock_counter -= 1 + + if self._context.lock_counter == 0 or force: + lock_id, lock_filename = id(self), self.lock_file + + _LOGGER.debug("Attempting to release lock %s on %s", lock_id, lock_filename) + await self._run_internal_method(self._release) + self._context.lock_counter = 0 + _LOGGER.debug("Lock %s released on %s", lock_id, lock_filename) + + async def _run_internal_method(self, method: Callable[[], Any]) -> None: + if asyncio.iscoroutinefunction(method): + await method() + elif self.run_in_executor: + loop = self.loop or asyncio.get_running_loop() + await loop.run_in_executor(self.executor, method) + else: + method() + + def __enter__(self) -> NoReturn: + """ + Replace old __enter__ method to avoid using it. + + NOTE: DO NOT USE `with` FOR ASYNCIO LOCKS, USE `async with` INSTEAD. + + :return: none + :rtype: NoReturn + """ + msg = "Do not use `with` for asyncio locks, use `async with` instead." + raise NotImplementedError(msg) + + async def __aenter__(self) -> Self: + """ + Acquire the lock. + + :return: the lock object + + """ + await self.acquire() + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + """ + Release the lock. + + :param exc_type: the exception type if raised + :param exc_value: the exception value if raised + :param traceback: the exception traceback if raised + + """ + await self.release() + + def __del__(self) -> None: + """Called when the lock object is deleted.""" + with contextlib.suppress(RuntimeError): + loop = self.loop or asyncio.get_running_loop() + if not loop.is_running(): # pragma: no cover + loop.run_until_complete(self.release(force=True)) + else: + loop.create_task(self.release(force=True)) + + +class AsyncSoftFileLock(SoftFileLock, BaseAsyncFileLock): + """Simply watches the existence of the lock file.""" + + +class AsyncUnixFileLock(UnixFileLock, BaseAsyncFileLock): + """Uses the :func:`fcntl.flock` to hard lock the lock file on unix systems.""" + + +class AsyncWindowsFileLock(WindowsFileLock, BaseAsyncFileLock): + """Uses the :func:`msvcrt.locking` to hard lock the lock file on windows systems.""" + + +__all__ = [ + "AsyncAcquireReturnProxy", + "AsyncSoftFileLock", + "AsyncUnixFileLock", + "AsyncWindowsFileLock", + "BaseAsyncFileLock", +] diff --git a/lib/filelock/version.py b/lib/filelock/version.py index 0afa0cef..3b5e79de 100644 --- a/lib/filelock/version.py +++ b/lib/filelock/version.py @@ -1,4 +1,4 @@ # file generated by setuptools_scm # don't change, don't track in version control -__version__ = version = '3.14.0' -__version_tuple__ = version_tuple = (3, 14, 0) +__version__ = version = '3.15.4' +__version_tuple__ = version_tuple = (3, 15, 4) diff --git a/lib/pkg_resources/__init__.py b/lib/pkg_resources/__init__.py index c10a88e5..c4ace5aa 100644 --- a/lib/pkg_resources/__init__.py +++ b/lib/pkg_resources/__init__.py @@ -20,9 +20,11 @@ This module is deprecated. 
Users are directed to :mod:`importlib.resources`, :mod:`importlib.metadata` and :pypi:`packaging` instead. """ +from __future__ import annotations + import sys -if sys.version_info < (3, 8): +if sys.version_info < (3, 8): # noqa: UP036 # Check for unsupported versions raise RuntimeError("Python 3.8 or later is required") import os @@ -32,23 +34,21 @@ import re import types from typing import ( Any, + Literal, + Dict, + Iterator, Mapping, MutableSequence, NamedTuple, NoReturn, - Sequence, - Set, Tuple, - Type, Union, TYPE_CHECKING, - List, Protocol, Callable, - Dict, Iterable, - Optional, TypeVar, + overload, ) import zipfile import zipimport @@ -99,7 +99,8 @@ from pkg_resources.extern.packaging import version as _packaging_version from pkg_resources.extern.platformdirs import user_cache_dir as _user_cache_dir if TYPE_CHECKING: - from _typeshed import StrPath + from _typeshed import BytesPath, StrPath, StrOrBytesPath + from typing_extensions import Self warnings.warn( "pkg_resources is deprecated as an API. " @@ -110,17 +111,22 @@ warnings.warn( _T = TypeVar("_T") +_DistributionT = TypeVar("_DistributionT", bound="Distribution") # Type aliases _NestedStr = Union[str, Iterable[Union[str, Iterable["_NestedStr"]]]] -_InstallerType = Callable[["Requirement"], Optional["Distribution"]] +_InstallerTypeT = Callable[["Requirement"], "_DistributionT"] +_InstallerType = Callable[["Requirement"], Union["Distribution", None]] _PkgReqType = Union[str, "Requirement"] _EPDistType = Union["Distribution", _PkgReqType] -_MetadataType = Optional["IResourceProvider"] +_MetadataType = Union["IResourceProvider", None] +_ResolvedEntryPoint = Any # Can be any attribute in the module +_ResourceStream = Any # TODO / Incomplete: A readable file-like object # Any object works, but let's indicate we expect something like a module (optionally has __loader__ or __file__) _ModuleLike = Union[object, types.ModuleType] -_ProviderFactoryType = Callable[[_ModuleLike], "IResourceProvider"] +# Any: Should be _ModuleLike but we end up with issues where _ModuleLike doesn't have _ZipLoaderModule's __loader__ +_ProviderFactoryType = Callable[[Any], "IResourceProvider"] _DistFinderType = Callable[[_T, str, bool], Iterable["Distribution"]] -_NSHandlerType = Callable[[_T, str, str, types.ModuleType], Optional[str]] +_NSHandlerType = Callable[[_T, str, str, types.ModuleType], Union[str, None]] _AdapterT = TypeVar( "_AdapterT", _DistFinderType[Any], _ProviderFactoryType, _NSHandlerType[Any] ) @@ -131,6 +137,10 @@ class _LoaderProtocol(Protocol): def load_module(self, fullname: str, /) -> types.ModuleType: ... 
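Before the rest of the pkg_resources typing work, one note on the headline change earlier in this diff: the new `filelock.asyncio` module deliberately disables synchronous use (`__enter__` raises `NotImplementedError`), so the locks are driven with `async with`. A minimal usage sketch, assuming the vendored filelock 3.15.4 is importable as `filelock` and the working directory is writable; `demo.txt.lock` is a hypothetical path:

```python
import asyncio

from filelock import AsyncFileLock, Timeout


async def main() -> None:
    # timeout=2: give up after two seconds instead of blocking forever
    lock = AsyncFileLock("demo.txt.lock", timeout=2)
    try:
        async with lock:  # plain `with` raises NotImplementedError by design
            print("lock held:", lock.is_locked)
    except Timeout:
        print("another process holds demo.txt.lock")


asyncio.run(main())
```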
+class _ZipLoaderModule(Protocol): + __loader__: zipimport.zipimporter + + _PEP440_FALLBACK = re.compile(r"^v?(?P(?:[0-9]+!)?[0-9]+(?:\.[0-9]+)*)", re.I) @@ -144,7 +154,7 @@ class PEP440Warning(RuntimeWarning): parse_version = _packaging_version.Version -_state_vars: Dict[str, str] = {} +_state_vars: dict[str, str] = {} def _declare_state(vartype: str, varname: str, initial_value: _T) -> _T: @@ -152,7 +162,7 @@ def _declare_state(vartype: str, varname: str, initial_value: _T) -> _T: return initial_value -def __getstate__(): +def __getstate__() -> dict[str, Any]: state = {} g = globals() for k, v in _state_vars.items(): @@ -160,7 +170,7 @@ def __getstate__(): return state -def __setstate__(state): +def __setstate__(state: dict[str, Any]) -> dict[str, Any]: g = globals() for k, v in state.items(): g['_sset_' + _state_vars[k]](k, g[k], v) @@ -315,17 +325,17 @@ class VersionConflict(ResolutionError): _template = "{self.dist} is installed but {self.req} is required" @property - def dist(self): + def dist(self) -> Distribution: return self.args[0] @property - def req(self): + def req(self) -> Requirement: return self.args[1] def report(self): return self._template.format(**locals()) - def with_context(self, required_by: Set[Union["Distribution", str]]): + def with_context(self, required_by: set[Distribution | str]): """ If required_by is non-empty, return a version of self that is a ContextualVersionConflict. @@ -345,7 +355,7 @@ class ContextualVersionConflict(VersionConflict): _template = VersionConflict._template + ' by {self.required_by}' @property - def required_by(self): + def required_by(self) -> set[str]: return self.args[2] @@ -358,11 +368,11 @@ class DistributionNotFound(ResolutionError): ) @property - def req(self): + def req(self) -> Requirement: return self.args[0] @property - def requirers(self): + def requirers(self) -> set[str] | None: return self.args[1] @property @@ -382,7 +392,7 @@ class UnknownExtra(ResolutionError): """Distribution doesn't have an "extra feature" of the given name""" -_provider_factories: Dict[Type[_ModuleLike], _ProviderFactoryType] = {} +_provider_factories: dict[type[_ModuleLike], _ProviderFactoryType] = {} PY_MAJOR = '{}.{}'.format(*sys.version_info) EGG_DIST = 3 @@ -393,7 +403,7 @@ DEVELOP_DIST = -1 def register_loader_type( - loader_type: Type[_ModuleLike], provider_factory: _ProviderFactoryType + loader_type: type[_ModuleLike], provider_factory: _ProviderFactoryType ): """Register `provider_factory` to make providers for `loader_type` @@ -404,7 +414,11 @@ def register_loader_type( _provider_factories[loader_type] = provider_factory -def get_provider(moduleOrReq: Union[str, "Requirement"]): +@overload +def get_provider(moduleOrReq: str) -> IResourceProvider: ... +@overload +def get_provider(moduleOrReq: Requirement) -> Distribution: ... +def get_provider(moduleOrReq: str | Requirement) -> IResourceProvider | Distribution: """Return an IResourceProvider for the named module or requirement""" if isinstance(moduleOrReq, Requirement): return working_set.find(moduleOrReq) or require(str(moduleOrReq))[0] @@ -466,7 +480,7 @@ darwinVersionString = re.compile(r"darwin-(\d+)\.(\d+)\.(\d+)-(.*)") get_platform = get_build_platform -def compatible_platforms(provided: Optional[str], required: Optional[str]): +def compatible_platforms(provided: str | None, required: str | None): """Can code for the `provided` platform run on the `required` platform? Returns true if either platform is ``None``, or the platforms are equal. 
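A pattern that recurs throughout these pkg_resources changes: a single permissive signature (as in `get_provider` above, and `get_distribution`, `get_entry_map`, and `normalize_path` below) is split into `@overload` stubs so a type checker can narrow the return type from the argument type, while the runtime implementation is unchanged. A self-contained sketch of the idea, using hypothetical stand-in types rather than the real ones:

```python
from __future__ import annotations

from typing import overload


class Provider: ...


class Distribution(Provider): ...


@overload
def lookup(key: int) -> Distribution: ...
@overload
def lookup(key: str) -> Provider: ...
def lookup(key: int | str) -> Distribution | Provider:
    # Single runtime implementation; the stubs above only guide the checker.
    return Distribution() if isinstance(key, int) else Provider()


dist = lookup(42)  # a checker infers Distribution here, not the union
```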
@@ -515,23 +529,34 @@ def compatible_platforms(provided: Optional[str], required: Optional[str]): return False -def get_distribution(dist: _EPDistType): +@overload +def get_distribution(dist: _DistributionT) -> _DistributionT: ... +@overload +def get_distribution(dist: _PkgReqType) -> Distribution: ... +def get_distribution(dist: Distribution | _PkgReqType) -> Distribution: """Return a current distribution object for a Requirement or string""" if isinstance(dist, str): dist = Requirement.parse(dist) if isinstance(dist, Requirement): - dist = get_provider(dist) + # Bad type narrowing, dist has to be a Requirement here, so get_provider has to return Distribution + dist = get_provider(dist) # type: ignore[assignment] if not isinstance(dist, Distribution): - raise TypeError("Expected string, Requirement, or Distribution", dist) + raise TypeError("Expected str, Requirement, or Distribution", dist) return dist -def load_entry_point(dist: _EPDistType, group: str, name: str): +def load_entry_point(dist: _EPDistType, group: str, name: str) -> _ResolvedEntryPoint: """Return `name` entry point of `group` for `dist` or raise ImportError""" return get_distribution(dist).load_entry_point(group, name) -def get_entry_map(dist: _EPDistType, group: Optional[str] = None): +@overload +def get_entry_map( + dist: _EPDistType, group: None = None +) -> dict[str, dict[str, EntryPoint]]: ... +@overload +def get_entry_map(dist: _EPDistType, group: str) -> dict[str, EntryPoint]: ... +def get_entry_map(dist: _EPDistType, group: str | None = None): """Return the entry point map for `group`, or the full entry map""" return get_distribution(dist).get_entry_map(group) @@ -545,10 +570,10 @@ class IMetadataProvider(Protocol): def has_metadata(self, name: str) -> bool: """Does the package's distribution contain the named metadata?""" - def get_metadata(self, name: str): + def get_metadata(self, name: str) -> str: """The named metadata resource as a string""" - def get_metadata_lines(self, name: str): + def get_metadata_lines(self, name: str) -> Iterator[str]: """Yield named metadata resource as list of non-blank non-comment lines Leading and trailing whitespace is stripped from each line, and lines @@ -557,49 +582,53 @@ class IMetadataProvider(Protocol): def metadata_isdir(self, name: str) -> bool: """Is the named metadata a directory? 
(like ``os.path.isdir()``)""" - def metadata_listdir(self, name: str): + def metadata_listdir(self, name: str) -> list[str]: """List of metadata names in the directory (like ``os.listdir()``)""" - def run_script(self, script_name: str, namespace: Dict[str, Any]): + def run_script(self, script_name: str, namespace: dict[str, Any]) -> None: """Execute the named script in the supplied namespace dictionary""" class IResourceProvider(IMetadataProvider, Protocol): """An object that provides access to package resources""" - def get_resource_filename(self, manager: "ResourceManager", resource_name: str): + def get_resource_filename( + self, manager: ResourceManager, resource_name: str + ) -> str: """Return a true filesystem path for `resource_name` `manager` must be a ``ResourceManager``""" - def get_resource_stream(self, manager: "ResourceManager", resource_name: str): + def get_resource_stream( + self, manager: ResourceManager, resource_name: str + ) -> _ResourceStream: """Return a readable file-like object for `resource_name` `manager` must be a ``ResourceManager``""" def get_resource_string( - self, manager: "ResourceManager", resource_name: str + self, manager: ResourceManager, resource_name: str ) -> bytes: """Return the contents of `resource_name` as :obj:`bytes` `manager` must be a ``ResourceManager``""" - def has_resource(self, resource_name: str): + def has_resource(self, resource_name: str) -> bool: """Does the package contain the named resource?""" - def resource_isdir(self, resource_name: str): + def resource_isdir(self, resource_name: str) -> bool: """Is the named resource a directory? (like ``os.path.isdir()``)""" - def resource_listdir(self, resource_name: str): + def resource_listdir(self, resource_name: str) -> list[str]: """List of resource names in the directory (like ``os.listdir()``)""" class WorkingSet: """A collection of active distributions on sys.path (or a similar list)""" - def __init__(self, entries: Optional[Iterable[str]] = None): + def __init__(self, entries: Iterable[str] | None = None): """Create working set from list of path entries (default=sys.path)""" - self.entries: List[str] = [] + self.entries: list[str] = [] self.entry_keys = {} self.by_key = {} self.normalized_to_canonical_keys = {} @@ -668,11 +697,11 @@ class WorkingSet: for dist in find_distributions(entry, True): self.add(dist, entry, False) - def __contains__(self, dist: "Distribution"): + def __contains__(self, dist: Distribution) -> bool: """True if `dist` is the active distribution for its project""" return self.by_key.get(dist.key) == dist - def find(self, req: "Requirement"): + def find(self, req: Requirement) -> Distribution | None: """Find a distribution matching requirement `req` If there is an active distribution for the requested project, this @@ -696,7 +725,7 @@ class WorkingSet: raise VersionConflict(dist, req) return dist - def iter_entry_points(self, group: str, name: Optional[str] = None): + def iter_entry_points(self, group: str, name: str | None = None): """Yield entry point objects from `group` matching `name` If `name` is None, yields all entry points in `group` from all @@ -718,13 +747,13 @@ class WorkingSet: ns['__name__'] = name self.require(requires)[0].run_script(script_name, ns) - def __iter__(self): + def __iter__(self) -> Iterator[Distribution]: """Yield distributions for non-duplicate projects in the working set The yield order is the order in which the items' path entries were added to the working set. 
""" - seen = {} + seen = set() for item in self.entries: if item not in self.entry_keys: # workaround a cache issue @@ -732,13 +761,13 @@ class WorkingSet: for key in self.entry_keys[item]: if key not in seen: - seen[key] = 1 + seen.add(key) yield self.by_key[key] def add( self, - dist: "Distribution", - entry: Optional[str] = None, + dist: Distribution, + entry: str | None = None, insert: bool = True, replace: bool = False, ): @@ -773,14 +802,42 @@ class WorkingSet: keys2.append(dist.key) self._added_new(dist) + @overload def resolve( self, - requirements: Iterable["Requirement"], - env: Optional["Environment"] = None, - installer: Optional[_InstallerType] = None, + requirements: Iterable[Requirement], + env: Environment | None, + installer: _InstallerTypeT[_DistributionT], replace_conflicting: bool = False, - extras: Optional[Tuple[str, ...]] = None, - ): + extras: tuple[str, ...] | None = None, + ) -> list[_DistributionT]: ... + @overload + def resolve( + self, + requirements: Iterable[Requirement], + env: Environment | None = None, + *, + installer: _InstallerTypeT[_DistributionT], + replace_conflicting: bool = False, + extras: tuple[str, ...] | None = None, + ) -> list[_DistributionT]: ... + @overload + def resolve( + self, + requirements: Iterable[Requirement], + env: Environment | None = None, + installer: _InstallerType | None = None, + replace_conflicting: bool = False, + extras: tuple[str, ...] | None = None, + ) -> list[Distribution]: ... + def resolve( + self, + requirements: Iterable[Requirement], + env: Environment | None = None, + installer: _InstallerType | None | _InstallerTypeT[_DistributionT] = None, + replace_conflicting: bool = False, + extras: tuple[str, ...] | None = None, + ) -> list[Distribution] | list[_DistributionT]: """List all distributions needed to (recursively) meet `requirements` `requirements` must be a sequence of ``Requirement`` objects. `env`, @@ -808,7 +865,7 @@ class WorkingSet: # set up the stack requirements = list(requirements)[::-1] # set of processed requirements - processed = {} + processed = set() # key -> dist best = {} to_activate = [] @@ -842,14 +899,14 @@ class WorkingSet: required_by[new_requirement].add(req.project_name) req_extras[new_requirement] = req.extras - processed[req] = True + processed.add(req) # return list of distros to activate return to_activate def _resolve_dist( self, req, best, replace_conflicting, env, installer, required_by, to_activate - ) -> "Distribution": + ) -> Distribution: dist = best.get(req.key) if dist is None: # Find the best distribution and add it to the map @@ -878,13 +935,41 @@ class WorkingSet: raise VersionConflict(dist, req).with_context(dependent_req) return dist + @overload def find_plugins( self, - plugin_env: "Environment", - full_env: Optional["Environment"] = None, - installer: Optional[_InstallerType] = None, + plugin_env: Environment, + full_env: Environment | None, + installer: _InstallerTypeT[_DistributionT], fallback: bool = True, - ): + ) -> tuple[list[_DistributionT], dict[Distribution, Exception]]: ... + @overload + def find_plugins( + self, + plugin_env: Environment, + full_env: Environment | None = None, + *, + installer: _InstallerTypeT[_DistributionT], + fallback: bool = True, + ) -> tuple[list[_DistributionT], dict[Distribution, Exception]]: ... 
+ @overload + def find_plugins( + self, + plugin_env: Environment, + full_env: Environment | None = None, + installer: _InstallerType | None = None, + fallback: bool = True, + ) -> tuple[list[Distribution], dict[Distribution, Exception]]: ... + def find_plugins( + self, + plugin_env: Environment, + full_env: Environment | None = None, + installer: _InstallerType | None | _InstallerTypeT[_DistributionT] = None, + fallback: bool = True, + ) -> tuple[ + list[Distribution] | list[_DistributionT], + dict[Distribution, Exception], + ]: """Find all activatable distributions in `plugin_env` Example usage:: @@ -923,8 +1008,8 @@ class WorkingSet: # scan project names in alphabetic order plugin_projects.sort() - error_info = {} - distributions = {} + error_info: dict[Distribution, Exception] = {} + distributions: dict[Distribution, Exception | None] = {} if full_env is None: env = Environment(self.entries) @@ -982,7 +1067,7 @@ class WorkingSet: return needed def subscribe( - self, callback: Callable[["Distribution"], object], existing: bool = True + self, callback: Callable[[Distribution], object], existing: bool = True ): """Invoke `callback` for all distributions @@ -1024,9 +1109,7 @@ class _ReqExtras(Dict["Requirement", Tuple[str, ...]]): Map each requirement to the extras that demanded it. """ - def markers_pass( - self, req: "Requirement", extras: Optional[Tuple[str, ...]] = None - ): + def markers_pass(self, req: Requirement, extras: tuple[str, ...] | None = None): """ Evaluate markers for req against each extra that demanded it. @@ -1046,9 +1129,9 @@ class Environment: def __init__( self, - search_path: Optional[Sequence[str]] = None, - platform: Optional[str] = get_supported_platform(), - python: Optional[str] = PY_MAJOR, + search_path: Iterable[str] | None = None, + platform: str | None = get_supported_platform(), + python: str | None = PY_MAJOR, ): """Snapshot distributions available on a search path @@ -1071,7 +1154,7 @@ class Environment: self.python = python self.scan(search_path) - def can_add(self, dist: "Distribution"): + def can_add(self, dist: Distribution): """Is distribution `dist` acceptable for this environment? The distribution must match the platform and python version @@ -1085,11 +1168,11 @@ class Environment: ) return py_compat and compatible_platforms(dist.platform, self.platform) - def remove(self, dist: "Distribution"): + def remove(self, dist: Distribution): """Remove `dist` from the environment""" self._distmap[dist.key].remove(dist) - def scan(self, search_path: Optional[Sequence[str]] = None): + def scan(self, search_path: Iterable[str] | None = None): """Scan `search_path` for distributions usable in this environment Any distributions found are added to the environment. 
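The `resolve` and `find_plugins` overload stacks above hinge on `_DistributionT`, a TypeVar bound to `Distribution`: when the caller supplies an installer that returns a `Distribution` subclass, that subclass flows through to the declared result type. A reduced sketch of the mechanism, with hypothetical simplified signatures rather than the real ones:

```python
from __future__ import annotations

from typing import Callable, TypeVar


class Requirement: ...


class Distribution: ...


class EggDistribution(Distribution): ...


_DistributionT = TypeVar("_DistributionT", bound=Distribution)


def resolve(
    req: Requirement, installer: Callable[[Requirement], _DistributionT]
) -> list[_DistributionT]:
    # Whatever subclass the installer returns propagates to the result type.
    return [installer(req)]


def egg_installer(req: Requirement) -> EggDistribution:
    return EggDistribution()


dists = resolve(Requirement(), egg_installer)  # inferred: list[EggDistribution]
```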
@@ -1104,7 +1187,7 @@ class Environment: for dist in find_distributions(item): self.add(dist) - def __getitem__(self, project_name: str): + def __getitem__(self, project_name: str) -> list[Distribution]: """Return a newest-to-oldest list of distributions for `project_name` Uses case-insensitive `project_name` comparison, assuming all the @@ -1115,7 +1198,7 @@ class Environment: distribution_key = project_name.lower() return self._distmap.get(distribution_key, []) - def add(self, dist: "Distribution"): + def add(self, dist: Distribution): """Add `dist` if we ``can_add()`` it and it has not already been added""" if self.can_add(dist) and dist.has_version(): dists = self._distmap.setdefault(dist.key, []) @@ -1123,13 +1206,29 @@ class Environment: dists.append(dist) dists.sort(key=operator.attrgetter('hashcmp'), reverse=True) + @overload def best_match( self, - req: "Requirement", + req: Requirement, working_set: WorkingSet, - installer: Optional[Callable[["Requirement"], Any]] = None, + installer: _InstallerTypeT[_DistributionT], replace_conflicting: bool = False, - ): + ) -> _DistributionT: ... + @overload + def best_match( + self, + req: Requirement, + working_set: WorkingSet, + installer: _InstallerType | None = None, + replace_conflicting: bool = False, + ) -> Distribution | None: ... + def best_match( + self, + req: Requirement, + working_set: WorkingSet, + installer: _InstallerType | None | _InstallerTypeT[_DistributionT] = None, + replace_conflicting: bool = False, + ) -> Distribution | None: """Find distribution best matching `req` and usable on `working_set` This calls the ``find(req)`` method of the `working_set` to see if a @@ -1156,11 +1255,32 @@ class Environment: # try to download/install return self.obtain(req, installer) + @overload def obtain( self, - requirement: "Requirement", - installer: Optional[Callable[["Requirement"], Any]] = None, - ): + requirement: Requirement, + installer: _InstallerTypeT[_DistributionT], + ) -> _DistributionT: ... + @overload + def obtain( + self, + requirement: Requirement, + installer: Callable[[Requirement], None] | None = None, + ) -> None: ... + @overload + def obtain( + self, + requirement: Requirement, + installer: _InstallerType | None = None, + ) -> Distribution | None: ... + def obtain( + self, + requirement: Requirement, + installer: Callable[[Requirement], None] + | _InstallerType + | None + | _InstallerTypeT[_DistributionT] = None, + ) -> Distribution | None: """Obtain a distribution matching `requirement` (e.g. via download) Obtain a distro that matches requirement (e.g. via download). 
In the @@ -1171,13 +1291,13 @@ class Environment: to the `installer` argument.""" return installer(requirement) if installer else None - def __iter__(self): + def __iter__(self) -> Iterator[str]: """Yield the unique project names of the available distributions""" for key in self._distmap.keys(): if self[key]: yield key - def __iadd__(self, other: Union["Distribution", "Environment"]): + def __iadd__(self, other: Distribution | Environment): """In-place addition of a distribution or environment""" if isinstance(other, Distribution): self.add(other) @@ -1189,7 +1309,7 @@ class Environment: raise TypeError("Can't add %r to environment" % (other,)) return self - def __add__(self, other: Union["Distribution", "Environment"]): + def __add__(self, other: Distribution | Environment): """Add an environment or distribution to an environment""" new = self.__class__([], platform=None, python=None) for env in self, other: @@ -1216,15 +1336,15 @@ class ExtractionError(RuntimeError): The exception instance that caused extraction to fail """ - manager: "ResourceManager" + manager: ResourceManager cache_path: str - original_error: Optional[BaseException] + original_error: BaseException | None class ResourceManager: """Manage resource extraction and packages""" - extraction_path: Optional[str] = None + extraction_path: str | None = None def __init__(self): self.cached_files = {} @@ -1293,7 +1413,7 @@ class ResourceManager: err.original_error = old_exc raise err - def get_cache_path(self, archive_name: str, names: Iterable[str] = ()): + def get_cache_path(self, archive_name: str, names: Iterable[StrPath] = ()): """Return absolute location in cache for `archive_name` and `names` The parent directory of the resulting path will be created if it does @@ -1315,7 +1435,7 @@ class ResourceManager: self._warn_unsafe_extraction_path(extract_path) - self.cached_files[target_path] = 1 + self.cached_files[target_path] = True return target_path @staticmethod @@ -1345,7 +1465,7 @@ class ResourceManager: ).format(**locals()) warnings.warn(msg, UserWarning) - def postprocess(self, tempname: str, filename: str): + def postprocess(self, tempname: StrOrBytesPath, filename: StrOrBytesPath): """Perform any platform-specific postprocessing of `tempname` This is where Mac header rewrites should be done; other platforms don't @@ -1389,7 +1509,7 @@ class ResourceManager: self.extraction_path = path - def cleanup_resources(self, force: bool = False) -> List[str]: + def cleanup_resources(self, force: bool = False) -> list[str]: """ Delete all extracted resource files and directories, returning a list of the file and directory names that could not be successfully removed. @@ -1404,7 +1524,7 @@ class ResourceManager: return [] -def get_default_cache(): +def get_default_cache() -> str: """ Return the ``PYTHON_EGG_CACHE`` environment variable or a platform-relevant user cache dir for an app @@ -1496,7 +1616,7 @@ def invalid_marker(text: str): return False -def evaluate_marker(text: str, extra: Optional[str] = None): +def evaluate_marker(text: str, extra: str | None = None) -> bool: """ Evaluate a PEP 508 environment marker. Return a boolean indicating the marker result in this environment. 
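`evaluate_marker` now declares `-> bool`, while `invalid_marker` keeps returning either `False` (marker is valid) or the offending `SyntaxError`. A small usage sketch against the public API; the deprecation warning emitted on import is silenced only to keep the demo quiet:

```python
import warnings

with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # pkg_resources warns that it is deprecated
    import pkg_resources

print(pkg_resources.evaluate_marker('python_version >= "3.8"'))  # True on 3.8+
err = pkg_resources.invalid_marker("garbage ===")
print(bool(err))  # True: an invalid marker yields a SyntaxError instance
```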
@@ -1514,10 +1634,9 @@ def evaluate_marker(text: str, extra: Optional[str] = None): class NullProvider: """Try to implement resources and metadata for arbitrary PEP 302 loaders""" - egg_name: Optional[str] = None - egg_info: Optional[str] = None - loader: Optional[_LoaderProtocol] = None - module_path: Optional[str] # Some subclasses can have a None module_path + egg_name: str | None = None + egg_info: str | None = None + loader: _LoaderProtocol | None = None def __init__(self, module: _ModuleLike): self.loader = getattr(module, '__loader__', None) @@ -1560,7 +1679,7 @@ class NullProvider: exc.reason += ' in {} file at path: {}'.format(name, path) raise - def get_metadata_lines(self, name: str): + def get_metadata_lines(self, name: str) -> Iterator[str]: return yield_lines(self.get_metadata(name)) def resource_isdir(self, resource_name: str): @@ -1572,12 +1691,12 @@ class NullProvider: def resource_listdir(self, resource_name: str): return self._listdir(self._fn(self.module_path, resource_name)) - def metadata_listdir(self, name: str): + def metadata_listdir(self, name: str) -> list[str]: if self.egg_info: return self._listdir(self._fn(self.egg_info, name)) return [] - def run_script(self, script_name: str, namespace: Dict[str, Any]): + def run_script(self, script_name: str, namespace: dict[str, Any]): script = 'scripts/' + script_name if not self.has_metadata(script): raise ResolutionError( @@ -1585,6 +1704,7 @@ class NullProvider: **locals() ), ) + script_text = self.get_metadata(script).replace('\r\n', '\n') script_text = script_text.replace('\r', '\n') script_filename = self._fn(self.egg_info, script) @@ -1615,12 +1735,16 @@ class NullProvider: "Can't perform this operation for unregistered loader type" ) - def _listdir(self, path): + def _listdir(self, path) -> list[str]: raise NotImplementedError( "Can't perform this operation for unregistered loader type" ) - def _fn(self, base, resource_name: str): + def _fn(self, base: str | None, resource_name: str): + if base is None: + raise TypeError( + "`base` parameter in `_fn` is `None`. Either override this method or check the parameter first." + ) self._validate_resource_path(resource_name) if resource_name: return os.path.join(base, *resource_name.split('/')) @@ -1780,7 +1904,8 @@ DefaultProvider._register() class EmptyProvider(NullProvider): """Provider that returns nothing for all requests""" - module_path = None + # A special case, we don't want all Providers inheriting from NullProvider to have a potentially None module_path + module_path: str | None = None # type: ignore[assignment] _isdir = _has = lambda self, path: False @@ -1802,7 +1927,7 @@ class ZipManifests(Dict[str, "MemoizedZipManifests.manifest_mod"]): zip manifest builder """ - # `path` could be `Union["StrPath", IO[bytes]]` but that violates the LSP for `MemoizedZipManifests.load` + # `path` could be `StrPath | IO[bytes]` but that violates the LSP for `MemoizedZipManifests.load` @classmethod def build(cls, path: str): """ @@ -1831,10 +1956,10 @@ class MemoizedZipManifests(ZipManifests): """ class manifest_mod(NamedTuple): - manifest: Dict[str, zipfile.ZipInfo] + manifest: dict[str, zipfile.ZipInfo] mtime: float - def load(self, path: str): # type: ignore[override] # ZipManifests.load is a classmethod + def load(self, path: str) -> dict[str, zipfile.ZipInfo]: # type: ignore[override] # ZipManifests.load is a classmethod """ Load a manifest at path or return a suitable manifest already loaded. 
""" @@ -1851,12 +1976,12 @@ class MemoizedZipManifests(ZipManifests): class ZipProvider(EggProvider): """Resource support for zips and eggs""" - eagers: Optional[List[str]] = None + eagers: list[str] | None = None _zip_manifests = MemoizedZipManifests() # ZipProvider's loader should always be a zipimporter or equivalent loader: zipimport.zipimporter - def __init__(self, module: _ModuleLike): + def __init__(self, module: _ZipLoaderModule): super().__init__(module) self.zip_pre = self.loader.archive + os.sep @@ -1905,7 +2030,7 @@ class ZipProvider(EggProvider): return timestamp, size # FIXME: 'ZipProvider._extract_resource' is too complex (12) - def _extract_resource(self, manager: ResourceManager, zip_path): # noqa: C901 + def _extract_resource(self, manager: ResourceManager, zip_path) -> str: # noqa: C901 if zip_path in self._index(): for name in self._index()[zip_path]: last = self._extract_resource(manager, os.path.join(zip_path, name)) @@ -2033,7 +2158,7 @@ class FileMetadata(EmptyProvider): the provided location. """ - def __init__(self, path: "StrPath"): + def __init__(self, path: StrPath): self.path = path def _get_metadata_path(self, name): @@ -2042,7 +2167,7 @@ class FileMetadata(EmptyProvider): def has_metadata(self, name: str) -> bool: return name == 'PKG-INFO' and os.path.isfile(self.path) - def get_metadata(self, name): + def get_metadata(self, name: str): if name != 'PKG-INFO': raise KeyError("No metadata except PKG-INFO is available") @@ -2058,7 +2183,7 @@ class FileMetadata(EmptyProvider): msg = tmpl.format(**locals()) warnings.warn(msg) - def get_metadata_lines(self, name): + def get_metadata_lines(self, name: str) -> Iterator[str]: return yield_lines(self.get_metadata(name)) @@ -2102,12 +2227,12 @@ class EggMetadata(ZipProvider): self._setup_prefix() -_distribution_finders: Dict[type, _DistFinderType[Any]] = _declare_state( +_distribution_finders: dict[type, _DistFinderType[Any]] = _declare_state( 'dict', '_distribution_finders', {} ) -def register_finder(importer_type: Type[_T], distribution_finder: _DistFinderType[_T]): +def register_finder(importer_type: type[_T], distribution_finder: _DistFinderType[_T]): """Register `distribution_finder` to find distributions in sys.path items `importer_type` is the type or class of a PEP 302 "Importer" (sys.path item @@ -2126,7 +2251,7 @@ def find_distributions(path_item: str, only: bool = False): def find_eggs_in_zip( importer: zipimport.zipimporter, path_item: str, only: bool = False -): +) -> Iterator[Distribution]: """ Find eggs in zip files; possibly multiple nested eggs. """ @@ -2156,7 +2281,7 @@ register_finder(zipimport.zipimporter, find_eggs_in_zip) def find_nothing( - importer: Optional[object], path_item: Optional[str], only: Optional[bool] = False + importer: object | None, path_item: str | None, only: bool | None = False ): return () @@ -2164,7 +2289,7 @@ def find_nothing( register_finder(object, find_nothing) -def find_on_path(importer: Optional[object], path_item, only=False): +def find_on_path(importer: object | None, path_item, only=False): """Yield distributions accessible on a sys.path directory""" path_item = _normalize_cached(path_item) @@ -2219,7 +2344,7 @@ class NoDists: return iter(()) -def safe_listdir(path): +def safe_listdir(path: StrOrBytesPath): """ Attempt to list contents of path, but suppress some exceptions. 
""" @@ -2235,13 +2360,13 @@ def safe_listdir(path): return () -def distributions_from_metadata(path): +def distributions_from_metadata(path: str): root = os.path.dirname(path) if os.path.isdir(path): if len(os.listdir(path)) == 0: # empty metadata dir; skip return - metadata = PathMetadata(root, path) + metadata: _MetadataType = PathMetadata(root, path) else: metadata = FileMetadata(path) entry = os.path.basename(path) @@ -2281,16 +2406,16 @@ if hasattr(pkgutil, 'ImpImporter'): register_finder(importlib.machinery.FileFinder, find_on_path) -_namespace_handlers: Dict[type, _NSHandlerType[Any]] = _declare_state( +_namespace_handlers: dict[type, _NSHandlerType[Any]] = _declare_state( 'dict', '_namespace_handlers', {} ) -_namespace_packages: Dict[Optional[str], List[str]] = _declare_state( +_namespace_packages: dict[str | None, list[str]] = _declare_state( 'dict', '_namespace_packages', {} ) def register_namespace_handler( - importer_type: Type[_T], namespace_handler: _NSHandlerType[_T] + importer_type: type[_T], namespace_handler: _NSHandlerType[_T] ): """Register `namespace_handler` to declare namespace packages @@ -2423,7 +2548,7 @@ def declare_namespace(packageName: str): _imp.release_lock() -def fixup_namespace_packages(path_item: str, parent: Optional[str] = None): +def fixup_namespace_packages(path_item: str, parent: str | None = None): """Ensure that previously-declared namespace packages include path_item""" _imp.acquire_lock() try: @@ -2437,7 +2562,7 @@ def fixup_namespace_packages(path_item: str, parent: Optional[str] = None): def file_ns_handler( importer: object, - path_item: "StrPath", + path_item: StrPath, packageName: str, module: types.ModuleType, ): @@ -2462,9 +2587,9 @@ register_namespace_handler(importlib.machinery.FileFinder, file_ns_handler) def null_ns_handler( importer: object, - path_item: Optional[str], - packageName: Optional[str], - module: Optional[_ModuleLike], + path_item: str | None, + packageName: str | None, + module: _ModuleLike | None, ): return None @@ -2472,12 +2597,16 @@ def null_ns_handler( register_namespace_handler(object, null_ns_handler) -def normalize_path(filename: "StrPath"): +@overload +def normalize_path(filename: StrPath) -> str: ... +@overload +def normalize_path(filename: BytesPath) -> bytes: ... +def normalize_path(filename: StrOrBytesPath): """Normalize a file/dir name for comparison purposes""" return os.path.normcase(os.path.realpath(os.path.normpath(_cygwin_patch(filename)))) -def _cygwin_patch(filename: "StrPath"): # pragma: nocover +def _cygwin_patch(filename: StrOrBytesPath): # pragma: nocover """ Contrary to POSIX 2008, on Cygwin, getcwd (3) contains symlink components. Using @@ -2488,9 +2617,19 @@ def _cygwin_patch(filename: "StrPath"): # pragma: nocover return os.path.abspath(filename) if sys.platform == 'cygwin' else filename -@functools.lru_cache(maxsize=None) -def _normalize_cached(filename): - return normalize_path(filename) +if TYPE_CHECKING: + # https://github.com/python/mypy/issues/16261 + # https://github.com/python/typeshed/issues/6347 + @overload + def _normalize_cached(filename: StrPath) -> str: ... + @overload + def _normalize_cached(filename: BytesPath) -> bytes: ... + def _normalize_cached(filename: StrOrBytesPath) -> str | bytes: ... 
+else: + + @functools.lru_cache(maxsize=None) + def _normalize_cached(filename): + return normalize_path(filename) def _is_egg_path(path): @@ -2549,7 +2688,7 @@ class EntryPoint: module_name: str, attrs: Iterable[str] = (), extras: Iterable[str] = (), - dist: Optional["Distribution"] = None, + dist: Distribution | None = None, ): if not MODULE(module_name): raise ValueError("Invalid module name", module_name) @@ -2570,12 +2709,26 @@ class EntryPoint: def __repr__(self): return "EntryPoint.parse(%r)" % str(self) + @overload + def load( + self, + require: Literal[True] = True, + env: Environment | None = None, + installer: _InstallerType | None = None, + ) -> _ResolvedEntryPoint: ... + @overload + def load( + self, + require: Literal[False], + *args: Any, + **kwargs: Any, + ) -> _ResolvedEntryPoint: ... def load( self, require: bool = True, - *args: Optional[Union[Environment, _InstallerType]], - **kwargs: Optional[Union[Environment, _InstallerType]], - ): + *args: Environment | _InstallerType | None, + **kwargs: Environment | _InstallerType | None, + ) -> _ResolvedEntryPoint: """ Require packages for this EntryPoint, then resolve it. """ @@ -2592,7 +2745,7 @@ class EntryPoint: self.require(*args, **kwargs) # type: ignore return self.resolve() - def resolve(self): + def resolve(self) -> _ResolvedEntryPoint: """ Resolve the entry point from its module and attrs. """ @@ -2604,8 +2757,8 @@ class EntryPoint: def require( self, - env: Optional[Environment] = None, - installer: Optional[_InstallerType] = None, + env: Environment | None = None, + installer: _InstallerType | None = None, ): if not self.dist: error_cls = UnknownExtra if self.extras else AttributeError @@ -2630,7 +2783,7 @@ class EntryPoint: ) @classmethod - def parse(cls, src: str, dist: Optional["Distribution"] = None): + def parse(cls, src: str, dist: Distribution | None = None): """Parse a single entry point from string `src` Entry point syntax follows the form:: @@ -2663,12 +2816,12 @@ class EntryPoint: cls, group: str, lines: _NestedStr, - dist: Optional["Distribution"] = None, + dist: Distribution | None = None, ): """Parse an entry point group""" if not MODULE(group): raise ValueError("Invalid group name", group) - this = {} + this: dict[str, Self] = {} for line in yield_lines(lines): ep = cls.parse(line, dist) if ep.name in this: @@ -2679,15 +2832,16 @@ class EntryPoint: @classmethod def parse_map( cls, - data: Union[str, Iterable[str], Dict[str, Union[str, Iterable[str]]]], - dist: Optional["Distribution"] = None, + data: str | Iterable[str] | dict[str, str | Iterable[str]], + dist: Distribution | None = None, ): """Parse a map of entry point groups""" + _data: Iterable[tuple[str | None, str | Iterable[str]]] if isinstance(data, dict): _data = data.items() else: _data = split_sections(data) - maps: Dict[str, Dict[str, EntryPoint]] = {} + maps: dict[str, dict[str, Self]] = {} for group, lines in _data: if group is None: if not lines: @@ -2722,12 +2876,12 @@ class Distribution: def __init__( self, - location: Optional[str] = None, + location: str | None = None, metadata: _MetadataType = None, - project_name: Optional[str] = None, - version: Optional[str] = None, - py_version: Optional[str] = PY_MAJOR, - platform: Optional[str] = None, + project_name: str | None = None, + version: str | None = None, + py_version: str | None = PY_MAJOR, + platform: str | None = None, precedence: int = EGG_DIST, ): self.project_name = safe_name(project_name or 'Unknown') @@ -2743,10 +2897,10 @@ class Distribution: def from_location( cls, 
location: str, - basename: str, + basename: StrPath, metadata: _MetadataType = None, **kw: int, # We could set `precedence` explicitly, but keeping this as `**kw` for full backwards and subclassing compatibility - ): + ) -> Distribution: project_name, version, py_version, platform = [None] * 4 basename, ext = os.path.splitext(basename) if ext.lower() in _distributionImpl: @@ -2784,16 +2938,16 @@ class Distribution: def __hash__(self): return hash(self.hashcmp) - def __lt__(self, other: "Distribution"): + def __lt__(self, other: Distribution): return self.hashcmp < other.hashcmp - def __le__(self, other: "Distribution"): + def __le__(self, other: Distribution): return self.hashcmp <= other.hashcmp - def __gt__(self, other: "Distribution"): + def __gt__(self, other: Distribution): return self.hashcmp > other.hashcmp - def __ge__(self, other: "Distribution"): + def __ge__(self, other: Distribution): return self.hashcmp >= other.hashcmp def __eq__(self, other: object): @@ -2885,14 +3039,14 @@ class Distribution: return self.__dep_map @staticmethod - def _filter_extras(dm): + def _filter_extras(dm: dict[str | None, list[Requirement]]): """ Given a mapping of extras to dependencies, strip off environment markers and filter out any dependencies not matching the markers. """ for extra in list(filter(None, dm)): - new_extra = extra + new_extra: str | None = extra reqs = dm.pop(extra) new_extra, _, marker = extra.partition(':') fails_marker = marker and ( @@ -2915,7 +3069,7 @@ class Distribution: def requires(self, extras: Iterable[str] = ()): """List of Requirements needed for this distro if `extras` are used""" dm = self._dep_map - deps = [] + deps: list[Requirement] = [] deps.extend(dm.get(None, ())) for ext in extras: try: @@ -2951,7 +3105,7 @@ class Distribution: lines = self._get_metadata(self.PKG_INFO) return _version_from_file(lines) - def activate(self, path: Optional[List[str]] = None, replace: bool = False): + def activate(self, path: list[str] | None = None, replace: bool = False): """Ensure distribution is importable on `path` (default=sys.path)""" if path is None: path = sys.path @@ -3003,7 +3157,7 @@ class Distribution: @classmethod def from_filename( cls, - filename: str, + filename: StrPath, metadata: _MetadataType = None, **kw: int, # We could set `precedence` explicitly, but keeping this as `**kw` for full backwards and subclassing compatibility ): @@ -3020,14 +3174,18 @@ class Distribution: return Requirement.parse(spec) - def load_entry_point(self, group: str, name: str): + def load_entry_point(self, group: str, name: str) -> _ResolvedEntryPoint: """Return the `name` entry point of `group` or raise ImportError""" ep = self.get_entry_info(group, name) if ep is None: raise ImportError("Entry point %r not found" % ((group, name),)) return ep.load() - def get_entry_map(self, group: Optional[str] = None): + @overload + def get_entry_map(self, group: None = None) -> dict[str, dict[str, EntryPoint]]: ... + @overload + def get_entry_map(self, group: str) -> dict[str, EntryPoint]: ... 
+ def get_entry_map(self, group: str | None = None): """Return the entry point map for `group`, or the full entry map""" if not hasattr(self, "_ep_map"): self._ep_map = EntryPoint.parse_map( @@ -3044,7 +3202,7 @@ class Distribution: # FIXME: 'Distribution.insert_on' is too complex (13) def insert_on( # noqa: C901 self, - path: List[str], + path: list[str], loc=None, replace: bool = False, ): @@ -3152,7 +3310,7 @@ class Distribution: return False return True - def clone(self, **kw: Optional[Union[str, int, IResourceProvider]]): + def clone(self, **kw: str | int | IResourceProvider | None): """Copy this distribution, substituting in any changed keyword args""" names = 'project_name version py_version platform location precedence' for attr in names.split(): @@ -3212,11 +3370,11 @@ class DistInfoDistribution(Distribution): self.__dep_map = self._compute_dependencies() return self.__dep_map - def _compute_dependencies(self): + def _compute_dependencies(self) -> dict[str | None, list[Requirement]]: """Recompute this distribution's dependencies.""" - dm = self.__dep_map = {None: []} + self.__dep_map: dict[str | None, list[Requirement]] = {None: []} - reqs = [] + reqs: list[Requirement] = [] # Including any condition expressions for req in self._parsed_pkg_info.get_all('Requires-Dist') or []: reqs.extend(parse_requirements(req)) @@ -3227,13 +3385,15 @@ class DistInfoDistribution(Distribution): yield req common = types.MappingProxyType(dict.fromkeys(reqs_for_extra(None))) - dm[None].extend(common) + self.__dep_map[None].extend(common) for extra in self._parsed_pkg_info.get_all('Provides-Extra') or []: s_extra = safe_extra(extra.strip()) - dm[s_extra] = [r for r in reqs_for_extra(extra) if r not in common] + self.__dep_map[s_extra] = [ + r for r in reqs_for_extra(extra) if r not in common + ] - return dm + return self.__dep_map _distributionImpl = { @@ -3278,7 +3438,7 @@ class Requirement(_packaging_requirements.Requirement): self.project_name, self.key = project_name, project_name.lower() self.specs = [(spec.operator, spec.version) for spec in self.specifier] # packaging.requirements.Requirement uses a set for its extras. 
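The rewritten `_compute_dependencies` below leans on `dict.fromkeys` for an order-preserving dedupe and `types.MappingProxyType` for a read-only view with cheap membership tests. The trick in isolation, with made-up requirement names:

```python
import types

# Hypothetical requirement names; duplicates are dropped, order is kept
# (dict preserves insertion order since Python 3.7).
base_reqs = ["requests", "urllib3", "requests"]

common = types.MappingProxyType(dict.fromkeys(base_reqs))
print(list(common))  # ['requests', 'urllib3']

# Per-extra requirements keep only what the base set does not already carry.
extra_reqs = ["urllib3", "brotli"]
print([r for r in extra_reqs if r not in common])  # ['brotli']
```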
We use a variable-length tuple - self.extras: Tuple[str] = tuple(map(safe_extra, self.extras)) + self.extras: tuple[str] = tuple(map(safe_extra, self.extras)) self.hashCmp = ( self.key, self.url, @@ -3294,7 +3454,7 @@ class Requirement(_packaging_requirements.Requirement): def __ne__(self, other): return not self == other - def __contains__(self, item: Union[Distribution, str, Tuple[str, ...]]): + def __contains__(self, item: Distribution | str | tuple[str, ...]) -> bool: if isinstance(item, Distribution): if item.key != self.key: return False @@ -3313,7 +3473,7 @@ class Requirement(_packaging_requirements.Requirement): return "Requirement.parse(%r)" % str(self) @staticmethod - def parse(s: Union[str, Iterable[str]]): + def parse(s: str | Iterable[str]): (req,) = parse_requirements(s) return req @@ -3339,7 +3499,7 @@ def _find_adapter(registry: Mapping[type, _AdapterT], ob: object) -> _AdapterT: raise TypeError(f"Could not find adapter for {registry} and {ob}") -def ensure_directory(path: str): +def ensure_directory(path: StrOrBytesPath): """Ensure that the parent directory of `path` exists""" dirname = os.path.dirname(path) os.makedirs(dirname, exist_ok=True) @@ -3358,7 +3518,7 @@ def _bypass_ensure_directory(path): pass -def split_sections(s: _NestedStr): +def split_sections(s: _NestedStr) -> Iterator[tuple[str | None, list[str]]]: """Split a string or iterable thereof into (section, content) pairs Each ``section`` is a stripped version of the section header ("[section]") @@ -3411,6 +3571,38 @@ class PkgResourcesDeprecationWarning(Warning): """ +# Ported from ``setuptools`` to avoid introducing an import inter-dependency: +_LOCALE_ENCODING = "locale" if sys.version_info >= (3, 10) else None + + +def _read_utf8_with_fallback(file: str, fallback_encoding=_LOCALE_ENCODING) -> str: + """See setuptools.unicode_utils._read_utf8_with_fallback""" + try: + with open(file, "r", encoding="utf-8") as f: + return f.read() + except UnicodeDecodeError: # pragma: no cover + msg = f"""\ + ******************************************************************************** + `encoding="utf-8"` fails with {file!r}, trying `encoding={fallback_encoding!r}`. + + This fallback behaviour is considered **deprecated** and future versions of + `setuptools/pkg_resources` may not implement it. + + Please encode {file!r} with "utf-8" to ensure future builds will succeed. + + If this file was produced by `setuptools` itself, cleaning up the cached files + and re-building/re-installing the package with a newer version of `setuptools` + (e.g. by updating `build-system.requires` in its `pyproject.toml`) + might solve the problem. + ******************************************************************************** + """ + # TODO: Add a deadline? 
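The `_LOCALE_ENCODING` shim above encodes a version gate: `encoding="locale"` is the explicit PEP 597 spelling available from Python 3.10, while `encoding=None` selects the same locale default on older interpreters. A stripped-down sketch of a fallback reader built on it (names are illustrative, not the vendored helper):

```python
import sys

# encoding="locale" is the explicit PEP 597 spelling (Python 3.10+);
# encoding=None gives the same locale-default behaviour before that.
LOCALE_ENCODING = "locale" if sys.version_info >= (3, 10) else None


def read_text_with_locale_fallback(path: str) -> str:
    """Try UTF-8 first, then retry with the locale encoding."""
    try:
        with open(path, encoding="utf-8") as f:
            return f.read()
    except UnicodeDecodeError:
        with open(path, encoding=LOCALE_ENCODING) as f:
            return f.read()
```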
+ # See comment in setuptools.unicode_utils._Utf8EncodingNeeded + warnings.warn(msg, PkgResourcesDeprecationWarning, stacklevel=2) + with open(file, "r", encoding=fallback_encoding) as f: + return f.read() + + # from jaraco.functools 1.3 def _call_aside(f, *args, **kwargs): f(*args, **kwargs) @@ -3483,35 +3675,3 @@ if TYPE_CHECKING: add_activation_listener = working_set.subscribe run_script = working_set.run_script run_main = run_script - - -# ---- Ported from ``setuptools`` to avoid introducing an import inter-dependency ---- -LOCALE_ENCODING = "locale" if sys.version_info >= (3, 10) else None - - -def _read_utf8_with_fallback(file: str, fallback_encoding=LOCALE_ENCODING) -> str: - """See setuptools.unicode_utils._read_utf8_with_fallback""" - try: - with open(file, "r", encoding="utf-8") as f: - return f.read() - except UnicodeDecodeError: # pragma: no cover - msg = f"""\ - ******************************************************************************** - `encoding="utf-8"` fails with {file!r}, trying `encoding={fallback_encoding!r}`. - - This fallback behaviour is considered **deprecated** and future versions of - `setuptools/pkg_resources` may not implement it. - - Please encode {file!r} with "utf-8" to ensure future builds will succeed. - - If this file was produced by `setuptools` itself, cleaning up the cached files - and re-building/re-installing the package with a newer version of `setuptools` - (e.g. by updating `build-system.requires` in its `pyproject.toml`) - might solve the problem. - ******************************************************************************** - """ - # TODO: Add a deadline? - # See comment in setuptools.unicode_utils._Utf8EncodingNeeded - warnings.warn(msg, PkgResourcesDeprecationWarning, stacklevel=2) - with open(file, "r", encoding=fallback_encoding) as f: - return f.read() diff --git a/lib/pkg_resources/extern/__init__.py b/lib/pkg_resources/extern/__init__.py index a1b7490d..9b9ac10a 100644 --- a/lib/pkg_resources/extern/__init__.py +++ b/lib/pkg_resources/extern/__init__.py @@ -1,8 +1,9 @@ +from __future__ import annotations from importlib.machinery import ModuleSpec import importlib.util import sys from types import ModuleType -from typing import Iterable, Optional, Sequence +from typing import Iterable, Sequence class VendorImporter: @@ -15,7 +16,7 @@ class VendorImporter: self, root_name: str, vendored_names: Iterable[str] = (), - vendor_pkg: Optional[str] = None, + vendor_pkg: str | None = None, ): self.root_name = root_name self.vendored_names = set(vendored_names) @@ -65,8 +66,8 @@ class VendorImporter: def find_spec( self, fullname: str, - path: Optional[Sequence[str]] = None, - target: Optional[ModuleType] = None, + path: Sequence[str] | None = None, + target: ModuleType | None = None, ): """Return a module spec for vendored names.""" return ( diff --git a/lib/urllib3/_base_connection.py b/lib/urllib3/_base_connection.py index bb349c74..29ca3348 100644 --- a/lib/urllib3/_base_connection.py +++ b/lib/urllib3/_base_connection.py @@ -12,7 +12,7 @@ _TYPE_BODY = typing.Union[bytes, typing.IO[typing.Any], typing.Iterable[bytes], class ProxyConfig(typing.NamedTuple): ssl_context: ssl.SSLContext | None use_forwarding_for_https: bool - assert_hostname: None | str | Literal[False] + assert_hostname: None | str | typing.Literal[False] assert_fingerprint: str | None @@ -28,7 +28,7 @@ class _ResponseOptions(typing.NamedTuple): if typing.TYPE_CHECKING: import ssl - from typing import Literal, Protocol + from typing import Protocol from .response import 
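The urllib3 hunks that follow swap bare `Literal` (previously imported only under `TYPE_CHECKING`) for `typing.Literal`, which also resolves at runtime. The `assert_hostname` annotation itself is a tri-state switch; a small sketch of how such a parameter is typically interpreted, assuming urllib3's documented semantics (`None` verifies the requested host, a string overrides it, `False` disables verification):

```python
from __future__ import annotations

import typing


def check_hostname_policy(assert_hostname: None | str | typing.Literal[False]) -> str:
    # Tri-state: None -> default verification, str -> explicit override,
    # False -> verification disabled. Literal[False] keeps True out of the
    # accepted values, which a plain bool annotation would not.
    if assert_hostname is False:
        return "hostname verification disabled"
    if assert_hostname is None:
        return "verify against the requested host"
    return f"verify against override {assert_hostname!r}"


print(check_hostname_policy(None))
print(check_hostname_policy("internal.example"))  # hypothetical hostname
print(check_hostname_policy(False))
```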
BaseHTTPResponse @@ -124,7 +124,7 @@ if typing.TYPE_CHECKING: # Certificate verification methods cert_reqs: int | str | None - assert_hostname: None | str | Literal[False] + assert_hostname: None | str | typing.Literal[False] assert_fingerprint: str | None ssl_context: ssl.SSLContext | None @@ -155,7 +155,7 @@ if typing.TYPE_CHECKING: proxy: Url | None = None, proxy_config: ProxyConfig | None = None, cert_reqs: int | str | None = None, - assert_hostname: None | str | Literal[False] = None, + assert_hostname: None | str | typing.Literal[False] = None, assert_fingerprint: str | None = None, server_hostname: str | None = None, ssl_context: ssl.SSLContext | None = None, diff --git a/lib/urllib3/_collections.py b/lib/urllib3/_collections.py index 55b03247..8a4409a1 100644 --- a/lib/urllib3/_collections.py +++ b/lib/urllib3/_collections.py @@ -427,7 +427,7 @@ class HTTPHeaderDict(typing.MutableMapping[str, str]): val = other.getlist(key) self._container[key.lower()] = [key, *val] - def copy(self) -> HTTPHeaderDict: + def copy(self) -> Self: clone = type(self)() clone._copy_from(self) return clone @@ -462,7 +462,7 @@ class HTTPHeaderDict(typing.MutableMapping[str, str]): self.extend(maybe_constructable) return self - def __or__(self, other: object) -> HTTPHeaderDict: + def __or__(self, other: object) -> Self: # Supports merging header dicts using operator | # combining items with add instead of __setitem__ maybe_constructable = ensure_can_construct_http_header_dict(other) @@ -472,7 +472,7 @@ class HTTPHeaderDict(typing.MutableMapping[str, str]): result.extend(maybe_constructable) return result - def __ror__(self, other: object) -> HTTPHeaderDict: + def __ror__(self, other: object) -> Self: # Supports merging header dicts using operator | when other is on left side # combining items with add instead of __setitem__ maybe_constructable = ensure_can_construct_http_header_dict(other) diff --git a/lib/urllib3/_version.py b/lib/urllib3/_version.py index 095cf3c1..7442f2b8 100644 --- a/lib/urllib3/_version.py +++ b/lib/urllib3/_version.py @@ -1,4 +1,4 @@ # This file is protected via CODEOWNERS from __future__ import annotations -__version__ = "2.2.1" +__version__ = "2.2.2" diff --git a/lib/urllib3/connection.py b/lib/urllib3/connection.py index aa5c547c..1b162792 100644 --- a/lib/urllib3/connection.py +++ b/lib/urllib3/connection.py @@ -14,8 +14,6 @@ from http.client import ResponseNotReady from socket import timeout as SocketTimeout if typing.TYPE_CHECKING: - from typing import Literal - from .response import HTTPResponse from .util.ssl_ import _TYPE_PEER_CERT_RET_DICT from .util.ssltransport import SSLTransport @@ -482,6 +480,7 @@ class HTTPConnection(_HTTPConnection): headers=headers, status=httplib_response.status, version=httplib_response.version, + version_string=getattr(self, "_http_vsn_str", "HTTP/?"), reason=httplib_response.reason, preload_content=resp_options.preload_content, decode_content=resp_options.decode_content, @@ -523,7 +522,7 @@ class HTTPSConnection(HTTPConnection): proxy: Url | None = None, proxy_config: ProxyConfig | None = None, cert_reqs: int | str | None = None, - assert_hostname: None | str | Literal[False] = None, + assert_hostname: None | str | typing.Literal[False] = None, assert_fingerprint: str | None = None, server_hostname: str | None = None, ssl_context: ssl.SSLContext | None = None, @@ -577,7 +576,7 @@ class HTTPSConnection(HTTPConnection): cert_reqs: int | str | None = None, key_password: str | None = None, ca_certs: str | None = None, - assert_hostname: None | str | 
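`HTTPHeaderDict.copy`, `__or__`, and `__ror__` now return `Self`, so subclasses of the header dict keep their own type through copies and merges. A toy illustration of why `Self` beats a hard-coded class name here (assumes `typing_extensions` on pre-3.11 interpreters):

```python
from __future__ import annotations

import sys

if sys.version_info >= (3, 11):
    from typing import Self
else:  # backport, as the vendored libraries above assume
    from typing_extensions import Self


class Box:
    def __init__(self) -> None:
        self.items: list[str] = []

    def copy(self) -> Self:
        # type(self)() builds the running subclass, and the Self annotation
        # lets the type checker keep that subclass type for the result.
        clone = type(self)()
        clone.items = list(self.items)
        return clone


class LoggingBox(Box):
    pass


box = LoggingBox()
clone = box.copy()  # inferred as LoggingBox, not Box
```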
Literal[False] = None, + assert_hostname: None | str | typing.Literal[False] = None, assert_fingerprint: str | None = None, ca_cert_dir: str | None = None, ca_cert_data: None | str | bytes = None, @@ -742,7 +741,7 @@ def _ssl_wrap_socket_and_match_hostname( ca_certs: str | None, ca_cert_dir: str | None, ca_cert_data: None | str | bytes, - assert_hostname: None | str | Literal[False], + assert_hostname: None | str | typing.Literal[False], assert_fingerprint: str | None, server_hostname: str | None, ssl_context: ssl.SSLContext | None, diff --git a/lib/urllib3/connectionpool.py b/lib/urllib3/connectionpool.py index bd58ff14..a2c3cf60 100644 --- a/lib/urllib3/connectionpool.py +++ b/lib/urllib3/connectionpool.py @@ -53,7 +53,8 @@ from .util.util import to_str if typing.TYPE_CHECKING: import ssl - from typing import Literal + + from typing_extensions import Self from ._base_connection import BaseHTTPConnection, BaseHTTPSConnection @@ -61,8 +62,6 @@ log = logging.getLogger(__name__) _TYPE_TIMEOUT = typing.Union[Timeout, float, _TYPE_DEFAULT, None] -_SelfT = typing.TypeVar("_SelfT") - # Pool objects class ConnectionPool: @@ -95,7 +94,7 @@ class ConnectionPool: def __str__(self) -> str: return f"{type(self).__name__}(host={self.host!r}, port={self.port!r})" - def __enter__(self: _SelfT) -> _SelfT: + def __enter__(self) -> Self: return self def __exit__( @@ -103,7 +102,7 @@ class ConnectionPool: exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None, - ) -> Literal[False]: + ) -> typing.Literal[False]: self.close() # Return False to re-raise any potential exceptions return False @@ -544,17 +543,14 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods): response._connection = response_conn # type: ignore[attr-defined] response._pool = self # type: ignore[attr-defined] - # emscripten connection doesn't have _http_vsn_str - http_version = getattr(conn, "_http_vsn_str", "HTTP/?") log.debug( - '%s://%s:%s "%s %s %s" %s %s', + '%s://%s:%s "%s %s HTTP/%s" %s %s', self.scheme, self.host, self.port, method, url, - # HTTP version - http_version, + response.version, response.status, response.length_remaining, ) @@ -1002,7 +998,7 @@ class HTTPSConnectionPool(HTTPConnectionPool): ssl_version: int | str | None = None, ssl_minimum_version: ssl.TLSVersion | None = None, ssl_maximum_version: ssl.TLSVersion | None = None, - assert_hostname: str | Literal[False] | None = None, + assert_hostname: str | typing.Literal[False] | None = None, assert_fingerprint: str | None = None, ca_cert_dir: str | None = None, **conn_kw: typing.Any, diff --git a/lib/urllib3/contrib/socks.py b/lib/urllib3/contrib/socks.py index 5a803916..c62b5e03 100644 --- a/lib/urllib3/contrib/socks.py +++ b/lib/urllib3/contrib/socks.py @@ -71,10 +71,8 @@ try: except ImportError: ssl = None # type: ignore[assignment] -from typing import TypedDict - -class _TYPE_SOCKS_OPTIONS(TypedDict): +class _TYPE_SOCKS_OPTIONS(typing.TypedDict): socks_version: int proxy_host: str | None proxy_port: str | None diff --git a/lib/urllib3/http2.py b/lib/urllib3/http2.py index 15fa9d91..ceb40602 100644 --- a/lib/urllib3/http2.py +++ b/lib/urllib3/http2.py @@ -195,6 +195,7 @@ class HTTP2Response(BaseHTTPResponse): headers=headers, # Following CPython, we map HTTP versions to major * 10 + minor integers version=20, + version_string="HTTP/2", # No reason phrase in HTTP/2 reason=None, decode_content=decode_content, diff --git a/lib/urllib3/poolmanager.py b/lib/urllib3/poolmanager.py index 32da0a00..085d1dba 100644 --- 
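The connection-pool hunks retire the `_SelfT = TypeVar(...)` boilerplate in favour of `Self` for `__enter__`, while `__exit__` keeps its `Literal[False]` return to signal that exceptions propagate. A minimal context-manager sketch of both, assuming `typing_extensions` is importable as in the diff:

```python
from __future__ import annotations

import typing
from types import TracebackType

from typing_extensions import Self  # assumed installed, as in the hunks above


class Pool:
    def close(self) -> None:
        print("pool closed")

    def __enter__(self) -> Self:
        # Self replaces the old `_SelfT = TypeVar("_SelfT")` pattern:
        # subclasses entering the context keep their own static type.
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> typing.Literal[False]:
        self.close()
        return False  # Literal[False]: exceptions always propagate


with Pool() as pool:
    pass
```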
a/lib/urllib3/poolmanager.py +++ b/lib/urllib3/poolmanager.py @@ -26,7 +26,8 @@ from .util.url import Url, parse_url if typing.TYPE_CHECKING: import ssl - from typing import Literal + + from typing_extensions import Self __all__ = ["PoolManager", "ProxyManager", "proxy_from_url"] @@ -51,8 +52,6 @@ SSL_KEYWORDS = ( # http.client.HTTPConnection & http.client.HTTPSConnection in Python 3.7 _DEFAULT_BLOCKSIZE = 16384 -_SelfT = typing.TypeVar("_SelfT") - class PoolKey(typing.NamedTuple): """ @@ -214,7 +213,7 @@ class PoolManager(RequestMethods): self.pool_classes_by_scheme = pool_classes_by_scheme self.key_fn_by_scheme = key_fn_by_scheme.copy() - def __enter__(self: _SelfT) -> _SelfT: + def __enter__(self) -> Self: return self def __exit__( @@ -222,7 +221,7 @@ class PoolManager(RequestMethods): exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None, - ) -> Literal[False]: + ) -> typing.Literal[False]: self.clear() # Return False to re-raise any potential exceptions return False @@ -553,7 +552,7 @@ class ProxyManager(PoolManager): proxy_headers: typing.Mapping[str, str] | None = None, proxy_ssl_context: ssl.SSLContext | None = None, use_forwarding_for_https: bool = False, - proxy_assert_hostname: None | str | Literal[False] = None, + proxy_assert_hostname: None | str | typing.Literal[False] = None, proxy_assert_fingerprint: str | None = None, **connection_pool_kw: typing.Any, ) -> None: diff --git a/lib/urllib3/response.py b/lib/urllib3/response.py index d31fac9b..a0273d65 100644 --- a/lib/urllib3/response.py +++ b/lib/urllib3/response.py @@ -26,20 +26,21 @@ except ImportError: brotli = None try: - import zstandard as zstd # type: ignore[import-not-found] - + import zstandard as zstd +except (AttributeError, ImportError, ValueError): # Defensive: + HAS_ZSTD = False +else: # The package 'zstandard' added the 'eof' property starting # in v0.18.0 which we require to ensure a complete and # valid zstd stream was fed into the ZstdDecoder. # See: https://github.com/urllib3/urllib3/pull/2624 - _zstd_version = _zstd_version = tuple( + _zstd_version = tuple( map(int, re.search(r"^([0-9]+)\.([0-9]+)", zstd.__version__).groups()) # type: ignore[union-attr] ) if _zstd_version < (0, 18): # Defensive: - zstd = None - -except (AttributeError, ImportError, ValueError): # Defensive: - zstd = None + HAS_ZSTD = False + else: + HAS_ZSTD = True from . 
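The `response.py` hunk beginning here replaces the `zstd = None` sentinel with a `HAS_ZSTD` boolean computed once at import, including the 0.18 version gate. The same optional-dependency probe in a self-contained form:

```python
from __future__ import annotations

import re

# Probe the optional dependency once at import time and publish one boolean,
# instead of re-testing a `zstd = None` sentinel throughout the module.
try:
    import zstandard as zstd
except (AttributeError, ImportError, ValueError):
    HAS_ZSTD = False
else:
    _match = re.search(r"^([0-9]+)\.([0-9]+)", zstd.__version__)
    _zstd_version = tuple(map(int, _match.groups())) if _match else (0, 0)
    # v0.18.0 added the decompressor `eof` property the decoder relies on.
    HAS_ZSTD = _zstd_version >= (0, 18)

if HAS_ZSTD:
    print("zstd decoding available")
```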
import util from ._base_connection import _TYPE_BODY @@ -61,8 +62,6 @@ from .util.response import is_fp_closed, is_response_to_head from .util.retry import Retry if typing.TYPE_CHECKING: - from typing import Literal - from .connectionpool import HTTPConnectionPool log = logging.getLogger(__name__) @@ -163,7 +162,7 @@ if brotli is not None: return b"" -if zstd is not None: +if HAS_ZSTD: class ZstdDecoder(ContentDecoder): def __init__(self) -> None: @@ -183,7 +182,7 @@ if zstd is not None: ret = self._obj.flush() # note: this is a no-op if not self._obj.eof: raise DecodeError("Zstandard data is incomplete") - return ret # type: ignore[no-any-return] + return ret class MultiDecoder(ContentDecoder): @@ -219,7 +218,7 @@ def _get_decoder(mode: str) -> ContentDecoder: if brotli is not None and mode == "br": return BrotliDecoder() - if zstd is not None and mode == "zstd": + if HAS_ZSTD and mode == "zstd": return ZstdDecoder() return DeflateDecoder() @@ -302,7 +301,7 @@ class BaseHTTPResponse(io.IOBase): CONTENT_DECODERS = ["gzip", "x-gzip", "deflate"] if brotli is not None: CONTENT_DECODERS += ["br"] - if zstd is not None: + if HAS_ZSTD: CONTENT_DECODERS += ["zstd"] REDIRECT_STATUSES = [301, 302, 303, 307, 308] @@ -310,7 +309,7 @@ class BaseHTTPResponse(io.IOBase): if brotli is not None: DECODER_ERROR_CLASSES += (brotli.error,) - if zstd is not None: + if HAS_ZSTD: DECODER_ERROR_CLASSES += (zstd.ZstdError,) def __init__( @@ -319,6 +318,7 @@ class BaseHTTPResponse(io.IOBase): headers: typing.Mapping[str, str] | typing.Mapping[bytes, bytes] | None = None, status: int, version: int, + version_string: str, reason: str | None, decode_content: bool, request_url: str | None, @@ -330,6 +330,7 @@ class BaseHTTPResponse(io.IOBase): self.headers = HTTPHeaderDict(headers) # type: ignore[arg-type] self.status = status self.version = version + self.version_string = version_string self.reason = reason self.decode_content = decode_content self._has_decoded_content = False @@ -346,7 +347,7 @@ class BaseHTTPResponse(io.IOBase): self._decoder: ContentDecoder | None = None self.length_remaining: int | None - def get_redirect_location(self) -> str | None | Literal[False]: + def get_redirect_location(self) -> str | None | typing.Literal[False]: """ Should we redirect and where to? @@ -364,13 +365,21 @@ class BaseHTTPResponse(io.IOBase): def json(self) -> typing.Any: """ - Parses the body of the HTTP response as JSON. + Deserializes the body of the HTTP response as a Python object. - To use a custom JSON decoder pass the result of :attr:`HTTPResponse.data` to the decoder. + The body of the HTTP response must be encoded using UTF-8, as per + `RFC 8529 Section 8.1 `_. - This method can raise either `UnicodeDecodeError` or `json.JSONDecodeError`. + To use a custom JSON decoder pass the result of :attr:`HTTPResponse.data` to + your custom decoder instead. - Read more :ref:`here `. + If the body of the HTTP response is not decodable to UTF-8, a + `UnicodeDecodeError` will be raised. If the body of the HTTP response is not a + valid JSON document, a `json.JSONDecodeError` will be raised. + + Read more :ref:`here `. + + :returns: The body of the HTTP response as a Python object. 
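The expanded `json()` docstring above spells out the UTF-8 and JSON failure modes. A short usage sketch against urllib3 v2's top-level request API (the URL is a placeholder):

```python
import json

import urllib3

# The URL is a placeholder; urllib3.request() is the v2 top-level helper.
resp = urllib3.request("GET", "https://example.org/api/data.json")
try:
    payload = resp.json()  # decodes the body as UTF-8, then json.loads()
except UnicodeDecodeError:
    payload = None  # body was not valid UTF-8
except json.JSONDecodeError:
    payload = None  # body was not a valid JSON document
```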
""" data = self.data.decode("utf-8") return _json.loads(data) @@ -567,6 +576,7 @@ class HTTPResponse(BaseHTTPResponse): headers: typing.Mapping[str, str] | typing.Mapping[bytes, bytes] | None = None, status: int = 0, version: int = 0, + version_string: str = "HTTP/?", reason: str | None = None, preload_content: bool = True, decode_content: bool = True, @@ -584,6 +594,7 @@ class HTTPResponse(BaseHTTPResponse): headers=headers, status=status, version=version, + version_string=version_string, reason=reason, decode_content=decode_content, request_url=request_url, @@ -926,7 +937,10 @@ class HTTPResponse(BaseHTTPResponse): if decode_content is None: decode_content = self.decode_content - if amt is not None: + if amt and amt < 0: + # Negative numbers and `None` should be treated the same. + amt = None + elif amt is not None: cache_content = False if len(self._decoded_buffer) >= amt: @@ -986,6 +1000,9 @@ class HTTPResponse(BaseHTTPResponse): """ if decode_content is None: decode_content = self.decode_content + if amt and amt < 0: + # Negative numbers and `None` should be treated the same. + amt = None # try and respond without going to the network if self._has_decoded_content: if not decode_content: @@ -1180,6 +1197,11 @@ class HTTPResponse(BaseHTTPResponse): if self._fp.fp is None: # type: ignore[union-attr] return None + if amt and amt < 0: + # Negative numbers and `None` should be treated the same, + # but httplib handles only `None` correctly. + amt = None + while True: self._update_chunk_length() if self.chunk_left == 0: diff --git a/lib/urllib3/util/request.py b/lib/urllib3/util/request.py index fe0e3485..859597e2 100644 --- a/lib/urllib3/util/request.py +++ b/lib/urllib3/util/request.py @@ -29,7 +29,7 @@ except ImportError: else: ACCEPT_ENCODING += ",br" try: - import zstandard as _unused_module_zstd # type: ignore[import-not-found] # noqa: F401 + import zstandard as _unused_module_zstd # noqa: F401 except ImportError: pass else: diff --git a/lib/urllib3/util/retry.py b/lib/urllib3/util/retry.py index 7572bfd2..0456cceb 100644 --- a/lib/urllib3/util/retry.py +++ b/lib/urllib3/util/retry.py @@ -21,6 +21,8 @@ from ..exceptions import ( from .util import reraise if typing.TYPE_CHECKING: + from typing_extensions import Self + from ..connectionpool import ConnectionPool from ..response import BaseHTTPResponse @@ -187,7 +189,9 @@ class Retry: RETRY_AFTER_STATUS_CODES = frozenset([413, 429, 503]) #: Default headers to be used for ``remove_headers_on_redirect`` - DEFAULT_REMOVE_HEADERS_ON_REDIRECT = frozenset(["Cookie", "Authorization"]) + DEFAULT_REMOVE_HEADERS_ON_REDIRECT = frozenset( + ["Cookie", "Authorization", "Proxy-Authorization"] + ) #: Default maximum backoff time. DEFAULT_BACKOFF_MAX = 120 @@ -240,7 +244,7 @@ class Retry: ) self.backoff_jitter = backoff_jitter - def new(self, **kw: typing.Any) -> Retry: + def new(self, **kw: typing.Any) -> Self: params = dict( total=self.total, connect=self.connect, @@ -429,7 +433,7 @@ class Retry: error: Exception | None = None, _pool: ConnectionPool | None = None, _stacktrace: TracebackType | None = None, - ) -> Retry: + ) -> Self: """Return a new Retry object with incremented retry counters. 
:param response: A response object, or None, if the server did not diff --git a/lib/urllib3/util/ssl_.py b/lib/urllib3/util/ssl_.py index b14cf27b..c46dd83e 100644 --- a/lib/urllib3/util/ssl_.py +++ b/lib/urllib3/util/ssl_.py @@ -78,7 +78,7 @@ def _is_has_never_check_common_name_reliable( if typing.TYPE_CHECKING: from ssl import VerifyMode - from typing import Literal, TypedDict + from typing import TypedDict from .ssltransport import SSLTransport as SSLTransportType @@ -365,7 +365,7 @@ def ssl_wrap_socket( ca_cert_dir: str | None = ..., key_password: str | None = ..., ca_cert_data: None | str | bytes = ..., - tls_in_tls: Literal[False] = ..., + tls_in_tls: typing.Literal[False] = ..., ) -> ssl.SSLSocket: ... diff --git a/lib/urllib3/util/ssltransport.py b/lib/urllib3/util/ssltransport.py index fa9f2b37..b52c477c 100644 --- a/lib/urllib3/util/ssltransport.py +++ b/lib/urllib3/util/ssltransport.py @@ -8,12 +8,11 @@ import typing from ..exceptions import ProxySchemeUnsupported if typing.TYPE_CHECKING: - from typing import Literal + from typing_extensions import Self from .ssl_ import _TYPE_PEER_CERT_RET, _TYPE_PEER_CERT_RET_DICT -_SelfT = typing.TypeVar("_SelfT", bound="SSLTransport") _WriteBuffer = typing.Union[bytearray, memoryview] _ReturnValue = typing.TypeVar("_ReturnValue") @@ -70,7 +69,7 @@ class SSLTransport: # Perform initial handshake. self._ssl_io_loop(self.sslobj.do_handshake) - def __enter__(self: _SelfT) -> _SelfT: + def __enter__(self) -> Self: return self def __exit__(self, *_: typing.Any) -> None: @@ -174,12 +173,12 @@ class SSLTransport: @typing.overload def getpeercert( - self, binary_form: Literal[False] = ... + self, binary_form: typing.Literal[False] = ... ) -> _TYPE_PEER_CERT_RET_DICT | None: ... @typing.overload - def getpeercert(self, binary_form: Literal[True]) -> bytes | None: + def getpeercert(self, binary_form: typing.Literal[True]) -> bytes | None: ... def getpeercert(self, binary_form: bool = False) -> _TYPE_PEER_CERT_RET: diff --git a/sickgear/__init__.py b/sickgear/__init__.py index e079fa16..ad405d9d 100644 --- a/sickgear/__init__.py +++ b/sickgear/__init__.py @@ -2477,14 +2477,14 @@ def _save_config(force=False, **kwargs): try: if check_valid_config(CONFIG_FILE): for _t in range(0, 3): - copy_file(CONFIG_FILE, backup_config) - if not check_valid_config(backup_config): - if 2 > _t: - logger.debug('backup config file seems to be invalid, retrying...') - else: - logger.warning('backup config file seems to be invalid, not backing up.') - backup_config = None - remove_file_perm(backup_config) + copy_file(CONFIG_FILE, backup_config) + if not check_valid_config(backup_config): + if 2 > _t: + logger.debug('backup config file seems to be invalid, retrying...') + else: + logger.warning('backup config file seems to be invalid, not backing up.') + backup_config = None + remove_file_perm(backup_config) 2 > _t and time.sleep(3) else: break
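The re-indented `_save_config` block moves the copy-and-validate calls inside the retry loop, so an invalid backup is retried up to three times with a short sleep between attempts. A hedged sketch of that shape, with `is_valid_config` as a stand-in for SickGear's `check_valid_config`:

```python
from __future__ import annotations

import shutil
import time


def is_valid_config(path: str) -> bool:
    """Stand-in for SickGear's check_valid_config()."""
    with open(path, "rb") as f:
        return bool(f.read(1))  # hypothetical sanity check


def backup_with_retries(src: str, dest: str, attempts: int = 3) -> str | None:
    # Copy, validate the copy, and retry a few times before giving up,
    # matching the shape of the re-indented _save_config loop above.
    for attempt in range(attempts):
        shutil.copyfile(src, dest)
        if is_valid_config(dest):
            return dest
        if attempt < attempts - 1:
            time.sleep(3)  # brief pause before retrying
    return None  # backup abandoned; the caller proceeds without one
```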