Update package resource API 68.1.2 (1ef36f2) → 68.2.2 (8ad627d).

JackDandy 2024-06-07 17:18:00 +01:00
parent 354e8d640a
commit 4572ed367c
27 changed files with 5014 additions and 2976 deletions

View file

@@ -10,6 +10,7 @@
 * Update filelock 3.12.4 (c1163ae) to 3.14.0 (8556141)
 * Update idna library 3.4 (cab054c) to 3.7 (1d365e1)
 * Update imdbpie 5.6.4 (f695e87) to 5.6.5 (f8ed7a0)
+* Update package resource API 68.1.2 (1ef36f2) to 68.2.2 (8ad627d)
 * Update profilehooks module 1.12.1 (c3fc078) to 1.13.0.dev0 (99f8a31)
 * Update pytz 2023.3/2023c (488d3eb) to 2024.1/2024a (3680953)
 * Update Rarfile 4.1a1 (8a72967) to 4.2 (db1df33)

File diff suppressed because it is too large

File diff suppressed because it is too large

View file

@@ -1,15 +1,26 @@
-import os
-import subprocess
+from __future__ import annotations
+
 import contextlib
 import functools
-import tempfile
-import shutil
 import operator
+import os
+import shutil
+import subprocess
+import sys
+import tempfile
+import urllib.request
 import warnings
+from typing import Iterator
+
+if sys.version_info < (3, 12):
+    from pkg_resources.extern.backports import tarfile
+else:
+    import tarfile


 @contextlib.contextmanager
-def pushd(dir):
+def pushd(dir: str | os.PathLike) -> Iterator[str | os.PathLike]:
     """
     >>> tmp_path = getfixture('tmp_path')
     >>> with pushd(tmp_path):
@@ -26,33 +37,88 @@ def pushd(dir):

 @contextlib.contextmanager
-def tarball_context(url, target_dir=None, runner=None, pushd=pushd):
+def tarball(
+    url, target_dir: str | os.PathLike | None = None
+) -> Iterator[str | os.PathLike]:
     """
-    Get a tarball, extract it, change to that directory, yield, then
-    clean up.
+    Get a tarball, extract it, yield, then clean up.

-    `runner` is the function to invoke commands.
-    `pushd` is a context manager for changing the directory.
+    >>> import urllib.request
+    >>> url = getfixture('tarfile_served')
+    >>> target = getfixture('tmp_path') / 'out'
+    >>> tb = tarball(url, target_dir=target)
+    >>> import pathlib
+    >>> with tb as extracted:
+    ...     contents = pathlib.Path(extracted, 'contents.txt').read_text(encoding='utf-8')
+    >>> assert not os.path.exists(extracted)
     """
     if target_dir is None:
         target_dir = os.path.basename(url).replace('.tar.gz', '').replace('.tgz', '')
-    if runner is None:
-        runner = functools.partial(subprocess.check_call, shell=True)
-    else:
-        warnings.warn("runner parameter is deprecated", DeprecationWarning)
     # In the tar command, use --strip-components=1 to strip the first path and
     # then
     # use -C to cause the files to be extracted to {target_dir}. This ensures
     # that we always know where the files were extracted.
-    runner('mkdir {target_dir}'.format(**vars()))
+    os.mkdir(target_dir)
     try:
-        getter = 'wget {url} -O -'
-        extract = 'tar x{compression} --strip-components=1 -C {target_dir}'
-        cmd = ' | '.join((getter, extract))
-        runner(cmd.format(compression=infer_compression(url), **vars()))
-        with pushd(target_dir):
-            yield target_dir
+        req = urllib.request.urlopen(url)
+        with tarfile.open(fileobj=req, mode='r|*') as tf:
+            tf.extractall(path=target_dir, filter=strip_first_component)
+        yield target_dir
     finally:
-        runner('rm -Rf {target_dir}'.format(**vars()))
+        shutil.rmtree(target_dir)
+
+
+def strip_first_component(
+    member: tarfile.TarInfo,
+    path,
+) -> tarfile.TarInfo:
+    _, member.name = member.name.split('/', 1)
+    return member
+
+
+def _compose(*cmgrs):
+    """
+    Compose any number of dependent context managers into a single one.
+
+    The last, innermost context manager may take arbitrary arguments, but
+    each successive context manager should accept the result from the
+    previous as a single parameter.
+
+    Like :func:`jaraco.functools.compose`, behavior works from right to
+    left, so the context manager should be indicated from outermost to
+    innermost.
+
+    Example, to create a context manager to change to a temporary
+    directory:
+
+    >>> temp_dir_as_cwd = _compose(pushd, temp_dir)
+    >>> with temp_dir_as_cwd() as dir:
+    ...     assert os.path.samefile(os.getcwd(), dir)
+    """
+
+    def compose_two(inner, outer):
+        def composed(*args, **kwargs):
+            with inner(*args, **kwargs) as saved, outer(saved) as res:
+                yield res
+
+        return contextlib.contextmanager(composed)
+
+    return functools.reduce(compose_two, reversed(cmgrs))
+
+
+tarball_cwd = _compose(pushd, tarball)
+
+
+@contextlib.contextmanager
+def tarball_context(*args, **kwargs):
+    warnings.warn(
+        "tarball_context is deprecated. Use tarball or tarball_cwd instead.",
+        DeprecationWarning,
+        stacklevel=2,
+    )
+    pushd_ctx = kwargs.pop('pushd', pushd)
+    with tarball(*args, **kwargs) as tball, pushd_ctx(tball) as dir:
+        yield dir


 def infer_compression(url):
@@ -68,6 +134,11 @@ def infer_compression(url):
     >>> infer_compression('file.xz')
     'J'
     """
+    warnings.warn(
+        "infer_compression is deprecated with no replacement",
+        DeprecationWarning,
+        stacklevel=2,
+    )
     # cheat and just assume it's the last two characters
     compression_indicator = url[-2:]
     mapping = dict(gz='z', bz='j', xz='J')
@@ -84,7 +155,7 @@ def temp_dir(remover=shutil.rmtree):
     >>> import pathlib
     >>> with temp_dir() as the_dir:
     ...     assert os.path.isdir(the_dir)
-    ...     _ = pathlib.Path(the_dir).joinpath('somefile').write_text('contents')
+    ...     _ = pathlib.Path(the_dir).joinpath('somefile').write_text('contents', encoding='utf-8')
     >>> assert not os.path.exists(the_dir)
     """
     temp_dir = tempfile.mkdtemp()
@@ -113,15 +184,23 @@ def repo_context(url, branch=None, quiet=True, dest_ctx=temp_dir):
         yield repo_dir


-@contextlib.contextmanager
 def null():
     """
     A null context suitable to stand in for a meaningful context.

     >>> with null() as value:
     ...     assert value is None
+
+    This context is most useful when dealing with two or more code
+    branches but only some need a context. Wrap the others in a null
+    context to provide symmetry across all options.
     """
-    yield
+    warnings.warn(
+        "null is deprecated. Use contextlib.nullcontext",
+        DeprecationWarning,
+        stacklevel=2,
+    )
+    return contextlib.nullcontext()


 class ExceptionTrap:
@@ -267,13 +346,7 @@ class on_interrupt(contextlib.ContextDecorator):
     ...     on_interrupt('ignore')(do_interrupt)()
     """

-    def __init__(
-        self,
-        action='error',
-        # py3.7 compat
-        # /,
-        code=1,
-    ):
+    def __init__(self, action='error', /, code=1):
         self.action = action
         self.code = code
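Note: the rewritten tarball() above swaps the old wget-pipe-tar shell pipeline for urllib plus a tarfile extraction filter. A minimal sketch of how such a filter behaves, separate from the diff; the archive name and contents are hypothetical, and extractall(filter=...) needs Python 3.12+ (or the tarfile backport that the new import block selects):

    import tarfile

    def strip_first_component(member: tarfile.TarInfo, path) -> tarfile.TarInfo:
        # Drop the leading 'pkg-1.0/' style component, mirroring the old
        # pipeline's --strip-components=1 (assumes every member has one).
        _, member.name = member.name.split('/', 1)
        return member

    # extractall() invokes the filter once per member; the returned TarInfo's
    # .name decides where the member lands under 'out/'.
    with tarfile.open('demo.tar.gz', mode='r:*') as tf:
        tf.extractall(path='out', filter=strip_first_component)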

View file

@@ -1,18 +1,14 @@
+import collections.abc
 import functools
-import time
 import inspect
-import collections
-import types
 import itertools
+import operator
+import time
+import types
 import warnings

 import pkg_resources.extern.more_itertools

-from typing import Callable, TypeVar
-
-
-CallableT = TypeVar("CallableT", bound=Callable[..., object])
-

 def compose(*funcs):
     """
@@ -38,24 +34,6 @@ def compose(*funcs):
     return functools.reduce(compose_two, funcs)


-def method_caller(method_name, *args, **kwargs):
-    """
-    Return a function that will call a named method on the
-    target object with optional positional and keyword
-    arguments.
-
-    >>> lower = method_caller('lower')
-    >>> lower('MyString')
-    'mystring'
-    """
-
-    def call_method(target):
-        func = getattr(target, method_name)
-        return func(*args, **kwargs)
-
-    return call_method
-
-
 def once(func):
     """
     Decorate func so it's only ever called the first time.
@@ -98,12 +76,7 @@ def once(func):
     return wrapper


-def method_cache(
-    method: CallableT,
-    cache_wrapper: Callable[
-        [CallableT], CallableT
-    ] = functools.lru_cache(),  # type: ignore[assignment]
-) -> CallableT:
+def method_cache(method, cache_wrapper=functools.lru_cache()):
     """
     Wrap lru_cache to support storing the cache data in the object instances.
@@ -171,21 +144,17 @@ def method_cache(
     for another implementation and additional justification.
     """

-    def wrapper(self: object, *args: object, **kwargs: object) -> object:
+    def wrapper(self, *args, **kwargs):
         # it's the first call, replace the method with a cached, bound method
-        bound_method: CallableT = types.MethodType(  # type: ignore[assignment]
-            method, self
-        )
+        bound_method = types.MethodType(method, self)
         cached_method = cache_wrapper(bound_method)
         setattr(self, method.__name__, cached_method)
         return cached_method(*args, **kwargs)

     # Support cache clear even before cache has been created.
-    wrapper.cache_clear = lambda: None  # type: ignore[attr-defined]
+    wrapper.cache_clear = lambda: None

-    return (  # type: ignore[return-value]
-        _special_method_cache(method, cache_wrapper) or wrapper
-    )
+    return _special_method_cache(method, cache_wrapper) or wrapper


 def _special_method_cache(method, cache_wrapper):
@@ -201,12 +170,13 @@ def _special_method_cache(method, cache_wrapper):
     """
     name = method.__name__
     special_names = '__getattr__', '__getitem__'
+
     if name not in special_names:
-        return
+        return None

     wrapper_name = '__cached' + name

-    def proxy(self, *args, **kwargs):
+    def proxy(self, /, *args, **kwargs):
         if wrapper_name not in vars(self):
             bound = types.MethodType(method, self)
             cache = cache_wrapper(bound)
@@ -243,7 +213,7 @@ def result_invoke(action):
     r"""
     Decorate a function with an action function that is
     invoked on the results returned from the decorated
-    function (for its side-effect), then return the original
+    function (for its side effect), then return the original
     result.

     >>> @result_invoke(print)
@@ -267,7 +237,7 @@ def result_invoke(action):
     return wrap


-def invoke(f, *args, **kwargs):
+def invoke(f, /, *args, **kwargs):
     """
     Call a function for its side effect after initialization.
@@ -302,25 +272,15 @@ def invoke(f, *args, **kwargs):
     Use functools.partial to pass parameters to the initial call

     >>> @functools.partial(invoke, name='bingo')
-    ... def func(name): print("called with", name)
+    ... def func(name): print('called with', name)
     called with bingo
     """
     f(*args, **kwargs)
     return f


-def call_aside(*args, **kwargs):
-    """
-    Deprecated name for invoke.
-    """
-    warnings.warn("call_aside is deprecated, use invoke", DeprecationWarning)
-    return invoke(*args, **kwargs)
-
-
 class Throttler:
-    """
-    Rate-limit a function (or other callable)
-    """
+    """Rate-limit a function (or other callable)."""

     def __init__(self, func, max_rate=float('Inf')):
         if isinstance(func, Throttler):
@@ -337,20 +297,20 @@ class Throttler:
         return self.func(*args, **kwargs)

     def _wait(self):
-        "ensure at least 1/max_rate seconds from last call"
+        """Ensure at least 1/max_rate seconds from last call."""
         elapsed = time.time() - self.last_called
         must_wait = 1 / self.max_rate - elapsed
         time.sleep(max(0, must_wait))
         self.last_called = time.time()

-    def __get__(self, obj, type=None):
+    def __get__(self, obj, owner=None):
         return first_invoke(self._wait, functools.partial(self.func, obj))


 def first_invoke(func1, func2):
     """
     Return a function that when invoked will invoke func1 without
-    any parameters (for its side-effect) and then invoke func2
+    any parameters (for its side effect) and then invoke func2
     with whatever parameters were passed, returning its result.
     """
@@ -361,6 +321,17 @@ def first_invoke(func1, func2):
     return wrapper


+method_caller = first_invoke(
+    lambda: warnings.warn(
+        '`jaraco.functools.method_caller` is deprecated, '
+        'use `operator.methodcaller` instead',
+        DeprecationWarning,
+        stacklevel=3,
+    ),
+    operator.methodcaller,
+)
+
+
 def retry_call(func, cleanup=lambda: None, retries=0, trap=()):
     """
     Given a callable func, trap the indicated exceptions
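Note: since method_caller is now only a deprecation shim, operator.methodcaller is the stdlib replacement. A quick illustration of the stdlib behavior (not part of the diff):

    from operator import methodcaller

    lower = methodcaller('lower')
    assert lower('MyString') == 'mystring'

    # Extra positional/keyword arguments are forwarded to the named method:
    split = methodcaller('split', ',', maxsplit=1)
    assert split('a,b,c') == ['a', 'b,c']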
@@ -369,7 +340,7 @@ def retry_call(func, cleanup=lambda: None, retries=0, trap=()):
     to propagate.
     """
     attempts = itertools.count() if retries == float('inf') else range(retries)
-    for attempt in attempts:
+    for _ in attempts:
         try:
             return func()
         except trap:
@@ -406,7 +377,7 @@ def retry(*r_args, **r_kwargs):

 def print_yielded(func):
     """
-    Convert a generator into a function that prints all yielded elements
+    Convert a generator into a function that prints all yielded elements.

     >>> @print_yielded
     ... def x():
@@ -422,7 +393,7 @@ def print_yielded(func):

 def pass_none(func):
     """
-    Wrap func so it's not called if its first param is None
+    Wrap func so it's not called if its first param is None.

     >>> print_text = pass_none(print)
     >>> print_text('text')
@@ -431,9 +402,10 @@ def pass_none(func):
     """

     @functools.wraps(func)
-    def wrapper(param, *args, **kwargs):
+    def wrapper(param, /, *args, **kwargs):
         if param is not None:
             return func(param, *args, **kwargs)
+        return None

     return wrapper
@@ -507,7 +479,7 @@ def save_method_args(method):
     args_and_kwargs = collections.namedtuple('args_and_kwargs', 'args kwargs')

     @functools.wraps(method)
-    def wrapper(self, *args, **kwargs):
+    def wrapper(self, /, *args, **kwargs):
         attr_name = '_saved_' + method.__name__
         attr = args_and_kwargs(args, kwargs)
         setattr(self, attr_name, attr)
@@ -554,3 +526,108 @@ def except_(*exceptions, replace=None, use=None):
         return wrapper

     return decorate
+
+
+def identity(x):
+    """
+    Return the argument.
+
+    >>> o = object()
+    >>> identity(o) is o
+    True
+    """
+    return x
+
+
+def bypass_when(check, *, _op=identity):
+    """
+    Decorate a function to return its parameter when ``check``.
+
+    >>> bypassed = []  # False
+    >>> @bypass_when(bypassed)
+    ... def double(x):
+    ...     return x * 2
+    >>> double(2)
+    4
+    >>> bypassed[:] = [object()]  # True
+    >>> double(2)
+    2
+    """
+
+    def decorate(func):
+        @functools.wraps(func)
+        def wrapper(param, /):
+            return param if _op(check) else func(param)
+
+        return wrapper
+
+    return decorate
+
+
+def bypass_unless(check):
+    """
+    Decorate a function to return its parameter unless ``check``.
+
+    >>> enabled = [object()]  # True
+    >>> @bypass_unless(enabled)
+    ... def double(x):
+    ...     return x * 2
+    >>> double(2)
+    4
+    >>> del enabled[:]  # False
+    >>> double(2)
+    2
+    """
+    return bypass_when(check, _op=operator.not_)
+
+
+@functools.singledispatch
+def _splat_inner(args, func):
+    """Splat args to func."""
+    return func(*args)
+
+
+@_splat_inner.register
+def _(args: collections.abc.Mapping, func):
+    """Splat kargs to func as kwargs."""
+    return func(**args)
+
+
+def splat(func):
+    """
+    Wrap func to expect its parameters to be passed positionally in a tuple.
+
+    Has a similar effect to that of ``itertools.starmap`` over
+    simple ``map``.
+
+    >>> pairs = [(-1, 1), (0, 2)]
+    >>> pkg_resources.extern.more_itertools.consume(itertools.starmap(print, pairs))
+    -1 1
+    0 2
+    >>> pkg_resources.extern.more_itertools.consume(map(splat(print), pairs))
+    -1 1
+    0 2
+
+    The approach generalizes to other iterators that don't have a "star"
+    equivalent, such as a "starfilter".
+
+    >>> list(filter(splat(operator.add), pairs))
+    [(0, 2)]
+
+    Splat also accepts a mapping argument.
+
+    >>> def is_nice(msg, code):
+    ...     return "smile" in msg or code == 0
+    >>> msgs = [
+    ...     dict(msg='smile!', code=20),
+    ...     dict(msg='error :(', code=1),
+    ...     dict(msg='unknown', code=0),
+    ... ]
+    >>> for msg in filter(splat(is_nice), msgs):
+    ...     print(msg)
+    {'msg': 'smile!', 'code': 20}
+    {'msg': 'unknown', 'code': 0}
+    """
+    return functools.wraps(func)(functools.partial(_splat_inner, func=func))
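For reference, a small usage sketch of the now de-annotated method_cache. This is a pared-down copy (it omits the _special_method_cache branch for __getattr__/__getitem__), and the Calculator class is invented for illustration:

    import functools
    import types

    def method_cache(method, cache_wrapper=functools.lru_cache()):
        def wrapper(self, *args, **kwargs):
            # First call: replace the method on the instance with a cached,
            # bound method so later calls bypass this wrapper entirely.
            bound_method = types.MethodType(method, self)
            cached_method = cache_wrapper(bound_method)
            setattr(self, method.__name__, cached_method)
            return cached_method(*args, **kwargs)

        wrapper.cache_clear = lambda: None
        return wrapper

    class Calculator:
        calls = 0

        @method_cache
        def add(self, a, b):
            self.calls += 1
            return a + b

    calc = Calculator()
    assert calc.add(2, 3) == 5 and calc.add(2, 3) == 5
    assert calc.calls == 1   # the repeat call hit the per-instance cache
    calc.add.cache_clear()   # calc.add is now the lru_cache-wrapped bound method
    assert calc.add(2, 3) == 5 and calc.calls == 2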

View file

@@ -0,0 +1,128 @@
+from collections.abc import Callable, Hashable, Iterator
+from functools import partial
+from operator import methodcaller
+import sys
+from typing import (
+    Any,
+    Generic,
+    Protocol,
+    TypeVar,
+    overload,
+)
+
+if sys.version_info >= (3, 10):
+    from typing import Concatenate, ParamSpec
+else:
+    from typing_extensions import Concatenate, ParamSpec
+
+_P = ParamSpec('_P')
+_R = TypeVar('_R')
+_T = TypeVar('_T')
+_R1 = TypeVar('_R1')
+_R2 = TypeVar('_R2')
+_V = TypeVar('_V')
+_S = TypeVar('_S')
+_R_co = TypeVar('_R_co', covariant=True)
+
+class _OnceCallable(Protocol[_P, _R]):
+    saved_result: _R
+    reset: Callable[[], None]
+    def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: ...
+
+class _ProxyMethodCacheWrapper(Protocol[_R_co]):
+    cache_clear: Callable[[], None]
+    def __call__(self, *args: Hashable, **kwargs: Hashable) -> _R_co: ...
+
+class _MethodCacheWrapper(Protocol[_R_co]):
+    def cache_clear(self) -> None: ...
+    def __call__(self, *args: Hashable, **kwargs: Hashable) -> _R_co: ...
+
+# `compose()` overloads below will cover most use cases.
+
+@overload
+def compose(
+    __func1: Callable[[_R], _T],
+    __func2: Callable[_P, _R],
+    /,
+) -> Callable[_P, _T]: ...
+@overload
+def compose(
+    __func1: Callable[[_R], _T],
+    __func2: Callable[[_R1], _R],
+    __func3: Callable[_P, _R1],
+    /,
+) -> Callable[_P, _T]: ...
+@overload
+def compose(
+    __func1: Callable[[_R], _T],
+    __func2: Callable[[_R2], _R],
+    __func3: Callable[[_R1], _R2],
+    __func4: Callable[_P, _R1],
+    /,
+) -> Callable[_P, _T]: ...
+def once(func: Callable[_P, _R]) -> _OnceCallable[_P, _R]: ...
+def method_cache(
+    method: Callable[..., _R],
+    cache_wrapper: Callable[[Callable[..., _R]], _MethodCacheWrapper[_R]] = ...,
+) -> _MethodCacheWrapper[_R] | _ProxyMethodCacheWrapper[_R]: ...
+def apply(
+    transform: Callable[[_R], _T]
+) -> Callable[[Callable[_P, _R]], Callable[_P, _T]]: ...
+def result_invoke(
+    action: Callable[[_R], Any]
+) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]: ...
+def invoke(
+    f: Callable[_P, _R], /, *args: _P.args, **kwargs: _P.kwargs
+) -> Callable[_P, _R]: ...
+def call_aside(
+    f: Callable[_P, _R], *args: _P.args, **kwargs: _P.kwargs
+) -> Callable[_P, _R]: ...
+
+class Throttler(Generic[_R]):
+    last_called: float
+    func: Callable[..., _R]
+    max_rate: float
+    def __init__(
+        self, func: Callable[..., _R] | Throttler[_R], max_rate: float = ...
+    ) -> None: ...
+    def reset(self) -> None: ...
+    def __call__(self, *args: Any, **kwargs: Any) -> _R: ...
+    def __get__(self, obj: Any, owner: type[Any] | None = ...) -> Callable[..., _R]: ...
+
+def first_invoke(
+    func1: Callable[..., Any], func2: Callable[_P, _R]
+) -> Callable[_P, _R]: ...
+
+method_caller: Callable[..., methodcaller]
+
+def retry_call(
+    func: Callable[..., _R],
+    cleanup: Callable[..., None] = ...,
+    retries: int | float = ...,
+    trap: type[BaseException] | tuple[type[BaseException], ...] = ...,
+) -> _R: ...
+def retry(
+    cleanup: Callable[..., None] = ...,
+    retries: int | float = ...,
+    trap: type[BaseException] | tuple[type[BaseException], ...] = ...,
+) -> Callable[[Callable[..., _R]], Callable[..., _R]]: ...
+def print_yielded(func: Callable[_P, Iterator[Any]]) -> Callable[_P, None]: ...
+def pass_none(
+    func: Callable[Concatenate[_T, _P], _R]
+) -> Callable[Concatenate[_T, _P], _R]: ...
+def assign_params(
+    func: Callable[..., _R], namespace: dict[str, Any]
+) -> partial[_R]: ...
+def save_method_args(
+    method: Callable[Concatenate[_S, _P], _R]
+) -> Callable[Concatenate[_S, _P], _R]: ...
+def except_(
+    *exceptions: type[BaseException], replace: Any = ..., use: Any = ...
+) -> Callable[[Callable[_P, Any]], Callable[_P, Any]]: ...
+def identity(x: _T) -> _T: ...
+def bypass_when(
+    check: _V, *, _op: Callable[[_V], Any] = ...
+) -> Callable[[Callable[[_T], _R]], Callable[[_T], _T | _R]]: ...
+def bypass_unless(
+    check: Any,
+) -> Callable[[Callable[[_T], _R]], Callable[[_T], _T | _R]]: ...
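The new stub leans on ParamSpec so that decorators preserve the wrapped callable's exact signature. A toy runtime illustration of the pattern behind the compose() overloads; the names here are invented for the sketch:

    import sys
    from collections.abc import Callable
    from typing import TypeVar

    if sys.version_info >= (3, 10):
        from typing import ParamSpec
    else:
        from typing_extensions import ParamSpec

    P = ParamSpec('P')
    R = TypeVar('R')
    T = TypeVar('T')

    def compose2(f: Callable[[R], T], g: Callable[P, R]) -> Callable[P, T]:
        # The composed callable keeps g's parameters and adopts f's return
        # type -- the same shape as the stub's first compose() overload.
        def composed(*args: P.args, **kwargs: P.kwargs) -> T:
            return f(g(*args, **kwargs))

        return composed

    greeting_len = compose2(len, 'hello {}'.format)
    assert greeting_len('world') == len('hello world')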

View file

@@ -3,4 +3,4 @@
 from .more import *  # noqa
 from .recipes import *  # noqa

-__version__ = '9.1.0'
+__version__ = '10.2.0'

View file

@@ -2,7 +2,7 @@ import warnings
 from collections import Counter, defaultdict, deque, abc
 from collections.abc import Sequence
-from functools import partial, reduce, wraps
+from functools import cached_property, partial, reduce, wraps
 from heapq import heapify, heapreplace, heappop
 from itertools import (
     chain,
@@ -17,8 +17,9 @@ from itertools import (
     takewhile,
     tee,
     zip_longest,
+    product,
 )
-from math import exp, factorial, floor, log
+from math import exp, factorial, floor, log, perm, comb
 from queue import Empty, Queue
 from random import random, randrange, uniform
 from operator import itemgetter, mul, sub, gt, lt, ge, le
@@ -36,6 +37,7 @@ from .recipes import (
     take,
     unique_everseen,
     all_equal,
+    batched,
 )

 __all__ = [
@@ -53,6 +55,7 @@ __all__ = [
     'circular_shifts',
     'collapse',
     'combination_index',
+    'combination_with_replacement_index',
     'consecutive_groups',
     'constrained_batches',
     'consumer',
@@ -65,8 +68,10 @@ __all__ = [
     'divide',
     'duplicates_everseen',
     'duplicates_justseen',
+    'classify_unique',
     'exactly_n',
     'filter_except',
+    'filter_map',
     'first',
     'gray_product',
     'groupby_transform',
@@ -80,6 +85,7 @@ __all__ = [
     'is_sorted',
     'islice_extended',
     'iterate',
+    'iter_suppress',
     'last',
     'locate',
     'longest_common_prefix',
@@ -93,10 +99,13 @@ __all__ = [
     'nth_or_last',
     'nth_permutation',
     'nth_product',
+    'nth_combination_with_replacement',
     'numeric_range',
     'one',
     'only',
+    'outer_product',
     'padded',
+    'partial_product',
     'partitions',
     'peekable',
     'permutation_index',
@@ -125,6 +134,7 @@ __all__ = [
     'strictly_n',
     'substrings',
     'substrings_indexes',
+    'takewhile_inclusive',
     'time_limited',
     'unique_in_window',
     'unique_to_each',
@@ -191,15 +201,14 @@ def first(iterable, default=_marker):
     ``next(iter(iterable), default)``.

     """
-    try:
-        return next(iter(iterable))
-    except StopIteration as e:
-        if default is _marker:
-            raise ValueError(
-                'first() was called on an empty iterable, and no '
-                'default value was provided.'
-            ) from e
-        return default
+    for item in iterable:
+        return item
+    if default is _marker:
+        raise ValueError(
+            'first() was called on an empty iterable, and no '
+            'default value was provided.'
+        )
+    return default


 def last(iterable, default=_marker):
@@ -472,7 +481,10 @@ def iterate(func, start):
     """
     while True:
         yield start
-        start = func(start)
+        try:
+            start = func(start)
+        except StopIteration:
+            break


 def with_iter(context_manager):
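Note: iterate() now lets the step function signal exhaustion by raising StopIteration. A small sketch of the new contract (assumes a standalone more-itertools >= 10 install; inside this tree the import path is pkg_resources.extern.more_itertools, and the halving function is invented):

    from more_itertools import iterate

    def halve(n):
        if n == 0:
            raise StopIteration  # tells iterate() to stop after yielding 0
        return n // 2

    # Previously this sequence needed an explicit cap such as islice().
    assert list(iterate(halve, 8)) == [8, 4, 2, 1, 0]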
@@ -572,6 +584,9 @@ def strictly_n(iterable, n, too_short=None, too_long=None):
     >>> list(strictly_n(iterable, n))
     ['a', 'b', 'c', 'd']

+    Note that the returned iterable must be consumed in order for the check to
+    be made.
+
     By default, *too_short* and *too_long* are functions that raise
     ``ValueError``.
@@ -909,7 +924,7 @@ def substrings_indexes(seq, reverse=False):

 class bucket:
-    """Wrap *iterable* and return an object that buckets it iterable into
+    """Wrap *iterable* and return an object that buckets the iterable into
     child iterables based on a *key* function.

     >>> iterable = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2', 'b3']
@@ -2069,7 +2084,6 @@ class numeric_range(abc.Sequence, abc.Hashable):
         if self._step == self._zero:
             raise ValueError('numeric_range() arg 3 must not be zero')
         self._growing = self._step > self._zero
-        self._init_len()

     def __bool__(self):
         if self._growing:
@@ -2145,7 +2159,8 @@ class numeric_range(abc.Sequence, abc.Hashable):
     def __len__(self):
         return self._len

-    def _init_len(self):
+    @cached_property
+    def _len(self):
         if self._growing:
             start = self._start
             stop = self._stop
@@ -2156,10 +2171,10 @@ class numeric_range(abc.Sequence, abc.Hashable):
             step = -self._step
         distance = stop - start
         if distance <= self._zero:
-            self._len = 0
+            return 0
         else:  # distance > 0 and step > 0: regular euclidean division
             q, r = divmod(distance, step)
-            self._len = int(q) + int(r != self._zero)
+            return int(q) + int(r != self._zero)

     def __reduce__(self):
         return numeric_range, (self._start, self._stop, self._step)
@@ -2699,6 +2714,9 @@ class seekable:
         >>> it.seek(10)
         >>> next(it)
         '10'
+        >>> it.relative_seek(-2)  # Seeking relative to the current position
+        >>> next(it)
+        '9'
         >>> it.seek(20)  # Seeking past the end of the source isn't a problem
         >>> list(it)
         []
@@ -2812,6 +2830,10 @@ class seekable:
             if remainder > 0:
                 consume(self, remainder)

+    def relative_seek(self, count):
+        index = len(self._cache)
+        self.seek(max(index + count, 0))
+

 class run_length:
     """
@@ -3205,6 +3227,8 @@ class time_limited:
     stops if the time elapsed is greater than *limit_seconds*. If your time
     limit is 1 second, but it takes 2 seconds to generate the first item from
     the iterable, the function will run for 2 seconds and not yield anything.
+    As a special case, when *limit_seconds* is zero, the iterator never
+    returns anything.

     """
@@ -3220,6 +3244,9 @@ class time_limited:
         return self

     def __next__(self):
+        if self.limit_seconds == 0:
+            self.timed_out = True
+            raise StopIteration
         item = next(self._iterable)
         if monotonic() - self._start_time > self.limit_seconds:
             self.timed_out = True
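A short check of the documented zero-limit special case (usage sketch; same import caveat as above):

    from itertools import count
    from more_itertools import time_limited

    limited = time_limited(0, count())
    assert list(limited) == []   # nothing is yielded, regardless of timing
    assert limited.timed_out     # and the cutoff is reported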
@@ -3339,7 +3366,7 @@ def iequals(*iterables):
     >>> iequals("abc", "acb")
     False

-    Not to be confused with :func:`all_equals`, which checks whether all
+    Not to be confused with :func:`all_equal`, which checks whether all
     elements of iterable are equal to each other.

     """
@@ -3835,7 +3862,7 @@ def nth_permutation(iterable, r, index):
     elif not 0 <= r < n:
         raise ValueError
     else:
-        c = factorial(n) // factorial(n - r)
+        c = perm(n, r)

     if index < 0:
         index += c
@@ -3858,6 +3885,52 @@ def nth_permutation(iterable, r, index):
     return tuple(map(pool.pop, result))


+def nth_combination_with_replacement(iterable, r, index):
+    """Equivalent to
+    ``list(combinations_with_replacement(iterable, r))[index]``.
+
+    The subsequences with repetition of *iterable* that are of length *r* can
+    be ordered lexicographically. :func:`nth_combination_with_replacement`
+    computes the subsequence at sort position *index* directly, without
+    computing the previous subsequences with replacement.
+
+    >>> nth_combination_with_replacement(range(5), 3, 5)
+    (0, 1, 1)
+
+    ``ValueError`` will be raised If *r* is negative or greater than the length
+    of *iterable*.
+    ``IndexError`` will be raised if the given *index* is invalid.
+    """
+    pool = tuple(iterable)
+    n = len(pool)
+    if (r < 0) or (r > n):
+        raise ValueError
+
+    c = comb(n + r - 1, r)
+
+    if index < 0:
+        index += c
+
+    if (index < 0) or (index >= c):
+        raise IndexError
+
+    result = []
+    i = 0
+    while r:
+        r -= 1
+        while n >= 0:
+            num_combs = comb(n + r - 1, r)
+            if index < num_combs:
+                break
+            n -= 1
+            i += 1
+            index -= num_combs
+        result.append(pool[i])
+
+    return tuple(result)
+
+
 def value_chain(*args):
     """Yield all arguments passed to the function in the same order in which
     they were passed. If an argument itself is iterable then iterate over its
@@ -3949,9 +4022,66 @@ def combination_index(element, iterable):
     for i, j in enumerate(reversed(indexes), start=1):
         j = n - j
         if i <= j:
-            index += factorial(j) // (factorial(i) * factorial(j - i))
+            index += comb(j, i)

-    return factorial(n + 1) // (factorial(k + 1) * factorial(n - k)) - index
+    return comb(n + 1, k + 1) - index
+
+
+def combination_with_replacement_index(element, iterable):
+    """Equivalent to
+    ``list(combinations_with_replacement(iterable, r)).index(element)``
+
+    The subsequences with repetition of *iterable* that are of length *r* can
+    be ordered lexicographically. :func:`combination_with_replacement_index`
+    computes the index of the first *element*, without computing the previous
+    combinations with replacement.
+
+    >>> combination_with_replacement_index('adf', 'abcdefg')
+    20
+
+    ``ValueError`` will be raised if the given *element* isn't one of the
+    combinations with replacement of *iterable*.
+    """
+    element = tuple(element)
+    l = len(element)
+    element = enumerate(element)
+
+    k, y = next(element, (None, None))
+    if k is None:
+        return 0
+
+    indexes = []
+    pool = tuple(iterable)
+    for n, x in enumerate(pool):
+        while x == y:
+            indexes.append(n)
+            tmp, y = next(element, (None, None))
+            if tmp is None:
+                break
+            else:
+                k = tmp
+        if y is None:
+            break
+    else:
+        raise ValueError(
+            'element is not a combination with replacement of iterable'
+        )
+
+    n = len(pool)
+    occupations = [0] * n
+    for p in indexes:
+        occupations[p] += 1
+
+    index = 0
+    cumulative_sum = 0
+    for k in range(1, n):
+        cumulative_sum += occupations[k - 1]
+        j = l + n - 1 - k - cumulative_sum
+        i = n - k
+        if i <= j:
+            index += comb(j, i)
+
+    return index


 def permutation_index(element, iterable):
@@ -4056,26 +4186,20 @@ def _chunked_even_finite(iterable, N, n):
     num_full = N - partial_size * num_lists
     num_partial = num_lists - num_full

-    buffer = []
-    iterator = iter(iterable)
-
     # Yield num_full lists of full_size
-    for x in iterator:
-        buffer.append(x)
-        if len(buffer) == full_size:
-            yield buffer
-            buffer = []
-            num_full -= 1
-            if num_full <= 0:
-                break
+    partial_start_idx = num_full * full_size
+    if full_size > 0:
+        for i in range(0, partial_start_idx, full_size):
+            yield list(islice(iterable, i, i + full_size))

     # Yield num_partial lists of partial_size
-    for x in iterator:
-        buffer.append(x)
-        if len(buffer) == partial_size:
-            yield buffer
-            buffer = []
-            num_partial -= 1
+    if partial_size > 0:
+        for i in range(
+            partial_start_idx,
+            partial_start_idx + (num_partial * partial_size),
+            partial_size,
+        ):
+            yield list(islice(iterable, i, i + partial_size))


 def zip_broadcast(*objects, scalar_types=(str, bytes), strict=False):
@@ -4114,30 +4238,23 @@ def zip_broadcast(*objects, scalar_types=(str, bytes), strict=False):
     if not size:
         return

+    new_item = [None] * size
     iterables, iterable_positions = [], []
-    scalars, scalar_positions = [], []
     for i, obj in enumerate(objects):
         if is_scalar(obj):
-            scalars.append(obj)
-            scalar_positions.append(i)
+            new_item[i] = obj
         else:
             iterables.append(iter(obj))
             iterable_positions.append(i)

-    if len(scalars) == size:
+    if not iterables:
         yield tuple(objects)
         return

     zipper = _zip_equal if strict else zip
     for item in zipper(*iterables):
-        new_item = [None] * size
-
-        for i, elem in zip(iterable_positions, item):
-            new_item[i] = elem
-
-        for i, elem in zip(scalar_positions, scalars):
-            new_item[i] = elem
-
+        for i, new_item[i] in zip(iterable_positions, item):
+            pass
         yield tuple(new_item)
@@ -4162,22 +4279,23 @@ def unique_in_window(iterable, n, key=None):
         raise ValueError('n must be greater than 0')

     window = deque(maxlen=n)
-    uniques = set()
+    counts = defaultdict(int)
     use_key = key is not None

     for item in iterable:
+        if len(window) == n:
+            to_discard = window[0]
+            if counts[to_discard] == 1:
+                del counts[to_discard]
+            else:
+                counts[to_discard] -= 1
+
         k = key(item) if use_key else item
-        if k in uniques:
-            continue
-
-        if len(uniques) == n:
-            uniques.discard(window[0])
-
-        uniques.add(k)
+        if k not in counts:
+            yield item
+
+        counts[k] += 1
         window.append(k)
-
-        yield item


 def duplicates_everseen(iterable, key=None):
     """Yield duplicate elements after their first appearance.
@@ -4187,7 +4305,7 @@ def duplicates_everseen(iterable, key=None):
     >>> list(duplicates_everseen('AaaBbbCccAaa', str.lower))
     ['a', 'a', 'b', 'b', 'c', 'c', 'A', 'a', 'a']

-    This function is analagous to :func:`unique_everseen` and is subject to
+    This function is analogous to :func:`unique_everseen` and is subject to
     the same performance considerations.

     """
@@ -4217,15 +4335,52 @@ def duplicates_justseen(iterable, key=None):
     >>> list(duplicates_justseen('AaaBbbCccAaa', str.lower))
     ['a', 'a', 'b', 'b', 'c', 'c', 'a', 'a']

-    This function is analagous to :func:`unique_justseen`.
+    This function is analogous to :func:`unique_justseen`.

     """
-    return flatten(
-        map(
-            lambda group_tuple: islice_extended(group_tuple[1])[1:],
-            groupby(iterable, key),
-        )
-    )
+    return flatten(g for _, g in groupby(iterable, key) for _ in g)
+
+
+def classify_unique(iterable, key=None):
+    """Classify each element in terms of its uniqueness.
+
+    For each element in the input iterable, return a 3-tuple consisting of:
+
+    1. The element itself
+    2. ``False`` if the element is equal to the one preceding it in the input,
+       ``True`` otherwise (i.e. the equivalent of :func:`unique_justseen`)
+    3. ``False`` if this element has been seen anywhere in the input before,
+       ``True`` otherwise (i.e. the equivalent of :func:`unique_everseen`)
+
+    >>> list(classify_unique('otto'))  # doctest: +NORMALIZE_WHITESPACE
+    [('o', True, True),
+     ('t', True, True),
+     ('t', False, False),
+     ('o', True, False)]
+
+    This function is analogous to :func:`unique_everseen` and is subject to
+    the same performance considerations.
+
+    """
+    seen_set = set()
+    seen_list = []
+    use_key = key is not None
+    previous = None
+
+    for i, element in enumerate(iterable):
+        k = key(element) if use_key else element
+        is_unique_justseen = not i or previous != k
+        previous = k
+        is_unique_everseen = False
+        try:
+            if k not in seen_set:
+                seen_set.add(k)
+                is_unique_everseen = True
+        except TypeError:
+            if k not in seen_list:
+                seen_list.append(k)
+                is_unique_everseen = True
+        yield element, is_unique_justseen, is_unique_everseen


 def minmax(iterable_or_value, *others, key=None, default=_marker):
@@ -4389,3 +4544,112 @@ def gray_product(*iterables):
         o[j] = -o[j]
         f[j] = f[j + 1]
         f[j + 1] = j + 1
+
+
+def partial_product(*iterables):
+    """Yields tuples containing one item from each iterator, with subsequent
+    tuples changing a single item at a time by advancing each iterator until it
+    is exhausted. This sequence guarantees every value in each iterable is
+    output at least once without generating all possible combinations.
+
+    This may be useful, for example, when testing an expensive function.
+
+    >>> list(partial_product('AB', 'C', 'DEF'))
+    [('A', 'C', 'D'), ('B', 'C', 'D'), ('B', 'C', 'E'), ('B', 'C', 'F')]
+    """
+    iterators = list(map(iter, iterables))
+
+    try:
+        prod = [next(it) for it in iterators]
+    except StopIteration:
+        return
+    yield tuple(prod)
+
+    for i, it in enumerate(iterators):
+        for prod[i] in it:
+            yield tuple(prod)
+
+
+def takewhile_inclusive(predicate, iterable):
+    """A variant of :func:`takewhile` that yields one additional element.
+
+    >>> list(takewhile_inclusive(lambda x: x < 5, [1, 4, 6, 4, 1]))
+    [1, 4, 6]
+
+    :func:`takewhile` would return ``[1, 4]``.
+    """
+    for x in iterable:
+        yield x
+        if not predicate(x):
+            break
+
+
+def outer_product(func, xs, ys, *args, **kwargs):
+    """A generalized outer product that applies a binary function to all
+    pairs of items. Returns a 2D matrix with ``len(xs)`` rows and ``len(ys)``
+    columns.
+    Also accepts ``*args`` and ``**kwargs`` that are passed to ``func``.
+
+    Multiplication table:
+
+    >>> list(outer_product(mul, range(1, 4), range(1, 6)))
+    [(1, 2, 3, 4, 5), (2, 4, 6, 8, 10), (3, 6, 9, 12, 15)]
+
+    Cross tabulation:
+
+    >>> xs = ['A', 'B', 'A', 'A', 'B', 'B', 'A', 'A', 'B', 'B']
+    >>> ys = ['X', 'X', 'X', 'Y', 'Z', 'Z', 'Y', 'Y', 'Z', 'Z']
+    >>> rows = list(zip(xs, ys))
+    >>> count_rows = lambda x, y: rows.count((x, y))
+    >>> list(outer_product(count_rows, sorted(set(xs)), sorted(set(ys))))
+    [(2, 3, 0), (1, 0, 4)]
+
+    Usage with ``*args`` and ``**kwargs``:
+
+    >>> animals = ['cat', 'wolf', 'mouse']
+    >>> list(outer_product(min, animals, animals, key=len))
+    [('cat', 'cat', 'cat'), ('cat', 'wolf', 'wolf'), ('cat', 'wolf', 'mouse')]
+    """
+    ys = tuple(ys)
+    return batched(
+        starmap(lambda x, y: func(x, y, *args, **kwargs), product(xs, ys)),
+        n=len(ys),
+    )
+
+
+def iter_suppress(iterable, *exceptions):
+    """Yield each of the items from *iterable*. If the iteration raises one of
+    the specified *exceptions*, that exception will be suppressed and iteration
+    will stop.
+
+    >>> from itertools import chain
+    >>> def breaks_at_five(x):
+    ...     while True:
+    ...         if x >= 5:
+    ...             raise RuntimeError
+    ...         yield x
+    ...         x += 1
+    >>> it_1 = iter_suppress(breaks_at_five(1), RuntimeError)
+    >>> it_2 = iter_suppress(breaks_at_five(2), RuntimeError)
+    >>> list(chain(it_1, it_2))
+    [1, 2, 3, 4, 2, 3, 4]
+    """
+    try:
+        yield from iterable
+    except exceptions:
+        return
+
+
+def filter_map(func, iterable):
+    """Apply *func* to every element of *iterable*, yielding only those which
+    are not ``None``.
+
+    >>> elems = ['1', 'a', '2', 'b', '3']
+    >>> list(filter_map(lambda s: int(s) if s.isnumeric() else None, elems))
+    [1, 2, 3]
+    """
+    for x in iterable:
+        y = func(x)
+        if y is not None:
+            yield y

View file

@@ -29,7 +29,7 @@ _U = TypeVar('_U')
 _V = TypeVar('_V')
 _W = TypeVar('_W')
 _T_co = TypeVar('_T_co', covariant=True)
-_GenFn = TypeVar('_GenFn', bound=Callable[..., Iterator[object]])
+_GenFn = TypeVar('_GenFn', bound=Callable[..., Iterator[Any]])
 _Raisable = BaseException | Type[BaseException]

 @type_check_only
@@ -74,7 +74,7 @@ class peekable(Generic[_T], Iterator[_T]):
     def __getitem__(self, index: slice) -> list[_T]: ...

 def consumer(func: _GenFn) -> _GenFn: ...
-def ilen(iterable: Iterable[object]) -> int: ...
+def ilen(iterable: Iterable[_T]) -> int: ...
 def iterate(func: Callable[[_T], _T], start: _T) -> Iterator[_T]: ...
 def with_iter(
     context_manager: ContextManager[Iterable[_T]],
@@ -116,7 +116,7 @@ class bucket(Generic[_T, _U], Container[_U]):
         self,
         iterable: Iterable[_T],
         key: Callable[[_T], _U],
-        validator: Callable[[object], object] | None = ...,
+        validator: Callable[[_U], object] | None = ...,
     ) -> None: ...
     def __contains__(self, value: object) -> bool: ...
     def __iter__(self) -> Iterator[_U]: ...
@@ -383,7 +383,7 @@ def mark_ends(
     iterable: Iterable[_T],
 ) -> Iterable[tuple[bool, bool, _T]]: ...
 def locate(
-    iterable: Iterable[object],
+    iterable: Iterable[_T],
     pred: Callable[..., Any] = ...,
     window_size: int | None = ...,
 ) -> Iterator[int]: ...
@@ -440,6 +440,7 @@ class seekable(Generic[_T], Iterator[_T]):
     def peek(self, default: _U) -> _T | _U: ...
     def elements(self) -> SequenceView[_T]: ...
     def seek(self, index: int) -> None: ...
+    def relative_seek(self, count: int) -> None: ...

 class run_length:
     @staticmethod
@@ -578,6 +579,9 @@ def all_unique(
     iterable: Iterable[_T], key: Callable[[_T], _U] | None = ...
 ) -> bool: ...
 def nth_product(index: int, *args: Iterable[_T]) -> tuple[_T, ...]: ...
+def nth_combination_with_replacement(
+    iterable: Iterable[_T], r: int, index: int
+) -> tuple[_T, ...]: ...
 def nth_permutation(
     iterable: Iterable[_T], r: int, index: int
 ) -> tuple[_T, ...]: ...
@@ -586,6 +590,9 @@ def product_index(element: Iterable[_T], *args: Iterable[_T]) -> int: ...
 def combination_index(
     element: Iterable[_T], iterable: Iterable[_T]
 ) -> int: ...
+def combination_with_replacement_index(
+    element: Iterable[_T], iterable: Iterable[_T]
+) -> int: ...
 def permutation_index(
     element: Iterable[_T], iterable: Iterable[_T]
 ) -> int: ...
@@ -611,6 +618,9 @@ def duplicates_everseen(
 def duplicates_justseen(
     iterable: Iterable[_T], key: Callable[[_T], _U] | None = ...
 ) -> Iterator[_T]: ...
+def classify_unique(
+    iterable: Iterable[_T], key: Callable[[_T], _U] | None = ...
+) -> Iterator[tuple[_T, bool, bool]]: ...

 class _SupportsLessThan(Protocol):
     def __lt__(self, __other: Any) -> bool: ...
@@ -655,12 +665,31 @@ def minmax(
 def longest_common_prefix(
     iterables: Iterable[Iterable[_T]],
 ) -> Iterator[_T]: ...
-def iequals(*iterables: Iterable[object]) -> bool: ...
+def iequals(*iterables: Iterable[Any]) -> bool: ...
 def constrained_batches(
-    iterable: Iterable[object],
+    iterable: Iterable[_T],
     max_size: int,
     max_count: int | None = ...,
     get_len: Callable[[_T], object] = ...,
     strict: bool = ...,
 ) -> Iterator[tuple[_T]]: ...
 def gray_product(*iterables: Iterable[_T]) -> Iterator[tuple[_T, ...]]: ...
+def partial_product(*iterables: Iterable[_T]) -> Iterator[tuple[_T, ...]]: ...
+def takewhile_inclusive(
+    predicate: Callable[[_T], bool], iterable: Iterable[_T]
+) -> Iterator[_T]: ...
+def outer_product(
+    func: Callable[[_T, _U], _V],
+    xs: Iterable[_T],
+    ys: Iterable[_U],
+    *args: Any,
+    **kwargs: Any,
+) -> Iterator[tuple[_V, ...]]: ...
+def iter_suppress(
+    iterable: Iterable[_T],
+    *exceptions: Type[BaseException],
+) -> Iterator[_T]: ...
+def filter_map(
+    func: Callable[[_T], _V | None],
+    iterable: Iterable[_T],
+) -> Iterator[_V]: ...

View file

@@ -9,11 +9,10 @@ Some backward-compatible usability improvements have been made.
 """
 import math
 import operator
-import warnings

 from collections import deque
 from collections.abc import Sized
-from functools import reduce
+from functools import partial, reduce
 from itertools import (
     chain,
     combinations,
@@ -52,10 +51,13 @@ __all__ = [
     'pad_none',
     'pairwise',
     'partition',
+    'polynomial_eval',
     'polynomial_from_roots',
+    'polynomial_derivative',
     'powerset',
     'prepend',
     'quantify',
+    'reshape',
     'random_combination_with_replacement',
     'random_combination',
     'random_permutation',
@@ -65,9 +67,11 @@ __all__ = [
     'sieve',
     'sliding_window',
     'subslices',
+    'sum_of_squares',
     'tabulate',
     'tail',
     'take',
+    'totient',
     'transpose',
     'triplewise',
     'unique_everseen',
@@ -77,6 +81,18 @@ __all__ = [

 _marker = object()

+
+# zip with strict is available for Python 3.10+
+try:
+    zip(strict=True)
+except TypeError:
+    _zip_strict = zip
+else:
+    _zip_strict = partial(zip, strict=True)
+
+# math.sumprod is available for Python 3.12+
+_sumprod = getattr(math, 'sumprod', lambda x, y: dotproduct(x, y))
+

 def take(n, iterable):
     """Return first *n* items of the iterable as a list.
@@ -293,7 +309,7 @@ def _pairwise(iterable):
     """
     a, b = tee(iterable)
     next(b, None)
-    yield from zip(a, b)
+    return zip(a, b)


 try:
@@ -303,7 +319,7 @@ except ImportError:
 else:

     def pairwise(iterable):
-        yield from itertools_pairwise(iterable)
+        return itertools_pairwise(iterable)

     pairwise.__doc__ = _pairwise.__doc__
@@ -334,13 +350,9 @@ def _zip_equal(*iterables):
         for i, it in enumerate(iterables[1:], 1):
             size = len(it)
             if size != first_size:
-                break
-        else:
-            # If we didn't break out, we can use the built-in zip.
-            return zip(*iterables)
-
-        # If we did break out, there was a mismatch.
-        raise UnequalIterablesError(details=(first_size, i, size))
+                raise UnequalIterablesError(details=(first_size, i, size))
+        # All sizes are equal, we can use the built-in zip.
+        return zip(*iterables)
     # If any one of the iterables didn't have a length, start reading
     # them until one runs out.
     except TypeError:
@@ -433,12 +445,9 @@ def partition(pred, iterable):
     if pred is None:
         pred = bool

-    evaluations = ((pred(x), x) for x in iterable)
-    t1, t2 = tee(evaluations)
-    return (
-        (x for (cond, x) in t1 if not cond),
-        (x for (cond, x) in t2 if cond),
-    )
+    t1, t2, p = tee(iterable, 3)
+    p1, p2 = tee(map(pred, p))
+    return (compress(t1, map(operator.not_, p1)), compress(t2, p2))


 def powerset(iterable):
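Note: the tee/compress rewrite of partition still evaluates pred once per item and stays lazy. A usage sketch (same import caveat as the notes above; recipes are re-exported at the package level):

    from more_itertools import partition

    evens, odds = partition(lambda x: x % 2, range(10))
    assert list(evens) == [0, 2, 4, 6, 8]  # items where pred is falsey
    assert list(odds) == [1, 3, 5, 7, 9]   # items where pred is truthy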
@@ -486,7 +495,7 @@ def unique_everseen(iterable, key=None):
     >>> list(unique_everseen(iterable, key=tuple))  # Faster
     [[1, 2], [2, 3]]

-    Similary, you may want to convert unhashable ``set`` objects with
+    Similarly, you may want to convert unhashable ``set`` objects with
     ``key=frozenset``. For ``dict`` objects,
     ``key=lambda x: frozenset(x.items())`` can be used.
@@ -518,6 +527,9 @@ def unique_justseen(iterable, key=None):
     ['A', 'B', 'C', 'A', 'D']

     """
+    if key is None:
+        return map(operator.itemgetter(0), groupby(iterable))
+
     return map(next, map(operator.itemgetter(1), groupby(iterable, key)))
@@ -712,12 +724,14 @@ def convolve(signal, kernel):
     is immediately consumed and stored.

     """
+    # This implementation intentionally doesn't match the one in the itertools
+    # documentation.
     kernel = tuple(kernel)[::-1]
     n = len(kernel)
     window = deque([0], maxlen=n) * n
     for x in chain(signal, repeat(0, n - 1)):
         window.append(x)
-        yield sum(map(operator.mul, kernel, window))
+        yield _sumprod(kernel, window)


 def before_and_after(predicate, it):
@@ -778,9 +792,7 @@ def sliding_window(iterable, n):
     For a variant with more features, see :func:`windowed`.
     """
     it = iter(iterable)
-    window = deque(islice(it, n), maxlen=n)
-    if len(window) == n:
-        yield tuple(window)
+    window = deque(islice(it, n - 1), maxlen=n)
     for x in it:
         window.append(x)
         yield tuple(window)
@@ -807,39 +819,38 @@ def polynomial_from_roots(roots):
     >>> polynomial_from_roots(roots)  # x^3 - 4 * x^2 - 17 * x + 60
     [1, -4, -17, 60]
     """
-    # Use math.prod for Python 3.8+,
-    prod = getattr(math, 'prod', lambda x: reduce(operator.mul, x, 1))
-    roots = list(map(operator.neg, roots))
-    return [
-        sum(map(prod, combinations(roots, k))) for k in range(len(roots) + 1)
-    ]
+    factors = zip(repeat(1), map(operator.neg, roots))
+    return list(reduce(convolve, factors, [1]))


-def iter_index(iterable, value, start=0):
+def iter_index(iterable, value, start=0, stop=None):
     """Yield the index of each place in *iterable* that *value* occurs,
-    beginning with index *start*.
+    beginning with index *start* and ending before index *stop*.

     See :func:`locate` for a more general means of finding the indexes
     associated with particular values.

     >>> list(iter_index('AABCADEAF', 'A'))
     [0, 1, 4, 7]
+    >>> list(iter_index('AABCADEAF', 'A', 1))  # start index is inclusive
+    [1, 4, 7]
+    >>> list(iter_index('AABCADEAF', 'A', 1, 7))  # stop index is not inclusive
+    [1, 4]
     """
-    try:
-        seq_index = iterable.index
-    except AttributeError:
+    seq_index = getattr(iterable, 'index', None)
+    if seq_index is None:
         # Slow path for general iterables
-        it = islice(iterable, start, None)
+        it = islice(iterable, start, stop)
         for i, element in enumerate(it, start):
             if element is value or element == value:
                 yield i
     else:
         # Fast path for sequences
+        stop = len(iterable) if stop is None else stop
         i = start - 1
         try:
             while True:
-                i = seq_index(value, i + 1)
-                yield i
+                yield (i := seq_index(value, i + 1, stop))
         except ValueError:
             pass
@ -850,81 +861,152 @@ def sieve(n):
>>> list(sieve(30)) >>> list(sieve(30))
[2, 3, 5, 7, 11, 13, 17, 19, 23, 29] [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
""" """
isqrt = getattr(math, 'isqrt', lambda x: int(math.sqrt(x))) if n > 2:
yield 2
start = 3
data = bytearray((0, 1)) * (n // 2) data = bytearray((0, 1)) * (n // 2)
data[:3] = 0, 0, 0 limit = math.isqrt(n) + 1
limit = isqrt(n) + 1 for p in iter_index(data, 1, start, limit):
for p in compress(range(limit), data): yield from iter_index(data, 1, start, p * p)
data[p * p : n : p + p] = bytes(len(range(p * p, n, p + p))) data[p * p : n : p + p] = bytes(len(range(p * p, n, p + p)))
data[2] = 1 start = p * p
return iter_index(data, 1) if n > 2 else iter([]) yield from iter_index(data, 1, start)
def batched(iterable, n): def _batched(iterable, n, *, strict=False):
"""Batch data into lists of length *n*. The last batch may be shorter. """Batch data into tuples of length *n*. If the number of items in
*iterable* is not divisible by *n*:
* The last batch will be shorter if *strict* is ``False``.
* :exc:`ValueError` will be raised if *strict* is ``True``.
>>> list(batched('ABCDEFG', 3)) >>> list(batched('ABCDEFG', 3))
[['A', 'B', 'C'], ['D', 'E', 'F'], ['G']] [('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)]
This recipe is from the ``itertools`` docs. This library also provides On Python 3.13 and above, this is an alias for :func:`itertools.batched`.
:func:`chunked`, which has a different implementation.
""" """
if hexversion >= 0x30C00A0: # Python 3.12.0a0 if n < 1:
warnings.warn( raise ValueError('n must be at least one')
(
'batched will be removed in a future version of '
'more-itertools. Use the standard library '
'itertools.batched function instead'
),
DeprecationWarning,
)
it = iter(iterable) it = iter(iterable)
while True: while batch := tuple(islice(it, n)):
batch = list(islice(it, n)) if strict and len(batch) != n:
if not batch: raise ValueError('batched(): incomplete batch')
break
yield batch yield batch
if hexversion >= 0x30D00A2:
from itertools import batched as itertools_batched
def batched(iterable, n, *, strict=False):
return itertools_batched(iterable, n, strict=strict)
else:
batched = _batched
batched.__doc__ = _batched.__doc__
def transpose(it):
    """Swap the rows and columns of the input matrix.

    >>> list(transpose([(1, 2, 3), (11, 22, 33)]))
    [(1, 11), (2, 22), (3, 33)]

    The caller should ensure that the dimensions of the input are compatible.
    If the input is empty, no output will be produced.
    """
    return _zip_strict(*it)


def reshape(matrix, cols):
    """Reshape the 2-D input *matrix* to have a column count given by *cols*.

    >>> matrix = [(0, 1), (2, 3), (4, 5)]
    >>> cols = 3
    >>> list(reshape(matrix, cols))
    [(0, 1, 2), (3, 4, 5)]
    """
    return batched(chain.from_iterable(matrix), cols)
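
Since transpose() now delegates to a strict zip, ragged rows raise instead of being silently truncated where the `_zip_strict` helper maps to `zip(..., strict=True)` (an assumption about the helper; older interpreters presumably fall back to plain `zip`). Illustrative usage of the pair:

    >>> list(transpose([(1, 2), (11, 22), (111, 222)]))
    [(1, 11, 111), (2, 22, 222)]
    >>> list(reshape([(0, 1, 2), (3, 4, 5)], 2))
    [(0, 1), (2, 3), (4, 5)]
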
def matmul(m1, m2):
    """Multiply two matrices.

    >>> list(matmul([(7, 5), (3, 5)], [(2, 5), (7, 9)]))
    [(49, 80), (41, 60)]

    The caller should ensure that the dimensions of the input matrices are
    compatible with each other.
    """
    n = len(m2[0])
    return batched(starmap(_sumprod, product(m1, transpose(m2))), n)
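
Result rows are now tuples (courtesy of the tuple-yielding batched()), and the dot products come from `_sumprod` rather than `dotproduct`. For example, multiplying by the identity returns the other operand's rows (illustrative):

    >>> list(matmul([(1, 0), (0, 1)], [(2, 5), (7, 9)]))
    [(2, 5), (7, 9)]
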
def factor(n):
    """Yield the prime factors of n.

    >>> list(factor(360))
    [2, 2, 2, 3, 3, 5]
    """
    for prime in sieve(math.isqrt(n) + 1):
        while not n % prime:
            yield prime
            n //= prime
            if n == 1:
                return
    if n > 1:
        yield n
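
`while not n % prime` with floor division replaces the old divmod() loop; the output is unchanged. Illustrative:

    >>> list(factor(128))
    [2, 2, 2, 2, 2, 2, 2]
    >>> list(factor(97))    # primes above the sieve limit yield themselves
    [97]
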
def polynomial_eval(coefficients, x):
    """Evaluate a polynomial at a specific value.

    Example: evaluating x^3 - 4 * x^2 - 17 * x + 60 at x = 2.5:

    >>> coefficients = [1, -4, -17, 60]
    >>> x = 2.5
    >>> polynomial_eval(coefficients, x)
    8.125
    """
    n = len(coefficients)
    if n == 0:
        return x * 0  # coerce zero to the type of x
    powers = map(pow, repeat(x), reversed(range(n)))
    return _sumprod(coefficients, powers)
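
The evaluation is one sum-of-products of the coefficients against descending powers of x, i.e. c[0]*x^(n-1) + ... + c[n-1]. Two edge cases worth noting (illustrative):

    >>> polynomial_eval([1, -4, -17, 60], 0)    # x = 0 leaves only the constant term
    60
    >>> polynomial_eval([], 2.5)                # empty input returns a zero of x's type
    0.0
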
def sum_of_squares(it):
    """Return the sum of the squares of the input values.

    >>> sum_of_squares([10, 20, 30])
    1400
    """
    return _sumprod(*tee(it))


def polynomial_derivative(coefficients):
    """Compute the first derivative of a polynomial.

    Example: evaluating the derivative of x^3 - 4 * x^2 - 17 * x + 60

    >>> coefficients = [1, -4, -17, 60]
    >>> derivative_coefficients = polynomial_derivative(coefficients)
    >>> derivative_coefficients
    [3, -8, -17]
    """
    n = len(coefficients)
    powers = reversed(range(1, n))
    return list(map(operator.mul, coefficients, powers))


def totient(n):
    """Return the count of natural numbers up to *n* that are coprime with *n*.

    >>> totient(9)
    6
    >>> totient(12)
    4
    """
    for p in unique_justseen(factor(n)):
        n = n // p * (p - 1)
    return n

@@ -14,6 +14,8 @@ from typing import (
# Type and type variable definitions
_T = TypeVar('_T')
_T1 = TypeVar('_T1')
_T2 = TypeVar('_T2')
_U = TypeVar('_U')

def take(n: int, iterable: Iterable[_T]) -> list[_T]: ...

@@ -21,19 +23,19 @@ def tabulate(
    function: Callable[[int], _T], start: int = ...
) -> Iterator[_T]: ...
def tail(n: int, iterable: Iterable[_T]) -> Iterator[_T]: ...
def consume(iterator: Iterable[_T], n: int | None = ...) -> None: ...
@overload
def nth(iterable: Iterable[_T], n: int) -> _T | None: ...
@overload
def nth(iterable: Iterable[_T], n: int, default: _U) -> _T | _U: ...
def all_equal(iterable: Iterable[_T]) -> bool: ...
def quantify(
    iterable: Iterable[_T], pred: Callable[[_T], bool] = ...
) -> int: ...
def pad_none(iterable: Iterable[_T]) -> Iterator[_T | None]: ...
def padnone(iterable: Iterable[_T]) -> Iterator[_T | None]: ...
def ncycles(iterable: Iterable[_T], n: int) -> Iterator[_T]: ...
def dotproduct(vec1: Iterable[_T1], vec2: Iterable[_T2]) -> Any: ...
def flatten(listOfLists: Iterable[Iterable[_T]]) -> Iterator[_T]: ...
def repeatfunc(
    func: Callable[..., _U], times: int | None = ..., *args: Any

@@ -101,19 +103,26 @@ def sliding_window(
    iterable: Iterable[_T], n: int
) -> Iterator[tuple[_T, ...]]: ...
def subslices(iterable: Iterable[_T]) -> Iterator[list[_T]]: ...
def polynomial_from_roots(roots: Sequence[_T]) -> list[_T]: ...
def iter_index(
    iterable: Iterable[_T],
    value: Any,
    start: int | None = ...,
    stop: int | None = ...,
) -> Iterator[int]: ...
def sieve(n: int) -> Iterator[int]: ...
def batched(
    iterable: Iterable[_T], n: int, *, strict: bool = False
) -> Iterator[tuple[_T]]: ...
def transpose(
    it: Iterable[Iterable[_T]],
) -> Iterator[tuple[_T, ...]]: ...
def reshape(
    matrix: Iterable[Iterable[_T]], cols: int
) -> Iterator[tuple[_T, ...]]: ...
def matmul(m1: Sequence[_T], m2: Sequence[_T]) -> Iterator[tuple[_T]]: ...
def factor(n: int) -> Iterator[int]: ...
def polynomial_eval(coefficients: Sequence[_T], x: _U) -> _U: ...
def sum_of_squares(it: Iterable[_T]) -> _T: ...
def polynomial_derivative(coefficients: Sequence[_T]) -> list[_T]: ...
def totient(n: int) -> int: ...

@ -6,10 +6,10 @@ __title__ = "packaging"
__summary__ = "Core utilities for Python packages" __summary__ = "Core utilities for Python packages"
__uri__ = "https://github.com/pypa/packaging" __uri__ = "https://github.com/pypa/packaging"
__version__ = "23.1" __version__ = "24.0"
__author__ = "Donald Stufft and individual contributors" __author__ = "Donald Stufft and individual contributors"
__email__ = "donald@stufft.io" __email__ = "donald@stufft.io"
__license__ = "BSD-2-Clause or Apache-2.0" __license__ = "BSD-2-Clause or Apache-2.0"
__copyright__ = "2014-2019 %s" % __author__ __copyright__ = "2014 %s" % __author__

@@ -5,7 +5,7 @@ import os
import re
import sys
import warnings
from typing import Dict, Generator, Iterator, NamedTuple, Optional, Sequence, Tuple

from ._elffile import EIClass, EIData, ELFFile, EMachine

@@ -50,12 +50,21 @@ def _is_linux_i686(executable: str) -> bool:
    )


def _have_compatible_abi(executable: str, archs: Sequence[str]) -> bool:
    if "armv7l" in archs:
        return _is_linux_armhf(executable)
    if "i686" in archs:
        return _is_linux_i686(executable)
    allowed_archs = {
        "x86_64",
        "aarch64",
        "ppc64",
        "ppc64le",
        "s390x",
        "loongarch64",
        "riscv64",
    }
    return any(arch in allowed_archs for arch in archs)


# If glibc ever changes its major version, we need to know what the last

@@ -81,7 +90,7 @@ def _glibc_version_string_confstr() -> Optional[str]:
    # https://github.com/python/cpython/blob/fcf1d003bf4f0100c/Lib/platform.py#L175-L183
    try:
        # Should be a string like "glibc 2.17".
        version_string: Optional[str] = os.confstr("CS_GNU_LIBC_VERSION")
        assert version_string is not None
        _, version = version_string.rsplit()
    except (AssertionError, AttributeError, OSError, ValueError):

@@ -167,13 +176,13 @@ def _get_glibc_version() -> Tuple[int, int]:
# From PEP 513, PEP 600
def _is_compatible(arch: str, version: _GLibCVersion) -> bool:
    sys_glibc = _get_glibc_version()
    if sys_glibc < version:
        return False
    # Check for presence of _manylinux module.
    try:
        import _manylinux
    except ImportError:
        return True
    if hasattr(_manylinux, "manylinux_compatible"):

@@ -203,12 +212,22 @@ _LEGACY_MANYLINUX_MAP = {
}


def platform_tags(archs: Sequence[str]) -> Iterator[str]:
    """Generate manylinux tags compatible to the current platform.

    :param archs: Sequence of compatible architectures.
        The first one shall be the closest to the actual architecture and be the part of
        platform tag after the ``linux_`` prefix, e.g. ``x86_64``.
        The ``linux_`` prefix is assumed as a prerequisite for the current platform to
        be manylinux-compatible.

    :returns: An iterator of compatible manylinux tags.
    """
    if not _have_compatible_abi(sys.executable, archs):
        return
    # Oldest glibc to be supported regardless of architecture is (2, 17).
    too_old_glibc2 = _GLibCVersion(2, 16)
    if set(archs) & {"x86_64", "i686"}:
        # On x86/i686 also oldest glibc to be supported is (2, 5).
        too_old_glibc2 = _GLibCVersion(2, 4)
    current_glibc = _GLibCVersion(*_get_glibc_version())

@@ -222,19 +241,20 @@ def platform_tags(linux: str, arch: str) -> Iterator[str]:
    for glibc_major in range(current_glibc.major - 1, 1, -1):
        glibc_minor = _LAST_GLIBC_MINOR[glibc_major]
        glibc_max_list.append(_GLibCVersion(glibc_major, glibc_minor))
    for arch in archs:
        for glibc_max in glibc_max_list:
            if glibc_max.major == too_old_glibc2.major:
                min_minor = too_old_glibc2.minor
            else:
                # For other glibc major versions oldest supported is (x, 0).
                min_minor = -1
            for glibc_minor in range(glibc_max.minor, min_minor, -1):
                glibc_version = _GLibCVersion(glibc_max.major, glibc_minor)
                tag = "manylinux_{}_{}".format(*glibc_version)
                if _is_compatible(arch, glibc_version):
                    yield f"{tag}_{arch}"
                # Handle the legacy manylinux1, manylinux2010, manylinux2014 tags.
                if glibc_version in _LEGACY_MANYLINUX_MAP:
                    legacy_tag = _LEGACY_MANYLINUX_MAP[glibc_version]
                    if _is_compatible(arch, glibc_version):
                        yield f"{legacy_tag}_{arch}"

@@ -8,7 +8,7 @@ import functools
import re
import subprocess
import sys
from typing import Iterator, NamedTuple, Optional, Sequence

from ._elffile import ELFFile

@@ -47,24 +47,27 @@ def _get_musl_version(executable: str) -> Optional[_MuslVersion]:
        return None
    if ld is None or "musl" not in ld:
        return None
    proc = subprocess.run([ld], stderr=subprocess.PIPE, text=True)
    return _parse_musl_version(proc.stderr)


def platform_tags(archs: Sequence[str]) -> Iterator[str]:
    """Generate musllinux tags compatible to the current platform.

    :param archs: Sequence of compatible architectures.
        The first one shall be the closest to the actual architecture and be the part of
        platform tag after the ``linux_`` prefix, e.g. ``x86_64``.
        The ``linux_`` prefix is assumed as a prerequisite for the current platform to
        be musllinux-compatible.

    :returns: An iterator of compatible musllinux tags.
    """
    sys_musl = _get_musl_version(sys.executable)
    if sys_musl is None:  # Python not dynamically linked against musl.
        return
    for arch in archs:
        for minor in range(sys_musl.minor, -1, -1):
            yield f"musllinux_{sys_musl.major}_{minor}_{arch}"


if __name__ == "__main__":  # pragma: no cover

@@ -252,7 +252,13 @@ def _parse_version_many(tokenizer: Tokenizer) -> str:
# Recursive descent parser for marker expression
# --------------------------------------------------------------------------------------
def parse_marker(source: str) -> MarkerList:
    return _parse_full_marker(Tokenizer(source, rules=DEFAULT_RULES))


def _parse_full_marker(tokenizer: Tokenizer) -> MarkerList:
    retval = _parse_marker(tokenizer)
    tokenizer.expect("END", expected="end of marker expression")
    return retval


def _parse_marker(tokenizer: Tokenizer) -> MarkerList:

@@ -318,10 +324,7 @@ def _parse_marker_var(tokenizer: Tokenizer) -> MarkerVar:
def process_env_var(env_var: str) -> Variable:
    if env_var in ("platform_python_implementation", "python_implementation"):
        return Variable("platform_python_implementation")
    else:
        return Variable(env_var)
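
The explicit END expectation turns trailing tokens after a well-formed marker into a hard parse error rather than silently ignored input. A sketch against the public wrapper (exact error text aside):

    >>> from packaging.markers import Marker, InvalidMarker
    >>> try:
    ...     Marker("python_version >= '3.8' junk")
    ... except InvalidMarker:
    ...     print("rejected")
    rejected
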

@@ -5,23 +5,77 @@ import email.parser
import email.policy
import sys
import typing
from typing import (
    Any,
    Callable,
    Dict,
    Generic,
    List,
    Optional,
    Tuple,
    Type,
    Union,
    cast,
)

from . import requirements, specifiers, utils, version as version_module

T = typing.TypeVar("T")
if sys.version_info[:2] >= (3, 8):  # pragma: no cover
    from typing import Literal, TypedDict
else:  # pragma: no cover
    if typing.TYPE_CHECKING:
        from typing_extensions import Literal, TypedDict
    else:
        try:
            from typing_extensions import Literal, TypedDict
        except ImportError:

            class Literal:
                def __init_subclass__(*_args, **_kwargs):
                    pass

            class TypedDict:
                def __init_subclass__(*_args, **_kwargs):
                    pass


try:
    ExceptionGroup
except NameError:  # pragma: no cover

    class ExceptionGroup(Exception):  # noqa: N818
        """A minimal implementation of :external:exc:`ExceptionGroup` from Python 3.11.

        If :external:exc:`ExceptionGroup` is already defined by Python itself,
        that version is used instead.
        """

        message: str
        exceptions: List[Exception]

        def __init__(self, message: str, exceptions: List[Exception]) -> None:
            self.message = message
            self.exceptions = exceptions

        def __repr__(self) -> str:
            return f"{self.__class__.__name__}({self.message!r}, {self.exceptions!r})"

else:  # pragma: no cover
    ExceptionGroup = ExceptionGroup


class InvalidMetadata(ValueError):
    """A metadata field contains invalid data."""

    field: str
    """The name of the field that contains invalid data."""

    def __init__(self, field: str, message: str) -> None:
        self.field = field
        super().__init__(message)


# The RawMetadata class attempts to make as few assumptions about the underlying
# serialization formats as possible. The idea is that as long as a serialization
# formats offer some very basic primitives in *some* way then we can support

@@ -33,7 +87,8 @@ class RawMetadata(TypedDict, total=False):
    provided). The key is lower-case and underscores are used instead of dashes
    compared to the equivalent core metadata field. Any core metadata field that
    can be specified multiple times or can hold multiple values in a single
    field have a key with a plural name. See :class:`Metadata` whose attributes
    match the keys of this dictionary.

    Core metadata fields that can be specified multiple times are stored as a
    list or dict depending on which is appropriate for the field. Any fields

@@ -77,7 +132,7 @@ class RawMetadata(TypedDict, total=False):
    # but got stuck without ever being able to build consensus on
    # it and ultimately ended up withdrawn.
    #
    # However, a number of tools had started emitting METADATA with
    # `2.0` Metadata-Version, so for historical reasons, this version
    # was skipped.

@@ -110,7 +165,7 @@ _STRING_FIELDS = {
    "version",
}

_LIST_FIELDS = {
    "classifiers",
    "dynamic",
    "obsoletes",

@@ -125,6 +180,10 @@ _LIST_STRING_FIELDS = {
    "supported_platforms",
}

_DICT_FIELDS = {
    "project_urls",
}


def _parse_keywords(data: str) -> List[str]:
    """Split a string of comma-separate keyboards into a list of keywords."""

@@ -230,10 +289,11 @@ _EMAIL_TO_RAW_MAPPING = {
    "supported-platform": "supported_platforms",
    "version": "version",
}
_RAW_TO_EMAIL_MAPPING = {raw: email for email, raw in _EMAIL_TO_RAW_MAPPING.items()}


def parse_email(data: Union[bytes, str]) -> Tuple[RawMetadata, Dict[str, List[str]]]:
    """Parse a distribution's metadata stored as email headers (e.g. from ``METADATA``).

    This function returns a two-item tuple of dicts. The first dict is of
    recognized fields from the core metadata specification. Fields that can be

@@ -267,7 +327,7 @@ def parse_email(data: Union[bytes, str]) -> Tuple[RawMetadata, Dict[str, List[st
        # We use get_all() here, even for fields that aren't multiple use,
        # because otherwise someone could have e.g. two Name fields, and we
        # would just silently ignore it rather than doing something about it.
        headers = parsed.get_all(name) or []

        # The way the email module works when parsing bytes is that it
        # unconditionally decodes the bytes as ascii using the surrogateescape

@@ -349,7 +409,7 @@ def parse_email(data: Union[bytes, str]) -> Tuple[RawMetadata, Dict[str, List[st
        # If this is one of our list of string fields, then we can just assign
        # the value, since email *only* has strings, and our get_all() call
        # above ensures that this is a list.
        elif raw_name in _LIST_FIELDS:
            raw[raw_name] = value
        # Special Case: Keywords
        # The keywords field is implemented in the metadata spec as a str,
@@ -406,3 +466,360 @@ def parse_email(data: Union[bytes, str]) -> Tuple[RawMetadata, Dict[str, List[st
    # way this function is implemented, our `TypedDict` can only have valid key
    # names.
    return cast(RawMetadata, raw), unparsed


_NOT_FOUND = object()


# Keep the two values in sync.
_VALID_METADATA_VERSIONS = ["1.0", "1.1", "1.2", "2.1", "2.2", "2.3"]
_MetadataVersion = Literal["1.0", "1.1", "1.2", "2.1", "2.2", "2.3"]

_REQUIRED_ATTRS = frozenset(["metadata_version", "name", "version"])


class _Validator(Generic[T]):
    """Validate a metadata field.

    All _process_*() methods correspond to a core metadata field. The method is
    called with the field's raw value. If the raw value is valid it is returned
    in its "enriched" form (e.g. ``version.Version`` for the ``Version`` field).
    If the raw value is invalid, :exc:`InvalidMetadata` is raised (with a cause
    as appropriate).
    """

    name: str
    raw_name: str
    added: _MetadataVersion

    def __init__(
        self,
        *,
        added: _MetadataVersion = "1.0",
    ) -> None:
        self.added = added

    def __set_name__(self, _owner: "Metadata", name: str) -> None:
        self.name = name
        self.raw_name = _RAW_TO_EMAIL_MAPPING[name]

    def __get__(self, instance: "Metadata", _owner: Type["Metadata"]) -> T:
        # With Python 3.8, the caching can be replaced with functools.cached_property().
        # No need to check the cache as attribute lookup will resolve into the
        # instance's __dict__ before __get__ is called.
        cache = instance.__dict__
        value = instance._raw.get(self.name)

        # To make the _process_* methods easier, we'll check if the value is None
        # and if this field is NOT a required attribute, and if both of those
        # things are true, we'll skip the the converter. This will mean that the
        # converters never have to deal with the None union.
        if self.name in _REQUIRED_ATTRS or value is not None:
            try:
                converter: Callable[[Any], T] = getattr(self, f"_process_{self.name}")
            except AttributeError:
                pass
            else:
                value = converter(value)

        cache[self.name] = value
        try:
            del instance._raw[self.name]  # type: ignore[misc]
        except KeyError:
            pass

        return cast(T, value)

    def _invalid_metadata(
        self, msg: str, cause: Optional[Exception] = None
    ) -> InvalidMetadata:
        exc = InvalidMetadata(
            self.raw_name, msg.format_map({"field": repr(self.raw_name)})
        )
        exc.__cause__ = cause
        return exc

    def _process_metadata_version(self, value: str) -> _MetadataVersion:
        # Implicitly makes Metadata-Version required.
        if value not in _VALID_METADATA_VERSIONS:
            raise self._invalid_metadata(f"{value!r} is not a valid metadata version")
        return cast(_MetadataVersion, value)

    def _process_name(self, value: str) -> str:
        if not value:
            raise self._invalid_metadata("{field} is a required field")
        # Validate the name as a side-effect.
        try:
            utils.canonicalize_name(value, validate=True)
        except utils.InvalidName as exc:
            raise self._invalid_metadata(
                f"{value!r} is invalid for {{field}}", cause=exc
            )
        else:
            return value

    def _process_version(self, value: str) -> version_module.Version:
        if not value:
            raise self._invalid_metadata("{field} is a required field")
        try:
            return version_module.parse(value)
        except version_module.InvalidVersion as exc:
            raise self._invalid_metadata(
                f"{value!r} is invalid for {{field}}", cause=exc
            )

    def _process_summary(self, value: str) -> str:
        """Check the field contains no newlines."""
        if "\n" in value:
            raise self._invalid_metadata("{field} must be a single line")
        return value

    def _process_description_content_type(self, value: str) -> str:
        content_types = {"text/plain", "text/x-rst", "text/markdown"}
        message = email.message.EmailMessage()
        message["content-type"] = value

        content_type, parameters = (
            # Defaults to `text/plain` if parsing failed.
            message.get_content_type().lower(),
            message["content-type"].params,
        )
        # Check if content-type is valid or defaulted to `text/plain` and thus was
        # not parseable.
        if content_type not in content_types or content_type not in value.lower():
            raise self._invalid_metadata(
                f"{{field}} must be one of {list(content_types)}, not {value!r}"
            )

        charset = parameters.get("charset", "UTF-8")
        if charset != "UTF-8":
            raise self._invalid_metadata(
                f"{{field}} can only specify the UTF-8 charset, not {list(charset)}"
            )

        markdown_variants = {"GFM", "CommonMark"}
        variant = parameters.get("variant", "GFM")  # Use an acceptable default.
        if content_type == "text/markdown" and variant not in markdown_variants:
            raise self._invalid_metadata(
                f"valid Markdown variants for {{field}} are {list(markdown_variants)}, "
                f"not {variant!r}",
            )
        return value

    def _process_dynamic(self, value: List[str]) -> List[str]:
        for dynamic_field in map(str.lower, value):
            if dynamic_field in {"name", "version", "metadata-version"}:
                raise self._invalid_metadata(
                    f"{value!r} is not allowed as a dynamic field"
                )
            elif dynamic_field not in _EMAIL_TO_RAW_MAPPING:
                raise self._invalid_metadata(f"{value!r} is not a valid dynamic field")
        return list(map(str.lower, value))

    def _process_provides_extra(
        self,
        value: List[str],
    ) -> List[utils.NormalizedName]:
        normalized_names = []
        try:
            for name in value:
                normalized_names.append(utils.canonicalize_name(name, validate=True))
        except utils.InvalidName as exc:
            raise self._invalid_metadata(
                f"{name!r} is invalid for {{field}}", cause=exc
            )
        else:
            return normalized_names

    def _process_requires_python(self, value: str) -> specifiers.SpecifierSet:
        try:
            return specifiers.SpecifierSet(value)
        except specifiers.InvalidSpecifier as exc:
            raise self._invalid_metadata(
                f"{value!r} is invalid for {{field}}", cause=exc
            )

    def _process_requires_dist(
        self,
        value: List[str],
    ) -> List[requirements.Requirement]:
        reqs = []
        try:
            for req in value:
                reqs.append(requirements.Requirement(req))
        except requirements.InvalidRequirement as exc:
            raise self._invalid_metadata(f"{req!r} is invalid for {{field}}", cause=exc)
        else:
            return reqs


class Metadata:
    """Representation of distribution metadata.

    Compared to :class:`RawMetadata`, this class provides objects representing
    metadata fields instead of only using built-in types. Any invalid metadata
    will cause :exc:`InvalidMetadata` to be raised (with a
    :py:attr:`~BaseException.__cause__` attribute as appropriate).
    """

    _raw: RawMetadata

    @classmethod
    def from_raw(cls, data: RawMetadata, *, validate: bool = True) -> "Metadata":
        """Create an instance from :class:`RawMetadata`.

        If *validate* is true, all metadata will be validated. All exceptions
        related to validation will be gathered and raised as an :class:`ExceptionGroup`.
        """
        ins = cls()
        ins._raw = data.copy()  # Mutations occur due to caching enriched values.

        if validate:
            exceptions: List[Exception] = []
            try:
                metadata_version = ins.metadata_version
                metadata_age = _VALID_METADATA_VERSIONS.index(metadata_version)
            except InvalidMetadata as metadata_version_exc:
                exceptions.append(metadata_version_exc)
                metadata_version = None

            # Make sure to check for the fields that are present, the required
            # fields (so their absence can be reported).
            fields_to_check = frozenset(ins._raw) | _REQUIRED_ATTRS
            # Remove fields that have already been checked.
            fields_to_check -= {"metadata_version"}

            for key in fields_to_check:
                try:
                    if metadata_version:
                        # Can't use getattr() as that triggers descriptor protocol which
                        # will fail due to no value for the instance argument.
                        try:
                            field_metadata_version = cls.__dict__[key].added
                        except KeyError:
                            exc = InvalidMetadata(key, f"unrecognized field: {key!r}")
                            exceptions.append(exc)
                            continue
                        field_age = _VALID_METADATA_VERSIONS.index(
                            field_metadata_version
                        )
                        if field_age > metadata_age:
                            field = _RAW_TO_EMAIL_MAPPING[key]
                            exc = InvalidMetadata(
                                field,
                                "{field} introduced in metadata version "
                                "{field_metadata_version}, not {metadata_version}",
                            )
                            exceptions.append(exc)
                            continue
                    getattr(ins, key)
                except InvalidMetadata as exc:
                    exceptions.append(exc)

            if exceptions:
                raise ExceptionGroup("invalid metadata", exceptions)

        return ins

    @classmethod
    def from_email(
        cls, data: Union[bytes, str], *, validate: bool = True
    ) -> "Metadata":
        """Parse metadata from email headers.

        If *validate* is true, the metadata will be validated. All exceptions
        related to validation will be gathered and raised as an :class:`ExceptionGroup`.
        """
        raw, unparsed = parse_email(data)

        if validate:
            exceptions: list[Exception] = []
            for unparsed_key in unparsed:
                if unparsed_key in _EMAIL_TO_RAW_MAPPING:
                    message = f"{unparsed_key!r} has invalid data"
                else:
                    message = f"unrecognized field: {unparsed_key!r}"
                exceptions.append(InvalidMetadata(unparsed_key, message))

            if exceptions:
                raise ExceptionGroup("unparsed", exceptions)

        try:
            return cls.from_raw(raw, validate=validate)
        except ExceptionGroup as exc_group:
            raise ExceptionGroup(
                "invalid or unparsed metadata", exc_group.exceptions
            ) from None

    metadata_version: _Validator[_MetadataVersion] = _Validator()
    """:external:ref:`core-metadata-metadata-version`
    (required; validated to be a valid metadata version)"""
    name: _Validator[str] = _Validator()
    """:external:ref:`core-metadata-name`
    (required; validated using :func:`~packaging.utils.canonicalize_name` and its
    *validate* parameter)"""
    version: _Validator[version_module.Version] = _Validator()
    """:external:ref:`core-metadata-version` (required)"""
    dynamic: _Validator[Optional[List[str]]] = _Validator(
        added="2.2",
    )
    """:external:ref:`core-metadata-dynamic`
    (validated against core metadata field names and lowercased)"""
    platforms: _Validator[Optional[List[str]]] = _Validator()
    """:external:ref:`core-metadata-platform`"""
    supported_platforms: _Validator[Optional[List[str]]] = _Validator(added="1.1")
    """:external:ref:`core-metadata-supported-platform`"""
    summary: _Validator[Optional[str]] = _Validator()
    """:external:ref:`core-metadata-summary` (validated to contain no newlines)"""
    description: _Validator[Optional[str]] = _Validator()  # TODO 2.1: can be in body
    """:external:ref:`core-metadata-description`"""
    description_content_type: _Validator[Optional[str]] = _Validator(added="2.1")
    """:external:ref:`core-metadata-description-content-type` (validated)"""
    keywords: _Validator[Optional[List[str]]] = _Validator()
    """:external:ref:`core-metadata-keywords`"""
    home_page: _Validator[Optional[str]] = _Validator()
    """:external:ref:`core-metadata-home-page`"""
    download_url: _Validator[Optional[str]] = _Validator(added="1.1")
    """:external:ref:`core-metadata-download-url`"""
    author: _Validator[Optional[str]] = _Validator()
    """:external:ref:`core-metadata-author`"""
    author_email: _Validator[Optional[str]] = _Validator()
    """:external:ref:`core-metadata-author-email`"""
    maintainer: _Validator[Optional[str]] = _Validator(added="1.2")
    """:external:ref:`core-metadata-maintainer`"""
    maintainer_email: _Validator[Optional[str]] = _Validator(added="1.2")
    """:external:ref:`core-metadata-maintainer-email`"""
    license: _Validator[Optional[str]] = _Validator()
    """:external:ref:`core-metadata-license`"""
    classifiers: _Validator[Optional[List[str]]] = _Validator(added="1.1")
    """:external:ref:`core-metadata-classifier`"""
    requires_dist: _Validator[Optional[List[requirements.Requirement]]] = _Validator(
        added="1.2"
    )
    """:external:ref:`core-metadata-requires-dist`"""
    requires_python: _Validator[Optional[specifiers.SpecifierSet]] = _Validator(
        added="1.2"
    )
    """:external:ref:`core-metadata-requires-python`"""
    # Because `Requires-External` allows for non-PEP 440 version specifiers, we
    # don't do any processing on the values.
    requires_external: _Validator[Optional[List[str]]] = _Validator(added="1.2")
    """:external:ref:`core-metadata-requires-external`"""
    project_urls: _Validator[Optional[Dict[str, str]]] = _Validator(added="1.2")
    """:external:ref:`core-metadata-project-url`"""
    # PEP 685 lets us raise an error if an extra doesn't pass `Name` validation
    # regardless of metadata version.
    provides_extra: _Validator[Optional[List[utils.NormalizedName]]] = _Validator(
        added="2.1",
    )
    """:external:ref:`core-metadata-provides-extra`"""
    provides_dist: _Validator[Optional[List[str]]] = _Validator(added="1.2")
    """:external:ref:`core-metadata-provides-dist`"""
    obsoletes_dist: _Validator[Optional[List[str]]] = _Validator(added="1.2")
    """:external:ref:`core-metadata-obsoletes-dist`"""
    requires: _Validator[Optional[List[str]]] = _Validator(added="1.1")
    """``Requires`` (deprecated)"""
    provides: _Validator[Optional[List[str]]] = _Validator(added="1.1")
    """``Provides`` (deprecated)"""
    obsoletes: _Validator[Optional[List[str]]] = _Validator(added="1.1")
    """``Obsoletes`` (deprecated)"""

@@ -2,13 +2,13 @@
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
# for complete details.

from typing import Any, Iterator, Optional, Set

from ._parser import parse_requirement as _parse_requirement
from ._tokenizer import ParserSyntaxError
from .markers import Marker, _normalize_extra_values
from .specifiers import SpecifierSet
from .utils import canonicalize_name


class InvalidRequirement(ValueError):

@@ -37,57 +37,52 @@ class Requirement:
            raise InvalidRequirement(str(e)) from e

        self.name: str = parsed.name
        self.url: Optional[str] = parsed.url or None
        self.extras: Set[str] = set(parsed.extras or [])
        self.specifier: SpecifierSet = SpecifierSet(parsed.specifier)
        self.marker: Optional[Marker] = None
        if parsed.marker is not None:
            self.marker = Marker.__new__(Marker)
            self.marker._markers = _normalize_extra_values(parsed.marker)

    def _iter_parts(self, name: str) -> Iterator[str]:
        yield name

        if self.extras:
            formatted_extras = ",".join(sorted(self.extras))
            yield f"[{formatted_extras}]"

        if self.specifier:
            yield str(self.specifier)

        if self.url:
            yield f"@ {self.url}"
            if self.marker:
                yield " "

        if self.marker:
            yield f"; {self.marker}"

    def __str__(self) -> str:
        return "".join(self._iter_parts(self.name))

    def __repr__(self) -> str:
        return f"<Requirement('{self}')>"

    def __hash__(self) -> int:
        return hash(
            (
                self.__class__.__name__,
                *self._iter_parts(canonicalize_name(self.name)),
            )
        )

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, Requirement):
            return NotImplemented

        return (
            canonicalize_name(self.name) == canonicalize_name(other.name)
            and self.extras == other.extras
            and self.specifier == other.specifier
            and self.url == other.url
View file

@ -11,17 +11,7 @@
import abc import abc
import itertools import itertools
import re import re
from typing import ( from typing import Callable, Iterable, Iterator, List, Optional, Tuple, TypeVar, Union
Callable,
Iterable,
Iterator,
List,
Optional,
Set,
Tuple,
TypeVar,
Union,
)
from .utils import canonicalize_version from .utils import canonicalize_version
from .version import Version from .version import Version
@ -383,7 +373,7 @@ class Specifier(BaseSpecifier):
# We want everything but the last item in the version, but we want to # We want everything but the last item in the version, but we want to
# ignore suffix segments. # ignore suffix segments.
prefix = ".".join( prefix = _version_join(
list(itertools.takewhile(_is_not_suffix, _version_split(spec)))[:-1] list(itertools.takewhile(_is_not_suffix, _version_split(spec)))[:-1]
) )
@ -404,13 +394,13 @@ class Specifier(BaseSpecifier):
) )
# Get the normalized version string ignoring the trailing .* # Get the normalized version string ignoring the trailing .*
normalized_spec = canonicalize_version(spec[:-2], strip_trailing_zero=False) normalized_spec = canonicalize_version(spec[:-2], strip_trailing_zero=False)
# Split the spec out by dots, and pretend that there is an implicit # Split the spec out by bangs and dots, and pretend that there is
# dot in between a release segment and a pre-release segment. # an implicit dot in between a release segment and a pre-release segment.
split_spec = _version_split(normalized_spec) split_spec = _version_split(normalized_spec)
# Split the prospective version out by dots, and pretend that there # Split the prospective version out by bangs and dots, and pretend
# is an implicit dot in between a release segment and a pre-release # that there is an implicit dot in between a release segment and
# segment. # a pre-release segment.
split_prospective = _version_split(normalized_prospective) split_prospective = _version_split(normalized_prospective)
# 0-pad the prospective version before shortening it to get the correct # 0-pad the prospective version before shortening it to get the correct
@ -644,8 +634,19 @@ _prefix_regex = re.compile(r"^([0-9]+)((?:a|b|c|rc)[0-9]+)$")
def _version_split(version: str) -> List[str]: def _version_split(version: str) -> List[str]:
"""Split version into components.
The split components are intended for version comparison. The logic does
not attempt to retain the original version string, so joining the
components back with :func:`_version_join` may not produce the original
version string.
"""
result: List[str] = [] result: List[str] = []
for item in version.split("."):
epoch, _, rest = version.rpartition("!")
result.append(epoch or "0")
for item in rest.split("."):
match = _prefix_regex.search(item) match = _prefix_regex.search(item)
if match: if match:
result.extend(match.groups()) result.extend(match.groups())
@ -654,6 +655,17 @@ def _version_split(version: str) -> List[str]:
return result return result
def _version_join(components: List[str]) -> str:
"""Join split version components into a version string.
This function assumes the input came from :func:`_version_split`, where the
first component must be the epoch (either empty or numeric), and all other
components numeric.
"""
epoch, *rest = components
return f"{epoch}!{'.'.join(rest)}"
def _is_not_suffix(segment: str) -> bool:
    return not any(
        segment.startswith(prefix) for prefix in ("dev", "a", "b", "rc", "post")

@@ -675,7 +687,10 @@ def _pad_version(left: List[str], right: List[str]) -> Tuple[List[str], List[str
    left_split.insert(1, ["0"] * max(0, len(right_split[0]) - len(left_split[0])))
    right_split.insert(1, ["0"] * max(0, len(left_split[0]) - len(right_split[0])))

    return (
        list(itertools.chain.from_iterable(left_split)),
        list(itertools.chain.from_iterable(right_split)),
    )


class SpecifierSet(BaseSpecifier):

@@ -707,14 +722,8 @@ class SpecifierSet(BaseSpecifier):
        # strip each item to remove leading/trailing whitespace.
        split_specifiers = [s.strip() for s in specifiers.split(",") if s.strip()]

        # Make each individual specifier a Specifier and save in a frozen set for later.
        self._specs = frozenset(map(Specifier, split_specifiers))

        # Store our prereleases value so we can use it later to determine if
        # we accept prereleases or not.

@@ -4,6 +4,8 @@
import logging
import platform
import re
import struct
import subprocess
import sys
import sysconfig
@@ -37,7 +39,7 @@ INTERPRETER_SHORT_NAMES: Dict[str, str] = {
}

_32_BIT_INTERPRETER = struct.calcsize("P") == 4


class Tag:
@@ -123,20 +125,37 @@ def _normalize_string(string: str) -> str:
    return string.replace(".", "_").replace("-", "_").replace(" ", "_")


def _is_threaded_cpython(abis: List[str]) -> bool:
    """
    Determine if the ABI corresponds to a threaded (`--disable-gil`) build.

    The threaded builds are indicated by a "t" in the abiflags.
    """
    if len(abis) == 0:
        return False
    # expect e.g., cp313
    m = re.match(r"cp\d+(.*)", abis[0])
    if not m:
        return False
    abiflags = m.group(1)
    return "t" in abiflags


def _abi3_applies(python_version: PythonVersion, threading: bool) -> bool:
    """
    Determine if the Python version supports abi3.

    PEP 384 was first implemented in Python 3.2. The threaded (`--disable-gil`)
    builds do not support abi3.
    """
    return len(python_version) > 1 and tuple(python_version) >= (3, 2) and not threading


def _cpython_abis(py_version: PythonVersion, warn: bool = False) -> List[str]:
    py_version = tuple(py_version)  # To allow for version comparison.
    abis = []
    version = _version_nodot(py_version[:2])
    threading = debug = pymalloc = ucs4 = ""
    with_debug = _get_config_var("Py_DEBUG", warn)
    has_refcount = hasattr(sys, "gettotalrefcount")
    # Windows doesn't set Py_DEBUG, so checking for support of debug-compiled

@@ -145,6 +164,8 @@ def _cpython_abis(py_version: PythonVersion, warn: bool = False) -> List[str]:
    has_ext = "_d.pyd" in EXTENSION_SUFFIXES
    if with_debug or (with_debug is None and (has_refcount or has_ext)):
        debug = "d"
    if py_version >= (3, 13) and _get_config_var("Py_GIL_DISABLED", warn):
        threading = "t"
    if py_version < (3, 8):
        with_pymalloc = _get_config_var("WITH_PYMALLOC", warn)
        if with_pymalloc or with_pymalloc is None:

@@ -158,13 +179,8 @@ def _cpython_abis(py_version: PythonVersion, warn: bool = False) -> List[str]:
    elif debug:
        # Debug builds can also load "normal" extension modules.
        # We can also assume no UCS-4 or pymalloc requirement.
        abis.append(f"cp{version}{threading}")
    abis.insert(0, f"cp{version}{threading}{debug}{pymalloc}{ucs4}")
    return abis
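
On a free-threaded 3.13 build the ABI list becomes e.g. ['cp313t'] instead of ['cp313'], and that "t" flag is what later suppresses abi3 tags. A sketch of the two helpers (illustrative values):

    >>> _is_threaded_cpython(['cp313t'])
    True
    >>> _abi3_applies((3, 13), True)     # threaded builds opt out of abi3
    False
    >>> _abi3_applies((3, 13), False)
    True
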
@@ -212,11 +228,14 @@ def cpython_tags(
    for abi in abis:
        for platform_ in platforms:
            yield Tag(interpreter, abi, platform_)

    threading = _is_threaded_cpython(abis)
    use_abi3 = _abi3_applies(python_version, threading)
    if use_abi3:
        yield from (Tag(interpreter, "abi3", platform_) for platform_ in platforms)
    yield from (Tag(interpreter, "none", platform_) for platform_ in platforms)

    if use_abi3:
        for minor_version in range(python_version[1] - 1, 1, -1):
            for platform_ in platforms:
                interpreter = "cp{version}".format(
@@ -406,7 +425,7 @@ def mac_platforms(
            check=True,
            env={"SYSTEM_VERSION_COMPAT": "0"},
            stdout=subprocess.PIPE,
            text=True,
        ).stdout
        version = cast("MacVersion", tuple(map(int, version_str.split(".")[:2])))
    else:

@@ -469,15 +488,21 @@ def mac_platforms(
def _linux_platforms(is_32bit: bool = _32_BIT_INTERPRETER) -> Iterator[str]:
    linux = _normalize_string(sysconfig.get_platform())
    if not linux.startswith("linux_"):
        # we should never be here, just yield the sysconfig one and return
        yield linux
        return
    if is_32bit:
        if linux == "linux_x86_64":
            linux = "linux_i686"
        elif linux == "linux_aarch64":
            linux = "linux_armv8l"
    _, arch = linux.split("_", 1)
    archs = {"armv8l": ["armv8l", "armv7l"]}.get(arch, [arch])
    yield from _manylinux.platform_tags(archs)
    yield from _musllinux.platform_tags(archs)
    for arch in archs:
        yield f"linux_{arch}"


def _generic_platforms() -> Iterator[str]:

@@ -12,6 +12,12 @@ BuildTag = Union[Tuple[()], Tuple[int, str]]
NormalizedName = NewType("NormalizedName", str)


class InvalidName(ValueError):
    """
    An invalid distribution name; users should refer to the packaging user guide.
    """


class InvalidWheelFilename(ValueError):
    """
    An invalid wheel filename was found, users should refer to PEP 427.
@@ -24,17 +30,28 @@ class InvalidSdistFilename(ValueError):
    """


# Core metadata spec for `Name`
_validate_regex = re.compile(
    r"^([A-Z0-9]|[A-Z0-9][A-Z0-9._-]*[A-Z0-9])$", re.IGNORECASE
)
_canonicalize_regex = re.compile(r"[-_.]+")
_normalized_regex = re.compile(r"^([a-z0-9]|[a-z0-9]([a-z0-9-](?!--))*[a-z0-9])$")
# PEP 427: The build number must start with a digit.
_build_tag_regex = re.compile(r"(\d+)(.*)")


def canonicalize_name(name: str, *, validate: bool = False) -> NormalizedName:
    if validate and not _validate_regex.match(name):
        raise InvalidName(f"name is invalid: {name!r}")
    # This is taken from PEP 503.
    value = _canonicalize_regex.sub("-", name).lower()
    return cast(NormalizedName, value)


def is_normalized_name(name: str) -> bool:
    return _normalized_regex.match(name) is not None


def canonicalize_version(
    version: Union[Version, str], *, strip_trailing_zero: bool = True
) -> str:
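
validate=True applies the core-metadata name rule before PEP 503 normalization, and is_normalized_name() checks an already-normalized spelling. Sketch:

    >>> canonicalize_name('Foo.Bar_baz')
    'foo-bar-baz'
    >>> is_normalized_name('foo-bar-baz')
    True
    >>> is_normalized_name('foo--bar')    # consecutive dashes are not normalized form
    False
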
@@ -100,11 +117,18 @@ def parse_wheel_filename(
    parts = filename.split("-", dashes - 2)
    name_part = parts[0]
    # See PEP 427 for the rules on escaping the project name.
    if "__" in name_part or re.match(r"^[\w\d._]*$", name_part, re.UNICODE) is None:
        raise InvalidWheelFilename(f"Invalid project name: {filename}")
    name = canonicalize_name(name_part)

    try:
        version = Version(parts[1])
    except InvalidVersion as e:
        raise InvalidWheelFilename(
            f"Invalid wheel filename (invalid version): {filename}"
        ) from e

    if dashes == 5:
        build_part = parts[2]
        build_match = _build_tag_regex.match(build_part)

@@ -137,5 +161,12 @@ def parse_sdist_filename(filename: str) -> Tuple[NormalizedName, Version]:
        raise InvalidSdistFilename(f"Invalid sdist filename: {filename}")

    name = canonicalize_name(name_part)

    try:
        version = Version(version_part)
    except InvalidVersion as e:
        raise InvalidSdistFilename(
            f"Invalid sdist filename (invalid version): {filename}"
        ) from e

    return (name, version)
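
A bad version segment now surfaces as the filename-specific exception, with the InvalidVersion chained as its cause, instead of leaking InvalidVersion itself. Sketch:

    >>> from packaging.utils import parse_wheel_filename, InvalidWheelFilename
    >>> try:
    ...     parse_wheel_filename('demo-not.a.version-py3-none-any.whl')
    ... except InvalidWheelFilename as exc:
    ...     print(type(exc.__cause__).__name__)
    InvalidVersion
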

@@ -7,37 +7,39 @@
    from packaging.version import parse, Version
"""

import itertools
import re
from typing import Any, Callable, NamedTuple, Optional, SupportsInt, Tuple, Union

from ._structures import Infinity, InfinityType, NegativeInfinity, NegativeInfinityType

__all__ = ["VERSION_PATTERN", "parse", "Version", "InvalidVersion"]

LocalType = Tuple[Union[int, str], ...]

CmpPrePostDevType = Union[InfinityType, NegativeInfinityType, Tuple[str, int]]
CmpLocalType = Union[
    NegativeInfinityType,
    Tuple[Union[Tuple[int, str], Tuple[NegativeInfinityType, Union[int, str]]], ...],
]
CmpKey = Tuple[
    int,
    Tuple[int, ...],
    CmpPrePostDevType,
    CmpPrePostDevType,
    CmpPrePostDevType,
    CmpLocalType,
]
VersionComparisonMethod = Callable[[CmpKey, CmpKey], bool]


class _Version(NamedTuple):
    epoch: int
    release: Tuple[int, ...]
    dev: Optional[Tuple[str, int]]
    pre: Optional[Tuple[str, int]]
    post: Optional[Tuple[str, int]]
    local: Optional[LocalType]


def parse(version: str) -> "Version":
@ -117,7 +119,7 @@ _VERSION_PATTERN = r"""
(?P<release>[0-9]+(?:\.[0-9]+)*) # release segment (?P<release>[0-9]+(?:\.[0-9]+)*) # release segment
(?P<pre> # pre-release (?P<pre> # pre-release
[-_\.]? [-_\.]?
(?P<pre_l>(a|b|c|rc|alpha|beta|pre|preview)) (?P<pre_l>alpha|a|beta|b|preview|pre|c|rc)
[-_\.]? [-_\.]?
(?P<pre_n>[0-9]+)? (?P<pre_n>[0-9]+)?
)? )?
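
Ordering the alternation longest-first ('alpha' before 'a', 'preview' before 'pre') lets the scanner take the full spelling directly instead of relying on regex backtracking; the parsed value is normalized the same way afterwards. Sketch:

    >>> Version('1.0alpha1').pre     # 'alpha' normalizes to 'a'
    ('a', 1)
    >>> Version('1.0c2') == Version('1.0rc2')    # 'c' normalizes to 'rc'
    True
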
@@ -269,8 +271,7 @@ class Version(_BaseVersion):
         >>> Version("1!2.0.0").epoch
         1
         """
-        _epoch: int = self._version.epoch
-        return _epoch
+        return self._version.epoch

     @property
     def release(self) -> Tuple[int, ...]:
@@ -286,8 +287,7 @@ class Version(_BaseVersion):
         Includes trailing zeroes but not the epoch or any pre-release / development /
         post-release suffixes.
         """
-        _release: Tuple[int, ...] = self._version.release
-        return _release
+        return self._version.release

     @property
     def pre(self) -> Optional[Tuple[str, int]]:
@@ -302,8 +302,7 @@ class Version(_BaseVersion):
         >>> Version("1.2.3rc1").pre
         ('rc', 1)
         """
-        _pre: Optional[Tuple[str, int]] = self._version.pre
-        return _pre
+        return self._version.pre

     @property
     def post(self) -> Optional[int]:
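
The deleted local variables existed only to hand mypy an annotation; now that _Version is a typed NamedTuple, each property can return its field directly with no change in behaviour:

    from packaging.version import Version

    v = Version("1!2.0.0rc1")
    assert v.epoch == 1
    assert v.release == (2, 0, 0)
    assert v.pre == ("rc", 1)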
@@ -451,7 +450,7 @@ class Version(_BaseVersion):

 def _parse_letter_version(
-    letter: str, number: Union[str, bytes, SupportsInt]
+    letter: Optional[str], number: Union[str, bytes, SupportsInt, None]
 ) -> Optional[Tuple[str, int]]:
     if letter:
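
The widened Optional parameters match how the helper is actually invoked: either regex group can be absent. Its normalization rules are unchanged and remain observable through the public API, for example:

    from packaging.version import Version

    assert str(Version("1.0rev4")) == "1.0.post4"  # rev/r spell a post-release
    assert str(Version("1.0-3")) == "1.0.post3"    # implicit post-release number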
@@ -489,7 +488,7 @@ def _parse_letter_version(
 _local_version_separators = re.compile(r"[\._-]")

-def _parse_local_version(local: str) -> Optional[LocalType]:
+def _parse_local_version(local: Optional[str]) -> Optional[LocalType]:
     """
     Takes a string like abc.1.twelve and turns it into ("abc", 1, "twelve").
     """
@@ -507,7 +506,7 @@ def _cmpkey(
     pre: Optional[Tuple[str, int]],
     post: Optional[Tuple[str, int]],
     dev: Optional[Tuple[str, int]],
-    local: Optional[Tuple[SubLocalType]],
+    local: Optional[LocalType],
 ) -> CmpKey:

     # When we compare a release version, we want to compare it with all of the
@@ -524,7 +523,7 @@ def _cmpkey(
     # if there is not a pre or a post segment. If we have one of those then
     # the normal sorting rules will handle this case correctly.
     if pre is None and post is None and dev is not None:
-        _pre: PrePostDevType = NegativeInfinity
+        _pre: CmpPrePostDevType = NegativeInfinity
     # Versions without a pre-release (except as noted above) should sort after
     # those with one.
     elif pre is None:
@@ -534,21 +533,21 @@ def _cmpkey(
     # Versions without a post segment should sort before those with one.
     if post is None:
-        _post: PrePostDevType = NegativeInfinity
+        _post: CmpPrePostDevType = NegativeInfinity
     else:
         _post = post

     # Versions without a development segment should sort after those with one.
     if dev is None:
-        _dev: PrePostDevType = Infinity
+        _dev: CmpPrePostDevType = Infinity
     else:
         _dev = dev

     if local is None:
         # Versions without a local segment should sort before those with one.
-        _local: LocalType = NegativeInfinity
+        _local: CmpLocalType = NegativeInfinity
     else:
         # Versions with a local segment need that segment parsed to implement
         # the sorting rules in PEP440.
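
Only the alias names feeding these annotations changed; the PEP 440 ordering that _cmpkey encodes is untouched. A quick sanity check of those rules:

    from packaging.version import Version

    ordered = sorted(
        Version(s) for s in ["1.0.post1", "1.0", "1.0a1", "1.0.dev1", "1.0+local"]
    )
    # dev < pre-release < final < final+local < post-release
    assert [str(v) for v in ordered] == [
        "1.0.dev1", "1.0a1", "1.0", "1.0+local", "1.0.post1"
    ]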

View file

@@ -0,0 +1 @@
+exclude = ["*"]

File diff suppressed because it is too large

View file

@ -1,11 +1,13 @@
packaging==23.1 packaging==24
platformdirs==2.6.2 platformdirs==2.6.2
# required for platformdirs on Python < 3.8
typing_extensions==4.4.0
jaraco.text==3.7.0 jaraco.text==3.7.0
# required for jaraco.text on older Pythons # required for jaraco.text on older Pythons
importlib_resources==5.10.2 importlib_resources==5.10.2
# required for importlib_resources on older Pythons # required for importlib_resources on older Pythons
zipp==3.7.0 zipp==3.7.0
# required for jaraco.functools
more_itertools==10.2.0
# required for jaraco.context on older Pythons
backports.tarfile

View file

@@ -1,5 +1,8 @@
+from importlib.machinery import ModuleSpec
 import importlib.util
 import sys
+from types import ModuleType
+from typing import Iterable, Optional, Sequence


 class VendorImporter:
@@ -8,7 +11,12 @@ class VendorImporter:
     or otherwise naturally-installed packages from root_name.
     """

-    def __init__(self, root_name, vendored_names=(), vendor_pkg=None):
+    def __init__(
+        self,
+        root_name: str,
+        vendored_names: Iterable[str] = (),
+        vendor_pkg: Optional[str] = None,
+    ):
         self.root_name = root_name
         self.vendored_names = set(vendored_names)
         self.vendor_pkg = vendor_pkg or root_name.replace('extern', '_vendor')
@@ -26,7 +34,7 @@ class VendorImporter:
         root, base, target = fullname.partition(self.root_name + '.')
         return not root and any(map(target.startswith, self.vendored_names))

-    def load_module(self, fullname):
+    def load_module(self, fullname: str):
         """
         Iterate over the search path to locate and load fullname.
         """
@@ -48,16 +56,22 @@ class VendorImporter:
                 "distribution.".format(**locals())
             )

-    def create_module(self, spec):
+    def create_module(self, spec: ModuleSpec):
         return self.load_module(spec.name)

-    def exec_module(self, module):
+    def exec_module(self, module: ModuleType):
         pass

-    def find_spec(self, fullname, path=None, target=None):
+    def find_spec(
+        self,
+        fullname: str,
+        path: Optional[Sequence[str]] = None,
+        target: Optional[ModuleType] = None,
+    ):
         """Return a module spec for vendored names."""
         return (
-            importlib.util.spec_from_loader(fullname, self)
+            # This should fix itself next mypy release https://github.com/python/typeshed/pull/11890
+            importlib.util.spec_from_loader(fullname, self)  # type: ignore[arg-type]
             if self._module_matches_namespace(fullname)
             else None
         )
@@ -70,11 +84,20 @@ class VendorImporter:
             sys.meta_path.append(self)


+# [[[cog
+# import cog
+# from tools.vendored import yield_top_level
+# names = "\n".join(f"    {x!r}," for x in yield_top_level('pkg_resources'))
+# cog.outl(f"names = (\n{names}\n)")
+# ]]]
 names = (
+    'backports',
+    'importlib_resources',
+    'jaraco',
+    'more_itertools',
     'packaging',
     'platformdirs',
-    'jaraco',
-    'importlib_resources',
-    'more_itertools',
+    'zipp',
 )
+# [[[end]]]
 VendorImporter(__name__, names).install()
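
The cog fences mark the names tuple as generated: the comments record how tools/vendored.py (part of the setuptools repo, not shipped here) re-derives the list from the vendored tree. Runtime behaviour is unchanged; once install() has put the importer on sys.meta_path, vendored packages resolve transparently. A sketch of the consumer-side effect, assuming a pkg_resources that still ships a vendored packaging:

    import sys
    from pkg_resources.extern import packaging  # vendored copy if present, else top-level

    # load_module registers the module under the extern name as well.
    assert sys.modules["pkg_resources.extern.packaging"] is packaging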