Update urllib3 1.26.15 (25cca389) → 2.0.4 (af7c78fa).

This commit is contained in:
JackDandy 2023-09-17 20:52:08 +01:00
parent 4bb3ba0a15
commit 07935763df
49 changed files with 4441 additions and 5007 deletions

View file

@ -10,6 +10,7 @@
* Update package resource API 67.5.1 (f51eccd) to 68.1.2 (1ef36f2) * Update package resource API 67.5.1 (f51eccd) to 68.1.2 (1ef36f2)
* Update soupsieve 2.3.2.post1 (792d566) to 2.4.1 (2e66beb) * Update soupsieve 2.3.2.post1 (792d566) to 2.4.1 (2e66beb)
* Update Tornado Web Server 6.3.2 (e3aa6c5) to 6.3.3 (e4d6984) * Update Tornado Web Server 6.3.2 (e3aa6c5) to 6.3.3 (e4d6984)
* Update urllib3 1.26.15 (25cca389) to 2.0.4 (af7c78fa)
* Add thefuzz 0.19.0 (c2cd4f4) as a replacement with fallback to fuzzywuzzy 0.18.0 (2188520) * Add thefuzz 0.19.0 (c2cd4f4) as a replacement with fallback to fuzzywuzzy 0.18.0 (2188520)
* Fix regex that was not using py312 notation * Fix regex that was not using py312 notation
* Change sort backlog and manual segment search results episode number * Change sort backlog and manual segment search results episode number

View file

@ -1,23 +1,48 @@
""" """
Python HTTP library with thread-safe connection pooling, file post support, user friendly, and more Python HTTP library with thread-safe connection pooling, file post support, user friendly, and more
""" """
from __future__ import absolute_import
from __future__ import annotations
# Set default logging handler to avoid "No handler found" warnings. # Set default logging handler to avoid "No handler found" warnings.
import logging import logging
import typing
import warnings import warnings
from logging import NullHandler from logging import NullHandler
from . import exceptions from . import exceptions
from ._base_connection import _TYPE_BODY
from ._collections import HTTPHeaderDict
from ._version import __version__ from ._version import __version__
from .connectionpool import HTTPConnectionPool, HTTPSConnectionPool, connection_from_url from .connectionpool import HTTPConnectionPool, HTTPSConnectionPool, connection_from_url
from .filepost import encode_multipart_formdata from .filepost import _TYPE_FIELDS, encode_multipart_formdata
from .poolmanager import PoolManager, ProxyManager, proxy_from_url from .poolmanager import PoolManager, ProxyManager, proxy_from_url
from .response import HTTPResponse from .response import BaseHTTPResponse, HTTPResponse
from .util.request import make_headers from .util.request import make_headers
from .util.retry import Retry from .util.retry import Retry
from .util.timeout import Timeout from .util.timeout import Timeout
from .util.url import get_host
# Import-time guard: urllib3 v2 requires an 'ssl' module built against
# OpenSSL 1.1.1 or newer. A Python without any 'ssl' module is tolerated;
# the check only applies when the module can actually be imported.
try:
    import ssl
except ImportError:
    pass
else:
    if ssl.OPENSSL_VERSION.startswith("OpenSSL "):
        # Genuine OpenSSL build: refuse to import when it is too old.
        if ssl.OPENSSL_VERSION_INFO < (1, 1, 1):  # Defensive:
            raise ImportError(
                "urllib3 v2.0 only supports OpenSSL 1.1.1+, currently "
                f"the 'ssl' module is compiled with {ssl.OPENSSL_VERSION!r}. "
                "See: https://github.com/urllib3/urllib3/issues/2168"
            )
    else:  # Defensive:
        # The version string does not identify itself as OpenSSL, so only
        # warn instead of failing the import.
        warnings.warn(
            "urllib3 v2.0 only supports OpenSSL 1.1.1+, currently "
            f"the 'ssl' module is compiled with {ssl.OPENSSL_VERSION!r}. "
            "See: https://github.com/urllib3/urllib3/issues/3020",
            exceptions.NotOpenSSLWarning,
        )
# === NOTE TO REPACKAGERS AND VENDORS === # === NOTE TO REPACKAGERS AND VENDORS ===
# Please delete this block, this logic is only # Please delete this block, this logic is only
@ -25,12 +50,12 @@ from .util.url import get_host
# See: https://github.com/urllib3/urllib3/issues/2680 # See: https://github.com/urllib3/urllib3/issues/2680
try: try:
import urllib3_secure_extra # type: ignore # noqa: F401 import urllib3_secure_extra # type: ignore # noqa: F401
except ImportError: except ModuleNotFoundError:
pass pass
else: else:
warnings.warn( warnings.warn(
"'urllib3[secure]' extra is deprecated and will be removed " "'urllib3[secure]' extra is deprecated and will be removed "
"in a future release of urllib3 2.x. Read more in this issue: " "in urllib3 v2.1.0. Read more in this issue: "
"https://github.com/urllib3/urllib3/issues/2680", "https://github.com/urllib3/urllib3/issues/2680",
category=DeprecationWarning, category=DeprecationWarning,
stacklevel=2, stacklevel=2,
@ -42,6 +67,7 @@ __version__ = __version__
__all__ = ( __all__ = (
"HTTPConnectionPool", "HTTPConnectionPool",
"HTTPHeaderDict",
"HTTPSConnectionPool", "HTTPSConnectionPool",
"PoolManager", "PoolManager",
"ProxyManager", "ProxyManager",
@ -52,15 +78,18 @@ __all__ = (
"connection_from_url", "connection_from_url",
"disable_warnings", "disable_warnings",
"encode_multipart_formdata", "encode_multipart_formdata",
"get_host",
"make_headers", "make_headers",
"proxy_from_url", "proxy_from_url",
"request",
"BaseHTTPResponse",
) )
logging.getLogger(__name__).addHandler(NullHandler()) logging.getLogger(__name__).addHandler(NullHandler())
def add_stderr_logger(level=logging.DEBUG): def add_stderr_logger(
level: int = logging.DEBUG,
) -> logging.StreamHandler[typing.TextIO]:
""" """
Helper for quickly adding a StreamHandler to the logger. Useful for Helper for quickly adding a StreamHandler to the logger. Useful for
debugging. debugging.
@ -87,16 +116,51 @@ del NullHandler
# mechanisms to silence them. # mechanisms to silence them.
# SecurityWarning's always go off by default. # SecurityWarning's always go off by default.
warnings.simplefilter("always", exceptions.SecurityWarning, append=True) warnings.simplefilter("always", exceptions.SecurityWarning, append=True)
# SubjectAltNameWarning's should go off once per host
warnings.simplefilter("default", exceptions.SubjectAltNameWarning, append=True)
# InsecurePlatformWarning's don't vary between requests, so we keep it default. # InsecurePlatformWarning's don't vary between requests, so we keep it default.
warnings.simplefilter("default", exceptions.InsecurePlatformWarning, append=True) warnings.simplefilter("default", exceptions.InsecurePlatformWarning, append=True)
# SNIMissingWarnings should go off only once.
warnings.simplefilter("default", exceptions.SNIMissingWarning, append=True)
def disable_warnings(category=exceptions.HTTPWarning): def disable_warnings(category: type[Warning] = exceptions.HTTPWarning) -> None:
""" """
Helper for quickly disabling all urllib3 warnings. Helper for quickly disabling all urllib3 warnings.
""" """
warnings.simplefilter("ignore", category) warnings.simplefilter("ignore", category)
_DEFAULT_POOL = PoolManager()
def request(
    method: str,
    url: str,
    *,
    body: _TYPE_BODY | None = None,
    fields: _TYPE_FIELDS | None = None,
    headers: typing.Mapping[str, str] | None = None,
    preload_content: bool | None = True,
    decode_content: bool | None = True,
    redirect: bool | None = True,
    retries: Retry | bool | int | None = None,
    timeout: Timeout | float | int | None = 3,
    json: typing.Any | None = None,
) -> BaseHTTPResponse:
    """
    A convenience, top-level request method. It uses a module-global ``PoolManager`` instance.
    Therefore, its side effects could be shared across dependencies relying on it.
    To avoid side effects create a new ``PoolManager`` instance and use it instead.
    The method does not accept low-level ``**urlopen_kw`` keyword arguments.

    Every argument is handed unchanged to ``_DEFAULT_POOL.request``; this
    function adds no processing of its own. Everything after ``url`` is
    keyword-only.

    :param method: HTTP method name, passed positionally to the pool.
    :param url: Target URL, passed positionally to the pool.
    :param body: Forwarded as ``body=``; defaults to ``None``.
    :param fields: Forwarded as ``fields=``; defaults to ``None``.
    :param headers: Forwarded as ``headers=``; defaults to ``None``.
    :param preload_content: Forwarded as ``preload_content=``; defaults to ``True``.
    :param decode_content: Forwarded as ``decode_content=``; defaults to ``True``.
    :param redirect: Forwarded as ``redirect=``; defaults to ``True``.
    :param retries: Forwarded as ``retries=``; defaults to ``None``.
    :param timeout: Forwarded as ``timeout=``; defaults to ``3``
        (NOTE(review): presumably seconds -- confirm against ``Timeout``).
    :param json: Forwarded as ``json=``; defaults to ``None``.
    :returns: Whatever ``_DEFAULT_POOL.request`` returns, annotated as
        ``BaseHTTPResponse``.
    """
    # Pure pass-through to the shared module-level pool manager.
    return _DEFAULT_POOL.request(
        method,
        url,
        body=body,
        fields=fields,
        headers=headers,
        preload_content=preload_content,
        decode_content=decode_content,
        redirect=redirect,
        retries=retries,
        timeout=timeout,
        json=json,
    )

View file

@ -0,0 +1,173 @@
from __future__ import annotations
import typing
from .util.connection import _TYPE_SOCKET_OPTIONS
from .util.timeout import _DEFAULT_TIMEOUT, _TYPE_TIMEOUT
from .util.url import Url
_TYPE_BODY = typing.Union[bytes, typing.IO[typing.Any], typing.Iterable[bytes], str]
class ProxyConfig(typing.NamedTuple):
    """Immutable record of proxy settings, used as the type of ``proxy_config``.

    ``ssl`` and ``Literal`` are imported in this module only under
    ``typing.TYPE_CHECKING``, so these annotations rely on
    ``from __future__ import annotations`` to remain unevaluated at runtime.
    """

    # NOTE(review): presumably the TLS context used when talking to the proxy
    # itself -- confirm against ProxyManager.
    ssl_context: ssl.SSLContext | None
    # NOTE(review): presumably selects forwarding instead of tunnelling for
    # HTTPS targets -- confirm against ProxyManager.
    use_forwarding_for_https: bool
    # Same annotation shape as ``BaseHTTPSConnection.assert_hostname``.
    assert_hostname: None | str | Literal[False]
    assert_fingerprint: str | None
class _ResponseOptions(typing.NamedTuple):
    """Immutable record of a request's method, URL and response-handling flags.

    The three boolean fields share their names with the keyword-only
    parameters of ``BaseHTTPConnection.request``.
    NOTE(review): presumably captured when a request is sent and consumed
    when the response is built -- confirm against the connection module.
    """

    # TODO: Remove this in favor of a better
    # HTTP request/response lifecycle tracking.
    request_method: str
    request_url: str
    preload_content: bool
    decode_content: bool
    enforce_content_length: bool
if typing.TYPE_CHECKING:
    # Everything in this branch is visible to static type checkers only;
    # none of these names are defined when the module runs.
    import ssl

    from typing_extensions import Literal, Protocol

    from .response import BaseHTTPResponse

    class BaseHTTPConnection(Protocol):
        """Structural type listing the attributes and methods of an HTTP connection.

        Method bodies are ``...`` placeholders: the protocol declares
        signatures only and provides no behavior.
        """

        default_port: typing.ClassVar[int]
        default_socket_options: typing.ClassVar[_TYPE_SOCKET_OPTIONS]

        host: str
        port: int
        timeout: None | (
            float
        )  # Instance doesn't store _DEFAULT_TIMEOUT, must be resolved.
        blocksize: int
        source_address: tuple[str, int] | None
        socket_options: _TYPE_SOCKET_OPTIONS | None

        proxy: Url | None
        proxy_config: ProxyConfig | None

        is_verified: bool
        proxy_is_verified: bool | None

        def __init__(
            self,
            host: str,
            port: int | None = None,
            *,
            timeout: _TYPE_TIMEOUT = _DEFAULT_TIMEOUT,
            source_address: tuple[str, int] | None = None,
            blocksize: int = 8192,
            # ``...`` stands in for the default here; the concrete value is
            # not stated by this protocol.
            socket_options: _TYPE_SOCKET_OPTIONS | None = ...,
            proxy: Url | None = None,
            proxy_config: ProxyConfig | None = None,
        ) -> None:
            ...

        def set_tunnel(
            self,
            host: str,
            port: int | None = None,
            headers: typing.Mapping[str, str] | None = None,
            scheme: str = "http",
        ) -> None:
            ...

        def connect(self) -> None:
            ...

        def request(
            self,
            method: str,
            url: str,
            body: _TYPE_BODY | None = None,
            headers: typing.Mapping[str, str] | None = None,
            # We know *at least* botocore is depending on the order of the
            # first 3 parameters so to be safe we only mark the later ones
            # as keyword-only to ensure we have space to extend.
            *,
            chunked: bool = False,
            preload_content: bool = True,
            decode_content: bool = True,
            enforce_content_length: bool = True,
        ) -> None:
            ...

        def getresponse(self) -> BaseHTTPResponse:
            ...

        def close(self) -> None:
            ...

        @property
        def is_closed(self) -> bool:
            """Whether the connection either is brand new or has been previously closed.

            If this property is True then both ``is_connected`` and ``has_connected_to_proxy``
            properties must be False.
            """

        @property
        def is_connected(self) -> bool:
            """Whether the connection is actively connected to any origin (proxy or target)"""

        @property
        def has_connected_to_proxy(self) -> bool:
            """Whether the connection has successfully connected to its proxy.

            This returns False if no proxy is in use. Used to determine whether
            errors are coming from the proxy layer or from tunnelling to the target origin.
            """

    class BaseHTTPSConnection(BaseHTTPConnection, Protocol):
        """Structural type for an HTTPS connection.

        Extends ``BaseHTTPConnection`` with certificate-verification, trusted-CA,
        TLS-version and client-certificate attributes, and a constructor that
        accepts them as keyword-only arguments.
        """

        default_port: typing.ClassVar[int]
        default_socket_options: typing.ClassVar[_TYPE_SOCKET_OPTIONS]

        # Certificate verification methods
        cert_reqs: int | str | None
        assert_hostname: None | str | Literal[False]
        assert_fingerprint: str | None
        ssl_context: ssl.SSLContext | None

        # Trusted CAs
        ca_certs: str | None
        ca_cert_dir: str | None
        ca_cert_data: None | str | bytes

        # TLS version
        ssl_minimum_version: int | None
        ssl_maximum_version: int | None
        ssl_version: int | str | None  # Deprecated

        # Client certificates
        cert_file: str | None
        key_file: str | None
        key_password: str | None

        def __init__(
            self,
            host: str,
            port: int | None = None,
            *,
            timeout: _TYPE_TIMEOUT = _DEFAULT_TIMEOUT,
            source_address: tuple[str, int] | None = None,
            # Declared default differs from BaseHTTPConnection's 8192.
            blocksize: int = 16384,
            socket_options: _TYPE_SOCKET_OPTIONS | None = ...,
            proxy: Url | None = None,
            proxy_config: ProxyConfig | None = None,
            cert_reqs: int | str | None = None,
            assert_hostname: None | str | Literal[False] = None,
            assert_fingerprint: str | None = None,
            server_hostname: str | None = None,
            ssl_context: ssl.SSLContext | None = None,
            ca_certs: str | None = None,
            ca_cert_dir: str | None = None,
            ca_cert_data: None | str | bytes = None,
            ssl_minimum_version: int | None = None,
            ssl_maximum_version: int | None = None,
            ssl_version: int | str | None = None,  # Deprecated
            cert_file: str | None = None,
            key_file: str | None = None,
            key_password: str | None = None,
        ) -> None:
            ...

View file

@ -1,34 +1,66 @@
from __future__ import absolute_import from __future__ import annotations
try:
from collections.abc import Mapping, MutableMapping
except ImportError:
from collections import Mapping, MutableMapping
try:
from threading import RLock
except ImportError: # Platform-specific: No threads available
class RLock:
def __enter__(self):
pass
def __exit__(self, exc_type, exc_value, traceback):
pass
import typing
from collections import OrderedDict from collections import OrderedDict
from enum import Enum, auto
from threading import RLock
if typing.TYPE_CHECKING:
    # We can only import Protocol if TYPE_CHECKING because it's a development
    # dependency, and is not available at runtime.
    from typing_extensions import Protocol

    class HasGettableStringKeys(Protocol):
        """Structural type for objects offering ``keys()`` and string ``__getitem__``.

        Exists for type checkers only; at runtime this name is undefined, so
        other code must refer to it through a string annotation.
        """

        def keys(self) -> typing.Iterator[str]:
            ...

        def __getitem__(self, key: str) -> str:
            ...
from .exceptions import InvalidHeader
from .packages import six
from .packages.six import iterkeys, itervalues
__all__ = ["RecentlyUsedContainer", "HTTPHeaderDict"] __all__ = ["RecentlyUsedContainer", "HTTPHeaderDict"]
_Null = object() # Key type
_KT = typing.TypeVar("_KT")
# Value type
_VT = typing.TypeVar("_VT")
# Default type
_DT = typing.TypeVar("_DT")
ValidHTTPHeaderSource = typing.Union[
"HTTPHeaderDict",
typing.Mapping[str, str],
typing.Iterable[typing.Tuple[str, str]],
"HasGettableStringKeys",
]
class RecentlyUsedContainer(MutableMapping): class _Sentinel(Enum):
not_passed = auto()
def ensure_can_construct_http_header_dict(
    potential: object,
) -> ValidHTTPHeaderSource | None:
    """Narrow an arbitrary object to a header source, or return ``None``.

    The object itself is returned (only re-typed for the checker) when it is
    an ``HTTPHeaderDict``, a ``Mapping``, an ``Iterable``, or has both ``keys``
    and ``__getitem__`` attributes. The checks run in that order. Contents are
    never inspected.

    NOTE(review): because contents are not checked, any iterable -- a plain
    ``str`` included -- passes the ``Iterable`` test and is returned.
    """
    # An existing header dict needs no narrowing.
    if isinstance(potential, HTTPHeaderDict):
        return potential

    # Checking every key and value of a Mapping at runtime would be costly,
    # so any Mapping is assumed to have the right shape.
    if isinstance(potential, typing.Mapping):
        return typing.cast(typing.Mapping[str, str], potential)

    # Likewise, any Iterable is assumed to yield (name, value) pairs.
    if isinstance(potential, typing.Iterable):
        return typing.cast(typing.Iterable[typing.Tuple[str, str]], potential)

    # Last resort: duck-typed objects that look like a partial mapping.
    looks_like_mapping = hasattr(potential, "keys") and hasattr(
        potential, "__getitem__"
    )
    if looks_like_mapping:
        return typing.cast("HasGettableStringKeys", potential)
    return None
class RecentlyUsedContainer(typing.Generic[_KT, _VT], typing.MutableMapping[_KT, _VT]):
""" """
Provides a thread-safe dict-like container which maintains up to Provides a thread-safe dict-like container which maintains up to
``maxsize`` keys while throwing away the least-recently-used keys beyond ``maxsize`` keys while throwing away the least-recently-used keys beyond
@ -42,69 +74,134 @@ class RecentlyUsedContainer(MutableMapping):
``dispose_func(value)`` is called. Callback which will get called ``dispose_func(value)`` is called. Callback which will get called
""" """
ContainerCls = OrderedDict _container: typing.OrderedDict[_KT, _VT]
_maxsize: int
dispose_func: typing.Callable[[_VT], None] | None
lock: RLock
def __init__(self, maxsize=10, dispose_func=None): def __init__(
self,
maxsize: int = 10,
dispose_func: typing.Callable[[_VT], None] | None = None,
) -> None:
super().__init__()
self._maxsize = maxsize self._maxsize = maxsize
self.dispose_func = dispose_func self.dispose_func = dispose_func
self._container = OrderedDict()
self._container = self.ContainerCls()
self.lock = RLock() self.lock = RLock()
def __getitem__(self, key): def __getitem__(self, key: _KT) -> _VT:
# Re-insert the item, moving it to the end of the eviction line. # Re-insert the item, moving it to the end of the eviction line.
with self.lock: with self.lock:
item = self._container.pop(key) item = self._container.pop(key)
self._container[key] = item self._container[key] = item
return item return item
def __setitem__(self, key, value): def __setitem__(self, key: _KT, value: _VT) -> None:
evicted_value = _Null evicted_item = None
with self.lock: with self.lock:
# Possibly evict the existing value of 'key' # Possibly evict the existing value of 'key'
evicted_value = self._container.get(key, _Null) try:
self._container[key] = value # If the key exists, we'll overwrite it, which won't change the
# size of the pool. Because accessing a key should move it to
# the end of the eviction line, we pop it out first.
evicted_item = key, self._container.pop(key)
self._container[key] = value
except KeyError:
# When the key does not exist, we insert the value first so that
# evicting works in all cases, including when self._maxsize is 0
self._container[key] = value
if len(self._container) > self._maxsize:
# If we didn't evict an existing value, and we've hit our maximum
# size, then we have to evict the least recently used item from
# the beginning of the container.
evicted_item = self._container.popitem(last=False)
# If we didn't evict an existing value, we might have to evict the # After releasing the lock on the pool, dispose of any evicted value.
# least recently used item from the beginning of the container. if evicted_item is not None and self.dispose_func:
if len(self._container) > self._maxsize: _, evicted_value = evicted_item
_key, evicted_value = self._container.popitem(last=False)
if self.dispose_func and evicted_value is not _Null:
self.dispose_func(evicted_value) self.dispose_func(evicted_value)
def __delitem__(self, key): def __delitem__(self, key: _KT) -> None:
with self.lock: with self.lock:
value = self._container.pop(key) value = self._container.pop(key)
if self.dispose_func: if self.dispose_func:
self.dispose_func(value) self.dispose_func(value)
def __len__(self): def __len__(self) -> int:
with self.lock: with self.lock:
return len(self._container) return len(self._container)
def __iter__(self): def __iter__(self) -> typing.NoReturn:
raise NotImplementedError( raise NotImplementedError(
"Iteration over this class is unlikely to be threadsafe." "Iteration over this class is unlikely to be threadsafe."
) )
def clear(self): def clear(self) -> None:
with self.lock: with self.lock:
# Copy pointers to all values, then wipe the mapping # Copy pointers to all values, then wipe the mapping
values = list(itervalues(self._container)) values = list(self._container.values())
self._container.clear() self._container.clear()
if self.dispose_func: if self.dispose_func:
for value in values: for value in values:
self.dispose_func(value) self.dispose_func(value)
def keys(self): def keys(self) -> set[_KT]: # type: ignore[override]
with self.lock: with self.lock:
return list(iterkeys(self._container)) return set(self._container.keys())
class HTTPHeaderDict(MutableMapping): class HTTPHeaderDictItemView(typing.Set[typing.Tuple[str, str]]):
"""
HTTPHeaderDict is unusual for a Mapping[str, str] in that it has two modes of
address.
If we directly try to get an item with a particular name, we will get a string
back that is the concatenated version of all the values:
>>> d['X-Header-Name']
'Value1, Value2, Value3'
However, if we iterate over an HTTPHeaderDict's items, we will optionally combine
these values based on whether combine=True was called when building up the dictionary
>>> d = HTTPHeaderDict({"A": "1", "B": "foo"})
>>> d.add("A", "2", combine=True)
>>> d.add("B", "bar")
>>> list(d.items())
[
('A', '1, 2'),
('B', 'foo'),
('B', 'bar'),
]
This class conforms to the interface required by the MutableMapping ABC while
also giving us the nonstandard iteration behavior we want; items with duplicate
keys, ordered by time of first insertion.
"""
_headers: HTTPHeaderDict
def __init__(self, headers: HTTPHeaderDict) -> None:
self._headers = headers
def __len__(self) -> int:
return len(list(self._headers.iteritems()))
def __iter__(self) -> typing.Iterator[tuple[str, str]]:
return self._headers.iteritems()
def __contains__(self, item: object) -> bool:
if isinstance(item, tuple) and len(item) == 2:
passed_key, passed_val = item
if isinstance(passed_key, str) and isinstance(passed_val, str):
return self._headers._has_value_for_header(passed_key, passed_val)
return False
class HTTPHeaderDict(typing.MutableMapping[str, str]):
""" """
:param headers: :param headers:
An iterable of field-value pairs. Must not contain multiple field names An iterable of field-value pairs. Must not contain multiple field names
@ -138,9 +235,11 @@ class HTTPHeaderDict(MutableMapping):
'7' '7'
""" """
def __init__(self, headers=None, **kwargs): _container: typing.MutableMapping[str, list[str]]
super(HTTPHeaderDict, self).__init__()
self._container = OrderedDict() def __init__(self, headers: ValidHTTPHeaderSource | None = None, **kwargs: str):
super().__init__()
self._container = {} # 'dict' is insert-ordered in Python 3.7+
if headers is not None: if headers is not None:
if isinstance(headers, HTTPHeaderDict): if isinstance(headers, HTTPHeaderDict):
self._copy_from(headers) self._copy_from(headers)
@ -149,123 +248,147 @@ class HTTPHeaderDict(MutableMapping):
if kwargs: if kwargs:
self.extend(kwargs) self.extend(kwargs)
def __setitem__(self, key, val): def __setitem__(self, key: str, val: str) -> None:
# avoid a bytes/str comparison by decoding before httplib
if isinstance(key, bytes):
key = key.decode("latin-1")
self._container[key.lower()] = [key, val] self._container[key.lower()] = [key, val]
return self._container[key.lower()]
def __getitem__(self, key): def __getitem__(self, key: str) -> str:
val = self._container[key.lower()] val = self._container[key.lower()]
return ", ".join(val[1:]) return ", ".join(val[1:])
def __delitem__(self, key): def __delitem__(self, key: str) -> None:
del self._container[key.lower()] del self._container[key.lower()]
def __contains__(self, key): def __contains__(self, key: object) -> bool:
return key.lower() in self._container if isinstance(key, str):
return key.lower() in self._container
return False
def __eq__(self, other): def setdefault(self, key: str, default: str = "") -> str:
if not isinstance(other, Mapping) and not hasattr(other, "keys"): return super().setdefault(key, default)
def __eq__(self, other: object) -> bool:
maybe_constructable = ensure_can_construct_http_header_dict(other)
if maybe_constructable is None:
return False return False
if not isinstance(other, type(self)): else:
other = type(self)(other) other_as_http_header_dict = type(self)(maybe_constructable)
return dict((k.lower(), v) for k, v in self.itermerged()) == dict(
(k.lower(), v) for k, v in other.itermerged()
)
def __ne__(self, other): return {k.lower(): v for k, v in self.itermerged()} == {
k.lower(): v for k, v in other_as_http_header_dict.itermerged()
}
def __ne__(self, other: object) -> bool:
return not self.__eq__(other) return not self.__eq__(other)
if six.PY2: # Python 2 def __len__(self) -> int:
iterkeys = MutableMapping.iterkeys
itervalues = MutableMapping.itervalues
__marker = object()
def __len__(self):
return len(self._container) return len(self._container)
def __iter__(self): def __iter__(self) -> typing.Iterator[str]:
# Only provide the originally cased names # Only provide the originally cased names
for vals in self._container.values(): for vals in self._container.values():
yield vals[0] yield vals[0]
def pop(self, key, default=__marker): def discard(self, key: str) -> None:
"""D.pop(k[,d]) -> v, remove specified key and return the corresponding value.
If key is not found, d is returned if given, otherwise KeyError is raised.
"""
# Using the MutableMapping function directly fails due to the private marker.
# Using ordinary dict.pop would expose the internal structures.
# So let's reinvent the wheel.
try:
value = self[key]
except KeyError:
if default is self.__marker:
raise
return default
else:
del self[key]
return value
def discard(self, key):
try: try:
del self[key] del self[key]
except KeyError: except KeyError:
pass pass
def add(self, key, val): def add(self, key: str, val: str, *, combine: bool = False) -> None:
"""Adds a (name, value) pair, doesn't overwrite the value if it already """Adds a (name, value) pair, doesn't overwrite the value if it already
exists. exists.
If this is called with combine=True, instead of adding a new header value
as a distinct item during iteration, this will instead append the value to
any existing header value with a comma. If no existing header value exists
for the key, then the value will simply be added, ignoring the combine parameter.
>>> headers = HTTPHeaderDict(foo='bar') >>> headers = HTTPHeaderDict(foo='bar')
>>> headers.add('Foo', 'baz') >>> headers.add('Foo', 'baz')
>>> headers['foo'] >>> headers['foo']
'bar, baz' 'bar, baz'
>>> list(headers.items())
[('foo', 'bar'), ('foo', 'baz')]
>>> headers.add('foo', 'quz', combine=True)
>>> list(headers.items())
[('foo', 'bar, baz, quz')]
""" """
# avoid a bytes/str comparison by decoding before httplib
if isinstance(key, bytes):
key = key.decode("latin-1")
key_lower = key.lower() key_lower = key.lower()
new_vals = [key, val] new_vals = [key, val]
# Keep the common case aka no item present as fast as possible # Keep the common case aka no item present as fast as possible
vals = self._container.setdefault(key_lower, new_vals) vals = self._container.setdefault(key_lower, new_vals)
if new_vals is not vals: if new_vals is not vals:
vals.append(val) # if there are values here, then there is at least the initial
# key/value pair
assert len(vals) >= 2
if combine:
vals[-1] = vals[-1] + ", " + val
else:
vals.append(val)
def extend(self, *args, **kwargs): def extend(self, *args: ValidHTTPHeaderSource, **kwargs: str) -> None:
"""Generic import function for any type of header-like object. """Generic import function for any type of header-like object.
Adapted version of MutableMapping.update in order to insert items Adapted version of MutableMapping.update in order to insert items
with self.add instead of self.__setitem__ with self.add instead of self.__setitem__
""" """
if len(args) > 1: if len(args) > 1:
raise TypeError( raise TypeError(
"extend() takes at most 1 positional " f"extend() takes at most 1 positional arguments ({len(args)} given)"
"arguments ({0} given)".format(len(args))
) )
other = args[0] if len(args) >= 1 else () other = args[0] if len(args) >= 1 else ()
if isinstance(other, HTTPHeaderDict): if isinstance(other, HTTPHeaderDict):
for key, val in other.iteritems(): for key, val in other.iteritems():
self.add(key, val) self.add(key, val)
elif isinstance(other, Mapping): elif isinstance(other, typing.Mapping):
for key in other: for key, val in other.items():
self.add(key, other[key]) self.add(key, val)
elif hasattr(other, "keys"): elif isinstance(other, typing.Iterable):
for key in other.keys(): other = typing.cast(typing.Iterable[typing.Tuple[str, str]], other)
self.add(key, other[key])
else:
for key, value in other: for key, value in other:
self.add(key, value) self.add(key, value)
elif hasattr(other, "keys") and hasattr(other, "__getitem__"):
# THIS IS NOT A TYPESAFE BRANCH
# In this branch, the object has a `keys` attr but is not a Mapping or any of
# the other types indicated in the method signature. We do some stuff with
# it as though it partially implements the Mapping interface, but we're not
# doing that stuff safely AT ALL.
for key in other.keys():
self.add(key, other[key])
for key, value in kwargs.items(): for key, value in kwargs.items():
self.add(key, value) self.add(key, value)
def getlist(self, key, default=__marker): @typing.overload
def getlist(self, key: str) -> list[str]:
...
@typing.overload
def getlist(self, key: str, default: _DT) -> list[str] | _DT:
...
def getlist(
self, key: str, default: _Sentinel | _DT = _Sentinel.not_passed
) -> list[str] | _DT:
"""Returns a list of all the values for the named field. Returns an """Returns a list of all the values for the named field. Returns an
empty list if the key doesn't exist.""" empty list if the key doesn't exist."""
try: try:
vals = self._container[key.lower()] vals = self._container[key.lower()]
except KeyError: except KeyError:
if default is self.__marker: if default is _Sentinel.not_passed:
# _DT is unbound; empty list is instance of List[str]
return [] return []
# _DT is bound; default is instance of _DT
return default return default
else: else:
# _DT may or may not be bound; vals[1:] is instance of List[str], which
# meets our external interface requirement of `Union[List[str], _DT]`.
return vals[1:] return vals[1:]
# Backwards compatibility for httplib # Backwards compatibility for httplib
@ -276,62 +399,65 @@ class HTTPHeaderDict(MutableMapping):
# Backwards compatibility for http.cookiejar # Backwards compatibility for http.cookiejar
get_all = getlist get_all = getlist
def __repr__(self): def __repr__(self) -> str:
return "%s(%s)" % (type(self).__name__, dict(self.itermerged())) return f"{type(self).__name__}({dict(self.itermerged())})"
def _copy_from(self, other): def _copy_from(self, other: HTTPHeaderDict) -> None:
for key in other: for key in other:
val = other.getlist(key) val = other.getlist(key)
if isinstance(val, list): self._container[key.lower()] = [key, *val]
# Don't need to convert tuples
val = list(val)
self._container[key.lower()] = [key] + val
def copy(self): def copy(self) -> HTTPHeaderDict:
clone = type(self)() clone = type(self)()
clone._copy_from(self) clone._copy_from(self)
return clone return clone
def iteritems(self): def iteritems(self) -> typing.Iterator[tuple[str, str]]:
"""Iterate over all header lines, including duplicate ones.""" """Iterate over all header lines, including duplicate ones."""
for key in self: for key in self:
vals = self._container[key.lower()] vals = self._container[key.lower()]
for val in vals[1:]: for val in vals[1:]:
yield vals[0], val yield vals[0], val
def itermerged(self): def itermerged(self) -> typing.Iterator[tuple[str, str]]:
"""Iterate over all headers, merging duplicate ones together.""" """Iterate over all headers, merging duplicate ones together."""
for key in self: for key in self:
val = self._container[key.lower()] val = self._container[key.lower()]
yield val[0], ", ".join(val[1:]) yield val[0], ", ".join(val[1:])
def items(self): def items(self) -> HTTPHeaderDictItemView: # type: ignore[override]
return list(self.iteritems()) return HTTPHeaderDictItemView(self)
@classmethod def _has_value_for_header(self, header_name: str, potential_value: str) -> bool:
def from_httplib(cls, message): # Python 2 if header_name in self:
"""Read headers from a Python 2 httplib message object.""" return potential_value in self._container[header_name.lower()][1:]
# python2.7 does not expose a proper API for exporting multiheaders return False
# efficiently. This function re-reads raw lines from the message
# object and extracts the multiheaders properly.
obs_fold_continued_leaders = (" ", "\t")
headers = []
for line in message.headers: def __ior__(self, other: object) -> HTTPHeaderDict:
if line.startswith(obs_fold_continued_leaders): # Supports extending a header dict in-place using operator |=
if not headers: # combining items with add instead of __setitem__
# We received a header line that starts with OWS as described maybe_constructable = ensure_can_construct_http_header_dict(other)
# in RFC-7230 S3.2.4. This indicates a multiline header, but if maybe_constructable is None:
# there exists no previous header to which we can attach it. return NotImplemented
raise InvalidHeader( self.extend(maybe_constructable)
"Header continuation with no previous header: %s" % line return self
)
else:
key, value = headers[-1]
headers[-1] = (key, value + " " + line.strip())
continue
key, value = line.split(":", 1) def __or__(self, other: object) -> HTTPHeaderDict:
headers.append((key, value.strip())) # Supports merging header dicts using operator |
# combining items with add instead of __setitem__
maybe_constructable = ensure_can_construct_http_header_dict(other)
if maybe_constructable is None:
return NotImplemented
result = self.copy()
result.extend(maybe_constructable)
return result
return cls(headers) def __ror__(self, other: object) -> HTTPHeaderDict:
# Supports merging header dicts using operator | when other is on left side
# combining items with add instead of __setitem__
maybe_constructable = ensure_can_construct_http_header_dict(other)
if maybe_constructable is None:
return NotImplemented
result = type(self)(maybe_constructable)
result.extend(self)
return result

View file

@ -1,12 +1,23 @@
from __future__ import absolute_import from __future__ import annotations
from .filepost import encode_multipart_formdata import json as _json
from .packages.six.moves.urllib.parse import urlencode import typing
from urllib.parse import urlencode
from ._base_connection import _TYPE_BODY
from ._collections import HTTPHeaderDict
from .filepost import _TYPE_FIELDS, encode_multipart_formdata
from .response import BaseHTTPResponse
__all__ = ["RequestMethods"] __all__ = ["RequestMethods"]
_TYPE_ENCODE_URL_FIELDS = typing.Union[
typing.Sequence[typing.Tuple[str, typing.Union[str, bytes]]],
typing.Mapping[str, typing.Union[str, bytes]],
]
class RequestMethods(object):
class RequestMethods:
""" """
Convenience mixin for classes who implement a :meth:`urlopen` method, such Convenience mixin for classes who implement a :meth:`urlopen` method, such
as :class:`urllib3.HTTPConnectionPool` and as :class:`urllib3.HTTPConnectionPool` and
@ -37,25 +48,34 @@ class RequestMethods(object):
_encode_url_methods = {"DELETE", "GET", "HEAD", "OPTIONS"} _encode_url_methods = {"DELETE", "GET", "HEAD", "OPTIONS"}
def __init__(self, headers=None): def __init__(self, headers: typing.Mapping[str, str] | None = None) -> None:
self.headers = headers or {} self.headers = headers or {}
def urlopen( def urlopen(
self, self,
method, method: str,
url, url: str,
body=None, body: _TYPE_BODY | None = None,
headers=None, headers: typing.Mapping[str, str] | None = None,
encode_multipart=True, encode_multipart: bool = True,
multipart_boundary=None, multipart_boundary: str | None = None,
**kw **kw: typing.Any,
): # Abstract ) -> BaseHTTPResponse: # Abstract
raise NotImplementedError( raise NotImplementedError(
"Classes extending RequestMethods must implement " "Classes extending RequestMethods must implement "
"their own ``urlopen`` method." "their own ``urlopen`` method."
) )
def request(self, method, url, fields=None, headers=None, **urlopen_kw): def request(
self,
method: str,
url: str,
body: _TYPE_BODY | None = None,
fields: _TYPE_FIELDS | None = None,
headers: typing.Mapping[str, str] | None = None,
json: typing.Any | None = None,
**urlopen_kw: typing.Any,
) -> BaseHTTPResponse:
""" """
Make a request using :meth:`urlopen` with the appropriate encoding of Make a request using :meth:`urlopen` with the appropriate encoding of
``fields`` based on the ``method`` used. ``fields`` based on the ``method`` used.
@ -68,18 +88,45 @@ class RequestMethods(object):
""" """
method = method.upper() method = method.upper()
urlopen_kw["request_url"] = url if json is not None and body is not None:
raise TypeError(
"request got values for both 'body' and 'json' parameters which are mutually exclusive"
)
if json is not None:
if headers is None:
headers = self.headers.copy() # type: ignore
if not ("content-type" in map(str.lower, headers.keys())):
headers["Content-Type"] = "application/json" # type: ignore
body = _json.dumps(json, separators=(",", ":"), ensure_ascii=False).encode(
"utf-8"
)
if body is not None:
urlopen_kw["body"] = body
if method in self._encode_url_methods: if method in self._encode_url_methods:
return self.request_encode_url( return self.request_encode_url(
method, url, fields=fields, headers=headers, **urlopen_kw method,
url,
fields=fields, # type: ignore[arg-type]
headers=headers,
**urlopen_kw,
) )
else: else:
return self.request_encode_body( return self.request_encode_body(
method, url, fields=fields, headers=headers, **urlopen_kw method, url, fields=fields, headers=headers, **urlopen_kw
) )
def request_encode_url(self, method, url, fields=None, headers=None, **urlopen_kw): def request_encode_url(
self,
method: str,
url: str,
fields: _TYPE_ENCODE_URL_FIELDS | None = None,
headers: typing.Mapping[str, str] | None = None,
**urlopen_kw: str,
) -> BaseHTTPResponse:
""" """
Make a request using :meth:`urlopen` with the ``fields`` encoded in Make a request using :meth:`urlopen` with the ``fields`` encoded in
the url. This is useful for request methods like GET, HEAD, DELETE, etc. the url. This is useful for request methods like GET, HEAD, DELETE, etc.
@ -87,7 +134,7 @@ class RequestMethods(object):
if headers is None: if headers is None:
headers = self.headers headers = self.headers
extra_kw = {"headers": headers} extra_kw: dict[str, typing.Any] = {"headers": headers}
extra_kw.update(urlopen_kw) extra_kw.update(urlopen_kw)
if fields: if fields:
@ -97,14 +144,14 @@ class RequestMethods(object):
def request_encode_body( def request_encode_body(
self, self,
method, method: str,
url, url: str,
fields=None, fields: _TYPE_FIELDS | None = None,
headers=None, headers: typing.Mapping[str, str] | None = None,
encode_multipart=True, encode_multipart: bool = True,
multipart_boundary=None, multipart_boundary: str | None = None,
**urlopen_kw **urlopen_kw: str,
): ) -> BaseHTTPResponse:
""" """
Make a request using :meth:`urlopen` with the ``fields`` encoded in Make a request using :meth:`urlopen` with the ``fields`` encoded in
the body. This is useful for request methods like POST, PUT, PATCH, etc. the body. This is useful for request methods like POST, PUT, PATCH, etc.
@ -143,7 +190,8 @@ class RequestMethods(object):
if headers is None: if headers is None:
headers = self.headers headers = self.headers
extra_kw = {"headers": {}} extra_kw: dict[str, typing.Any] = {"headers": HTTPHeaderDict(headers)}
body: bytes | str
if fields: if fields:
if "body" in urlopen_kw: if "body" in urlopen_kw:
@ -157,14 +205,13 @@ class RequestMethods(object):
) )
else: else:
body, content_type = ( body, content_type = (
urlencode(fields), urlencode(fields), # type: ignore[arg-type]
"application/x-www-form-urlencoded", "application/x-www-form-urlencoded",
) )
extra_kw["body"] = body extra_kw["body"] = body
extra_kw["headers"] = {"Content-Type": content_type} extra_kw["headers"].setdefault("Content-Type", content_type)
extra_kw["headers"].update(headers)
extra_kw.update(urlopen_kw) extra_kw.update(urlopen_kw)
return self.urlopen(method, url, **extra_kw) return self.urlopen(method, url, **extra_kw)

View file

@ -1,2 +1,4 @@
# This file is protected via CODEOWNERS # This file is protected via CODEOWNERS
__version__ = "1.26.15" from __future__ import annotations
__version__ = "2.0.4"

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -1,36 +0,0 @@
"""
This module provides means to detect the App Engine environment.
"""
import os
def is_appengine():
    """Report whether either App Engine flavour (local or production) is detected."""
    if is_local_appengine():
        return True
    return is_prod_appengine()
def is_appengine_sandbox():
    """Tell whether the app runs inside the first generation sandbox.

    Second generation runtimes are technically sandboxed as well, but far
    less restrictively, so there is generally no need to check for them.
    see https://cloud.google.com/appengine/docs/standard/runtimes
    """
    if not is_appengine():
        return False
    # Safe to index: is_appengine() only succeeds when APPENGINE_RUNTIME is set.
    return os.environ["APPENGINE_RUNTIME"] == "python27"
def is_local_appengine():
    """Detect the local development server.

    True only when ``APPENGINE_RUNTIME`` is present in the environment and
    ``SERVER_SOFTWARE`` starts with ``Development/``.
    """
    if "APPENGINE_RUNTIME" not in os.environ:
        return False
    server_software = os.environ.get("SERVER_SOFTWARE", "")
    return server_software.startswith("Development/")
def is_prod_appengine():
    """Detect the production App Engine runtime.

    True only when ``APPENGINE_RUNTIME`` is present in the environment and
    ``SERVER_SOFTWARE`` starts with ``Google App Engine/``.
    """
    if "APPENGINE_RUNTIME" not in os.environ:
        return False
    server_software = os.environ.get("SERVER_SOFTWARE", "")
    return server_software.startswith("Google App Engine/")
def is_prod_appengine_mvms():
    """Deprecated.

    Unconditionally returns ``False``.
    """
    return False

View file

@ -1,3 +1,5 @@
# type: ignore
""" """
This module uses ctypes to bind a whole bunch of functions and constants from This module uses ctypes to bind a whole bunch of functions and constants from
SecureTransport. The goal here is to provide the low-level API to SecureTransport. The goal here is to provide the low-level API to
@ -29,7 +31,8 @@ license and by oscrypto's:
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import absolute_import
from __future__ import annotations
import platform import platform
from ctypes import ( from ctypes import (
@ -48,8 +51,6 @@ from ctypes import (
) )
from ctypes.util import find_library from ctypes.util import find_library
from ...packages.six import raise_from
if platform.system() != "Darwin": if platform.system() != "Darwin":
raise ImportError("Only macOS is supported") raise ImportError("Only macOS is supported")
@ -57,16 +58,16 @@ version = platform.mac_ver()[0]
version_info = tuple(map(int, version.split("."))) version_info = tuple(map(int, version.split(".")))
if version_info < (10, 8): if version_info < (10, 8):
raise OSError( raise OSError(
"Only OS X 10.8 and newer are supported, not %s.%s" f"Only OS X 10.8 and newer are supported, not {version_info[0]}.{version_info[1]}"
% (version_info[0], version_info[1])
) )
def load_cdll(name, macos10_16_path): def load_cdll(name: str, macos10_16_path: str) -> CDLL:
"""Loads a CDLL by name, falling back to known path on 10.16+""" """Loads a CDLL by name, falling back to known path on 10.16+"""
try: try:
# Big Sur is technically 11 but we use 10.16 due to the Big Sur # Big Sur is technically 11 but we use 10.16 due to the Big Sur
# beta being labeled as 10.16. # beta being labeled as 10.16.
path: str | None
if version_info >= (10, 16): if version_info >= (10, 16):
path = macos10_16_path path = macos10_16_path
else: else:
@ -75,7 +76,7 @@ def load_cdll(name, macos10_16_path):
raise OSError # Caught and reraised as 'ImportError' raise OSError # Caught and reraised as 'ImportError'
return CDLL(path, use_errno=True) return CDLL(path, use_errno=True)
except OSError: except OSError:
raise_from(ImportError("The library %s failed to load" % name), None) raise ImportError(f"The library {name} failed to load") from None
Security = load_cdll( Security = load_cdll(
@ -416,104 +417,14 @@ try:
CoreFoundation.CFStringRef = CFStringRef CoreFoundation.CFStringRef = CFStringRef
CoreFoundation.CFDictionaryRef = CFDictionaryRef CoreFoundation.CFDictionaryRef = CFDictionaryRef
except (AttributeError): except AttributeError:
raise ImportError("Error initializing ctypes") raise ImportError("Error initializing ctypes") from None
class CFConst(object): class CFConst:
""" """
A class object that acts as essentially a namespace for CoreFoundation A class object that acts as essentially a namespace for CoreFoundation
constants. constants.
""" """
kCFStringEncodingUTF8 = CFStringEncoding(0x08000100) kCFStringEncodingUTF8 = CFStringEncoding(0x08000100)
class SecurityConst(object):
    """
    Plain namespace holding the integer constants of the Security framework
    that the SecureTransport bindings refer to by name.
    """

    # -- session options -----------------------------------------------------
    kSSLSessionOptionBreakOnServerAuth = 0

    # -- protocol identifiers ------------------------------------------------
    kSSLProtocol2 = 1
    kSSLProtocol3 = 2
    kTLSProtocol1 = 4
    kTLSProtocol11 = 7
    kTLSProtocol12 = 8
    # SecureTransport does not support TLS 1.3 even if there's a constant for it
    kTLSProtocol13 = 10
    kTLSProtocolMaxSupported = 999

    # -- connection setup ----------------------------------------------------
    kSSLClientSide = 1
    kSSLStreamType = 0

    # -- import/export formats -----------------------------------------------
    kSecFormatPEMSequence = 10

    # -- trust evaluation results --------------------------------------------
    kSecTrustResultInvalid = 0
    kSecTrustResultProceed = 1
    # This gap is present on purpose: this was kSecTrustResultConfirm, which
    # is deprecated.
    kSecTrustResultDeny = 3
    kSecTrustResultUnspecified = 4
    kSecTrustResultRecoverableTrustFailure = 5
    kSecTrustResultFatalTrustFailure = 6
    kSecTrustResultOtherError = 7

    # -- errSSL* status codes ------------------------------------------------
    errSSLProtocol = -9800
    errSSLWouldBlock = -9803
    errSSLClosedGraceful = -9805
    errSSLClosedNoNotify = -9816
    errSSLClosedAbort = -9806

    errSSLXCertChainInvalid = -9807
    errSSLCrypto = -9809
    errSSLInternal = -9810
    errSSLCertExpired = -9814
    errSSLCertNotYetValid = -9815
    errSSLUnknownRootCert = -9812
    errSSLNoRootCert = -9813
    errSSLHostNameMismatch = -9843
    errSSLPeerHandshakeFail = -9824
    errSSLPeerUserCancelled = -9839
    errSSLWeakPeerEphemeralDHKey = -9850
    errSSLServerAuthCompleted = -9841
    errSSLRecordOverflow = -9847

    # -- errSec* status codes ------------------------------------------------
    errSecVerifyFailed = -67808
    errSecNoTrustSettings = -25263
    errSecItemNotFound = -25300
    errSecInvalidTrustSettings = -25262

    # -- cipher suites -------------------------------------------------------
    # Cipher suites. We only pick the ones our default cipher string allows.
    # Source: https://developer.apple.com/documentation/security/1550981-ssl_cipher_suite_values
    TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 = 0xC02C
    TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 = 0xC030
    TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 = 0xC02B
    TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 = 0xC02F
    TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 = 0xCCA9
    TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 = 0xCCA8
    TLS_DHE_RSA_WITH_AES_256_GCM_SHA384 = 0x009F
    TLS_DHE_RSA_WITH_AES_128_GCM_SHA256 = 0x009E
    TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384 = 0xC024
    TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384 = 0xC028
    TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA = 0xC00A
    TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA = 0xC014
    TLS_DHE_RSA_WITH_AES_256_CBC_SHA256 = 0x006B
    TLS_DHE_RSA_WITH_AES_256_CBC_SHA = 0x0039
    TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 = 0xC023
    TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 = 0xC027
    TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA = 0xC009
    TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA = 0xC013
    TLS_DHE_RSA_WITH_AES_128_CBC_SHA256 = 0x0067
    TLS_DHE_RSA_WITH_AES_128_CBC_SHA = 0x0033
    TLS_RSA_WITH_AES_256_GCM_SHA384 = 0x009D
    TLS_RSA_WITH_AES_128_GCM_SHA256 = 0x009C
    TLS_RSA_WITH_AES_256_CBC_SHA256 = 0x003D
    TLS_RSA_WITH_AES_128_CBC_SHA256 = 0x003C
    TLS_RSA_WITH_AES_256_CBC_SHA = 0x0035
    TLS_RSA_WITH_AES_128_CBC_SHA = 0x002F
    TLS_AES_128_GCM_SHA256 = 0x1301
    TLS_AES_256_GCM_SHA384 = 0x1302
    TLS_AES_128_CCM_8_SHA256 = 0x1305
    TLS_AES_128_CCM_SHA256 = 0x1304

View file

@ -7,6 +7,8 @@ CoreFoundation messing about and memory management. The concerns in this module
are almost entirely about trying to avoid memory leaks and providing are almost entirely about trying to avoid memory leaks and providing
appropriate and useful assistance to the higher-level code. appropriate and useful assistance to the higher-level code.
""" """
from __future__ import annotations
import base64 import base64
import ctypes import ctypes
import itertools import itertools
@ -15,8 +17,20 @@ import re
import ssl import ssl
import struct import struct
import tempfile import tempfile
import typing
from .bindings import CFConst, CoreFoundation, Security from .bindings import ( # type: ignore[attr-defined]
CFArray,
CFConst,
CFData,
CFDictionary,
CFMutableArray,
CFString,
CFTypeRef,
CoreFoundation,
SecKeychainRef,
Security,
)
# This regular expression is used to grab PEM data out of a PEM bundle. # This regular expression is used to grab PEM data out of a PEM bundle.
_PEM_CERTS_RE = re.compile( _PEM_CERTS_RE = re.compile(
@ -24,7 +38,7 @@ _PEM_CERTS_RE = re.compile(
) )
def _cf_data_from_bytes(bytestring): def _cf_data_from_bytes(bytestring: bytes) -> CFData:
""" """
Given a bytestring, create a CFData object from it. This CFData object must Given a bytestring, create a CFData object from it. This CFData object must
be CFReleased by the caller. be CFReleased by the caller.
@ -34,7 +48,9 @@ def _cf_data_from_bytes(bytestring):
) )
def _cf_dictionary_from_tuples(tuples): def _cf_dictionary_from_tuples(
tuples: list[tuple[typing.Any, typing.Any]]
) -> CFDictionary:
""" """
Given a list of Python tuples, create an associated CFDictionary. Given a list of Python tuples, create an associated CFDictionary.
""" """
@ -56,7 +72,7 @@ def _cf_dictionary_from_tuples(tuples):
) )
def _cfstr(py_bstr): def _cfstr(py_bstr: bytes) -> CFString:
""" """
Given a Python binary data, create a CFString. Given a Python binary data, create a CFString.
The string must be CFReleased by the caller. The string must be CFReleased by the caller.
@ -70,7 +86,7 @@ def _cfstr(py_bstr):
return cf_str return cf_str
def _create_cfstring_array(lst): def _create_cfstring_array(lst: list[bytes]) -> CFMutableArray:
""" """
Given a list of Python binary data, create an associated CFMutableArray. Given a list of Python binary data, create an associated CFMutableArray.
The array must be CFReleased by the caller. The array must be CFReleased by the caller.
@ -97,11 +113,11 @@ def _create_cfstring_array(lst):
except BaseException as e: except BaseException as e:
if cf_arr: if cf_arr:
CoreFoundation.CFRelease(cf_arr) CoreFoundation.CFRelease(cf_arr)
raise ssl.SSLError("Unable to allocate array: %s" % (e,)) raise ssl.SSLError(f"Unable to allocate array: {e}") from None
return cf_arr return cf_arr
def _cf_string_to_unicode(value): def _cf_string_to_unicode(value: CFString) -> str | None:
""" """
Creates a Unicode string from a CFString object. Used entirely for error Creates a Unicode string from a CFString object. Used entirely for error
reporting. reporting.
@ -123,10 +139,12 @@ def _cf_string_to_unicode(value):
string = buffer.value string = buffer.value
if string is not None: if string is not None:
string = string.decode("utf-8") string = string.decode("utf-8")
return string return string # type: ignore[no-any-return]
def _assert_no_error(error, exception_class=None): def _assert_no_error(
error: int, exception_class: type[BaseException] | None = None
) -> None:
""" """
Checks the return code and throws an exception if there is an error to Checks the return code and throws an exception if there is an error to
report report
@ -138,8 +156,8 @@ def _assert_no_error(error, exception_class=None):
output = _cf_string_to_unicode(cf_error_string) output = _cf_string_to_unicode(cf_error_string)
CoreFoundation.CFRelease(cf_error_string) CoreFoundation.CFRelease(cf_error_string)
if output is None or output == u"": if output is None or output == "":
output = u"OSStatus %s" % error output = f"OSStatus {error}"
if exception_class is None: if exception_class is None:
exception_class = ssl.SSLError exception_class = ssl.SSLError
@ -147,7 +165,7 @@ def _assert_no_error(error, exception_class=None):
raise exception_class(output) raise exception_class(output)
def _cert_array_from_pem(pem_bundle): def _cert_array_from_pem(pem_bundle: bytes) -> CFArray:
""" """
Given a bundle of certs in PEM format, turns them into a CFArray of certs Given a bundle of certs in PEM format, turns them into a CFArray of certs
that can be used to validate a cert chain. that can be used to validate a cert chain.
@ -193,23 +211,23 @@ def _cert_array_from_pem(pem_bundle):
return cert_array return cert_array
def _is_cert(item): def _is_cert(item: CFTypeRef) -> bool:
""" """
Returns True if a given CFTypeRef is a certificate. Returns True if a given CFTypeRef is a certificate.
""" """
expected = Security.SecCertificateGetTypeID() expected = Security.SecCertificateGetTypeID()
return CoreFoundation.CFGetTypeID(item) == expected return CoreFoundation.CFGetTypeID(item) == expected # type: ignore[no-any-return]
def _is_identity(item): def _is_identity(item: CFTypeRef) -> bool:
""" """
Returns True if a given CFTypeRef is an identity. Returns True if a given CFTypeRef is an identity.
""" """
expected = Security.SecIdentityGetTypeID() expected = Security.SecIdentityGetTypeID()
return CoreFoundation.CFGetTypeID(item) == expected return CoreFoundation.CFGetTypeID(item) == expected # type: ignore[no-any-return]
def _temporary_keychain(): def _temporary_keychain() -> tuple[SecKeychainRef, str]:
""" """
This function creates a temporary Mac keychain that we can use to work with This function creates a temporary Mac keychain that we can use to work with
credentials. This keychain uses a one-time password and a temporary file to credentials. This keychain uses a one-time password and a temporary file to
@ -244,7 +262,9 @@ def _temporary_keychain():
return keychain, tempdirectory return keychain, tempdirectory
def _load_items_from_file(keychain, path): def _load_items_from_file(
keychain: SecKeychainRef, path: str
) -> tuple[list[CFTypeRef], list[CFTypeRef]]:
""" """
Given a single file, loads all the trust objects from it into arrays and Given a single file, loads all the trust objects from it into arrays and
the keychain. the keychain.
@ -299,7 +319,7 @@ def _load_items_from_file(keychain, path):
return (identities, certificates) return (identities, certificates)
def _load_client_cert_chain(keychain, *paths): def _load_client_cert_chain(keychain: SecKeychainRef, *paths: str | None) -> CFArray:
""" """
Load certificates and maybe keys from a number of files. Has the end goal Load certificates and maybe keys from a number of files. Has the end goal
of returning a CFArray containing one SecIdentityRef, and then zero or more of returning a CFArray containing one SecIdentityRef, and then zero or more
@ -335,10 +355,10 @@ def _load_client_cert_chain(keychain, *paths):
identities = [] identities = []
# Filter out bad paths. # Filter out bad paths.
paths = (path for path in paths if path) filtered_paths = (path for path in paths if path)
try: try:
for file_path in paths: for file_path in filtered_paths:
new_identities, new_certs = _load_items_from_file(keychain, file_path) new_identities, new_certs = _load_items_from_file(keychain, file_path)
identities.extend(new_identities) identities.extend(new_identities)
certificates.extend(new_certs) certificates.extend(new_certs)
@ -383,7 +403,7 @@ TLS_PROTOCOL_VERSIONS = {
} }
def _build_tls_unknown_ca_alert(version): def _build_tls_unknown_ca_alert(version: str) -> bytes:
""" """
Builds a TLS alert record for an unknown CA. Builds a TLS alert record for an unknown CA.
""" """
@ -395,3 +415,60 @@ def _build_tls_unknown_ca_alert(version):
record_type_alert = 0x15 record_type_alert = 0x15
record = struct.pack(">BBBH", record_type_alert, ver_maj, ver_min, msg_len) + msg record = struct.pack(">BBBH", record_type_alert, ver_maj, ver_min, msg_len) + msg
return record return record
class SecurityConst:
    """
    Plain namespace holding the integer constants of the Security framework
    that the SecureTransport helpers refer to by name.
    """

    # -- session options -----------------------------------------------------
    kSSLSessionOptionBreakOnServerAuth = 0

    # -- protocol identifiers ------------------------------------------------
    kSSLProtocol2 = 1
    kSSLProtocol3 = 2
    kTLSProtocol1 = 4
    kTLSProtocol11 = 7
    kTLSProtocol12 = 8
    # SecureTransport does not support TLS 1.3 even if there's a constant for it
    kTLSProtocol13 = 10
    kTLSProtocolMaxSupported = 999

    # -- connection setup ----------------------------------------------------
    kSSLClientSide = 1
    kSSLStreamType = 0

    # -- import/export formats -----------------------------------------------
    kSecFormatPEMSequence = 10

    # -- trust evaluation results --------------------------------------------
    kSecTrustResultInvalid = 0
    kSecTrustResultProceed = 1
    # This gap is present on purpose: this was kSecTrustResultConfirm, which
    # is deprecated.
    kSecTrustResultDeny = 3
    kSecTrustResultUnspecified = 4
    kSecTrustResultRecoverableTrustFailure = 5
    kSecTrustResultFatalTrustFailure = 6
    kSecTrustResultOtherError = 7

    # -- errSSL* status codes ------------------------------------------------
    errSSLProtocol = -9800
    errSSLWouldBlock = -9803
    errSSLClosedGraceful = -9805
    errSSLClosedNoNotify = -9816
    errSSLClosedAbort = -9806

    errSSLXCertChainInvalid = -9807
    errSSLCrypto = -9809
    errSSLInternal = -9810
    errSSLCertExpired = -9814
    errSSLCertNotYetValid = -9815
    errSSLUnknownRootCert = -9812
    errSSLNoRootCert = -9813
    errSSLHostNameMismatch = -9843
    errSSLPeerHandshakeFail = -9824
    errSSLPeerUserCancelled = -9839
    errSSLWeakPeerEphemeralDHKey = -9850
    errSSLServerAuthCompleted = -9841
    errSSLRecordOverflow = -9847

    # -- errSec* status codes ------------------------------------------------
    errSecVerifyFailed = -67808
    errSecNoTrustSettings = -25263
    errSecItemNotFound = -25300
    errSecInvalidTrustSettings = -25262

View file

@ -1,314 +0,0 @@
"""
This module provides a pool manager that uses Google App Engine's
`URLFetch Service <https://cloud.google.com/appengine/docs/python/urlfetch>`_.
Example usage::
from urllib3 import PoolManager
from urllib3.contrib.appengine import AppEngineManager, is_appengine_sandbox
if is_appengine_sandbox():
# AppEngineManager uses AppEngine's URLFetch API behind the scenes
http = AppEngineManager()
else:
# PoolManager uses a socket-level API behind the scenes
http = PoolManager()
r = http.request('GET', 'https://google.com/')
There are `limitations <https://cloud.google.com/appengine/docs/python/\
urlfetch/#Python_Quotas_and_limits>`_ to the URLFetch service and it may not be
the best choice for your application. There are three options for using
urllib3 on Google App Engine:
1. You can use :class:`AppEngineManager` with URLFetch. URLFetch is
cost-effective in many circumstances as long as your usage is within the
limitations.
2. You can use a normal :class:`~urllib3.PoolManager` by enabling sockets.
Sockets also have `limitations and restrictions
<https://cloud.google.com/appengine/docs/python/sockets/\
#limitations-and-restrictions>`_ and have a lower free quota than URLFetch.
To use sockets, be sure to specify the following in your ``app.yaml``::
env_variables:
GAE_USE_SOCKETS_HTTPLIB : 'true'
3. If you are using `App Engine Flexible
<https://cloud.google.com/appengine/docs/flexible/>`_, you can use the standard
:class:`PoolManager` without any configuration or special environment variables.
"""
from __future__ import absolute_import
import io
import logging
import warnings
from ..exceptions import (
HTTPError,
HTTPWarning,
MaxRetryError,
ProtocolError,
SSLError,
TimeoutError,
)
from ..packages.six.moves.urllib.parse import urljoin
from ..request import RequestMethods
from ..response import HTTPResponse
from ..util.retry import Retry
from ..util.timeout import Timeout
from . import _appengine_environ
try:
from google.appengine.api import urlfetch
except ImportError:
urlfetch = None
log = logging.getLogger(__name__)
class AppEnginePlatformWarning(HTTPWarning):
    """Warning category this module passes to ``warnings.warn`` for URLFetch notices."""
class AppEnginePlatformError(HTTPError):
    """Error this module raises when URLFetch is unavailable or rejects a request."""
class AppEngineManager(RequestMethods):
    """
    Connection manager for Google App Engine sandbox applications.

    This manager uses the URLFetch service directly instead of using the
    emulated httplib, and is subject to URLFetch limitations as described in
    the App Engine documentation `here
    <https://cloud.google.com/appengine/docs/python/urlfetch>`_.

    Notably it will raise an :class:`AppEnginePlatformError` if:
        * URLFetch is not available.
        * If you attempt to use this on App Engine Flexible, as full socket
          support is available.
        * If a request size is more than 10 megabytes.
        * If a response size is more than 32 megabytes.
        * If you use an unsupported request method such as OPTIONS.

    Beyond those cases, it will raise normal urllib3 errors.
    """

    def __init__(
        self,
        headers=None,
        retries=None,
        validate_certificate=True,
        urlfetch_retries=True,
    ):
        """
        :param headers: Default headers, handed to ``RequestMethods.__init__``.
        :param retries: Default retry configuration; ``Retry.DEFAULT`` is
            used when this is falsy.
        :param validate_certificate: Forwarded to ``urlfetch.fetch`` as
            ``validate_certificate``.
        :param urlfetch_retries: When true, URLFetch itself is allowed to
            follow redirects (see :meth:`urlopen`).
        :raises AppEnginePlatformError: If the ``urlfetch`` module could not
            be imported.
        """
        if not urlfetch:
            raise AppEnginePlatformError(
                "URLFetch is not available in this environment."
            )

        warnings.warn(
            "urllib3 is using URLFetch on Google App Engine sandbox instead "
            "of sockets. To use sockets directly instead of URLFetch see "
            "https://urllib3.readthedocs.io/en/1.26.x/reference/urllib3.contrib.html.",
            AppEnginePlatformWarning,
        )

        RequestMethods.__init__(self, headers)
        self.validate_certificate = validate_certificate
        self.urlfetch_retries = urlfetch_retries

        self.retries = retries or Retry.DEFAULT

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Return False to re-raise any potential exceptions
        return False

    def urlopen(
        self,
        method,
        url,
        body=None,
        headers=None,
        retries=None,
        redirect=True,
        timeout=Timeout.DEFAULT_TIMEOUT,
        **response_kw
    ):
        """
        Perform one request through ``urlfetch.fetch`` and wrap the result in
        an ``HTTPResponse``.

        URLFetch exceptions are translated into urllib3 exceptions (or
        :class:`AppEnginePlatformError` for platform limits). Redirects and
        retryable responses are handled by calling this method again with
        the incremented ``Retry`` object.
        """
        retries = self._get_retries(retries, redirect)

        try:
            follow_redirects = redirect and retries.redirect != 0 and retries.total
            response = urlfetch.fetch(
                url,
                payload=body,
                method=method,
                headers=headers or {},
                allow_truncated=False,
                follow_redirects=self.urlfetch_retries and follow_redirects,
                deadline=self._get_absolute_timeout(timeout),
                validate_certificate=self.validate_certificate,
            )
        except urlfetch.DeadlineExceededError as e:
            raise TimeoutError(self, e)

        except urlfetch.InvalidURLError as e:
            if "too large" in str(e):
                raise AppEnginePlatformError(
                    "URLFetch request too large, URLFetch only "
                    "supports requests up to 10mb in size.",
                    e,
                )
            raise ProtocolError(e)

        except urlfetch.DownloadError as e:
            if "Too many redirects" in str(e):
                raise MaxRetryError(self, url, reason=e)
            raise ProtocolError(e)

        except urlfetch.ResponseTooLargeError as e:
            # The two literals are concatenated; the trailing space keeps the
            # message from reading "supportsresponses".
            raise AppEnginePlatformError(
                "URLFetch response too large, URLFetch only supports "
                "responses up to 32mb in size.",
                e,
            )

        except urlfetch.SSLCertificateError as e:
            raise SSLError(e)

        except urlfetch.InvalidMethodError as e:
            raise AppEnginePlatformError(
                "URLFetch does not support method: %s" % method, e
            )

        http_response = self._urlfetch_response_to_http_response(
            response, retries=retries, **response_kw
        )

        # Handle redirect?
        redirect_location = redirect and http_response.get_redirect_location()
        if redirect_location:
            # Check for redirect response
            if self.urlfetch_retries and retries.raise_on_redirect:
                raise MaxRetryError(self, url, "too many redirects")
            else:
                if http_response.status == 303:
                    method = "GET"

                try:
                    retries = retries.increment(
                        method, url, response=http_response, _pool=self
                    )
                except MaxRetryError:
                    if retries.raise_on_redirect:
                        raise MaxRetryError(self, url, "too many redirects")
                    return http_response

                retries.sleep_for_retry(http_response)
                log.debug("Redirecting %s -> %s", url, redirect_location)
                redirect_url = urljoin(url, redirect_location)
                return self.urlopen(
                    method,
                    redirect_url,
                    body,
                    headers,
                    retries=retries,
                    redirect=redirect,
                    timeout=timeout,
                    **response_kw
                )

        # Check if we should retry the HTTP response.
        has_retry_after = bool(http_response.headers.get("Retry-After"))
        if retries.is_retry(method, http_response.status, has_retry_after):
            retries = retries.increment(method, url, response=http_response, _pool=self)
            log.debug("Retry: %s", url)
            retries.sleep(http_response)
            return self.urlopen(
                method,
                url,
                body=body,
                headers=headers,
                retries=retries,
                redirect=redirect,
                timeout=timeout,
                **response_kw
            )

        return http_response

    def _urlfetch_response_to_http_response(self, urlfetch_resp, **response_kw):
        """Build an ``HTTPResponse`` from a URLFetch response object.

        The ``content-encoding`` and ``transfer-encoding`` headers of
        ``urlfetch_resp`` are adjusted in place before the response is built.
        """
        if is_prod_appengine():
            # Production GAE handles deflate encoding automatically, but does
            # not remove the encoding header.
            content_encoding = urlfetch_resp.headers.get("content-encoding")

            if content_encoding == "deflate":
                del urlfetch_resp.headers["content-encoding"]

        transfer_encoding = urlfetch_resp.headers.get("transfer-encoding")
        # We have a full response's content,
        # so let's make sure we don't report ourselves as chunked data.
        # The header may list several codings ("gzip, chunked"), so drop the
        # "chunked" token wherever it appears rather than only when it is the
        # entire header value.
        if transfer_encoding:
            encodings = [token.strip() for token in transfer_encoding.split(",")]
            remaining = [token for token in encodings if token.lower() != "chunked"]
            if len(remaining) != len(encodings):
                urlfetch_resp.headers["transfer-encoding"] = ",".join(remaining)

        original_response = HTTPResponse(
            # In order for decoding to work, we must present the content as
            # a file-like object.
            body=io.BytesIO(urlfetch_resp.content),
            msg=urlfetch_resp.header_msg,
            headers=urlfetch_resp.headers,
            status=urlfetch_resp.status_code,
            **response_kw
        )

        return HTTPResponse(
            body=io.BytesIO(urlfetch_resp.content),
            headers=urlfetch_resp.headers,
            status=urlfetch_resp.status_code,
            original_response=original_response,
            **response_kw
        )

    def _get_absolute_timeout(self, timeout):
        """Reduce ``timeout`` to the single deadline value given to URLFetch.

        Returns ``None`` for ``Timeout.DEFAULT_TIMEOUT``, ``timeout.total``
        for a ``Timeout`` instance (warning if read/connect values were set),
        and ``timeout`` unchanged otherwise.
        """
        if timeout is Timeout.DEFAULT_TIMEOUT:
            return None  # Defer to URLFetch's default.
        if isinstance(timeout, Timeout):
            if timeout._read is not None or timeout._connect is not None:
                warnings.warn(
                    "URLFetch does not support granular timeout settings, "
                    "reverting to total or default URLFetch timeout.",
                    AppEnginePlatformWarning,
                )
            return timeout.total
        return timeout

    def _get_retries(self, retries, redirect):
        """Coerce ``retries`` to a ``Retry`` object, warning about settings URLFetch ignores."""
        if not isinstance(retries, Retry):
            retries = Retry.from_int(retries, redirect=redirect, default=self.retries)

        if retries.connect or retries.read or retries.redirect:
            warnings.warn(
                "URLFetch only supports total retries and does not "
                "recognize connect, read, or redirect retry parameters.",
                AppEnginePlatformWarning,
            )

        return retries
# Alias methods from _appengine_environ to maintain public API interface.
# These module-level names are bound to the helper functions themselves, so
# they can be imported from this module directly. is_prod_appengine is also
# the name AppEngineManager._urlfetch_response_to_http_response calls.
is_appengine = _appengine_environ.is_appengine
is_appengine_sandbox = _appengine_environ.is_appengine_sandbox
is_local_appengine = _appengine_environ.is_local_appengine
is_prod_appengine = _appengine_environ.is_prod_appengine
is_prod_appengine_mvms = _appengine_environ.is_prod_appengine_mvms

View file

@ -1,130 +0,0 @@
"""
NTLM authenticating pool, contributed by erikcederstran
Issue #10, see: http://code.google.com/p/urllib3/issues/detail?id=10
"""
from __future__ import absolute_import
import warnings
from logging import getLogger
from ntlm import ntlm
from .. import HTTPSConnectionPool
from ..packages.six.moves.http_client import HTTPSConnection
warnings.warn(
"The 'urllib3.contrib.ntlmpool' module is deprecated and will be removed "
"in urllib3 v2.0 release, urllib3 is not able to support it properly due "
"to reasons listed in issue: https://github.com/urllib3/urllib3/issues/2282. "
"If you are a user of this module please comment in the mentioned issue.",
DeprecationWarning,
)
log = getLogger(__name__)
class NTLMConnectionPool(HTTPSConnectionPool):
    """
    Implements an NTLM authentication version of an urllib3 connection pool.

    Every connection created by :meth:`_new_conn` has already completed the
    NTLM negotiate / challenge / authenticate exchange against ``authurl``.
    """

    scheme = "https"

    def __init__(self, user, pw, authurl, *args, **kwargs):
        """
        authurl is a random URL on the server that is protected by NTLM.
        user is the Windows user, probably in the DOMAIN\\username format.
        pw is the password for the user.

        Any further positional or keyword arguments are passed unchanged to
        the parent pool's constructor.

        Raises ``IndexError`` when ``user`` contains no backslash, because
        the part after the separator is used as the account name.
        """
        super(NTLMConnectionPool, self).__init__(*args, **kwargs)
        self.authurl = authurl
        self.rawuser = user
        user_parts = user.split("\\", 1)
        self.domain = user_parts[0].upper()
        self.user = user_parts[1]
        self.pw = pw

    def _new_conn(self):
        """
        Open a new HTTPS connection and run the NTLM handshake on it.

        :return: the authenticated ``HTTPSConnection``.
        :raises Exception: when the handshake fails (see
            :meth:`_ntlm_handshake`); the half-open connection is closed
            before the error propagates.
        """
        # Performs the NTLM handshake that secures the connection. The socket
        # must be kept open while requests are performed.
        self.num_connections += 1
        log.debug(
            "Starting NTLM HTTPS connection no. %d: https://%s%s",
            self.num_connections,
            self.host,
            self.authurl,
        )

        conn = HTTPSConnection(host=self.host, port=self.port)
        try:
            self._ntlm_handshake(conn)
        except BaseException:
            # A connection that never finished authenticating is useless to
            # the pool; close it so its socket is not leaked.
            conn.close()
            raise

        log.debug("Connection established")
        return conn

    def _ntlm_handshake(self, conn):
        """
        Perform the two-request NTLM exchange on ``conn``.

        :param conn: a fresh ``HTTPSConnection`` to ``self.host``.
        :raises Exception: when the server sends no usable ``NTLM``
            challenge, rejects the credentials (401), or answers the
            authenticate message with any other non-200 status.
        """
        headers = {"Connection": "Keep-Alive"}
        req_header = "Authorization"
        resp_header = "www-authenticate"

        # Send negotiation message
        headers[req_header] = "NTLM %s" % ntlm.create_NTLM_NEGOTIATE_MESSAGE(
            self.rawuser
        )
        log.debug("Request headers: %s", headers)
        conn.request("GET", self.authurl, None, headers)
        res = conn.getresponse()
        # getheader() matches header names case-insensitively and joins
        # repeated headers with ", ". Indexing dict(res.headers) with a
        # lowercase key fails when the server sends "WWW-Authenticate".
        challenge_header = res.getheader(resp_header)
        log.debug("Response status: %s %s", res.status, res.reason)
        log.debug("Response headers: %s", dict(res.headers))
        log.debug("Response data: %s [...]", res.read(100))

        # Remove the reference to the socket, so that it can not be closed by
        # the response object (we want to keep the socket open)
        res.fp = None

        # Server should respond with a challenge message
        auth_header_value = None
        if challenge_header is not None:
            for s in challenge_header.split(", "):
                if s[:5] == "NTLM ":
                    auth_header_value = s[5:]
        if auth_header_value is None:
            raise Exception(
                "Unexpected %s response header: %s" % (resp_header, challenge_header)
            )

        # Send authentication message
        ServerChallenge, NegotiateFlags = ntlm.parse_NTLM_CHALLENGE_MESSAGE(
            auth_header_value
        )
        auth_msg = ntlm.create_NTLM_AUTHENTICATE_MESSAGE(
            ServerChallenge, self.user, self.domain, self.pw, NegotiateFlags
        )
        headers[req_header] = "NTLM %s" % auth_msg
        log.debug("Request headers: %s", headers)
        conn.request("GET", self.authurl, None, headers)
        res = conn.getresponse()
        log.debug("Response status: %s %s", res.status, res.reason)
        log.debug("Response headers: %s", dict(res.headers))
        log.debug("Response data: %s [...]", res.read()[:100])
        if res.status != 200:
            if res.status == 401:
                raise Exception("Server rejected request: wrong username or password")
            raise Exception("Wrong server response: %s %s" % (res.status, res.reason))

        # Detach the socket from the response again so it stays open.
        res.fp = None

    def urlopen(
        self,
        method,
        url,
        body=None,
        headers=None,
        retries=3,
        redirect=True,
        assert_same_host=True,
    ):
        """
        Issue a request through the parent pool with ``Connection: Keep-Alive``
        forced into the request headers.

        All parameters are forwarded positionally to the parent ``urlopen``.
        The caller's ``headers`` mapping is copied, never modified.
        """
        # Work on a copy so the Keep-Alive entry does not leak into a dict
        # the caller may reuse for other requests.
        headers = {} if headers is None else dict(headers)
        headers["Connection"] = "Keep-Alive"
        return super(NTLMConnectionPool, self).urlopen(
            method, url, body, headers, retries, redirect, assert_same_host
        )

View file

@ -1,8 +1,8 @@
""" """
TLS with SNI_-support for Python 2. Follow these instructions if you would Module for using pyOpenSSL as a TLS backend. This module was relevant before
like to verify TLS certificates in Python 2. Note, the default libraries do the standard library ``ssl`` module supported SNI, but now that we've dropped
*not* do certificate checking; you need to do additional work to validate support for Python 2.7 all relevant Python versions support SNI so
certificates yourself. **this module is no longer recommended**.
This needs the following packages installed: This needs the following packages installed:
@ -10,7 +10,7 @@ This needs the following packages installed:
* `cryptography`_ (minimum 1.3.4, from pyopenssl) * `cryptography`_ (minimum 1.3.4, from pyopenssl)
* `idna`_ (minimum 2.0, from cryptography) * `idna`_ (minimum 2.0, from cryptography)
However, pyopenssl depends on cryptography, which depends on idna, so while we However, pyOpenSSL depends on cryptography, which depends on idna, so while we
use all three directly here we end up having relatively few packages required. use all three directly here we end up having relatively few packages required.
You can install them with the following command: You can install them with the following command:
@ -33,75 +33,55 @@ like this:
except ImportError: except ImportError:
pass pass
Now you can use :mod:`urllib3` as you normally would, and it will support SNI
when the required modules are installed.
Activating this module also has the positive side effect of disabling SSL/TLS
compression in Python 2 (see `CRIME attack`_).
.. _sni: https://en.wikipedia.org/wiki/Server_Name_Indication
.. _crime attack: https://en.wikipedia.org/wiki/CRIME_(security_exploit)
.. _pyopenssl: https://www.pyopenssl.org .. _pyopenssl: https://www.pyopenssl.org
.. _cryptography: https://cryptography.io .. _cryptography: https://cryptography.io
.. _idna: https://github.com/kjd/idna .. _idna: https://github.com/kjd/idna
""" """
from __future__ import absolute_import
import OpenSSL.crypto from __future__ import annotations
import OpenSSL.SSL
import OpenSSL.SSL # type: ignore[import]
from cryptography import x509 from cryptography import x509
from cryptography.hazmat.backends.openssl import backend as openssl_backend
try: try:
from cryptography.x509 import UnsupportedExtension from cryptography.x509 import UnsupportedExtension # type: ignore[attr-defined]
except ImportError: except ImportError:
# UnsupportedExtension is gone in cryptography >= 2.1.0 # UnsupportedExtension is gone in cryptography >= 2.1.0
class UnsupportedExtension(Exception): class UnsupportedExtension(Exception): # type: ignore[no-redef]
pass pass
from io import BytesIO
from socket import error as SocketError
from socket import timeout
try: # Platform-specific: Python 2
from socket import _fileobject
except ImportError: # Platform-specific: Python 3
_fileobject = None
from ..packages.backports.makefile import backport_makefile
import logging import logging
import ssl import ssl
import sys import typing
import warnings import warnings
from io import BytesIO
from socket import socket as socket_cls
from socket import timeout
from .. import util from .. import util
from ..packages import six
from ..util.ssl_ import PROTOCOL_TLS_CLIENT
warnings.warn( warnings.warn(
"'urllib3.contrib.pyopenssl' module is deprecated and will be removed " "'urllib3.contrib.pyopenssl' module is deprecated and will be removed "
"in a future release of urllib3 2.x. Read more in this issue: " "in urllib3 v2.1.0. Read more in this issue: "
"https://github.com/urllib3/urllib3/issues/2680", "https://github.com/urllib3/urllib3/issues/2680",
category=DeprecationWarning, category=DeprecationWarning,
stacklevel=2, stacklevel=2,
) )
__all__ = ["inject_into_urllib3", "extract_from_urllib3"] if typing.TYPE_CHECKING:
from OpenSSL.crypto import X509 # type: ignore[import]
# SNI always works.
HAS_SNI = True __all__ = ["inject_into_urllib3", "extract_from_urllib3"]
# Map from urllib3 to PyOpenSSL compatible parameter-values. # Map from urllib3 to PyOpenSSL compatible parameter-values.
_openssl_versions = { _openssl_versions = {
util.PROTOCOL_TLS: OpenSSL.SSL.SSLv23_METHOD, util.ssl_.PROTOCOL_TLS: OpenSSL.SSL.SSLv23_METHOD, # type: ignore[attr-defined]
PROTOCOL_TLS_CLIENT: OpenSSL.SSL.SSLv23_METHOD, util.ssl_.PROTOCOL_TLS_CLIENT: OpenSSL.SSL.SSLv23_METHOD, # type: ignore[attr-defined]
ssl.PROTOCOL_TLSv1: OpenSSL.SSL.TLSv1_METHOD, ssl.PROTOCOL_TLSv1: OpenSSL.SSL.TLSv1_METHOD,
} }
if hasattr(ssl, "PROTOCOL_SSLv3") and hasattr(OpenSSL.SSL, "SSLv3_METHOD"):
_openssl_versions[ssl.PROTOCOL_SSLv3] = OpenSSL.SSL.SSLv3_METHOD
if hasattr(ssl, "PROTOCOL_TLSv1_1") and hasattr(OpenSSL.SSL, "TLSv1_1_METHOD"): if hasattr(ssl, "PROTOCOL_TLSv1_1") and hasattr(OpenSSL.SSL, "TLSv1_1_METHOD"):
_openssl_versions[ssl.PROTOCOL_TLSv1_1] = OpenSSL.SSL.TLSv1_1_METHOD _openssl_versions[ssl.PROTOCOL_TLSv1_1] = OpenSSL.SSL.TLSv1_1_METHOD
@ -115,43 +95,77 @@ _stdlib_to_openssl_verify = {
ssl.CERT_REQUIRED: OpenSSL.SSL.VERIFY_PEER ssl.CERT_REQUIRED: OpenSSL.SSL.VERIFY_PEER
+ OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT, + OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT,
} }
_openssl_to_stdlib_verify = dict((v, k) for k, v in _stdlib_to_openssl_verify.items()) _openssl_to_stdlib_verify = {v: k for k, v in _stdlib_to_openssl_verify.items()}
# The SSLvX values are the most likely to be missing in the future
# but we check them all just to be sure.
# getattr(..., 0) makes a missing OP_NO_* constant contribute no bits
# instead of raising AttributeError at import time.
_OP_NO_SSLv2_OR_SSLv3: int = getattr(OpenSSL.SSL, "OP_NO_SSLv2", 0) | getattr(
    OpenSSL.SSL, "OP_NO_SSLv3", 0
)
_OP_NO_TLSv1: int = getattr(OpenSSL.SSL, "OP_NO_TLSv1", 0)
_OP_NO_TLSv1_1: int = getattr(OpenSSL.SSL, "OP_NO_TLSv1_1", 0)
_OP_NO_TLSv1_2: int = getattr(OpenSSL.SSL, "OP_NO_TLSv1_2", 0)
_OP_NO_TLSv1_3: int = getattr(OpenSSL.SSL, "OP_NO_TLSv1_3", 0)

# Lookup used when a *minimum* version is configured. Despite the name, the
# keys are ssl.TLSVersion members and the values are OpenSSL option bits:
# the OP_NO_* flags of every protocol older than the key.
_openssl_to_ssl_minimum_version: dict[int, int] = {
    ssl.TLSVersion.MINIMUM_SUPPORTED: _OP_NO_SSLv2_OR_SSLv3,
    ssl.TLSVersion.TLSv1: _OP_NO_SSLv2_OR_SSLv3,
    ssl.TLSVersion.TLSv1_1: _OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1,
    ssl.TLSVersion.TLSv1_2: _OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1 | _OP_NO_TLSv1_1,
    ssl.TLSVersion.TLSv1_3: (
        _OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1 | _OP_NO_TLSv1_1 | _OP_NO_TLSv1_2
    ),
    # Same flag set as the TLSv1_3 entry above.
    ssl.TLSVersion.MAXIMUM_SUPPORTED: (
        _OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1 | _OP_NO_TLSv1_1 | _OP_NO_TLSv1_2
    ),
}
# Lookup used when a *maximum* version is configured: the OP_NO_* flags of
# every protocol newer than the key. The SSLv2/SSLv3 flags are part of every
# entry.
_openssl_to_ssl_maximum_version: dict[int, int] = {
    # This entry sets every OP_NO_* flag defined above, TLSv1 included.
    ssl.TLSVersion.MINIMUM_SUPPORTED: (
        _OP_NO_SSLv2_OR_SSLv3
        | _OP_NO_TLSv1
        | _OP_NO_TLSv1_1
        | _OP_NO_TLSv1_2
        | _OP_NO_TLSv1_3
    ),
    ssl.TLSVersion.TLSv1: (
        _OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1_1 | _OP_NO_TLSv1_2 | _OP_NO_TLSv1_3
    ),
    ssl.TLSVersion.TLSv1_1: _OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1_2 | _OP_NO_TLSv1_3,
    ssl.TLSVersion.TLSv1_2: _OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1_3,
    ssl.TLSVersion.TLSv1_3: _OP_NO_SSLv2_OR_SSLv3,
    ssl.TLSVersion.MAXIMUM_SUPPORTED: _OP_NO_SSLv2_OR_SSLv3,
}
# OpenSSL will only write 16K at a time # OpenSSL will only write 16K at a time
SSL_WRITE_BLOCKSIZE = 16384 SSL_WRITE_BLOCKSIZE = 16384
orig_util_HAS_SNI = util.HAS_SNI
orig_util_SSLContext = util.ssl_.SSLContext orig_util_SSLContext = util.ssl_.SSLContext
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
def inject_into_urllib3(): def inject_into_urllib3() -> None:
"Monkey-patch urllib3 with PyOpenSSL-backed SSL-support." "Monkey-patch urllib3 with PyOpenSSL-backed SSL-support."
_validate_dependencies_met() _validate_dependencies_met()
util.SSLContext = PyOpenSSLContext util.SSLContext = PyOpenSSLContext # type: ignore[assignment]
util.ssl_.SSLContext = PyOpenSSLContext util.ssl_.SSLContext = PyOpenSSLContext # type: ignore[assignment]
util.HAS_SNI = HAS_SNI
util.ssl_.HAS_SNI = HAS_SNI
util.IS_PYOPENSSL = True util.IS_PYOPENSSL = True
util.ssl_.IS_PYOPENSSL = True util.ssl_.IS_PYOPENSSL = True
def extract_from_urllib3(): def extract_from_urllib3() -> None:
"Undo monkey-patching by :func:`inject_into_urllib3`." "Undo monkey-patching by :func:`inject_into_urllib3`."
util.SSLContext = orig_util_SSLContext util.SSLContext = orig_util_SSLContext
util.ssl_.SSLContext = orig_util_SSLContext util.ssl_.SSLContext = orig_util_SSLContext
util.HAS_SNI = orig_util_HAS_SNI
util.ssl_.HAS_SNI = orig_util_HAS_SNI
util.IS_PYOPENSSL = False util.IS_PYOPENSSL = False
util.ssl_.IS_PYOPENSSL = False util.ssl_.IS_PYOPENSSL = False
def _validate_dependencies_met(): def _validate_dependencies_met() -> None:
""" """
Verifies that PyOpenSSL's package-level dependencies have been met. Verifies that PyOpenSSL's package-level dependencies have been met.
Throws `ImportError` if they are not met. Throws `ImportError` if they are not met.
@ -177,7 +191,7 @@ def _validate_dependencies_met():
) )
def _dnsname_to_stdlib(name): def _dnsname_to_stdlib(name: str) -> str | None:
""" """
Converts a dNSName SubjectAlternativeName field to the form used by the Converts a dNSName SubjectAlternativeName field to the form used by the
standard library on the given Python version. standard library on the given Python version.
@ -191,7 +205,7 @@ def _dnsname_to_stdlib(name):
the name given should be skipped. the name given should be skipped.
""" """
def idna_encode(name): def idna_encode(name: str) -> bytes | None:
""" """
Borrowed wholesale from the Python Cryptography Project. It turns out Borrowed wholesale from the Python Cryptography Project. It turns out
that we can't just safely call `idna.encode`: it can explode for that we can't just safely call `idna.encode`: it can explode for
@ -200,7 +214,7 @@ def _dnsname_to_stdlib(name):
import idna import idna
try: try:
for prefix in [u"*.", u"."]: for prefix in ["*.", "."]:
if name.startswith(prefix): if name.startswith(prefix):
name = name[len(prefix) :] name = name[len(prefix) :]
return prefix.encode("ascii") + idna.encode(name) return prefix.encode("ascii") + idna.encode(name)
@ -212,24 +226,17 @@ def _dnsname_to_stdlib(name):
if ":" in name: if ":" in name:
return name return name
name = idna_encode(name) encoded_name = idna_encode(name)
if name is None: if encoded_name is None:
return None return None
elif sys.version_info >= (3, 0): return encoded_name.decode("utf-8")
name = name.decode("utf-8")
return name
def get_subj_alt_name(peer_cert): def get_subj_alt_name(peer_cert: X509) -> list[tuple[str, str]]:
""" """
Given an PyOpenSSL certificate, provides all the subject alternative names. Given an PyOpenSSL certificate, provides all the subject alternative names.
""" """
# Pass the cert to cryptography, which has much better APIs for this. cert = peer_cert.to_cryptography()
if hasattr(peer_cert, "to_cryptography"):
cert = peer_cert.to_cryptography()
else:
der = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_ASN1, peer_cert)
cert = x509.load_der_x509_certificate(der, openssl_backend)
# We want to find the SAN extension. Ask Cryptography to locate it (it's # We want to find the SAN extension. Ask Cryptography to locate it (it's
# faster than looping in Python) # faster than looping in Python)
@ -273,93 +280,94 @@ def get_subj_alt_name(peer_cert):
return names return names
class WrappedSocket(object): class WrappedSocket:
"""API-compatibility wrapper for Python OpenSSL's Connection-class. """API-compatibility wrapper for Python OpenSSL's Connection-class."""
Note: _makefile_refs, _drop() and _reuse() are needed for the garbage def __init__(
collector of pypy. self,
""" connection: OpenSSL.SSL.Connection,
socket: socket_cls,
def __init__(self, connection, socket, suppress_ragged_eofs=True): suppress_ragged_eofs: bool = True,
) -> None:
self.connection = connection self.connection = connection
self.socket = socket self.socket = socket
self.suppress_ragged_eofs = suppress_ragged_eofs self.suppress_ragged_eofs = suppress_ragged_eofs
self._makefile_refs = 0 self._io_refs = 0
self._closed = False self._closed = False
def fileno(self): def fileno(self) -> int:
return self.socket.fileno() return self.socket.fileno()
# Copy-pasted from Python 3.5 source code # Copy-pasted from Python 3.5 source code
def _decref_socketios(self): def _decref_socketios(self) -> None:
if self._makefile_refs > 0: if self._io_refs > 0:
self._makefile_refs -= 1 self._io_refs -= 1
if self._closed: if self._closed:
self.close() self.close()
def recv(self, *args, **kwargs): def recv(self, *args: typing.Any, **kwargs: typing.Any) -> bytes:
try: try:
data = self.connection.recv(*args, **kwargs) data = self.connection.recv(*args, **kwargs)
except OpenSSL.SSL.SysCallError as e: except OpenSSL.SSL.SysCallError as e:
if self.suppress_ragged_eofs and e.args == (-1, "Unexpected EOF"): if self.suppress_ragged_eofs and e.args == (-1, "Unexpected EOF"):
return b"" return b""
else: else:
raise SocketError(str(e)) raise OSError(e.args[0], str(e)) from e
except OpenSSL.SSL.ZeroReturnError: except OpenSSL.SSL.ZeroReturnError:
if self.connection.get_shutdown() == OpenSSL.SSL.RECEIVED_SHUTDOWN: if self.connection.get_shutdown() == OpenSSL.SSL.RECEIVED_SHUTDOWN:
return b"" return b""
else: else:
raise raise
except OpenSSL.SSL.WantReadError: except OpenSSL.SSL.WantReadError as e:
if not util.wait_for_read(self.socket, self.socket.gettimeout()): if not util.wait_for_read(self.socket, self.socket.gettimeout()):
raise timeout("The read operation timed out") raise timeout("The read operation timed out") from e
else: else:
return self.recv(*args, **kwargs) return self.recv(*args, **kwargs)
# TLS 1.3 post-handshake authentication # TLS 1.3 post-handshake authentication
except OpenSSL.SSL.Error as e: except OpenSSL.SSL.Error as e:
raise ssl.SSLError("read error: %r" % e) raise ssl.SSLError(f"read error: {e!r}") from e
else: else:
return data return data # type: ignore[no-any-return]
def recv_into(self, *args, **kwargs): def recv_into(self, *args: typing.Any, **kwargs: typing.Any) -> int:
try: try:
return self.connection.recv_into(*args, **kwargs) return self.connection.recv_into(*args, **kwargs) # type: ignore[no-any-return]
except OpenSSL.SSL.SysCallError as e: except OpenSSL.SSL.SysCallError as e:
if self.suppress_ragged_eofs and e.args == (-1, "Unexpected EOF"): if self.suppress_ragged_eofs and e.args == (-1, "Unexpected EOF"):
return 0 return 0
else: else:
raise SocketError(str(e)) raise OSError(e.args[0], str(e)) from e
except OpenSSL.SSL.ZeroReturnError: except OpenSSL.SSL.ZeroReturnError:
if self.connection.get_shutdown() == OpenSSL.SSL.RECEIVED_SHUTDOWN: if self.connection.get_shutdown() == OpenSSL.SSL.RECEIVED_SHUTDOWN:
return 0 return 0
else: else:
raise raise
except OpenSSL.SSL.WantReadError: except OpenSSL.SSL.WantReadError as e:
if not util.wait_for_read(self.socket, self.socket.gettimeout()): if not util.wait_for_read(self.socket, self.socket.gettimeout()):
raise timeout("The read operation timed out") raise timeout("The read operation timed out") from e
else: else:
return self.recv_into(*args, **kwargs) return self.recv_into(*args, **kwargs)
# TLS 1.3 post-handshake authentication # TLS 1.3 post-handshake authentication
except OpenSSL.SSL.Error as e: except OpenSSL.SSL.Error as e:
raise ssl.SSLError("read error: %r" % e) raise ssl.SSLError(f"read error: {e!r}") from e
def settimeout(self, timeout): def settimeout(self, timeout: float) -> None:
return self.socket.settimeout(timeout) return self.socket.settimeout(timeout)
def _send_until_done(self, data): def _send_until_done(self, data: bytes) -> int:
while True: while True:
try: try:
return self.connection.send(data) return self.connection.send(data) # type: ignore[no-any-return]
except OpenSSL.SSL.WantWriteError: except OpenSSL.SSL.WantWriteError as e:
if not util.wait_for_write(self.socket, self.socket.gettimeout()): if not util.wait_for_write(self.socket, self.socket.gettimeout()):
raise timeout() raise timeout() from e
continue continue
except OpenSSL.SSL.SysCallError as e: except OpenSSL.SSL.SysCallError as e:
raise SocketError(str(e)) raise OSError(e.args[0], str(e)) from e
def sendall(self, data): def sendall(self, data: bytes) -> None:
total_sent = 0 total_sent = 0
while total_sent < len(data): while total_sent < len(data):
sent = self._send_until_done( sent = self._send_until_done(
@ -367,135 +375,135 @@ class WrappedSocket(object):
) )
total_sent += sent total_sent += sent
def shutdown(self): def shutdown(self) -> None:
# FIXME rethrow compatible exceptions should we ever use this # FIXME rethrow compatible exceptions should we ever use this
self.connection.shutdown() self.connection.shutdown()
def close(self): def close(self) -> None:
if self._makefile_refs < 1: self._closed = True
try: if self._io_refs <= 0:
self._closed = True self._real_close()
return self.connection.close()
except OpenSSL.SSL.Error:
return
else:
self._makefile_refs -= 1
def getpeercert(self, binary_form=False): def _real_close(self) -> None:
try:
return self.connection.close() # type: ignore[no-any-return]
except OpenSSL.SSL.Error:
return
def getpeercert(
self, binary_form: bool = False
) -> dict[str, list[typing.Any]] | None:
x509 = self.connection.get_peer_certificate() x509 = self.connection.get_peer_certificate()
if not x509: if not x509:
return x509 return x509 # type: ignore[no-any-return]
if binary_form: if binary_form:
return OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_ASN1, x509) return OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_ASN1, x509) # type: ignore[no-any-return]
return { return {
"subject": ((("commonName", x509.get_subject().CN),),), "subject": ((("commonName", x509.get_subject().CN),),), # type: ignore[dict-item]
"subjectAltName": get_subj_alt_name(x509), "subjectAltName": get_subj_alt_name(x509),
} }
def version(self): def version(self) -> str:
return self.connection.get_protocol_version_name() return self.connection.get_protocol_version_name() # type: ignore[no-any-return]
def _reuse(self):
self._makefile_refs += 1
def _drop(self):
if self._makefile_refs < 1:
self.close()
else:
self._makefile_refs -= 1
if _fileobject: # Platform-specific: Python 2 WrappedSocket.makefile = socket_cls.makefile # type: ignore[attr-defined]
def makefile(self, mode, bufsize=-1):
self._makefile_refs += 1
return _fileobject(self, mode, bufsize, close=True)
else: # Platform-specific: Python 3
makefile = backport_makefile
WrappedSocket.makefile = makefile
class PyOpenSSLContext(object): class PyOpenSSLContext:
""" """
I am a wrapper class for the PyOpenSSL ``Context`` object. I am responsible I am a wrapper class for the PyOpenSSL ``Context`` object. I am responsible
for translating the interface of the standard library ``SSLContext`` object for translating the interface of the standard library ``SSLContext`` object
to calls into PyOpenSSL. to calls into PyOpenSSL.
""" """
def __init__(self, protocol): def __init__(self, protocol: int) -> None:
self.protocol = _openssl_versions[protocol] self.protocol = _openssl_versions[protocol]
self._ctx = OpenSSL.SSL.Context(self.protocol) self._ctx = OpenSSL.SSL.Context(self.protocol)
self._options = 0 self._options = 0
self.check_hostname = False self.check_hostname = False
self._minimum_version: int = ssl.TLSVersion.MINIMUM_SUPPORTED
self._maximum_version: int = ssl.TLSVersion.MAXIMUM_SUPPORTED
@property @property
def options(self): def options(self) -> int:
return self._options return self._options
@options.setter @options.setter
def options(self, value): def options(self, value: int) -> None:
self._options = value self._options = value
self._ctx.set_options(value) self._set_ctx_options()
@property @property
def verify_mode(self): def verify_mode(self) -> int:
return _openssl_to_stdlib_verify[self._ctx.get_verify_mode()] return _openssl_to_stdlib_verify[self._ctx.get_verify_mode()]
@verify_mode.setter @verify_mode.setter
def verify_mode(self, value): def verify_mode(self, value: ssl.VerifyMode) -> None:
self._ctx.set_verify(_stdlib_to_openssl_verify[value], _verify_callback) self._ctx.set_verify(_stdlib_to_openssl_verify[value], _verify_callback)
def set_default_verify_paths(self): def set_default_verify_paths(self) -> None:
self._ctx.set_default_verify_paths() self._ctx.set_default_verify_paths()
def set_ciphers(self, ciphers): def set_ciphers(self, ciphers: bytes | str) -> None:
if isinstance(ciphers, six.text_type): if isinstance(ciphers, str):
ciphers = ciphers.encode("utf-8") ciphers = ciphers.encode("utf-8")
self._ctx.set_cipher_list(ciphers) self._ctx.set_cipher_list(ciphers)
def load_verify_locations(self, cafile=None, capath=None, cadata=None): def load_verify_locations(
self,
cafile: str | None = None,
capath: str | None = None,
cadata: bytes | None = None,
) -> None:
if cafile is not None: if cafile is not None:
cafile = cafile.encode("utf-8") cafile = cafile.encode("utf-8") # type: ignore[assignment]
if capath is not None: if capath is not None:
capath = capath.encode("utf-8") capath = capath.encode("utf-8") # type: ignore[assignment]
try: try:
self._ctx.load_verify_locations(cafile, capath) self._ctx.load_verify_locations(cafile, capath)
if cadata is not None: if cadata is not None:
self._ctx.load_verify_locations(BytesIO(cadata)) self._ctx.load_verify_locations(BytesIO(cadata))
except OpenSSL.SSL.Error as e: except OpenSSL.SSL.Error as e:
raise ssl.SSLError("unable to load trusted certificates: %r" % e) raise ssl.SSLError(f"unable to load trusted certificates: {e!r}") from e
def load_cert_chain(self, certfile, keyfile=None, password=None): def load_cert_chain(
self._ctx.use_certificate_chain_file(certfile) self,
if password is not None: certfile: str,
if not isinstance(password, six.binary_type): keyfile: str | None = None,
password = password.encode("utf-8") password: str | None = None,
self._ctx.set_passwd_cb(lambda *_: password) ) -> None:
self._ctx.use_privatekey_file(keyfile or certfile) try:
self._ctx.use_certificate_chain_file(certfile)
if password is not None:
if not isinstance(password, bytes):
password = password.encode("utf-8") # type: ignore[assignment]
self._ctx.set_passwd_cb(lambda *_: password)
self._ctx.use_privatekey_file(keyfile or certfile)
except OpenSSL.SSL.Error as e:
raise ssl.SSLError(f"Unable to load certificate chain: {e!r}") from e
def set_alpn_protocols(self, protocols): def set_alpn_protocols(self, protocols: list[bytes | str]) -> None:
protocols = [six.ensure_binary(p) for p in protocols] protocols = [util.util.to_bytes(p, "ascii") for p in protocols]
return self._ctx.set_alpn_protos(protocols) return self._ctx.set_alpn_protos(protocols) # type: ignore[no-any-return]
def wrap_socket( def wrap_socket(
self, self,
sock, sock: socket_cls,
server_side=False, server_side: bool = False,
do_handshake_on_connect=True, do_handshake_on_connect: bool = True,
suppress_ragged_eofs=True, suppress_ragged_eofs: bool = True,
server_hostname=None, server_hostname: bytes | str | None = None,
): ) -> WrappedSocket:
cnx = OpenSSL.SSL.Connection(self._ctx, sock) cnx = OpenSSL.SSL.Connection(self._ctx, sock)
if isinstance(server_hostname, six.text_type): # Platform-specific: Python 3 # If server_hostname is an IP, don't use it for SNI, per RFC6066 Section 3
server_hostname = server_hostname.encode("utf-8") if server_hostname and not util.ssl_.is_ipaddress(server_hostname):
if isinstance(server_hostname, str):
if server_hostname is not None: server_hostname = server_hostname.encode("utf-8")
cnx.set_tlsext_host_name(server_hostname) cnx.set_tlsext_host_name(server_hostname)
cnx.set_connect_state() cnx.set_connect_state()
@ -503,16 +511,47 @@ class PyOpenSSLContext(object):
while True: while True:
try: try:
cnx.do_handshake() cnx.do_handshake()
except OpenSSL.SSL.WantReadError: except OpenSSL.SSL.WantReadError as e:
if not util.wait_for_read(sock, sock.gettimeout()): if not util.wait_for_read(sock, sock.gettimeout()):
raise timeout("select timed out") raise timeout("select timed out") from e
continue continue
except OpenSSL.SSL.Error as e: except OpenSSL.SSL.Error as e:
raise ssl.SSLError("bad handshake: %r" % e) raise ssl.SSLError(f"bad handshake: {e!r}") from e
break break
return WrappedSocket(cnx, sock) return WrappedSocket(cnx, sock)
def _set_ctx_options(self) -> None:
    """
    Apply the current option bitmask to the wrapped PyOpenSSL context.

    The value passed to ``set_options`` is the OR of the stored
    ``_options`` and the flag sets looked up for the stored minimum and
    maximum versions.
    """
    min_version_flags = _openssl_to_ssl_minimum_version[self._minimum_version]
    max_version_flags = _openssl_to_ssl_maximum_version[self._maximum_version]
    combined = self._options | min_version_flags | max_version_flags
    self._ctx.set_options(combined)
def _verify_callback(cnx, x509, err_no, err_depth, return_code): @property
def minimum_version(self) -> int:
return self._minimum_version
@minimum_version.setter
def minimum_version(self, minimum_version: int) -> None:
    """Store the new minimum version and re-apply the context options."""
    self._minimum_version = minimum_version
    # The option bitmask depends on the stored version, so recompute it.
    self._set_ctx_options()
@property
def maximum_version(self) -> int:
    """Return the maximum version currently stored on this context."""
    return self._maximum_version
@maximum_version.setter
def maximum_version(self, maximum_version: int) -> None:
    """Store the new maximum version and re-apply the context options."""
    self._maximum_version = maximum_version
    # The option bitmask depends on the stored version, so recompute it.
    self._set_ctx_options()
def _verify_callback(
cnx: OpenSSL.SSL.Connection,
x509: X509,
err_no: int,
err_depth: int,
return_code: int,
) -> bool:
return err_no == 0 return err_no == 0

View file

@ -51,7 +51,8 @@ license and by oscrypto's:
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import absolute_import
from __future__ import annotations
import contextlib import contextlib
import ctypes import ctypes
@ -62,14 +63,18 @@ import socket
import ssl import ssl
import struct import struct
import threading import threading
import typing
import warnings
import weakref import weakref
from socket import socket as socket_cls
import six
from .. import util from .. import util
from ..util.ssl_ import PROTOCOL_TLS_CLIENT from ._securetransport.bindings import ( # type: ignore[attr-defined]
from ._securetransport.bindings import CoreFoundation, Security, SecurityConst CoreFoundation,
Security,
)
from ._securetransport.low_level import ( from ._securetransport.low_level import (
SecurityConst,
_assert_no_error, _assert_no_error,
_build_tls_unknown_ca_alert, _build_tls_unknown_ca_alert,
_cert_array_from_pem, _cert_array_from_pem,
@ -78,18 +83,19 @@ from ._securetransport.low_level import (
_temporary_keychain, _temporary_keychain,
) )
try: # Platform-specific: Python 2 warnings.warn(
from socket import _fileobject "'urllib3.contrib.securetransport' module is deprecated and will be removed "
except ImportError: # Platform-specific: Python 3 "in urllib3 v2.1.0. Read more in this issue: "
_fileobject = None "https://github.com/urllib3/urllib3/issues/2681",
from ..packages.backports.makefile import backport_makefile category=DeprecationWarning,
stacklevel=2,
)
if typing.TYPE_CHECKING:
from typing_extensions import Literal
__all__ = ["inject_into_urllib3", "extract_from_urllib3"] __all__ = ["inject_into_urllib3", "extract_from_urllib3"]
# SNI always works
HAS_SNI = True
orig_util_HAS_SNI = util.HAS_SNI
orig_util_SSLContext = util.ssl_.SSLContext orig_util_SSLContext = util.ssl_.SSLContext
# This dictionary is used by the read callback to obtain a handle to the # This dictionary is used by the read callback to obtain a handle to the
@ -108,55 +114,24 @@ orig_util_SSLContext = util.ssl_.SSLContext
# #
# This is good: if we had to lock in the callbacks we'd drastically slow down # This is good: if we had to lock in the callbacks we'd drastically slow down
# the performance of this code. # the performance of this code.
_connection_refs = weakref.WeakValueDictionary() _connection_refs: weakref.WeakValueDictionary[
int, WrappedSocket
] = weakref.WeakValueDictionary()
_connection_ref_lock = threading.Lock() _connection_ref_lock = threading.Lock()
# Limit writes to 16kB. This is OpenSSL's limit, but we'll cargo-cult it over # Limit writes to 16kB. This is OpenSSL's limit, but we'll cargo-cult it over
# for no better reason than we need *a* limit, and this one is right there. # for no better reason than we need *a* limit, and this one is right there.
SSL_WRITE_BLOCKSIZE = 16384 SSL_WRITE_BLOCKSIZE = 16384
# This is our equivalent of util.ssl_.DEFAULT_CIPHERS, but expanded out to
# individual cipher suites. We need to do this because this is how
# SecureTransport wants them.
CIPHER_SUITES = [
SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
SecurityConst.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
SecurityConst.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
SecurityConst.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
SecurityConst.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
SecurityConst.TLS_DHE_RSA_WITH_AES_256_GCM_SHA384,
SecurityConst.TLS_DHE_RSA_WITH_AES_128_GCM_SHA256,
SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384,
SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256,
SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
SecurityConst.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384,
SecurityConst.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
SecurityConst.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,
SecurityConst.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
SecurityConst.TLS_DHE_RSA_WITH_AES_256_CBC_SHA256,
SecurityConst.TLS_DHE_RSA_WITH_AES_256_CBC_SHA,
SecurityConst.TLS_DHE_RSA_WITH_AES_128_CBC_SHA256,
SecurityConst.TLS_DHE_RSA_WITH_AES_128_CBC_SHA,
SecurityConst.TLS_AES_256_GCM_SHA384,
SecurityConst.TLS_AES_128_GCM_SHA256,
SecurityConst.TLS_RSA_WITH_AES_256_GCM_SHA384,
SecurityConst.TLS_RSA_WITH_AES_128_GCM_SHA256,
SecurityConst.TLS_AES_128_CCM_8_SHA256,
SecurityConst.TLS_AES_128_CCM_SHA256,
SecurityConst.TLS_RSA_WITH_AES_256_CBC_SHA256,
SecurityConst.TLS_RSA_WITH_AES_128_CBC_SHA256,
SecurityConst.TLS_RSA_WITH_AES_256_CBC_SHA,
SecurityConst.TLS_RSA_WITH_AES_128_CBC_SHA,
]
# Basically this is simple: for PROTOCOL_SSLv23 we turn it into a low of # Basically this is simple: for PROTOCOL_SSLv23 we turn it into a low of
# TLSv1 and a high of TLSv1.2. For everything else, we pin to that version. # TLSv1 and a high of TLSv1.2. For everything else, we pin to that version.
# TLSv1 to 1.2 are supported on macOS 10.8+ # TLSv1 to 1.2 are supported on macOS 10.8+
_protocol_to_min_max = { _protocol_to_min_max = {
util.PROTOCOL_TLS: (SecurityConst.kTLSProtocol1, SecurityConst.kTLSProtocol12), util.ssl_.PROTOCOL_TLS: (SecurityConst.kTLSProtocol1, SecurityConst.kTLSProtocol12), # type: ignore[attr-defined]
PROTOCOL_TLS_CLIENT: (SecurityConst.kTLSProtocol1, SecurityConst.kTLSProtocol12), util.ssl_.PROTOCOL_TLS_CLIENT: ( # type: ignore[attr-defined]
SecurityConst.kTLSProtocol1,
SecurityConst.kTLSProtocol12,
),
} }
if hasattr(ssl, "PROTOCOL_SSLv2"): if hasattr(ssl, "PROTOCOL_SSLv2"):
@ -186,31 +161,38 @@ if hasattr(ssl, "PROTOCOL_TLSv1_2"):
) )
def inject_into_urllib3(): _tls_version_to_st: dict[int, int] = {
ssl.TLSVersion.MINIMUM_SUPPORTED: SecurityConst.kTLSProtocol1,
ssl.TLSVersion.TLSv1: SecurityConst.kTLSProtocol1,
ssl.TLSVersion.TLSv1_1: SecurityConst.kTLSProtocol11,
ssl.TLSVersion.TLSv1_2: SecurityConst.kTLSProtocol12,
ssl.TLSVersion.MAXIMUM_SUPPORTED: SecurityConst.kTLSProtocol12,
}
def inject_into_urllib3() -> None:
""" """
Monkey-patch urllib3 with SecureTransport-backed SSL-support. Monkey-patch urllib3 with SecureTransport-backed SSL-support.
""" """
util.SSLContext = SecureTransportContext util.SSLContext = SecureTransportContext # type: ignore[assignment]
util.ssl_.SSLContext = SecureTransportContext util.ssl_.SSLContext = SecureTransportContext # type: ignore[assignment]
util.HAS_SNI = HAS_SNI
util.ssl_.HAS_SNI = HAS_SNI
util.IS_SECURETRANSPORT = True util.IS_SECURETRANSPORT = True
util.ssl_.IS_SECURETRANSPORT = True util.ssl_.IS_SECURETRANSPORT = True
def extract_from_urllib3(): def extract_from_urllib3() -> None:
""" """
Undo monkey-patching by :func:`inject_into_urllib3`. Undo monkey-patching by :func:`inject_into_urllib3`.
""" """
util.SSLContext = orig_util_SSLContext util.SSLContext = orig_util_SSLContext
util.ssl_.SSLContext = orig_util_SSLContext util.ssl_.SSLContext = orig_util_SSLContext
util.HAS_SNI = orig_util_HAS_SNI
util.ssl_.HAS_SNI = orig_util_HAS_SNI
util.IS_SECURETRANSPORT = False util.IS_SECURETRANSPORT = False
util.ssl_.IS_SECURETRANSPORT = False util.ssl_.IS_SECURETRANSPORT = False
def _read_callback(connection_id, data_buffer, data_length_pointer): def _read_callback(
connection_id: int, data_buffer: int, data_length_pointer: bytearray
) -> int:
""" """
SecureTransport read callback. This is called by ST to request that data SecureTransport read callback. This is called by ST to request that data
be returned from the socket. be returned from the socket.
@ -232,7 +214,7 @@ def _read_callback(connection_id, data_buffer, data_length_pointer):
while read_count < requested_length: while read_count < requested_length:
if timeout is None or timeout >= 0: if timeout is None or timeout >= 0:
if not util.wait_for_read(base_socket, timeout): if not util.wait_for_read(base_socket, timeout):
raise socket.error(errno.EAGAIN, "timed out") raise OSError(errno.EAGAIN, "timed out")
remaining = requested_length - read_count remaining = requested_length - read_count
buffer = (ctypes.c_char * remaining).from_address( buffer = (ctypes.c_char * remaining).from_address(
@ -244,7 +226,7 @@ def _read_callback(connection_id, data_buffer, data_length_pointer):
if not read_count: if not read_count:
return SecurityConst.errSSLClosedGraceful return SecurityConst.errSSLClosedGraceful
break break
except (socket.error) as e: except OSError as e:
error = e.errno error = e.errno
if error is not None and error != errno.EAGAIN: if error is not None and error != errno.EAGAIN:
@ -265,7 +247,9 @@ def _read_callback(connection_id, data_buffer, data_length_pointer):
return SecurityConst.errSSLInternal return SecurityConst.errSSLInternal
def _write_callback(connection_id, data_buffer, data_length_pointer): def _write_callback(
connection_id: int, data_buffer: int, data_length_pointer: bytearray
) -> int:
""" """
SecureTransport write callback. This is called by ST to request that data SecureTransport write callback. This is called by ST to request that data
actually be sent on the network. actually be sent on the network.
@ -288,14 +272,14 @@ def _write_callback(connection_id, data_buffer, data_length_pointer):
while sent < bytes_to_write: while sent < bytes_to_write:
if timeout is None or timeout >= 0: if timeout is None or timeout >= 0:
if not util.wait_for_write(base_socket, timeout): if not util.wait_for_write(base_socket, timeout):
raise socket.error(errno.EAGAIN, "timed out") raise OSError(errno.EAGAIN, "timed out")
chunk_sent = base_socket.send(data) chunk_sent = base_socket.send(data)
sent += chunk_sent sent += chunk_sent
# This has some needless copying here, but I'm not sure there's # This has some needless copying here, but I'm not sure there's
# much value in optimising this data path. # much value in optimising this data path.
data = data[chunk_sent:] data = data[chunk_sent:]
except (socket.error) as e: except OSError as e:
error = e.errno error = e.errno
if error is not None and error != errno.EAGAIN: if error is not None and error != errno.EAGAIN:
@ -323,22 +307,20 @@ _read_callback_pointer = Security.SSLReadFunc(_read_callback)
_write_callback_pointer = Security.SSLWriteFunc(_write_callback) _write_callback_pointer = Security.SSLWriteFunc(_write_callback)
class WrappedSocket(object): class WrappedSocket:
""" """
API-compatibility wrapper for Python's OpenSSL wrapped socket object. API-compatibility wrapper for Python's OpenSSL wrapped socket object.
Note: _makefile_refs, _drop(), and _reuse() are needed for the garbage
collector of PyPy.
""" """
def __init__(self, socket): def __init__(self, socket: socket_cls) -> None:
self.socket = socket self.socket = socket
self.context = None self.context = None
self._makefile_refs = 0 self._io_refs = 0
self._closed = False self._closed = False
self._exception = None self._real_closed = False
self._exception: Exception | None = None
self._keychain = None self._keychain = None
self._keychain_dir = None self._keychain_dir: str | None = None
self._client_cert_chain = None self._client_cert_chain = None
# We save off the previously-configured timeout and then set it to # We save off the previously-configured timeout and then set it to
@ -350,7 +332,7 @@ class WrappedSocket(object):
self.socket.settimeout(0) self.socket.settimeout(0)
@contextlib.contextmanager @contextlib.contextmanager
def _raise_on_error(self): def _raise_on_error(self) -> typing.Generator[None, None, None]:
""" """
A context manager that can be used to wrap calls that do I/O from A context manager that can be used to wrap calls that do I/O from
SecureTransport. If any of the I/O callbacks hit an exception, this SecureTransport. If any of the I/O callbacks hit an exception, this
@ -367,23 +349,10 @@ class WrappedSocket(object):
yield yield
if self._exception is not None: if self._exception is not None:
exception, self._exception = self._exception, None exception, self._exception = self._exception, None
self.close() self._real_close()
raise exception raise exception
def _set_ciphers(self): def _set_alpn_protocols(self, protocols: list[bytes] | None) -> None:
"""
Sets up the allowed ciphers. By default this matches the set in
util.ssl_.DEFAULT_CIPHERS, at least as supported by macOS. This is done
custom and doesn't allow changing at this time, mostly because parsing
OpenSSL cipher strings is going to be a freaking nightmare.
"""
ciphers = (Security.SSLCipherSuite * len(CIPHER_SUITES))(*CIPHER_SUITES)
result = Security.SSLSetEnabledCiphers(
self.context, ciphers, len(CIPHER_SUITES)
)
_assert_no_error(result)
def _set_alpn_protocols(self, protocols):
""" """
Sets up the ALPN protocols on the context. Sets up the ALPN protocols on the context.
""" """
@ -396,7 +365,7 @@ class WrappedSocket(object):
finally: finally:
CoreFoundation.CFRelease(protocols_arr) CoreFoundation.CFRelease(protocols_arr)
def _custom_validate(self, verify, trust_bundle): def _custom_validate(self, verify: bool, trust_bundle: bytes | None) -> None:
""" """
Called when we have set custom validation. We do this in two cases: Called when we have set custom validation. We do this in two cases:
first, when cert validation is entirely disabled; and second, when first, when cert validation is entirely disabled; and second, when
@ -404,7 +373,7 @@ class WrappedSocket(object):
Raises an SSLError if the connection is not trusted. Raises an SSLError if the connection is not trusted.
""" """
# If we disabled cert validation, just say: cool. # If we disabled cert validation, just say: cool.
if not verify: if not verify or trust_bundle is None:
return return
successes = ( successes = (
@ -415,10 +384,12 @@ class WrappedSocket(object):
trust_result = self._evaluate_trust(trust_bundle) trust_result = self._evaluate_trust(trust_bundle)
if trust_result in successes: if trust_result in successes:
return return
reason = "error code: %d" % (trust_result,) reason = f"error code: {int(trust_result)}"
exc = None
except Exception as e: except Exception as e:
# Do not trust on error # Do not trust on error
reason = "exception: %r" % (e,) reason = f"exception: {e!r}"
exc = e
# SecureTransport does not send an alert nor shuts down the connection. # SecureTransport does not send an alert nor shuts down the connection.
rec = _build_tls_unknown_ca_alert(self.version()) rec = _build_tls_unknown_ca_alert(self.version())
@ -428,10 +399,10 @@ class WrappedSocket(object):
# l_linger = 0, linger for 0 seoncds # l_linger = 0, linger for 0 seoncds
opts = struct.pack("ii", 1, 0) opts = struct.pack("ii", 1, 0)
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, opts) self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, opts)
self.close() self._real_close()
raise ssl.SSLError("certificate verify failed, %s" % reason) raise ssl.SSLError(f"certificate verify failed, {reason}") from exc
def _evaluate_trust(self, trust_bundle): def _evaluate_trust(self, trust_bundle: bytes) -> int:
# We want data in memory, so load it up. # We want data in memory, so load it up.
if os.path.isfile(trust_bundle): if os.path.isfile(trust_bundle):
with open(trust_bundle, "rb") as f: with open(trust_bundle, "rb") as f:
@ -469,20 +440,20 @@ class WrappedSocket(object):
if cert_array is not None: if cert_array is not None:
CoreFoundation.CFRelease(cert_array) CoreFoundation.CFRelease(cert_array)
return trust_result.value return trust_result.value # type: ignore[no-any-return]
def handshake( def handshake(
self, self,
server_hostname, server_hostname: bytes | str | None,
verify, verify: bool,
trust_bundle, trust_bundle: bytes | None,
min_version, min_version: int,
max_version, max_version: int,
client_cert, client_cert: str | None,
client_key, client_key: str | None,
client_key_passphrase, client_key_passphrase: typing.Any,
alpn_protocols, alpn_protocols: list[bytes] | None,
): ) -> None:
""" """
Actually performs the TLS handshake. This is run automatically by Actually performs the TLS handshake. This is run automatically by
wrapped socket, and shouldn't be needed in user code. wrapped socket, and shouldn't be needed in user code.
@ -510,6 +481,8 @@ class WrappedSocket(object):
_assert_no_error(result) _assert_no_error(result)
# If we have a server hostname, we should set that too. # If we have a server hostname, we should set that too.
# RFC6066 Section 3 tells us not to use SNI when the host is an IP, but we have
# to do it anyway to match server_hostname against the server certificate
if server_hostname: if server_hostname:
if not isinstance(server_hostname, bytes): if not isinstance(server_hostname, bytes):
server_hostname = server_hostname.encode("utf-8") server_hostname = server_hostname.encode("utf-8")
@ -519,9 +492,6 @@ class WrappedSocket(object):
) )
_assert_no_error(result) _assert_no_error(result)
# Setup the ciphers.
self._set_ciphers()
# Setup the ALPN protocols. # Setup the ALPN protocols.
self._set_alpn_protocols(alpn_protocols) self._set_alpn_protocols(alpn_protocols)
@ -564,25 +534,27 @@ class WrappedSocket(object):
_assert_no_error(result) _assert_no_error(result)
break break
def fileno(self): def fileno(self) -> int:
return self.socket.fileno() return self.socket.fileno()
# Copy-pasted from Python 3.5 source code # Copy-pasted from Python 3.5 source code
def _decref_socketios(self): def _decref_socketios(self) -> None:
if self._makefile_refs > 0: if self._io_refs > 0:
self._makefile_refs -= 1 self._io_refs -= 1
if self._closed: if self._closed:
self.close() self.close()
def recv(self, bufsiz): def recv(self, bufsiz: int) -> bytes:
buffer = ctypes.create_string_buffer(bufsiz) buffer = ctypes.create_string_buffer(bufsiz)
bytes_read = self.recv_into(buffer, bufsiz) bytes_read = self.recv_into(buffer, bufsiz)
data = buffer[:bytes_read] data = buffer[:bytes_read]
return data return typing.cast(bytes, data)
def recv_into(self, buffer, nbytes=None): def recv_into(
self, buffer: ctypes.Array[ctypes.c_char], nbytes: int | None = None
) -> int:
# Read short on EOF. # Read short on EOF.
if self._closed: if self._real_closed:
return 0 return 0
if nbytes is None: if nbytes is None:
@ -615,7 +587,7 @@ class WrappedSocket(object):
# well. Note that we don't actually return here because in # well. Note that we don't actually return here because in
# principle this could actually be fired along with return data. # principle this could actually be fired along with return data.
# It's unlikely though. # It's unlikely though.
self.close() self._real_close()
else: else:
_assert_no_error(result) _assert_no_error(result)
@ -623,13 +595,13 @@ class WrappedSocket(object):
# was actually read. # was actually read.
return processed_bytes.value return processed_bytes.value
def settimeout(self, timeout): def settimeout(self, timeout: float) -> None:
self._timeout = timeout self._timeout = timeout
def gettimeout(self): def gettimeout(self) -> float | None:
return self._timeout return self._timeout
def send(self, data): def send(self, data: bytes) -> int:
processed_bytes = ctypes.c_size_t(0) processed_bytes = ctypes.c_size_t(0)
with self._raise_on_error(): with self._raise_on_error():
@ -646,36 +618,38 @@ class WrappedSocket(object):
# We sent, and probably succeeded. Tell them how much we sent. # We sent, and probably succeeded. Tell them how much we sent.
return processed_bytes.value return processed_bytes.value
def sendall(self, data): def sendall(self, data: bytes) -> None:
total_sent = 0 total_sent = 0
while total_sent < len(data): while total_sent < len(data):
sent = self.send(data[total_sent : total_sent + SSL_WRITE_BLOCKSIZE]) sent = self.send(data[total_sent : total_sent + SSL_WRITE_BLOCKSIZE])
total_sent += sent total_sent += sent
def shutdown(self): def shutdown(self) -> None:
with self._raise_on_error(): with self._raise_on_error():
Security.SSLClose(self.context) Security.SSLClose(self.context)
def close(self): def close(self) -> None:
self._closed = True
# TODO: should I do clean shutdown here? Do I have to? # TODO: should I do clean shutdown here? Do I have to?
if self._makefile_refs < 1: if self._io_refs <= 0:
self._closed = True self._real_close()
if self.context:
CoreFoundation.CFRelease(self.context)
self.context = None
if self._client_cert_chain:
CoreFoundation.CFRelease(self._client_cert_chain)
self._client_cert_chain = None
if self._keychain:
Security.SecKeychainDelete(self._keychain)
CoreFoundation.CFRelease(self._keychain)
shutil.rmtree(self._keychain_dir)
self._keychain = self._keychain_dir = None
return self.socket.close()
else:
self._makefile_refs -= 1
def getpeercert(self, binary_form=False): def _real_close(self) -> None:
self._real_closed = True
if self.context:
CoreFoundation.CFRelease(self.context)
self.context = None
if self._client_cert_chain:
CoreFoundation.CFRelease(self._client_cert_chain)
self._client_cert_chain = None
if self._keychain:
Security.SecKeychainDelete(self._keychain)
CoreFoundation.CFRelease(self._keychain)
shutil.rmtree(self._keychain_dir)
self._keychain = self._keychain_dir = None
return self.socket.close()
def getpeercert(self, binary_form: bool = False) -> bytes | None:
# Urgh, annoying. # Urgh, annoying.
# #
# Here's how we do this: # Here's how we do this:
@ -733,7 +707,7 @@ class WrappedSocket(object):
return der_bytes return der_bytes
def version(self): def version(self) -> str:
protocol = Security.SSLProtocol() protocol = Security.SSLProtocol()
result = Security.SSLGetNegotiatedProtocolVersion( result = Security.SSLGetNegotiatedProtocolVersion(
self.context, ctypes.byref(protocol) self.context, ctypes.byref(protocol)
@ -752,55 +726,50 @@ class WrappedSocket(object):
elif protocol.value == SecurityConst.kSSLProtocol2: elif protocol.value == SecurityConst.kSSLProtocol2:
return "SSLv2" return "SSLv2"
else: else:
raise ssl.SSLError("Unknown TLS version: %r" % protocol) raise ssl.SSLError(f"Unknown TLS version: {protocol!r}")
def _reuse(self):
self._makefile_refs += 1
def _drop(self):
if self._makefile_refs < 1:
self.close()
else:
self._makefile_refs -= 1
if _fileobject: # Platform-specific: Python 2 def makefile(
self: socket_cls,
def makefile(self, mode, bufsize=-1): mode: (
self._makefile_refs += 1 Literal["r"] | Literal["w"] | Literal["rw"] | Literal["wr"] | Literal[""]
return _fileobject(self, mode, bufsize, close=True) ) = "r",
buffering: int | None = None,
else: # Platform-specific: Python 3 *args: typing.Any,
**kwargs: typing.Any,
def makefile(self, mode="r", buffering=None, *args, **kwargs): ) -> typing.BinaryIO | typing.TextIO:
# We disable buffering with SecureTransport because it conflicts with # We disable buffering with SecureTransport because it conflicts with
# the buffering that ST does internally (see issue #1153 for more). # the buffering that ST does internally (see issue #1153 for more).
buffering = 0 buffering = 0
return backport_makefile(self, mode, buffering, *args, **kwargs) return socket_cls.makefile(self, mode, buffering, *args, **kwargs)
WrappedSocket.makefile = makefile WrappedSocket.makefile = makefile # type: ignore[attr-defined]
class SecureTransportContext(object): class SecureTransportContext:
""" """
I am a wrapper class for the SecureTransport library, to translate the I am a wrapper class for the SecureTransport library, to translate the
interface of the standard library ``SSLContext`` object to calls into interface of the standard library ``SSLContext`` object to calls into
SecureTransport. SecureTransport.
""" """
def __init__(self, protocol): def __init__(self, protocol: int) -> None:
self._min_version, self._max_version = _protocol_to_min_max[protocol] self._minimum_version: int = ssl.TLSVersion.MINIMUM_SUPPORTED
self._maximum_version: int = ssl.TLSVersion.MAXIMUM_SUPPORTED
if protocol not in (None, ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLS_CLIENT):
self._min_version, self._max_version = _protocol_to_min_max[protocol]
self._options = 0 self._options = 0
self._verify = False self._verify = False
self._trust_bundle = None self._trust_bundle: bytes | None = None
self._client_cert = None self._client_cert: str | None = None
self._client_key = None self._client_key: str | None = None
self._client_key_passphrase = None self._client_key_passphrase = None
self._alpn_protocols = None self._alpn_protocols: list[bytes] | None = None
@property @property
def check_hostname(self): def check_hostname(self) -> Literal[True]:
""" """
SecureTransport cannot have its hostname checking disabled. For more, SecureTransport cannot have its hostname checking disabled. For more,
see the comment on getpeercert() in this file. see the comment on getpeercert() in this file.
@ -808,15 +777,14 @@ class SecureTransportContext(object):
return True return True
@check_hostname.setter @check_hostname.setter
def check_hostname(self, value): def check_hostname(self, value: typing.Any) -> None:
""" """
SecureTransport cannot have its hostname checking disabled. For more, SecureTransport cannot have its hostname checking disabled. For more,
see the comment on getpeercert() in this file. see the comment on getpeercert() in this file.
""" """
pass
@property @property
def options(self): def options(self) -> int:
# TODO: Well, crap. # TODO: Well, crap.
# #
# So this is the bit of the code that is the most likely to cause us # So this is the bit of the code that is the most likely to cause us
@ -826,19 +794,19 @@ class SecureTransportContext(object):
return self._options return self._options
@options.setter @options.setter
def options(self, value): def options(self, value: int) -> None:
# TODO: Update in line with above. # TODO: Update in line with above.
self._options = value self._options = value
@property @property
def verify_mode(self): def verify_mode(self) -> int:
return ssl.CERT_REQUIRED if self._verify else ssl.CERT_NONE return ssl.CERT_REQUIRED if self._verify else ssl.CERT_NONE
@verify_mode.setter @verify_mode.setter
def verify_mode(self, value): def verify_mode(self, value: int) -> None:
self._verify = True if value == ssl.CERT_REQUIRED else False self._verify = value == ssl.CERT_REQUIRED
def set_default_verify_paths(self): def set_default_verify_paths(self) -> None:
# So, this has to do something a bit weird. Specifically, what it does # So, this has to do something a bit weird. Specifically, what it does
# is nothing. # is nothing.
# #
@ -850,15 +818,18 @@ class SecureTransportContext(object):
# ignoring it. # ignoring it.
pass pass
def load_default_certs(self): def load_default_certs(self) -> None:
return self.set_default_verify_paths() return self.set_default_verify_paths()
def set_ciphers(self, ciphers): def set_ciphers(self, ciphers: typing.Any) -> None:
# For now, we just require the default cipher string. raise ValueError("SecureTransport doesn't support custom cipher strings")
if ciphers != util.ssl_.DEFAULT_CIPHERS:
raise ValueError("SecureTransport doesn't support custom cipher strings")
def load_verify_locations(self, cafile=None, capath=None, cadata=None): def load_verify_locations(
self,
cafile: str | None = None,
capath: str | None = None,
cadata: bytes | None = None,
) -> None:
# OK, we only really support cadata and cafile. # OK, we only really support cadata and cafile.
if capath is not None: if capath is not None:
raise ValueError("SecureTransport does not support cert directories") raise ValueError("SecureTransport does not support cert directories")
@ -868,14 +839,19 @@ class SecureTransportContext(object):
with open(cafile): with open(cafile):
pass pass
self._trust_bundle = cafile or cadata self._trust_bundle = cafile or cadata # type: ignore[assignment]
def load_cert_chain(self, certfile, keyfile=None, password=None): def load_cert_chain(
self,
certfile: str,
keyfile: str | None = None,
password: str | None = None,
) -> None:
self._client_cert = certfile self._client_cert = certfile
self._client_key = keyfile self._client_key = keyfile
self._client_cert_passphrase = password self._client_cert_passphrase = password
def set_alpn_protocols(self, protocols): def set_alpn_protocols(self, protocols: list[str | bytes]) -> None:
""" """
Sets the ALPN protocols that will later be set on the context. Sets the ALPN protocols that will later be set on the context.
@ -885,16 +861,16 @@ class SecureTransportContext(object):
raise NotImplementedError( raise NotImplementedError(
"SecureTransport supports ALPN only in macOS 10.12+" "SecureTransport supports ALPN only in macOS 10.12+"
) )
self._alpn_protocols = [six.ensure_binary(p) for p in protocols] self._alpn_protocols = [util.util.to_bytes(p, "ascii") for p in protocols]
def wrap_socket( def wrap_socket(
self, self,
sock, sock: socket_cls,
server_side=False, server_side: bool = False,
do_handshake_on_connect=True, do_handshake_on_connect: bool = True,
suppress_ragged_eofs=True, suppress_ragged_eofs: bool = True,
server_hostname=None, server_hostname: bytes | str | None = None,
): ) -> WrappedSocket:
# So, what do we do here? Firstly, we assert some properties. This is a # So, what do we do here? Firstly, we assert some properties. This is a
# stripped down shim, so there is some functionality we don't support. # stripped down shim, so there is some functionality we don't support.
# See PEP 543 for the real deal. # See PEP 543 for the real deal.
@ -911,11 +887,27 @@ class SecureTransportContext(object):
server_hostname, server_hostname,
self._verify, self._verify,
self._trust_bundle, self._trust_bundle,
self._min_version, _tls_version_to_st[self._minimum_version],
self._max_version, _tls_version_to_st[self._maximum_version],
self._client_cert, self._client_cert,
self._client_key, self._client_key,
self._client_key_passphrase, self._client_key_passphrase,
self._alpn_protocols, self._alpn_protocols,
) )
return wrapped_socket return wrapped_socket
@property
def minimum_version(self) -> int:
return self._minimum_version
@minimum_version.setter
def minimum_version(self, minimum_version: int) -> None:
self._minimum_version = minimum_version
@property
def maximum_version(self) -> int:
return self._maximum_version
@maximum_version.setter
def maximum_version(self, maximum_version: int) -> None:
self._maximum_version = maximum_version

View file

@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
""" """
This module contains provisional support for SOCKS proxies from within This module contains provisional support for SOCKS proxies from within
urllib3. This module supports SOCKS4, SOCKS4A (an extension of SOCKS4), and urllib3. This module supports SOCKS4, SOCKS4A (an extension of SOCKS4), and
@ -38,10 +37,11 @@ with the proxy:
proxy_url="socks5h://<username>:<password>@proxy-host" proxy_url="socks5h://<username>:<password>@proxy-host"
""" """
from __future__ import absolute_import
from __future__ import annotations
try: try:
import socks import socks # type: ignore[import]
except ImportError: except ImportError:
import warnings import warnings
@ -51,13 +51,13 @@ except ImportError:
( (
"SOCKS support in urllib3 requires the installation of optional " "SOCKS support in urllib3 requires the installation of optional "
"dependencies: specifically, PySocks. For more information, see " "dependencies: specifically, PySocks. For more information, see "
"https://urllib3.readthedocs.io/en/1.26.x/contrib.html#socks-proxies" "https://urllib3.readthedocs.io/en/latest/contrib.html#socks-proxies"
), ),
DependencyWarning, DependencyWarning,
) )
raise raise
from socket import error as SocketError import typing
from socket import timeout as SocketTimeout from socket import timeout as SocketTimeout
from ..connection import HTTPConnection, HTTPSConnection from ..connection import HTTPConnection, HTTPSConnection
@ -69,7 +69,21 @@ from ..util.url import parse_url
try: try:
import ssl import ssl
except ImportError: except ImportError:
ssl = None ssl = None # type: ignore[assignment]
try:
from typing import TypedDict
class _TYPE_SOCKS_OPTIONS(TypedDict):
socks_version: int
proxy_host: str | None
proxy_port: str | None
username: str | None
password: str | None
rdns: bool
except ImportError: # Python 3.7
_TYPE_SOCKS_OPTIONS = typing.Dict[str, typing.Any] # type: ignore[misc, assignment]
class SOCKSConnection(HTTPConnection): class SOCKSConnection(HTTPConnection):
@ -77,15 +91,20 @@ class SOCKSConnection(HTTPConnection):
A plain-text HTTP connection that connects via a SOCKS proxy. A plain-text HTTP connection that connects via a SOCKS proxy.
""" """
def __init__(self, *args, **kwargs): def __init__(
self._socks_options = kwargs.pop("_socks_options") self,
super(SOCKSConnection, self).__init__(*args, **kwargs) _socks_options: _TYPE_SOCKS_OPTIONS,
*args: typing.Any,
**kwargs: typing.Any,
) -> None:
self._socks_options = _socks_options
super().__init__(*args, **kwargs)
def _new_conn(self): def _new_conn(self) -> socks.socksocket:
""" """
Establish a new connection via the SOCKS proxy. Establish a new connection via the SOCKS proxy.
""" """
extra_kw = {} extra_kw: dict[str, typing.Any] = {}
if self.source_address: if self.source_address:
extra_kw["source_address"] = self.source_address extra_kw["source_address"] = self.source_address
@ -102,15 +121,14 @@ class SOCKSConnection(HTTPConnection):
proxy_password=self._socks_options["password"], proxy_password=self._socks_options["password"],
proxy_rdns=self._socks_options["rdns"], proxy_rdns=self._socks_options["rdns"],
timeout=self.timeout, timeout=self.timeout,
**extra_kw **extra_kw,
) )
except SocketTimeout: except SocketTimeout as e:
raise ConnectTimeoutError( raise ConnectTimeoutError(
self, self,
"Connection to %s timed out. (connect timeout=%s)" f"Connection to {self.host} timed out. (connect timeout={self.timeout})",
% (self.host, self.timeout), ) from e
)
except socks.ProxyError as e: except socks.ProxyError as e:
# This is fragile as hell, but it seems to be the only way to raise # This is fragile as hell, but it seems to be the only way to raise
@ -120,22 +138,23 @@ class SOCKSConnection(HTTPConnection):
if isinstance(error, SocketTimeout): if isinstance(error, SocketTimeout):
raise ConnectTimeoutError( raise ConnectTimeoutError(
self, self,
"Connection to %s timed out. (connect timeout=%s)" f"Connection to {self.host} timed out. (connect timeout={self.timeout})",
% (self.host, self.timeout), ) from e
)
else: else:
# Adding `from e` messes with coverage somehow, so it's omitted.
# See #2386.
raise NewConnectionError( raise NewConnectionError(
self, "Failed to establish a new connection: %s" % error self, f"Failed to establish a new connection: {error}"
) )
else: else:
raise NewConnectionError( raise NewConnectionError(
self, "Failed to establish a new connection: %s" % e self, f"Failed to establish a new connection: {e}"
) ) from e
except SocketError as e: # Defensive: PySocks should catch all these. except OSError as e: # Defensive: PySocks should catch all these.
raise NewConnectionError( raise NewConnectionError(
self, "Failed to establish a new connection: %s" % e self, f"Failed to establish a new connection: {e}"
) ) from e
return conn return conn
@ -169,12 +188,12 @@ class SOCKSProxyManager(PoolManager):
def __init__( def __init__(
self, self,
proxy_url, proxy_url: str,
username=None, username: str | None = None,
password=None, password: str | None = None,
num_pools=10, num_pools: int = 10,
headers=None, headers: typing.Mapping[str, str] | None = None,
**connection_pool_kw **connection_pool_kw: typing.Any,
): ):
parsed = parse_url(proxy_url) parsed = parse_url(proxy_url)
@ -195,7 +214,7 @@ class SOCKSProxyManager(PoolManager):
socks_version = socks.PROXY_TYPE_SOCKS4 socks_version = socks.PROXY_TYPE_SOCKS4
rdns = True rdns = True
else: else:
raise ValueError("Unable to determine SOCKS version from %s" % proxy_url) raise ValueError(f"Unable to determine SOCKS version from {proxy_url}")
self.proxy_url = proxy_url self.proxy_url = proxy_url
@ -209,8 +228,6 @@ class SOCKSProxyManager(PoolManager):
} }
connection_pool_kw["_socks_options"] = socks_options connection_pool_kw["_socks_options"] = socks_options
super(SOCKSProxyManager, self).__init__( super().__init__(num_pools, headers, **connection_pool_kw)
num_pools, headers, **connection_pool_kw
)
self.pool_classes_by_scheme = SOCKSProxyManager.pool_classes_by_scheme self.pool_classes_by_scheme = SOCKSProxyManager.pool_classes_by_scheme

View file

@ -1,6 +1,16 @@
from __future__ import absolute_import from __future__ import annotations
from .packages.six.moves.http_client import IncompleteRead as httplib_IncompleteRead import socket
import typing
import warnings
from email.errors import MessageDefect
from http.client import IncompleteRead as httplib_IncompleteRead
if typing.TYPE_CHECKING:
from .connection import HTTPConnection
from .connectionpool import ConnectionPool
from .response import HTTPResponse
from .util.retry import Retry
# Base Exceptions # Base Exceptions
@ -8,23 +18,24 @@ from .packages.six.moves.http_client import IncompleteRead as httplib_Incomplete
class HTTPError(Exception): class HTTPError(Exception):
"""Base exception used by this module.""" """Base exception used by this module."""
pass
class HTTPWarning(Warning): class HTTPWarning(Warning):
"""Base warning used by this module.""" """Base warning used by this module."""
pass
_TYPE_REDUCE_RESULT = typing.Tuple[
typing.Callable[..., object], typing.Tuple[object, ...]
]
class PoolError(HTTPError): class PoolError(HTTPError):
"""Base exception for errors caused within a pool.""" """Base exception for errors caused within a pool."""
def __init__(self, pool, message): def __init__(self, pool: ConnectionPool, message: str) -> None:
self.pool = pool self.pool = pool
HTTPError.__init__(self, "%s: %s" % (pool, message)) super().__init__(f"{pool}: {message}")
def __reduce__(self): def __reduce__(self) -> _TYPE_REDUCE_RESULT:
# For pickling purposes. # For pickling purposes.
return self.__class__, (None, None) return self.__class__, (None, None)
@ -32,11 +43,11 @@ class PoolError(HTTPError):
class RequestError(PoolError): class RequestError(PoolError):
"""Base exception for PoolErrors that have associated URLs.""" """Base exception for PoolErrors that have associated URLs."""
def __init__(self, pool, url, message): def __init__(self, pool: ConnectionPool, url: str, message: str) -> None:
self.url = url self.url = url
PoolError.__init__(self, pool, message) super().__init__(pool, message)
def __reduce__(self): def __reduce__(self) -> _TYPE_REDUCE_RESULT:
# For pickling purposes. # For pickling purposes.
return self.__class__, (None, self.url, None) return self.__class__, (None, self.url, None)
@ -44,28 +55,25 @@ class RequestError(PoolError):
class SSLError(HTTPError): class SSLError(HTTPError):
"""Raised when SSL certificate fails in an HTTPS connection.""" """Raised when SSL certificate fails in an HTTPS connection."""
pass
class ProxyError(HTTPError): class ProxyError(HTTPError):
"""Raised when the connection to a proxy fails.""" """Raised when the connection to a proxy fails."""
def __init__(self, message, error, *args): # The original error is also available as __cause__.
super(ProxyError, self).__init__(message, error, *args) original_error: Exception
def __init__(self, message: str, error: Exception) -> None:
super().__init__(message, error)
self.original_error = error self.original_error = error
class DecodeError(HTTPError): class DecodeError(HTTPError):
"""Raised when automatic decoding based on Content-Type fails.""" """Raised when automatic decoding based on Content-Type fails."""
pass
class ProtocolError(HTTPError): class ProtocolError(HTTPError):
"""Raised when something unexpected happens mid-request/response.""" """Raised when something unexpected happens mid-request/response."""
pass
#: Renamed to ProtocolError but aliased for backwards compatibility. #: Renamed to ProtocolError but aliased for backwards compatibility.
ConnectionError = ProtocolError ConnectionError = ProtocolError
@ -79,33 +87,36 @@ class MaxRetryError(RequestError):
:param pool: The connection pool :param pool: The connection pool
:type pool: :class:`~urllib3.connectionpool.HTTPConnectionPool` :type pool: :class:`~urllib3.connectionpool.HTTPConnectionPool`
:param string url: The requested Url :param str url: The requested Url
:param exceptions.Exception reason: The underlying error :param reason: The underlying error
:type reason: :class:`Exception`
""" """
def __init__(self, pool, url, reason=None): def __init__(
self, pool: ConnectionPool, url: str, reason: Exception | None = None
) -> None:
self.reason = reason self.reason = reason
message = "Max retries exceeded with url: %s (Caused by %r)" % (url, reason) message = f"Max retries exceeded with url: {url} (Caused by {reason!r})"
RequestError.__init__(self, pool, url, message) super().__init__(pool, url, message)
class HostChangedError(RequestError): class HostChangedError(RequestError):
"""Raised when an existing pool gets a request for a foreign host.""" """Raised when an existing pool gets a request for a foreign host."""
def __init__(self, pool, url, retries=3): def __init__(
message = "Tried to open a foreign host with url: %s" % url self, pool: ConnectionPool, url: str, retries: Retry | int = 3
RequestError.__init__(self, pool, url, message) ) -> None:
message = f"Tried to open a foreign host with url: {url}"
super().__init__(pool, url, message)
self.retries = retries self.retries = retries
class TimeoutStateError(HTTPError): class TimeoutStateError(HTTPError):
"""Raised when passing an invalid state to a timeout""" """Raised when passing an invalid state to a timeout"""
pass
class TimeoutError(HTTPError): class TimeoutError(HTTPError):
"""Raised when a socket timeout error occurs. """Raised when a socket timeout error occurs.
@ -114,53 +125,66 @@ class TimeoutError(HTTPError):
<ReadTimeoutError>` and :exc:`ConnectTimeoutErrors <ConnectTimeoutError>`. <ReadTimeoutError>` and :exc:`ConnectTimeoutErrors <ConnectTimeoutError>`.
""" """
pass
class ReadTimeoutError(TimeoutError, RequestError): class ReadTimeoutError(TimeoutError, RequestError):
"""Raised when a socket timeout occurs while receiving data from a server""" """Raised when a socket timeout occurs while receiving data from a server"""
pass
# This timeout error does not have a URL attached and needs to inherit from the # This timeout error does not have a URL attached and needs to inherit from the
# base HTTPError # base HTTPError
class ConnectTimeoutError(TimeoutError): class ConnectTimeoutError(TimeoutError):
"""Raised when a socket timeout occurs while connecting to a server""" """Raised when a socket timeout occurs while connecting to a server"""
pass
class NewConnectionError(ConnectTimeoutError, HTTPError):
class NewConnectionError(ConnectTimeoutError, PoolError):
"""Raised when we fail to establish a new connection. Usually ECONNREFUSED.""" """Raised when we fail to establish a new connection. Usually ECONNREFUSED."""
pass def __init__(self, conn: HTTPConnection, message: str) -> None:
self.conn = conn
super().__init__(f"{conn}: {message}")
@property
def pool(self) -> HTTPConnection:
warnings.warn(
"The 'pool' property is deprecated and will be removed "
"in urllib3 v2.1.0. Use 'conn' instead.",
DeprecationWarning,
stacklevel=2,
)
return self.conn
class NameResolutionError(NewConnectionError):
"""Raised when host name resolution fails."""
def __init__(self, host: str, conn: HTTPConnection, reason: socket.gaierror):
message = f"Failed to resolve '{host}' ({reason})"
super().__init__(conn, message)
class EmptyPoolError(PoolError): class EmptyPoolError(PoolError):
"""Raised when a pool runs out of connections and no more are allowed.""" """Raised when a pool runs out of connections and no more are allowed."""
pass
class FullPoolError(PoolError):
"""Raised when we try to add a connection to a full pool in blocking mode."""
class ClosedPoolError(PoolError): class ClosedPoolError(PoolError):
"""Raised when a request enters a pool after the pool has been closed.""" """Raised when a request enters a pool after the pool has been closed."""
pass
class LocationValueError(ValueError, HTTPError): class LocationValueError(ValueError, HTTPError):
"""Raised when there is something wrong with a given URL input.""" """Raised when there is something wrong with a given URL input."""
pass
class LocationParseError(LocationValueError): class LocationParseError(LocationValueError):
"""Raised when get_host or similar fails to parse the URL input.""" """Raised when get_host or similar fails to parse the URL input."""
def __init__(self, location): def __init__(self, location: str) -> None:
message = "Failed to parse: %s" % location message = f"Failed to parse: {location}"
HTTPError.__init__(self, message) super().__init__(message)
self.location = location self.location = location
@ -168,9 +192,9 @@ class LocationParseError(LocationValueError):
class URLSchemeUnknown(LocationValueError): class URLSchemeUnknown(LocationValueError):
"""Raised when a URL input has an unsupported scheme.""" """Raised when a URL input has an unsupported scheme."""
def __init__(self, scheme): def __init__(self, scheme: str):
message = "Not supported URL scheme %s" % scheme message = f"Not supported URL scheme {scheme}"
super(URLSchemeUnknown, self).__init__(message) super().__init__(message)
self.scheme = scheme self.scheme = scheme
@ -185,38 +209,22 @@ class ResponseError(HTTPError):
class SecurityWarning(HTTPWarning): class SecurityWarning(HTTPWarning):
"""Warned when performing security reducing actions""" """Warned when performing security reducing actions"""
pass
class SubjectAltNameWarning(SecurityWarning):
"""Warned when connecting to a host with a certificate missing a SAN."""
pass
class InsecureRequestWarning(SecurityWarning): class InsecureRequestWarning(SecurityWarning):
"""Warned when making an unverified HTTPS request.""" """Warned when making an unverified HTTPS request."""
pass
class NotOpenSSLWarning(SecurityWarning):
"""Warned when using unsupported SSL library"""
class SystemTimeWarning(SecurityWarning): class SystemTimeWarning(SecurityWarning):
"""Warned when system time is suspected to be wrong""" """Warned when system time is suspected to be wrong"""
pass
class InsecurePlatformWarning(SecurityWarning): class InsecurePlatformWarning(SecurityWarning):
"""Warned when certain TLS/SSL configuration is not available on a platform.""" """Warned when certain TLS/SSL configuration is not available on a platform."""
pass
class SNIMissingWarning(HTTPWarning):
"""Warned when making a HTTPS request without SNI available."""
pass
class DependencyWarning(HTTPWarning): class DependencyWarning(HTTPWarning):
""" """
@ -224,14 +232,10 @@ class DependencyWarning(HTTPWarning):
dependencies. dependencies.
""" """
pass
class ResponseNotChunked(ProtocolError, ValueError): class ResponseNotChunked(ProtocolError, ValueError):
"""Response needs to be chunked in order to read it as chunks.""" """Response needs to be chunked in order to read it as chunks."""
pass
class BodyNotHttplibCompatible(HTTPError): class BodyNotHttplibCompatible(HTTPError):
""" """
@ -239,8 +243,6 @@ class BodyNotHttplibCompatible(HTTPError):
(have an fp attribute which returns raw chunks) for read_chunked(). (have an fp attribute which returns raw chunks) for read_chunked().
""" """
pass
class IncompleteRead(HTTPError, httplib_IncompleteRead): class IncompleteRead(HTTPError, httplib_IncompleteRead):
""" """
@ -250,12 +252,13 @@ class IncompleteRead(HTTPError, httplib_IncompleteRead):
for ``partial`` to avoid creating large objects on streamed reads. for ``partial`` to avoid creating large objects on streamed reads.
""" """
def __init__(self, partial, expected): def __init__(self, partial: int, expected: int) -> None:
super(IncompleteRead, self).__init__(partial, expected) self.partial = partial # type: ignore[assignment]
self.expected = expected
def __repr__(self): def __repr__(self) -> str:
return "IncompleteRead(%i bytes read, %i more expected)" % ( return "IncompleteRead(%i bytes read, %i more expected)" % (
self.partial, self.partial, # type: ignore[str-format]
self.expected, self.expected,
) )
@ -263,14 +266,13 @@ class IncompleteRead(HTTPError, httplib_IncompleteRead):
class InvalidChunkLength(HTTPError, httplib_IncompleteRead): class InvalidChunkLength(HTTPError, httplib_IncompleteRead):
"""Invalid chunk length in a chunked response.""" """Invalid chunk length in a chunked response."""
def __init__(self, response, length): def __init__(self, response: HTTPResponse, length: bytes) -> None:
super(InvalidChunkLength, self).__init__( self.partial: int = response.tell() # type: ignore[assignment]
response.tell(), response.length_remaining self.expected: int | None = response.length_remaining
)
self.response = response self.response = response
self.length = length self.length = length
def __repr__(self): def __repr__(self) -> str:
return "InvalidChunkLength(got length %r, %i bytes read)" % ( return "InvalidChunkLength(got length %r, %i bytes read)" % (
self.length, self.length,
self.partial, self.partial,
@ -280,15 +282,13 @@ class InvalidChunkLength(HTTPError, httplib_IncompleteRead):
class InvalidHeader(HTTPError): class InvalidHeader(HTTPError):
"""The header provided was somehow invalid.""" """The header provided was somehow invalid."""
pass
class ProxySchemeUnknown(AssertionError, URLSchemeUnknown): class ProxySchemeUnknown(AssertionError, URLSchemeUnknown):
"""ProxyManager does not support the supplied scheme""" """ProxyManager does not support the supplied scheme"""
# TODO(t-8ch): Stop inheriting from AssertionError in v2.0. # TODO(t-8ch): Stop inheriting from AssertionError in v2.0.
def __init__(self, scheme): def __init__(self, scheme: str | None) -> None:
# 'localhost' is here because our URL parser parses # 'localhost' is here because our URL parser parses
# localhost:8080 -> scheme=localhost, remove if we fix this. # localhost:8080 -> scheme=localhost, remove if we fix this.
if scheme == "localhost": if scheme == "localhost":
@ -296,28 +296,23 @@ class ProxySchemeUnknown(AssertionError, URLSchemeUnknown):
if scheme is None: if scheme is None:
message = "Proxy URL had no scheme, should start with http:// or https://" message = "Proxy URL had no scheme, should start with http:// or https://"
else: else:
message = ( message = f"Proxy URL had unsupported scheme {scheme}, should use http:// or https://"
"Proxy URL had unsupported scheme %s, should use http:// or https://" super().__init__(message)
% scheme
)
super(ProxySchemeUnknown, self).__init__(message)
class ProxySchemeUnsupported(ValueError): class ProxySchemeUnsupported(ValueError):
"""Fetching HTTPS resources through HTTPS proxies is unsupported""" """Fetching HTTPS resources through HTTPS proxies is unsupported"""
pass
class HeaderParsingError(HTTPError): class HeaderParsingError(HTTPError):
"""Raised by assert_header_parsing, but we convert it to a log.warning statement.""" """Raised by assert_header_parsing, but we convert it to a log.warning statement."""
def __init__(self, defects, unparsed_data): def __init__(
message = "%s, unparsed data: %r" % (defects or "Unknown", unparsed_data) self, defects: list[MessageDefect], unparsed_data: bytes | str | None
super(HeaderParsingError, self).__init__(message) ) -> None:
message = f"{defects or 'Unknown'}, unparsed data: {unparsed_data!r}"
super().__init__(message)
class UnrewindableBodyError(HTTPError): class UnrewindableBodyError(HTTPError):
"""urllib3 encountered an error when trying to rewind a body""" """urllib3 encountered an error when trying to rewind a body"""
pass

View file

@ -1,55 +0,0 @@
from typing import TYPE_CHECKING, Any, Optional, Tuple, Union
if TYPE_CHECKING:
from urllib3.connectionpool import ConnectionPool
class HTTPError(Exception): ...
class HTTPWarning(Warning): ...
class PoolError(HTTPError):
pool: ConnectionPool
def __init__(self, pool: ConnectionPool, message: str) -> None: ...
def __reduce__(self) -> Union[str, Tuple[Any, ...]]: ...
class RequestError(PoolError):
url: str
def __init__(self, pool: ConnectionPool, url: str, message: str) -> None: ...
def __reduce__(self) -> Union[str, Tuple[Any, ...]]: ...
class SSLError(HTTPError): ...
class ProxyError(HTTPError): ...
class DecodeError(HTTPError): ...
class ProtocolError(HTTPError): ...
ConnectionError: ProtocolError
class MaxRetryError(RequestError):
reason: str
def __init__(
self, pool: ConnectionPool, url: str, reason: Optional[str]
) -> None: ...
class HostChangedError(RequestError):
retries: int
def __init__(self, pool: ConnectionPool, url: str, retries: int) -> None: ...
class TimeoutStateError(HTTPError): ...
class TimeoutError(HTTPError): ...
class ReadTimeoutError(TimeoutError, RequestError): ...
class ConnectTimeoutError(TimeoutError): ...
class EmptyPoolError(PoolError): ...
class ClosedPoolError(PoolError): ...
class LocationValueError(ValueError, HTTPError): ...
class LocationParseError(LocationValueError):
location: str
def __init__(self, location: str) -> None: ...
class ResponseError(HTTPError):
GENERIC_ERROR: Any
SPECIFIC_ERROR: Any
class SecurityWarning(HTTPWarning): ...
class InsecureRequestWarning(SecurityWarning): ...
class SystemTimeWarning(SecurityWarning): ...
class InsecurePlatformWarning(SecurityWarning): ...

View file

@ -1,13 +1,20 @@
from __future__ import absolute_import from __future__ import annotations
import email.utils import email.utils
import mimetypes import mimetypes
import re import typing
from .packages import six _TYPE_FIELD_VALUE = typing.Union[str, bytes]
_TYPE_FIELD_VALUE_TUPLE = typing.Union[
_TYPE_FIELD_VALUE,
typing.Tuple[str, _TYPE_FIELD_VALUE],
typing.Tuple[str, _TYPE_FIELD_VALUE, str],
]
def guess_content_type(filename, default="application/octet-stream"): def guess_content_type(
filename: str | None, default: str = "application/octet-stream"
) -> str:
""" """
Guess the "Content-Type" of a file. Guess the "Content-Type" of a file.
@ -21,7 +28,7 @@ def guess_content_type(filename, default="application/octet-stream"):
return default return default
def format_header_param_rfc2231(name, value): def format_header_param_rfc2231(name: str, value: _TYPE_FIELD_VALUE) -> str:
""" """
Helper function to format and quote a single header parameter using the Helper function to format and quote a single header parameter using the
strategy defined in RFC 2231. strategy defined in RFC 2231.
@ -34,14 +41,28 @@ def format_header_param_rfc2231(name, value):
The name of the parameter, a string expected to be ASCII only. The name of the parameter, a string expected to be ASCII only.
:param value: :param value:
The value of the parameter, provided as ``bytes`` or `str``. The value of the parameter, provided as ``bytes`` or `str``.
:ret: :returns:
An RFC-2231-formatted unicode string. An RFC-2231-formatted unicode string.
.. deprecated:: 2.0.0
Will be removed in urllib3 v2.1.0. This is not valid for
``multipart/form-data`` header parameters.
""" """
if isinstance(value, six.binary_type): import warnings
warnings.warn(
"'format_header_param_rfc2231' is deprecated and will be "
"removed in urllib3 v2.1.0. This is not valid for "
"multipart/form-data header parameters.",
DeprecationWarning,
stacklevel=2,
)
if isinstance(value, bytes):
value = value.decode("utf-8") value = value.decode("utf-8")
if not any(ch in value for ch in '"\\\r\n'): if not any(ch in value for ch in '"\\\r\n'):
result = u'%s="%s"' % (name, value) result = f'{name}="{value}"'
try: try:
result.encode("ascii") result.encode("ascii")
except (UnicodeEncodeError, UnicodeDecodeError): except (UnicodeEncodeError, UnicodeDecodeError):
@ -49,81 +70,87 @@ def format_header_param_rfc2231(name, value):
else: else:
return result return result
if six.PY2: # Python 2:
value = value.encode("utf-8")
# encode_rfc2231 accepts an encoded string and returns an ascii-encoded
# string in Python 2 but accepts and returns unicode strings in Python 3
value = email.utils.encode_rfc2231(value, "utf-8") value = email.utils.encode_rfc2231(value, "utf-8")
value = "%s*=%s" % (name, value) value = f"{name}*={value}"
if six.PY2: # Python 2:
value = value.decode("utf-8")
return value return value
_HTML5_REPLACEMENTS = { def format_multipart_header_param(name: str, value: _TYPE_FIELD_VALUE) -> str:
u"\u0022": u"%22",
# Replace "\" with "\\".
u"\u005C": u"\u005C\u005C",
}
# All control characters from 0x00 to 0x1F *except* 0x1B.
_HTML5_REPLACEMENTS.update(
{
six.unichr(cc): u"%{:02X}".format(cc)
for cc in range(0x00, 0x1F + 1)
if cc not in (0x1B,)
}
)
def _replace_multiple(value, needles_and_replacements):
def replacer(match):
return needles_and_replacements[match.group(0)]
pattern = re.compile(
r"|".join([re.escape(needle) for needle in needles_and_replacements.keys()])
)
result = pattern.sub(replacer, value)
return result
def format_header_param_html5(name, value):
""" """
Helper function to format and quote a single header parameter using the Format and quote a single multipart header parameter.
HTML5 strategy.
Particularly useful for header parameters which might contain This follows the `WHATWG HTML Standard`_ as of 2021/06/10, matching
non-ASCII values, like file names. This follows the `HTML5 Working Draft the behavior of current browser and curl versions. Values are
Section 4.10.22.7`_ and matches the behavior of curl and modern browsers. assumed to be UTF-8. The ``\\n``, ``\\r``, and ``"`` characters are
percent encoded.
.. _HTML5 Working Draft Section 4.10.22.7: .. _WHATWG HTML Standard:
https://w3c.github.io/html/sec-forms.html#multipart-form-data https://html.spec.whatwg.org/multipage/
form-control-infrastructure.html#multipart-form-data
:param name: :param name:
The name of the parameter, a string expected to be ASCII only. The name of the parameter, an ASCII-only ``str``.
:param value: :param value:
The value of the parameter, provided as ``bytes`` or `str``. The value of the parameter, a ``str`` or UTF-8 encoded
:ret: ``bytes``.
A unicode string, stripped of troublesome characters. :returns:
A string ``name="value"`` with the escaped value.
.. versionchanged:: 2.0.0
Matches the WHATWG HTML Standard as of 2021/06/10. Control
characters are no longer percent encoded.
.. versionchanged:: 2.0.0
Renamed from ``format_header_param_html5`` and
``format_header_param``. The old names will be removed in
urllib3 v2.1.0.
""" """
if isinstance(value, six.binary_type): if isinstance(value, bytes):
value = value.decode("utf-8") value = value.decode("utf-8")
value = _replace_multiple(value, _HTML5_REPLACEMENTS) # percent encode \n \r "
value = value.translate({10: "%0A", 13: "%0D", 34: "%22"})
return u'%s="%s"' % (name, value) return f'{name}="{value}"'
# For backwards-compatibility. def format_header_param_html5(name: str, value: _TYPE_FIELD_VALUE) -> str:
format_header_param = format_header_param_html5 """
.. deprecated:: 2.0.0
Renamed to :func:`format_multipart_header_param`. Will be
removed in urllib3 v2.1.0.
"""
import warnings
warnings.warn(
"'format_header_param_html5' has been renamed to "
"'format_multipart_header_param'. The old name will be "
"removed in urllib3 v2.1.0.",
DeprecationWarning,
stacklevel=2,
)
return format_multipart_header_param(name, value)
class RequestField(object): def format_header_param(name: str, value: _TYPE_FIELD_VALUE) -> str:
"""
.. deprecated:: 2.0.0
Renamed to :func:`format_multipart_header_param`. Will be
removed in urllib3 v2.1.0.
"""
import warnings
warnings.warn(
"'format_header_param' has been renamed to "
"'format_multipart_header_param'. The old name will be "
"removed in urllib3 v2.1.0.",
DeprecationWarning,
stacklevel=2,
)
return format_multipart_header_param(name, value)
class RequestField:
""" """
A data container for request body parameters. A data container for request body parameters.
@ -135,29 +162,47 @@ class RequestField(object):
An optional filename of the request field. Must be unicode. An optional filename of the request field. Must be unicode.
:param headers: :param headers:
An optional dict-like object of headers to initially use for the field. An optional dict-like object of headers to initially use for the field.
:param header_formatter:
An optional callable that is used to encode and format the headers. By .. versionchanged:: 2.0.0
default, this is :func:`format_header_param_html5`. The ``header_formatter`` parameter is deprecated and will
be removed in urllib3 v2.1.0.
""" """
def __init__( def __init__(
self, self,
name, name: str,
data, data: _TYPE_FIELD_VALUE,
filename=None, filename: str | None = None,
headers=None, headers: typing.Mapping[str, str] | None = None,
header_formatter=format_header_param_html5, header_formatter: typing.Callable[[str, _TYPE_FIELD_VALUE], str] | None = None,
): ):
self._name = name self._name = name
self._filename = filename self._filename = filename
self.data = data self.data = data
self.headers = {} self.headers: dict[str, str | None] = {}
if headers: if headers:
self.headers = dict(headers) self.headers = dict(headers)
self.header_formatter = header_formatter
if header_formatter is not None:
import warnings
warnings.warn(
"The 'header_formatter' parameter is deprecated and "
"will be removed in urllib3 v2.1.0.",
DeprecationWarning,
stacklevel=2,
)
self.header_formatter = header_formatter
else:
self.header_formatter = format_multipart_header_param
@classmethod @classmethod
def from_tuples(cls, fieldname, value, header_formatter=format_header_param_html5): def from_tuples(
cls,
fieldname: str,
value: _TYPE_FIELD_VALUE_TUPLE,
header_formatter: typing.Callable[[str, _TYPE_FIELD_VALUE], str] | None = None,
) -> RequestField:
""" """
A :class:`~urllib3.fields.RequestField` factory from old-style tuple parameters. A :class:`~urllib3.fields.RequestField` factory from old-style tuple parameters.
@ -174,11 +219,19 @@ class RequestField(object):
Field names and filenames must be unicode. Field names and filenames must be unicode.
""" """
filename: str | None
content_type: str | None
data: _TYPE_FIELD_VALUE
if isinstance(value, tuple): if isinstance(value, tuple):
if len(value) == 3: if len(value) == 3:
filename, data, content_type = value filename, data, content_type = typing.cast(
typing.Tuple[str, _TYPE_FIELD_VALUE, str], value
)
else: else:
filename, data = value filename, data = typing.cast(
typing.Tuple[str, _TYPE_FIELD_VALUE], value
)
content_type = guess_content_type(filename) content_type = guess_content_type(filename)
else: else:
filename = None filename = None
@ -192,20 +245,29 @@ class RequestField(object):
return request_param return request_param
def _render_part(self, name, value): def _render_part(self, name: str, value: _TYPE_FIELD_VALUE) -> str:
""" """
Overridable helper function to format a single header parameter. By Override this method to change how each multipart header
default, this calls ``self.header_formatter``. parameter is formatted. By default, this calls
:func:`format_multipart_header_param`.
:param name: :param name:
The name of the parameter, a string expected to be ASCII only. The name of the parameter, an ASCII-only ``str``.
:param value: :param value:
The value of the parameter, provided as a unicode string. The value of the parameter, a ``str`` or UTF-8 encoded
""" ``bytes``.
:meta public:
"""
return self.header_formatter(name, value) return self.header_formatter(name, value)
def _render_parts(self, header_parts): def _render_parts(
self,
header_parts: (
dict[str, _TYPE_FIELD_VALUE | None]
| typing.Sequence[tuple[str, _TYPE_FIELD_VALUE | None]]
),
) -> str:
""" """
Helper function to format and quote a single header. Helper function to format and quote a single header.
@ -216,18 +278,21 @@ class RequestField(object):
A sequence of (k, v) tuples or a :class:`dict` of (k, v) to format A sequence of (k, v) tuples or a :class:`dict` of (k, v) to format
as `k1="v1"; k2="v2"; ...`. as `k1="v1"; k2="v2"; ...`.
""" """
iterable: typing.Iterable[tuple[str, _TYPE_FIELD_VALUE | None]]
parts = [] parts = []
iterable = header_parts
if isinstance(header_parts, dict): if isinstance(header_parts, dict):
iterable = header_parts.items() iterable = header_parts.items()
else:
iterable = header_parts
for name, value in iterable: for name, value in iterable:
if value is not None: if value is not None:
parts.append(self._render_part(name, value)) parts.append(self._render_part(name, value))
return u"; ".join(parts) return "; ".join(parts)
def render_headers(self): def render_headers(self) -> str:
""" """
Renders the headers for this request field. Renders the headers for this request field.
""" """
@ -236,39 +301,45 @@ class RequestField(object):
sort_keys = ["Content-Disposition", "Content-Type", "Content-Location"] sort_keys = ["Content-Disposition", "Content-Type", "Content-Location"]
for sort_key in sort_keys: for sort_key in sort_keys:
if self.headers.get(sort_key, False): if self.headers.get(sort_key, False):
lines.append(u"%s: %s" % (sort_key, self.headers[sort_key])) lines.append(f"{sort_key}: {self.headers[sort_key]}")
for header_name, header_value in self.headers.items(): for header_name, header_value in self.headers.items():
if header_name not in sort_keys: if header_name not in sort_keys:
if header_value: if header_value:
lines.append(u"%s: %s" % (header_name, header_value)) lines.append(f"{header_name}: {header_value}")
lines.append(u"\r\n") lines.append("\r\n")
return u"\r\n".join(lines) return "\r\n".join(lines)
def make_multipart( def make_multipart(
self, content_disposition=None, content_type=None, content_location=None self,
): content_disposition: str | None = None,
content_type: str | None = None,
content_location: str | None = None,
) -> None:
""" """
Makes this request field into a multipart request field. Makes this request field into a multipart request field.
This method overrides "Content-Disposition", "Content-Type" and This method overrides "Content-Disposition", "Content-Type" and
"Content-Location" headers to the request parameter. "Content-Location" headers to the request parameter.
:param content_disposition:
The 'Content-Disposition' of the request body. Defaults to 'form-data'
:param content_type: :param content_type:
The 'Content-Type' of the request body. The 'Content-Type' of the request body.
:param content_location: :param content_location:
The 'Content-Location' of the request body. The 'Content-Location' of the request body.
""" """
self.headers["Content-Disposition"] = content_disposition or u"form-data" content_disposition = (content_disposition or "form-data") + "; ".join(
self.headers["Content-Disposition"] += u"; ".join(
[ [
u"", "",
self._render_parts( self._render_parts(
((u"name", self._name), (u"filename", self._filename)) (("name", self._name), ("filename", self._filename))
), ),
] ]
) )
self.headers["Content-Disposition"] = content_disposition
self.headers["Content-Type"] = content_type self.headers["Content-Type"] = content_type
self.headers["Content-Location"] = content_location self.headers["Content-Location"] = content_location

View file

@ -1,28 +0,0 @@
# Stubs for requests.packages.urllib3.fields (Python 3.4)
from typing import Any, Callable, Mapping, Optional
def guess_content_type(filename: str, default: str) -> str: ...
def format_header_param_rfc2231(name: str, value: str) -> str: ...
def format_header_param_html5(name: str, value: str) -> str: ...
def format_header_param(name: str, value: str) -> str: ...
class RequestField:
data: Any
headers: Optional[Mapping[str, str]]
def __init__(
self,
name: str,
data: Any,
filename: Optional[str],
headers: Optional[Mapping[str, str]],
header_formatter: Callable[[str, str], str],
) -> None: ...
@classmethod
def from_tuples(
cls, fieldname: str, value: str, header_formatter: Callable[[str, str], str]
) -> RequestField: ...
def render_headers(self) -> str: ...
def make_multipart(
self, content_disposition: str, content_type: str, content_location: str
) -> None: ...

View file

@ -1,28 +1,32 @@
from __future__ import absolute_import from __future__ import annotations
import binascii import binascii
import codecs import codecs
import os import os
import typing
from io import BytesIO from io import BytesIO
from .fields import RequestField from .fields import _TYPE_FIELD_VALUE_TUPLE, RequestField
from .packages import six
from .packages.six import b
writer = codecs.lookup("utf-8")[3] writer = codecs.lookup("utf-8")[3]
_TYPE_FIELDS_SEQUENCE = typing.Sequence[
typing.Union[typing.Tuple[str, _TYPE_FIELD_VALUE_TUPLE], RequestField]
]
_TYPE_FIELDS = typing.Union[
_TYPE_FIELDS_SEQUENCE,
typing.Mapping[str, _TYPE_FIELD_VALUE_TUPLE],
]
def choose_boundary():
def choose_boundary() -> str:
""" """
Our embarrassingly-simple replacement for mimetools.choose_boundary. Our embarrassingly-simple replacement for mimetools.choose_boundary.
""" """
boundary = binascii.hexlify(os.urandom(16)) return binascii.hexlify(os.urandom(16)).decode()
if not six.PY2:
boundary = boundary.decode("ascii")
return boundary
def iter_field_objects(fields): def iter_field_objects(fields: _TYPE_FIELDS) -> typing.Iterable[RequestField]:
""" """
Iterate over fields. Iterate over fields.
@ -30,42 +34,29 @@ def iter_field_objects(fields):
:class:`~urllib3.fields.RequestField`. :class:`~urllib3.fields.RequestField`.
""" """
if isinstance(fields, dict): iterable: typing.Iterable[RequestField | tuple[str, _TYPE_FIELD_VALUE_TUPLE]]
i = six.iteritems(fields)
else:
i = iter(fields)
for field in i: if isinstance(fields, typing.Mapping):
iterable = fields.items()
else:
iterable = fields
for field in iterable:
if isinstance(field, RequestField): if isinstance(field, RequestField):
yield field yield field
else: else:
yield RequestField.from_tuples(*field) yield RequestField.from_tuples(*field)
def iter_fields(fields): def encode_multipart_formdata(
""" fields: _TYPE_FIELDS, boundary: str | None = None
.. deprecated:: 1.6 ) -> tuple[bytes, str]:
Iterate over fields.
The addition of :class:`~urllib3.fields.RequestField` makes this function
obsolete. Instead, use :func:`iter_field_objects`, which returns
:class:`~urllib3.fields.RequestField` objects.
Supports list of (k, v) tuples and dicts.
"""
if isinstance(fields, dict):
return ((k, v) for k, v in six.iteritems(fields))
return ((k, v) for k, v in fields)
def encode_multipart_formdata(fields, boundary=None):
""" """
Encode a dictionary of ``fields`` using the multipart/form-data MIME format. Encode a dictionary of ``fields`` using the multipart/form-data MIME format.
:param fields: :param fields:
Dictionary of fields or list of (key, :class:`~urllib3.fields.RequestField`). Dictionary of fields or list of (key, :class:`~urllib3.fields.RequestField`).
Values are processed by :func:`urllib3.fields.RequestField.from_tuples`.
:param boundary: :param boundary:
If not specified, then a random boundary will be generated using If not specified, then a random boundary will be generated using
@ -76,7 +67,7 @@ def encode_multipart_formdata(fields, boundary=None):
boundary = choose_boundary() boundary = choose_boundary()
for field in iter_field_objects(fields): for field in iter_field_objects(fields):
body.write(b("--%s\r\n" % (boundary))) body.write(f"--{boundary}\r\n".encode("latin-1"))
writer(body).write(field.render_headers()) writer(body).write(field.render_headers())
data = field.data data = field.data
@ -84,15 +75,15 @@ def encode_multipart_formdata(fields, boundary=None):
if isinstance(data, int): if isinstance(data, int):
data = str(data) # Backwards compatibility data = str(data) # Backwards compatibility
if isinstance(data, six.text_type): if isinstance(data, str):
writer(body).write(data) writer(body).write(data)
else: else:
body.write(data) body.write(data)
body.write(b"\r\n") body.write(b"\r\n")
body.write(b("--%s--\r\n" % (boundary))) body.write(f"--{boundary}--\r\n".encode("latin-1"))
content_type = str("multipart/form-data; boundary=%s" % boundary) content_type = f"multipart/form-data; boundary={boundary}"
return body.getvalue(), content_type return body.getvalue(), content_type

View file

@ -1,16 +0,0 @@
from typing import Any, Generator, List, Mapping, Optional, Tuple, Union

from . import fields

RequestField = fields.RequestField

# Accepted ``fields`` inputs: a mapping of name -> value, a list of
# (name, value) pairs, or a list of ready-made RequestField objects.
# Values are typed ``Any`` because the implementation hands them to
# ``RequestField.from_tuples`` unchanged.
Fields = Union[Mapping[str, Any], List[Tuple[str, Any]], List[RequestField]]

# Generator of (key, value) pairs, as produced by ``iter_fields``.
Iterator = Generator[Tuple[str, Any], None, None]

writer: Any

def choose_boundary() -> str: ...

# Yields RequestField objects (tuples are converted on the way through).
def iter_field_objects(fields: Fields) -> Generator[RequestField, None, None]: ...
def iter_fields(fields: Fields) -> Iterator: ...

# ``boundary`` is optional in the implementation (defaults to None, in which
# case a random boundary is generated).  The result is the encoded body as
# bytes paired with the matching Content-Type header value.
def encode_multipart_formdata(
    fields: Fields, boundary: Optional[str] = ...
) -> Tuple[bytes, str]: ...

View file

@ -1,51 +0,0 @@
# -*- coding: utf-8 -*-
"""
backports.makefile
~~~~~~~~~~~~~~~~~~
Backports the Python 3 ``socket.makefile`` method for use with anything that
wants to create a "fake" socket object.
"""
import io
from socket import SocketIO
def backport_makefile(
    self, mode="r", buffering=None, encoding=None, errors=None, newline=None
):
    """
    Backport of ``socket.makefile`` from Python 3.5.

    Wraps ``self`` in a ``SocketIO`` raw stream and, depending on ``mode``
    and ``buffering``, layers a buffered and/or text wrapper on top of it.

    :param mode:
        Any combination of the characters ``r``, ``w`` and ``b``. When
        neither ``r`` nor ``w`` is given the stream is opened for reading.
    :param buffering:
        ``None`` or a negative value selects ``io.DEFAULT_BUFFER_SIZE``;
        ``0`` returns the raw unbuffered stream (binary modes only).
    :param encoding:
        Passed to ``io.TextIOWrapper`` for text modes.
    :param errors:
        Passed to ``io.TextIOWrapper`` for text modes.
    :param newline:
        Passed to ``io.TextIOWrapper`` for text modes.
    :raises ValueError:
        If ``mode`` contains other characters, or if an unbuffered text
        stream is requested.
    """
    if not set(mode) <= {"r", "w", "b"}:
        raise ValueError("invalid mode %r (only r, w, b allowed)" % (mode,))

    wants_write = "w" in mode
    # Reading is the default direction when "w" was not requested.
    wants_read = "r" in mode or not wants_write
    is_binary = "b" in mode

    raw_mode = ("r" if wants_read else "") + ("w" if wants_write else "")
    raw_stream = SocketIO(self, raw_mode)
    # Count the new file object against the owning socket-like object.
    self._makefile_refs += 1

    if buffering is None or buffering < 0:
        buffering = io.DEFAULT_BUFFER_SIZE

    if buffering == 0:
        # Unbuffered access only makes sense for bytes.
        if is_binary:
            return raw_stream
        raise ValueError("unbuffered streams must be binary")

    if wants_read and wants_write:
        buffered = io.BufferedRWPair(raw_stream, raw_stream, buffering)
    elif wants_read:
        buffered = io.BufferedReader(raw_stream, buffering)
    else:
        buffered = io.BufferedWriter(raw_stream, buffering)

    if is_binary:
        return buffered

    wrapper = io.TextIOWrapper(buffered, encoding, errors, newline)
    wrapper.mode = mode
    return wrapper

View file

@ -1,155 +0,0 @@
# -*- coding: utf-8 -*-
"""
backports.weakref_finalize
~~~~~~~~~~~~~~~~~~
Backports the Python 3 ``weakref.finalize`` class.
"""
from __future__ import absolute_import
import itertools
import sys
from weakref import ref
__all__ = ["weakref_finalize"]
class weakref_finalize(object):
    """Finalizer for objects that support weak references.

    ``weakref_finalize(obj, func, *args, **kwargs)`` creates a callable
    finalizer that is invoked when ``obj`` is garbage collected.  The first
    call evaluates ``func(*args, **kwargs)`` and returns its result; from
    then on the finalizer is dead and calling it returns ``None``.

    Finalizers that are still alive when the program exits, and whose
    ``atexit`` attribute is true, are run in reverse order of creation.
    ``atexit`` is true by default.
    """

    # A finalizer carries no state of its own: it only serves as the key
    # under which its _Info record is stored in the registry.  That way it
    # can never become part of a reference cycle.
    __slots__ = ()
    _registry = {}
    _shutdown = False
    _index_iter = itertools.count()
    _dirty = False
    _registered_with_atexit = False

    class _Info(object):
        __slots__ = ("weakref", "func", "args", "kwargs", "atexit", "index")

    def __init__(self, obj, func, *args, **kwargs):
        if not self._registered_with_atexit:
            # A thread race may register the exit hook more than once;
            # that is harmless.
            import atexit

            atexit.register(self._exitfunc)
            weakref_finalize._registered_with_atexit = True

        record = self._Info()
        # The finalizer itself is the weakref callback, so collection of
        # ``obj`` ends up in ``__call__``.
        record.weakref = ref(obj, self)
        record.func = func
        record.args = args
        record.kwargs = kwargs or None
        record.atexit = True
        record.index = next(self._index_iter)
        self._registry[self] = record
        # Tell a running ``_exitfunc`` that the registry has changed.
        weakref_finalize._dirty = True

    def __call__(self, _=None):
        """Run the callback if still alive and mark the finalizer dead.

        Returns ``func(*args, **kwargs)``, or ``None`` when the finalizer
        was already dead or shutdown has completed.
        """
        record = self._registry.pop(self, None)
        if record is None or self._shutdown:
            return None
        return record.func(*record.args, **(record.kwargs or {}))

    def detach(self):
        """Mark a live finalizer dead without running its callback.

        Returns ``(obj, func, args, kwargs)`` if it was alive, otherwise
        ``None``.
        """
        record = self._registry.get(self)
        target = record.weakref() if record is not None else None
        if target is None or not self._registry.pop(self, None):
            return None
        return (target, record.func, record.args, record.kwargs or {})

    def peek(self):
        """Return ``(obj, func, args, kwargs)`` if alive, otherwise ``None``.

        The finalizer's state is left untouched.
        """
        record = self._registry.get(self)
        target = record.weakref() if record is not None else None
        if target is None:
            return None
        return (target, record.func, record.args, record.kwargs or {})

    @property
    def alive(self):
        """Whether the finalizer is still registered."""
        return self in self._registry

    @property
    def atexit(self):
        """Whether the finalizer should be run at program exit."""
        record = self._registry.get(self)
        return record is not None and record.atexit

    @atexit.setter
    def atexit(self, value):
        record = self._registry.get(self)
        if record is not None:
            record.atexit = bool(value)

    def __repr__(self):
        record = self._registry.get(self)
        target = record.weakref() if record is not None else None
        own_name = type(self).__name__
        if target is not None:
            return "<%s object at %#x; for %r at %#x>" % (
                own_name,
                id(self),
                type(target).__name__,
                id(target),
            )
        return "<%s object at %#x; dead>" % (own_name, id(self))

    @classmethod
    def _select_for_exit(cls):
        # Live finalizers flagged for exit, ordered oldest first.
        flagged = [pair for pair in cls._registry.items() if pair[1].atexit]
        ordered = sorted(flagged, key=lambda pair: pair[1].index)
        return [finalizer for finalizer, _ in ordered]

    @classmethod
    def _exitfunc(cls):
        # Runs at shutdown, once every other non-daemonic thread has been
        # joined, and invokes the finalizers whose atexit flag is true.
        restore_gc = False
        try:
            if not cls._registry:
                return
            import gc

            if gc.isenabled():
                restore_gc = True
                gc.disable()
            pending = None
            while True:
                # Rebuild the work list whenever a finalizer was created
                # since the last selection.
                if pending is None or weakref_finalize._dirty:
                    pending = cls._select_for_exit()
                    weakref_finalize._dirty = False
                if not pending:
                    break
                # Newest first: the list is sorted oldest first.
                finalizer = pending.pop()
                try:
                    # With gc disabled (and assuming no daemonic threads)
                    # this call is the only statement in this function
                    # that might create a new finalizer.
                    finalizer()
                except Exception:
                    sys.excepthook(*sys.exc_info())
                assert finalizer not in cls._registry
        finally:
            # No finalizer may execute from here on during shutdown.
            weakref_finalize._shutdown = True
            if restore_gc:
                gc.enable()

File diff suppressed because it is too large Load diff

View file

@ -1,24 +1,33 @@
from __future__ import absolute_import from __future__ import annotations
import collections
import functools import functools
import logging import logging
import typing
import warnings
from types import TracebackType
from urllib.parse import urljoin
from ._collections import RecentlyUsedContainer from ._collections import RecentlyUsedContainer
from ._request_methods import RequestMethods
from .connection import ProxyConfig
from .connectionpool import HTTPConnectionPool, HTTPSConnectionPool, port_by_scheme from .connectionpool import HTTPConnectionPool, HTTPSConnectionPool, port_by_scheme
from .exceptions import ( from .exceptions import (
LocationValueError, LocationValueError,
MaxRetryError, MaxRetryError,
ProxySchemeUnknown, ProxySchemeUnknown,
ProxySchemeUnsupported,
URLSchemeUnknown, URLSchemeUnknown,
) )
from .packages import six from .response import BaseHTTPResponse
from .packages.six.moves.urllib.parse import urljoin from .util.connection import _TYPE_SOCKET_OPTIONS
from .request import RequestMethods
from .util.proxy import connection_requires_http_tunnel from .util.proxy import connection_requires_http_tunnel
from .util.retry import Retry from .util.retry import Retry
from .util.url import parse_url from .util.timeout import Timeout
from .util.url import Url, parse_url
if typing.TYPE_CHECKING:
import ssl
from typing_extensions import Literal
__all__ = ["PoolManager", "ProxyManager", "proxy_from_url"] __all__ = ["PoolManager", "ProxyManager", "proxy_from_url"]
@ -31,52 +40,61 @@ SSL_KEYWORDS = (
"cert_reqs", "cert_reqs",
"ca_certs", "ca_certs",
"ssl_version", "ssl_version",
"ssl_minimum_version",
"ssl_maximum_version",
"ca_cert_dir", "ca_cert_dir",
"ssl_context", "ssl_context",
"key_password", "key_password",
"server_hostname", "server_hostname",
) )
# Default value for `blocksize` - a new parameter introduced to
# http.client.HTTPConnection & http.client.HTTPSConnection in Python 3.7
_DEFAULT_BLOCKSIZE = 16384
# All known keyword arguments that could be provided to the pool manager, its _SelfT = typing.TypeVar("_SelfT")
# pools, or the underlying connections. This is used to construct a pool key.
_key_fields = (
"key_scheme", # str
"key_host", # str
"key_port", # int
"key_timeout", # int or float or Timeout
"key_retries", # int or Retry
"key_strict", # bool
"key_block", # bool
"key_source_address", # str
"key_key_file", # str
"key_key_password", # str
"key_cert_file", # str
"key_cert_reqs", # str
"key_ca_certs", # str
"key_ssl_version", # str
"key_ca_cert_dir", # str
"key_ssl_context", # instance of ssl.SSLContext or urllib3.util.ssl_.SSLContext
"key_maxsize", # int
"key_headers", # dict
"key__proxy", # parsed proxy url
"key__proxy_headers", # dict
"key__proxy_config", # class
"key_socket_options", # list of (level (int), optname (int), value (int or str)) tuples
"key__socks_options", # dict
"key_assert_hostname", # bool or string
"key_assert_fingerprint", # str
"key_server_hostname", # str
)
#: The namedtuple class used to construct keys for the connection pool.
#: All custom key schemes should include the fields in this key at a minimum.
PoolKey = collections.namedtuple("PoolKey", _key_fields)
_proxy_config_fields = ("ssl_context", "use_forwarding_for_https")
ProxyConfig = collections.namedtuple("ProxyConfig", _proxy_config_fields)
def _default_key_normalizer(key_class, request_context): class PoolKey(typing.NamedTuple):
"""
All known keyword arguments that could be provided to the pool manager, its
pools, or the underlying connections.
All custom key schemes should include the fields in this key at a minimum.
"""
key_scheme: str
key_host: str
key_port: int | None
key_timeout: Timeout | float | int | None
key_retries: Retry | bool | int | None
key_block: bool | None
key_source_address: tuple[str, int] | None
key_key_file: str | None
key_key_password: str | None
key_cert_file: str | None
key_cert_reqs: str | None
key_ca_certs: str | None
key_ssl_version: int | str | None
key_ssl_minimum_version: ssl.TLSVersion | None
key_ssl_maximum_version: ssl.TLSVersion | None
key_ca_cert_dir: str | None
key_ssl_context: ssl.SSLContext | None
key_maxsize: int | None
key_headers: frozenset[tuple[str, str]] | None
key__proxy: Url | None
key__proxy_headers: frozenset[tuple[str, str]] | None
key__proxy_config: ProxyConfig | None
key_socket_options: _TYPE_SOCKET_OPTIONS | None
key__socks_options: frozenset[tuple[str, str]] | None
key_assert_hostname: bool | str | None
key_assert_fingerprint: str | None
key_server_hostname: str | None
key_blocksize: int | None
def _default_key_normalizer(
key_class: type[PoolKey], request_context: dict[str, typing.Any]
) -> PoolKey:
""" """
Create a pool key out of a request context dictionary. Create a pool key out of a request context dictionary.
@ -122,6 +140,10 @@ def _default_key_normalizer(key_class, request_context):
if field not in context: if field not in context:
context[field] = None context[field] = None
# Default key_blocksize to _DEFAULT_BLOCKSIZE if missing from the context
if context.get("key_blocksize") is None:
context["key_blocksize"] = _DEFAULT_BLOCKSIZE
return key_class(**context) return key_class(**context)
@ -154,23 +176,36 @@ class PoolManager(RequestMethods):
Additional parameters are used to create fresh Additional parameters are used to create fresh
:class:`urllib3.connectionpool.ConnectionPool` instances. :class:`urllib3.connectionpool.ConnectionPool` instances.
Example:: Example:
>>> manager = PoolManager(num_pools=2) .. code-block:: python
>>> r = manager.request('GET', 'http://google.com/')
>>> r = manager.request('GET', 'http://google.com/mail') import urllib3
>>> r = manager.request('GET', 'http://yahoo.com/')
>>> len(manager.pools) http = urllib3.PoolManager(num_pools=2)
2
resp1 = http.request("GET", "https://google.com/")
resp2 = http.request("GET", "https://google.com/mail")
resp3 = http.request("GET", "https://yahoo.com/")
print(len(http.pools))
# 2
""" """
proxy = None proxy: Url | None = None
proxy_config = None proxy_config: ProxyConfig | None = None
def __init__(self, num_pools=10, headers=None, **connection_pool_kw): def __init__(
RequestMethods.__init__(self, headers) self,
num_pools: int = 10,
headers: typing.Mapping[str, str] | None = None,
**connection_pool_kw: typing.Any,
) -> None:
super().__init__(headers)
self.connection_pool_kw = connection_pool_kw self.connection_pool_kw = connection_pool_kw
self.pools: RecentlyUsedContainer[PoolKey, HTTPConnectionPool]
self.pools = RecentlyUsedContainer(num_pools) self.pools = RecentlyUsedContainer(num_pools)
# Locally set the pool classes and keys so other PoolManagers can # Locally set the pool classes and keys so other PoolManagers can
@ -178,15 +213,26 @@ class PoolManager(RequestMethods):
self.pool_classes_by_scheme = pool_classes_by_scheme self.pool_classes_by_scheme = pool_classes_by_scheme
self.key_fn_by_scheme = key_fn_by_scheme.copy() self.key_fn_by_scheme = key_fn_by_scheme.copy()
def __enter__(self): def __enter__(self: _SelfT) -> _SelfT:
return self return self
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> Literal[False]:
self.clear() self.clear()
# Return False to re-raise any potential exceptions # Return False to re-raise any potential exceptions
return False return False
def _new_pool(self, scheme, host, port, request_context=None): def _new_pool(
self,
scheme: str,
host: str,
port: int,
request_context: dict[str, typing.Any] | None = None,
) -> HTTPConnectionPool:
""" """
Create a new :class:`urllib3.connectionpool.ConnectionPool` based on host, port, scheme, and Create a new :class:`urllib3.connectionpool.ConnectionPool` based on host, port, scheme, and
any additional pool keyword arguments. any additional pool keyword arguments.
@ -196,10 +242,15 @@ class PoolManager(RequestMethods):
connection pools handed out by :meth:`connection_from_url` and connection pools handed out by :meth:`connection_from_url` and
companion methods. It is intended to be overridden for customization. companion methods. It is intended to be overridden for customization.
""" """
pool_cls = self.pool_classes_by_scheme[scheme] pool_cls: type[HTTPConnectionPool] = self.pool_classes_by_scheme[scheme]
if request_context is None: if request_context is None:
request_context = self.connection_pool_kw.copy() request_context = self.connection_pool_kw.copy()
# Default blocksize to _DEFAULT_BLOCKSIZE if missing or explicitly
# set to 'None' in the request_context.
if request_context.get("blocksize") is None:
request_context["blocksize"] = _DEFAULT_BLOCKSIZE
# Although the context has everything necessary to create the pool, # Although the context has everything necessary to create the pool,
# this function has historically only used the scheme, host, and port # this function has historically only used the scheme, host, and port
# in the positional args. When an API change is acceptable these can # in the positional args. When an API change is acceptable these can
@ -213,7 +264,7 @@ class PoolManager(RequestMethods):
return pool_cls(host, port, **request_context) return pool_cls(host, port, **request_context)
def clear(self): def clear(self) -> None:
""" """
Empty our store of pools and direct them all to close. Empty our store of pools and direct them all to close.
@ -222,7 +273,13 @@ class PoolManager(RequestMethods):
""" """
self.pools.clear() self.pools.clear()
def connection_from_host(self, host, port=None, scheme="http", pool_kwargs=None): def connection_from_host(
self,
host: str | None,
port: int | None = None,
scheme: str | None = "http",
pool_kwargs: dict[str, typing.Any] | None = None,
) -> HTTPConnectionPool:
""" """
Get a :class:`urllib3.connectionpool.ConnectionPool` based on the host, port, and scheme. Get a :class:`urllib3.connectionpool.ConnectionPool` based on the host, port, and scheme.
@ -245,13 +302,23 @@ class PoolManager(RequestMethods):
return self.connection_from_context(request_context) return self.connection_from_context(request_context)
def connection_from_context(self, request_context): def connection_from_context(
self, request_context: dict[str, typing.Any]
) -> HTTPConnectionPool:
""" """
Get a :class:`urllib3.connectionpool.ConnectionPool` based on the request context. Get a :class:`urllib3.connectionpool.ConnectionPool` based on the request context.
``request_context`` must at least contain the ``scheme`` key and its ``request_context`` must at least contain the ``scheme`` key and its
value must be a key in ``key_fn_by_scheme`` instance variable. value must be a key in ``key_fn_by_scheme`` instance variable.
""" """
if "strict" in request_context:
warnings.warn(
"The 'strict' parameter is no longer needed on Python 3+. "
"This will raise an error in urllib3 v2.1.0.",
DeprecationWarning,
)
request_context.pop("strict")
scheme = request_context["scheme"].lower() scheme = request_context["scheme"].lower()
pool_key_constructor = self.key_fn_by_scheme.get(scheme) pool_key_constructor = self.key_fn_by_scheme.get(scheme)
if not pool_key_constructor: if not pool_key_constructor:
@ -260,7 +327,9 @@ class PoolManager(RequestMethods):
return self.connection_from_pool_key(pool_key, request_context=request_context) return self.connection_from_pool_key(pool_key, request_context=request_context)
def connection_from_pool_key(self, pool_key, request_context=None): def connection_from_pool_key(
self, pool_key: PoolKey, request_context: dict[str, typing.Any]
) -> HTTPConnectionPool:
""" """
Get a :class:`urllib3.connectionpool.ConnectionPool` based on the provided pool key. Get a :class:`urllib3.connectionpool.ConnectionPool` based on the provided pool key.
@ -284,7 +353,9 @@ class PoolManager(RequestMethods):
return pool return pool
def connection_from_url(self, url, pool_kwargs=None): def connection_from_url(
self, url: str, pool_kwargs: dict[str, typing.Any] | None = None
) -> HTTPConnectionPool:
""" """
Similar to :func:`urllib3.connectionpool.connection_from_url`. Similar to :func:`urllib3.connectionpool.connection_from_url`.
@ -300,7 +371,9 @@ class PoolManager(RequestMethods):
u.host, port=u.port, scheme=u.scheme, pool_kwargs=pool_kwargs u.host, port=u.port, scheme=u.scheme, pool_kwargs=pool_kwargs
) )
def _merge_pool_kwargs(self, override): def _merge_pool_kwargs(
self, override: dict[str, typing.Any] | None
) -> dict[str, typing.Any]:
""" """
Merge a dictionary of override values for self.connection_pool_kw. Merge a dictionary of override values for self.connection_pool_kw.
@ -320,7 +393,7 @@ class PoolManager(RequestMethods):
base_pool_kwargs[key] = value base_pool_kwargs[key] = value
return base_pool_kwargs return base_pool_kwargs
def _proxy_requires_url_absolute_form(self, parsed_url): def _proxy_requires_url_absolute_form(self, parsed_url: Url) -> bool:
""" """
Indicates if the proxy requires the complete destination URL in the Indicates if the proxy requires the complete destination URL in the
request. Normally this is only needed when not using an HTTP CONNECT request. Normally this is only needed when not using an HTTP CONNECT
@ -333,24 +406,9 @@ class PoolManager(RequestMethods):
self.proxy, self.proxy_config, parsed_url.scheme self.proxy, self.proxy_config, parsed_url.scheme
) )
def _validate_proxy_scheme_url_selection(self, url_scheme): def urlopen( # type: ignore[override]
""" self, method: str, url: str, redirect: bool = True, **kw: typing.Any
Validates that we're not attempting to do TLS in TLS connections on ) -> BaseHTTPResponse:
Python2 or with unsupported SSL implementations.
"""
if self.proxy is None or url_scheme != "https":
return
if self.proxy.scheme != "https":
return
if six.PY2 and not self.proxy_config.use_forwarding_for_https:
raise ProxySchemeUnsupported(
"Contacting HTTPS destinations through HTTPS proxies "
"'via CONNECT tunnels' is not supported in Python 2"
)
def urlopen(self, method, url, redirect=True, **kw):
""" """
Same as :meth:`urllib3.HTTPConnectionPool.urlopen` Same as :meth:`urllib3.HTTPConnectionPool.urlopen`
with custom cross-host redirect logic and only sends the request-uri with custom cross-host redirect logic and only sends the request-uri
@ -360,7 +418,16 @@ class PoolManager(RequestMethods):
:class:`urllib3.connectionpool.ConnectionPool` can be chosen for it. :class:`urllib3.connectionpool.ConnectionPool` can be chosen for it.
""" """
u = parse_url(url) u = parse_url(url)
self._validate_proxy_scheme_url_selection(u.scheme)
if u.scheme is None:
warnings.warn(
"URLs without a scheme (ie 'https://') are deprecated and will raise an error "
"in a future version of urllib3. To avoid this DeprecationWarning ensure all URLs "
"start with 'https://' or 'http://'. Read more in this issue: "
"https://github.com/urllib3/urllib3/issues/2920",
category=DeprecationWarning,
stacklevel=2,
)
conn = self.connection_from_host(u.host, port=u.port, scheme=u.scheme) conn = self.connection_from_host(u.host, port=u.port, scheme=u.scheme)
@ -368,7 +435,7 @@ class PoolManager(RequestMethods):
kw["redirect"] = False kw["redirect"] = False
if "headers" not in kw: if "headers" not in kw:
kw["headers"] = self.headers.copy() kw["headers"] = self.headers
if self._proxy_requires_url_absolute_form(u): if self._proxy_requires_url_absolute_form(u):
response = conn.urlopen(method, url, **kw) response = conn.urlopen(method, url, **kw)
@ -396,10 +463,11 @@ class PoolManager(RequestMethods):
if retries.remove_headers_on_redirect and not conn.is_same_host( if retries.remove_headers_on_redirect and not conn.is_same_host(
redirect_location redirect_location
): ):
headers = list(six.iterkeys(kw["headers"])) new_headers = kw["headers"].copy()
for header in headers: for header in kw["headers"]:
if header.lower() in retries.remove_headers_on_redirect: if header.lower() in retries.remove_headers_on_redirect:
kw["headers"].pop(header, None) new_headers.pop(header, None)
kw["headers"] = new_headers
try: try:
retries = retries.increment(method, url, response=response, _pool=conn) retries = retries.increment(method, url, response=response, _pool=conn)
@ -445,37 +513,51 @@ class ProxyManager(PoolManager):
private. IP address, target hostname, SNI, and port are always visible private. IP address, target hostname, SNI, and port are always visible
to an HTTPS proxy even when this flag is disabled. to an HTTPS proxy even when this flag is disabled.
:param proxy_assert_hostname:
The hostname of the certificate to verify against.
:param proxy_assert_fingerprint:
The fingerprint of the certificate to verify against.
Example: Example:
>>> proxy = urllib3.ProxyManager('http://localhost:3128/')
>>> r1 = proxy.request('GET', 'http://google.com/') .. code-block:: python
>>> r2 = proxy.request('GET', 'http://httpbin.org/')
>>> len(proxy.pools) import urllib3
1
>>> r3 = proxy.request('GET', 'https://httpbin.org/') proxy = urllib3.ProxyManager("https://localhost:3128/")
>>> r4 = proxy.request('GET', 'https://twitter.com/')
>>> len(proxy.pools) resp1 = proxy.request("GET", "https://google.com/")
3 resp2 = proxy.request("GET", "https://httpbin.org/")
print(len(proxy.pools))
# 1
resp3 = proxy.request("GET", "https://httpbin.org/")
resp4 = proxy.request("GET", "https://twitter.com/")
print(len(proxy.pools))
# 3
""" """
def __init__( def __init__(
self, self,
proxy_url, proxy_url: str,
num_pools=10, num_pools: int = 10,
headers=None, headers: typing.Mapping[str, str] | None = None,
proxy_headers=None, proxy_headers: typing.Mapping[str, str] | None = None,
proxy_ssl_context=None, proxy_ssl_context: ssl.SSLContext | None = None,
use_forwarding_for_https=False, use_forwarding_for_https: bool = False,
**connection_pool_kw proxy_assert_hostname: None | str | Literal[False] = None,
): proxy_assert_fingerprint: str | None = None,
**connection_pool_kw: typing.Any,
) -> None:
if isinstance(proxy_url, HTTPConnectionPool): if isinstance(proxy_url, HTTPConnectionPool):
proxy_url = "%s://%s:%i" % ( str_proxy_url = f"{proxy_url.scheme}://{proxy_url.host}:{proxy_url.port}"
proxy_url.scheme, else:
proxy_url.host, str_proxy_url = proxy_url
proxy_url.port, proxy = parse_url(str_proxy_url)
)
proxy = parse_url(proxy_url)
if proxy.scheme not in ("http", "https"): if proxy.scheme not in ("http", "https"):
raise ProxySchemeUnknown(proxy.scheme) raise ProxySchemeUnknown(proxy.scheme)
@ -487,25 +569,38 @@ class ProxyManager(PoolManager):
self.proxy = proxy self.proxy = proxy
self.proxy_headers = proxy_headers or {} self.proxy_headers = proxy_headers or {}
self.proxy_ssl_context = proxy_ssl_context self.proxy_ssl_context = proxy_ssl_context
self.proxy_config = ProxyConfig(proxy_ssl_context, use_forwarding_for_https) self.proxy_config = ProxyConfig(
proxy_ssl_context,
use_forwarding_for_https,
proxy_assert_hostname,
proxy_assert_fingerprint,
)
connection_pool_kw["_proxy"] = self.proxy connection_pool_kw["_proxy"] = self.proxy
connection_pool_kw["_proxy_headers"] = self.proxy_headers connection_pool_kw["_proxy_headers"] = self.proxy_headers
connection_pool_kw["_proxy_config"] = self.proxy_config connection_pool_kw["_proxy_config"] = self.proxy_config
super(ProxyManager, self).__init__(num_pools, headers, **connection_pool_kw) super().__init__(num_pools, headers, **connection_pool_kw)
def connection_from_host(self, host, port=None, scheme="http", pool_kwargs=None): def connection_from_host(
self,
host: str | None,
port: int | None = None,
scheme: str | None = "http",
pool_kwargs: dict[str, typing.Any] | None = None,
) -> HTTPConnectionPool:
if scheme == "https": if scheme == "https":
return super(ProxyManager, self).connection_from_host( return super().connection_from_host(
host, port, scheme, pool_kwargs=pool_kwargs host, port, scheme, pool_kwargs=pool_kwargs
) )
return super(ProxyManager, self).connection_from_host( return super().connection_from_host(
self.proxy.host, self.proxy.port, self.proxy.scheme, pool_kwargs=pool_kwargs self.proxy.host, self.proxy.port, self.proxy.scheme, pool_kwargs=pool_kwargs # type: ignore[union-attr]
) )
def _set_proxy_headers(self, url, headers=None): def _set_proxy_headers(
self, url: str, headers: typing.Mapping[str, str] | None = None
) -> typing.Mapping[str, str]:
""" """
Sets headers needed by proxies: specifically, the Accept and Host Sets headers needed by proxies: specifically, the Accept and Host
headers. Only sets headers not provided by the user. headers. Only sets headers not provided by the user.
@ -520,7 +615,9 @@ class ProxyManager(PoolManager):
headers_.update(headers) headers_.update(headers)
return headers_ return headers_
def urlopen(self, method, url, redirect=True, **kw): def urlopen( # type: ignore[override]
self, method: str, url: str, redirect: bool = True, **kw: typing.Any
) -> BaseHTTPResponse:
"Same as HTTP(S)ConnectionPool.urlopen, ``url`` must be absolute." "Same as HTTP(S)ConnectionPool.urlopen, ``url`` must be absolute."
u = parse_url(url) u = parse_url(url)
if not connection_requires_http_tunnel(self.proxy, self.proxy_config, u.scheme): if not connection_requires_http_tunnel(self.proxy, self.proxy_config, u.scheme):
@ -530,8 +627,8 @@ class ProxyManager(PoolManager):
headers = kw.get("headers", self.headers) headers = kw.get("headers", self.headers)
kw["headers"] = self._set_proxy_headers(url, headers) kw["headers"] = self._set_proxy_headers(url, headers)
return super(ProxyManager, self).urlopen(method, url, redirect=redirect, **kw) return super().urlopen(method, url, redirect=redirect, **kw)
def proxy_from_url(url, **kw): def proxy_from_url(url: str, **kw: typing.Any) -> ProxyManager:
return ProxyManager(proxy_url=url, **kw) return ProxyManager(proxy_url=url, **kw)

2
lib/urllib3/py.typed Normal file
View file

@ -0,0 +1,2 @@
# Instruct type checkers to look for inline type annotations in this package.
# See PEP 561.

File diff suppressed because it is too large Load diff

View file

@ -1,46 +1,41 @@
from __future__ import absolute_import
# For backwards compatibility, provide imports that used to be here. # For backwards compatibility, provide imports that used to be here.
from __future__ import annotations
from .connection import is_connection_dropped from .connection import is_connection_dropped
from .request import SKIP_HEADER, SKIPPABLE_HEADERS, make_headers from .request import SKIP_HEADER, SKIPPABLE_HEADERS, make_headers
from .response import is_fp_closed from .response import is_fp_closed
from .retry import Retry from .retry import Retry
from .ssl_ import ( from .ssl_ import (
ALPN_PROTOCOLS, ALPN_PROTOCOLS,
HAS_SNI,
IS_PYOPENSSL, IS_PYOPENSSL,
IS_SECURETRANSPORT, IS_SECURETRANSPORT,
PROTOCOL_TLS,
SSLContext, SSLContext,
assert_fingerprint, assert_fingerprint,
create_urllib3_context,
resolve_cert_reqs, resolve_cert_reqs,
resolve_ssl_version, resolve_ssl_version,
ssl_wrap_socket, ssl_wrap_socket,
) )
from .timeout import Timeout, current_time from .timeout import Timeout
from .url import Url, get_host, parse_url, split_first from .url import Url, parse_url
from .wait import wait_for_read, wait_for_write from .wait import wait_for_read, wait_for_write
__all__ = ( __all__ = (
"HAS_SNI",
"IS_PYOPENSSL", "IS_PYOPENSSL",
"IS_SECURETRANSPORT", "IS_SECURETRANSPORT",
"SSLContext", "SSLContext",
"PROTOCOL_TLS",
"ALPN_PROTOCOLS", "ALPN_PROTOCOLS",
"Retry", "Retry",
"Timeout", "Timeout",
"Url", "Url",
"assert_fingerprint", "assert_fingerprint",
"current_time", "create_urllib3_context",
"is_connection_dropped", "is_connection_dropped",
"is_fp_closed", "is_fp_closed",
"get_host",
"parse_url", "parse_url",
"make_headers", "make_headers",
"resolve_cert_reqs", "resolve_cert_reqs",
"resolve_ssl_version", "resolve_ssl_version",
"split_first",
"ssl_wrap_socket", "ssl_wrap_socket",
"wait_for_read", "wait_for_read",
"wait_for_write", "wait_for_write",

View file

@ -1,33 +1,23 @@
from __future__ import absolute_import from __future__ import annotations
import socket import socket
import typing
from ..contrib import _appengine_environ
from ..exceptions import LocationParseError from ..exceptions import LocationParseError
from ..packages import six from .timeout import _DEFAULT_TIMEOUT, _TYPE_TIMEOUT
from .wait import NoWayToWaitForSocketError, wait_for_read
_TYPE_SOCKET_OPTIONS = typing.Sequence[typing.Tuple[int, int, typing.Union[int, bytes]]]
if typing.TYPE_CHECKING:
from .._base_connection import BaseHTTPConnection
def is_connection_dropped(conn): # Platform-specific def is_connection_dropped(conn: BaseHTTPConnection) -> bool: # Platform-specific
""" """
Returns True if the connection is dropped and should be closed. Returns True if the connection is dropped and should be closed.
:param conn: :class:`urllib3.connection.HTTPConnection` object.
:param conn:
:class:`http.client.HTTPConnection` object.
Note: For platforms like AppEngine, this will always return ``False`` to
let the platform handle connection recycling transparently for us.
""" """
sock = getattr(conn, "sock", False) return not conn.is_connected
if sock is False: # Platform-specific: AppEngine
return False
if sock is None: # Connection already closed (such as by httplib).
return True
try:
# Returns True if readable, which here means it's been dropped
return wait_for_read(sock, timeout=0.0)
except NoWayToWaitForSocketError: # Platform-specific: AppEngine
return False
# This function is copied from socket.py in the Python 2.7 standard # This function is copied from socket.py in the Python 2.7 standard
@ -35,11 +25,11 @@ def is_connection_dropped(conn): # Platform-specific
# One additional modification is that we avoid binding to IPv6 servers # One additional modification is that we avoid binding to IPv6 servers
# discovered in DNS if the system doesn't have IPv6 functionality. # discovered in DNS if the system doesn't have IPv6 functionality.
def create_connection( def create_connection(
address, address: tuple[str, int],
timeout=socket._GLOBAL_DEFAULT_TIMEOUT, timeout: _TYPE_TIMEOUT = _DEFAULT_TIMEOUT,
source_address=None, source_address: tuple[str, int] | None = None,
socket_options=None, socket_options: _TYPE_SOCKET_OPTIONS | None = None,
): ) -> socket.socket:
"""Connect to *address* and return the socket object. """Connect to *address* and return the socket object.
Convenience function. Connect to *address* (a 2-tuple ``(host, Convenience function. Connect to *address* (a 2-tuple ``(host,
@ -65,9 +55,7 @@ def create_connection(
try: try:
host.encode("idna") host.encode("idna")
except UnicodeError: except UnicodeError:
return six.raise_from( raise LocationParseError(f"'{host}', label empty or too long") from None
LocationParseError(u"'%s', label empty or too long" % host), None
)
for res in socket.getaddrinfo(host, port, family, socket.SOCK_STREAM): for res in socket.getaddrinfo(host, port, family, socket.SOCK_STREAM):
af, socktype, proto, canonname, sa = res af, socktype, proto, canonname, sa = res
@ -78,26 +66,33 @@ def create_connection(
# If provided, set socket level options before connecting. # If provided, set socket level options before connecting.
_set_socket_options(sock, socket_options) _set_socket_options(sock, socket_options)
if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT: if timeout is not _DEFAULT_TIMEOUT:
sock.settimeout(timeout) sock.settimeout(timeout)
if source_address: if source_address:
sock.bind(source_address) sock.bind(source_address)
sock.connect(sa) sock.connect(sa)
# Break explicitly a reference cycle
err = None
return sock return sock
except socket.error as e: except OSError as _:
err = e err = _
if sock is not None: if sock is not None:
sock.close() sock.close()
sock = None
if err is not None: if err is not None:
raise err try:
raise err
raise socket.error("getaddrinfo returns an empty list") finally:
# Break explicitly a reference cycle
err = None
else:
raise OSError("getaddrinfo returns an empty list")
def _set_socket_options(sock, options): def _set_socket_options(
sock: socket.socket, options: _TYPE_SOCKET_OPTIONS | None
) -> None:
if options is None: if options is None:
return return
@ -105,7 +100,7 @@ def _set_socket_options(sock, options):
sock.setsockopt(*opt) sock.setsockopt(*opt)
def allowed_gai_family(): def allowed_gai_family() -> socket.AddressFamily:
"""This function is designed to work in the context of """This function is designed to work in the context of
getaddrinfo, where family=socket.AF_UNSPEC is the default and getaddrinfo, where family=socket.AF_UNSPEC is the default and
will perform a DNS search for both IPv6 and IPv4 records.""" will perform a DNS search for both IPv6 and IPv4 records."""
@ -116,18 +111,11 @@ def allowed_gai_family():
return family return family
def _has_ipv6(host): def _has_ipv6(host: str) -> bool:
"""Returns True if the system can bind an IPv6 address.""" """Returns True if the system can bind an IPv6 address."""
sock = None sock = None
has_ipv6 = False has_ipv6 = False
# App Engine doesn't support IPV6 sockets and actually has a quota on the
# number of sockets that can be used, so just early out here instead of
# creating a socket needlessly.
# See https://github.com/urllib3/urllib3/issues/1446
if _appengine_environ.is_appengine_sandbox():
return False
if socket.has_ipv6: if socket.has_ipv6:
# has_ipv6 returns true if cPython was compiled with IPv6 support. # has_ipv6 returns true if cPython was compiled with IPv6 support.
# It does not tell us if the system has IPv6 support enabled. To # It does not tell us if the system has IPv6 support enabled. To

View file

@ -1,9 +1,18 @@
from .ssl_ import create_urllib3_context, resolve_cert_reqs, resolve_ssl_version from __future__ import annotations
import typing
from .url import Url
if typing.TYPE_CHECKING:
from ..connection import ProxyConfig
def connection_requires_http_tunnel( def connection_requires_http_tunnel(
proxy_url=None, proxy_config=None, destination_scheme=None proxy_url: Url | None = None,
): proxy_config: ProxyConfig | None = None,
destination_scheme: str | None = None,
) -> bool:
""" """
Returns True if the connection requires an HTTP CONNECT through the proxy. Returns True if the connection requires an HTTP CONNECT through the proxy.
@ -32,26 +41,3 @@ def connection_requires_http_tunnel(
# Otherwise always use a tunnel. # Otherwise always use a tunnel.
return True return True
def create_proxy_ssl_context(
    ssl_version, cert_reqs, ca_certs=None, ca_cert_dir=None, ca_cert_data=None
):
    """
    Build the default SSL context for talking to a proxy, for use when the
    user has not provided one.

    :param ssl_version: Value handed to ``resolve_ssl_version`` before the
        context is created.
    :param cert_reqs: Value handed to ``resolve_cert_reqs`` before the
        context is created.
    :param ca_certs: Optional CA source supplied by the caller.
    :param ca_cert_dir: Optional CA source supplied by the caller.
    :param ca_cert_data: Optional CA source supplied by the caller.
    :return: The context returned by ``create_urllib3_context``.
    """
    context = create_urllib3_context(
        ssl_version=resolve_ssl_version(ssl_version),
        cert_reqs=resolve_cert_reqs(cert_reqs),
    )

    # Default certificates are loaded only when the caller gave no CA source
    # at all and the context object actually offers the hook.
    caller_supplied_ca = ca_certs or ca_cert_dir or ca_cert_data
    if not caller_supplied_ca and hasattr(context, "load_default_certs"):
        context.load_default_certs()

    return context

View file

@ -1,22 +0,0 @@
import collections
from ..packages import six
from ..packages.six.moves import queue
if six.PY2:
# Queue is imported for side effects on MS Windows. See issue #229.
import Queue as _unused_module_Queue # noqa: F401
class LifoQueue(queue.Queue):
    """Queue whose ``get()`` hands back the most recently ``put()`` item.

    Only the storage hooks of ``queue.Queue`` are overridden: items live in a
    ``collections.deque`` and are taken from the same end they were added to.
    """

    def _init(self, _):
        # The size argument passed to this hook is not needed to create the
        # backing container, hence the throwaway parameter name.
        self.queue = collections.deque()

    def _qsize(self, len=len):
        # ``len`` is bound as a default argument, making it a local name here.
        stored = self.queue
        return len(stored)

    def _put(self, item):
        # New items go on the right-hand end of the deque ...
        self.queue.append(item)

    def _get(self):
        # ... and are removed from that same end, giving last-in, first-out.
        newest = self.queue.pop()
        return newest

View file

@ -1,9 +1,15 @@
from __future__ import absolute_import from __future__ import annotations
import io
import typing
from base64 import b64encode from base64 import b64encode
from enum import Enum
from ..exceptions import UnrewindableBodyError from ..exceptions import UnrewindableBodyError
from ..packages.six import b, integer_types from .util import to_bytes
if typing.TYPE_CHECKING:
from typing_extensions import Final
# Pass as a value within ``headers`` to skip # Pass as a value within ``headers`` to skip
# emitting some HTTP headers that are added automatically. # emitting some HTTP headers that are added automatically.
@ -15,25 +21,45 @@ SKIPPABLE_HEADERS = frozenset(["accept-encoding", "host", "user-agent"])
ACCEPT_ENCODING = "gzip,deflate" ACCEPT_ENCODING = "gzip,deflate"
try: try:
try: try:
import brotlicffi as _unused_module_brotli # noqa: F401 import brotlicffi as _unused_module_brotli # type: ignore[import] # noqa: F401
except ImportError: except ImportError:
import brotli as _unused_module_brotli # noqa: F401 import brotli as _unused_module_brotli # type: ignore[import] # noqa: F401
except ImportError: except ImportError:
pass pass
else: else:
ACCEPT_ENCODING += ",br" ACCEPT_ENCODING += ",br"
try:
import zstandard as _unused_module_zstd # type: ignore[import] # noqa: F401
except ImportError:
pass
else:
ACCEPT_ENCODING += ",zstd"
_FAILEDTELL = object()
class _TYPE_FAILEDTELL(Enum):
token = 0
_FAILEDTELL: Final[_TYPE_FAILEDTELL] = _TYPE_FAILEDTELL.token
_TYPE_BODY_POSITION = typing.Union[int, _TYPE_FAILEDTELL]
# When sending a request with these methods we aren't expecting
# a body so don't need to set an explicit 'Content-Length: 0'
# The reason we do this in the negative instead of tracking methods
# which 'should' have a body is because unknown methods should be
# treated as if they were 'POST' which *does* expect a body.
_METHODS_NOT_EXPECTING_BODY = {"GET", "HEAD", "DELETE", "TRACE", "OPTIONS", "CONNECT"}
def make_headers( def make_headers(
keep_alive=None, keep_alive: bool | None = None,
accept_encoding=None, accept_encoding: bool | list[str] | str | None = None,
user_agent=None, user_agent: str | None = None,
basic_auth=None, basic_auth: str | None = None,
proxy_basic_auth=None, proxy_basic_auth: str | None = None,
disable_cache=None, disable_cache: bool | None = None,
): ) -> dict[str, str]:
""" """
Shortcuts for generating request headers. Shortcuts for generating request headers.
@ -42,7 +68,8 @@ def make_headers(
:param accept_encoding: :param accept_encoding:
Can be a boolean, list, or string. Can be a boolean, list, or string.
``True`` translates to 'gzip,deflate'. ``True`` translates to 'gzip,deflate'. If either the ``brotli`` or
``brotlicffi`` package is installed 'gzip,deflate,br' is used instead.
List will get joined by comma. List will get joined by comma.
String will be used as provided. String will be used as provided.
@ -61,14 +88,18 @@ def make_headers(
:param disable_cache: :param disable_cache:
If ``True``, adds 'cache-control: no-cache' header. If ``True``, adds 'cache-control: no-cache' header.
Example:: Example:
>>> make_headers(keep_alive=True, user_agent="Batman/1.0") .. code-block:: python
{'connection': 'keep-alive', 'user-agent': 'Batman/1.0'}
>>> make_headers(accept_encoding=True) import urllib3
{'accept-encoding': 'gzip,deflate'}
print(urllib3.util.make_headers(keep_alive=True, user_agent="Batman/1.0"))
# {'connection': 'keep-alive', 'user-agent': 'Batman/1.0'}
print(urllib3.util.make_headers(accept_encoding=True))
# {'accept-encoding': 'gzip,deflate'}
""" """
headers = {} headers: dict[str, str] = {}
if accept_encoding: if accept_encoding:
if isinstance(accept_encoding, str): if isinstance(accept_encoding, str):
pass pass
@ -85,12 +116,14 @@ def make_headers(
headers["connection"] = "keep-alive" headers["connection"] = "keep-alive"
if basic_auth: if basic_auth:
headers["authorization"] = "Basic " + b64encode(b(basic_auth)).decode("utf-8") headers[
"authorization"
] = f"Basic {b64encode(basic_auth.encode('latin-1')).decode()}"
if proxy_basic_auth: if proxy_basic_auth:
headers["proxy-authorization"] = "Basic " + b64encode( headers[
b(proxy_basic_auth) "proxy-authorization"
).decode("utf-8") ] = f"Basic {b64encode(proxy_basic_auth.encode('latin-1')).decode()}"
if disable_cache: if disable_cache:
headers["cache-control"] = "no-cache" headers["cache-control"] = "no-cache"
@ -98,7 +131,9 @@ def make_headers(
return headers return headers
def set_file_position(body, pos): def set_file_position(
body: typing.Any, pos: _TYPE_BODY_POSITION | None
) -> _TYPE_BODY_POSITION | None:
""" """
If a position is provided, move file to that point. If a position is provided, move file to that point.
Otherwise, we'll attempt to record a position for future use. Otherwise, we'll attempt to record a position for future use.
@ -108,7 +143,7 @@ def set_file_position(body, pos):
elif getattr(body, "tell", None) is not None: elif getattr(body, "tell", None) is not None:
try: try:
pos = body.tell() pos = body.tell()
except (IOError, OSError): except OSError:
# This differentiates from None, allowing us to catch # This differentiates from None, allowing us to catch
# a failed `tell()` later when trying to rewind the body. # a failed `tell()` later when trying to rewind the body.
pos = _FAILEDTELL pos = _FAILEDTELL
@ -116,7 +151,7 @@ def set_file_position(body, pos):
return pos return pos
def rewind_body(body, body_pos): def rewind_body(body: typing.IO[typing.AnyStr], body_pos: _TYPE_BODY_POSITION) -> None:
""" """
Attempt to rewind body to a certain position. Attempt to rewind body to a certain position.
Primarily used for request redirects and retries. Primarily used for request redirects and retries.
@ -128,13 +163,13 @@ def rewind_body(body, body_pos):
Position to seek to in file. Position to seek to in file.
""" """
body_seek = getattr(body, "seek", None) body_seek = getattr(body, "seek", None)
if body_seek is not None and isinstance(body_pos, integer_types): if body_seek is not None and isinstance(body_pos, int):
try: try:
body_seek(body_pos) body_seek(body_pos)
except (IOError, OSError): except OSError as e:
raise UnrewindableBodyError( raise UnrewindableBodyError(
"An error occurred when rewinding request body for redirect/retry." "An error occurred when rewinding request body for redirect/retry."
) ) from e
elif body_pos is _FAILEDTELL: elif body_pos is _FAILEDTELL:
raise UnrewindableBodyError( raise UnrewindableBodyError(
"Unable to record file position for rewinding " "Unable to record file position for rewinding "
@ -142,5 +177,80 @@ def rewind_body(body, body_pos):
) )
else: else:
raise ValueError( raise ValueError(
"body_pos must be of type integer, instead it was %s." % type(body_pos) f"body_pos must be of type integer, instead it was {type(body_pos)}."
) )
class ChunksAndContentLength(typing.NamedTuple):
    """Pair of values describing how a request body should be sent."""

    # Iterable of byte chunks making up the body, or None when there is
    # no body to send.
    chunks: typing.Iterable[bytes] | None
    # Value for the 'Content-Length' header, or None when no length is
    # being recommended.
    content_length: int | None
def body_to_chunks(
    body: typing.Any | None, method: str, blocksize: int
) -> ChunksAndContentLength:
    """Convert a request body into chunks plus a suggested 'Content-Length'.

    Takes the HTTP request method, body, and blocksize and transforms
    them into an iterable of chunks to pass to socket.sendall() and an
    optional 'Content-Length' header.

    A 'Content-Length' of 'None' indicates the length of the body
    can't be determined so should use 'Transfer-Encoding: chunked'
    for framing instead.

    :param body: ``None``, ``str``/``bytes``, an object with a ``read``
        attribute, an object supporting the buffer API, or an iterable.
    :param method: HTTP method; only consulted when ``body`` is ``None``.
    :param blocksize: Size passed to each ``body.read()`` call when the
        body is file-like.
    :raises TypeError: If ``body`` is none of the supported kinds.
    """
    # No body: recommend a 'Content-Length' based on whether the request
    # method is expected to carry a body or not.
    if body is None:
        expects_body = method.upper() not in _METHODS_NOT_EXPECTING_BODY
        return ChunksAndContentLength(
            chunks=None, content_length=0 if expects_body else None
        )

    # Bytes or strings become a single bytes chunk of known length.
    if isinstance(body, (str, bytes)):
        payload = to_bytes(body)
        return ChunksAndContentLength(chunks=(payload,), content_length=len(payload))

    # File-like object, TODO: use seek() and tell() for length?
    if hasattr(body, "read"):

        def chunk_readable() -> typing.Iterable[bytes]:
            # Text streams yield str blocks, which have to be encoded
            # before they can be sent.
            is_text = isinstance(body, io.TextIOBase)
            while True:
                block = body.read(blocksize)
                if not block:
                    return
                yield block.encode("iso-8859-1") if is_text else block

        return ChunksAndContentLength(chunks=chunk_readable(), content_length=None)

    # Anything else is probed via duck-typing, buffer API first.
    try:
        view = memoryview(body)
    except TypeError:
        view = None

    if view is not None:
        # Since it implements the buffer API it can be passed directly
        # to socket.sendall().
        return ChunksAndContentLength(chunks=(body,), content_length=view.nbytes)

    # Last resort: accept any iterable, whose total length is unknown.
    try:
        iterator = iter(body)
    except TypeError:
        raise TypeError(
            f"'body' must be a bytes-like object, file-like "
            f"object, or iterable. Instead was {body!r}"
        ) from None

    return ChunksAndContentLength(chunks=iterator, content_length=None)

View file

@ -1,12 +1,12 @@
from __future__ import absolute_import from __future__ import annotations
import http.client as httplib
from email.errors import MultipartInvariantViolationDefect, StartBoundaryNotFoundDefect from email.errors import MultipartInvariantViolationDefect, StartBoundaryNotFoundDefect
from ..exceptions import HeaderParsingError from ..exceptions import HeaderParsingError
from ..packages.six.moves import http_client as httplib
def is_fp_closed(obj): def is_fp_closed(obj: object) -> bool:
""" """
Checks whether a given file-like object is closed. Checks whether a given file-like object is closed.
@ -17,27 +17,27 @@ def is_fp_closed(obj):
try: try:
# Check `isclosed()` first, in case Python3 doesn't set `closed`. # Check `isclosed()` first, in case Python3 doesn't set `closed`.
# GH Issue #928 # GH Issue #928
return obj.isclosed() return obj.isclosed() # type: ignore[no-any-return, attr-defined]
except AttributeError: except AttributeError:
pass pass
try: try:
# Check via the official file-like-object way. # Check via the official file-like-object way.
return obj.closed return obj.closed # type: ignore[no-any-return, attr-defined]
except AttributeError: except AttributeError:
pass pass
try: try:
# Check if the object is a container for another file-like object that # Check if the object is a container for another file-like object that
# gets released on exhaustion (e.g. HTTPResponse). # gets released on exhaustion (e.g. HTTPResponse).
return obj.fp is None return obj.fp is None # type: ignore[attr-defined]
except AttributeError: except AttributeError:
pass pass
raise ValueError("Unable to determine whether fp is closed.") raise ValueError("Unable to determine whether fp is closed.")
def assert_header_parsing(headers): def assert_header_parsing(headers: httplib.HTTPMessage) -> None:
""" """
Asserts whether all headers have been successfully parsed. Asserts whether all headers have been successfully parsed.
Extracts encountered errors from the result of parsing headers. Extracts encountered errors from the result of parsing headers.
@ -53,55 +53,49 @@ def assert_header_parsing(headers):
# This will fail silently if we pass in the wrong kind of parameter. # This will fail silently if we pass in the wrong kind of parameter.
# To make debugging easier add an explicit check. # To make debugging easier add an explicit check.
if not isinstance(headers, httplib.HTTPMessage): if not isinstance(headers, httplib.HTTPMessage):
raise TypeError("expected httplib.Message, got {0}.".format(type(headers))) raise TypeError(f"expected httplib.Message, got {type(headers)}.")
defects = getattr(headers, "defects", None)
get_payload = getattr(headers, "get_payload", None)
unparsed_data = None unparsed_data = None
if get_payload:
# get_payload is actually email.message.Message.get_payload;
# we're only interested in the result if it's not a multipart message
if not headers.is_multipart():
payload = get_payload()
if isinstance(payload, (bytes, str)): # get_payload is actually email.message.Message.get_payload;
unparsed_data = payload # we're only interested in the result if it's not a multipart message
if defects: if not headers.is_multipart():
# httplib is assuming a response body is available payload = headers.get_payload()
# when parsing headers even when httplib only sends
# header data to parse_headers() This results in
# defects on multipart responses in particular.
# See: https://github.com/urllib3/urllib3/issues/800
# So we ignore the following defects: if isinstance(payload, (bytes, str)):
# - StartBoundaryNotFoundDefect: unparsed_data = payload
# The claimed start boundary was never found.
# - MultipartInvariantViolationDefect: # httplib is assuming a response body is available
# A message claimed to be a multipart but no subparts were found. # when parsing headers even when httplib only sends
defects = [ # header data to parse_headers() This results in
defect # defects on multipart responses in particular.
for defect in defects # See: https://github.com/urllib3/urllib3/issues/800
if not isinstance(
defect, (StartBoundaryNotFoundDefect, MultipartInvariantViolationDefect) # So we ignore the following defects:
) # - StartBoundaryNotFoundDefect:
] # The claimed start boundary was never found.
# - MultipartInvariantViolationDefect:
# A message claimed to be a multipart but no subparts were found.
defects = [
defect
for defect in headers.defects
if not isinstance(
defect, (StartBoundaryNotFoundDefect, MultipartInvariantViolationDefect)
)
]
if defects or unparsed_data: if defects or unparsed_data:
raise HeaderParsingError(defects=defects, unparsed_data=unparsed_data) raise HeaderParsingError(defects=defects, unparsed_data=unparsed_data)
def is_response_to_head(response): def is_response_to_head(response: httplib.HTTPResponse) -> bool:
""" """
Checks whether the request of a response has been a HEAD-request. Checks whether the request of a response has been a HEAD-request.
Handles the quirks of AppEngine.
:param http.client.HTTPResponse response: :param http.client.HTTPResponse response:
Response to check if the originating request Response to check if the originating request
used 'HEAD' as a method. used 'HEAD' as a method.
""" """
# FIXME: Can we do this somehow without accessing private httplib _method? # FIXME: Can we do this somehow without accessing private httplib _method?
method = response._method method_str = response._method # type: str # type: ignore[attr-defined]
if isinstance(method, int): # Platform-specific: Appengine return method_str.upper() == "HEAD"
return method == 3
return method.upper() == "HEAD"

View file

@ -1,12 +1,13 @@
from __future__ import absolute_import from __future__ import annotations
import email import email
import logging import logging
import random
import re import re
import time import time
import warnings import typing
from collections import namedtuple
from itertools import takewhile from itertools import takewhile
from types import TracebackType
from ..exceptions import ( from ..exceptions import (
ConnectTimeoutError, ConnectTimeoutError,
@ -17,97 +18,49 @@ from ..exceptions import (
ReadTimeoutError, ReadTimeoutError,
ResponseError, ResponseError,
) )
from ..packages import six from .util import reraise
if typing.TYPE_CHECKING:
from ..connectionpool import ConnectionPool
from ..response import BaseHTTPResponse
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
# Data structure for representing the metadata of requests that result in a retry. # Data structure for representing the metadata of requests that result in a retry.
RequestHistory = namedtuple( class RequestHistory(typing.NamedTuple):
"RequestHistory", ["method", "url", "error", "status", "redirect_location"] method: str | None
) url: str | None
error: Exception | None
status: int | None
redirect_location: str | None
# TODO: In v2 we can remove this sentinel and metaclass with deprecated options. class Retry:
_Default = object()
def _warn_renamed_retry_attribute(old_name, new_name):
    # Emits the one deprecation warning shared by every renamed ``Retry``
    # class attribute handled by ``_RetryMeta``.
    warnings.warn(
        "Using 'Retry.%s' is deprecated and "
        "will be removed in v2.0. Use 'Retry.%s' instead" % (old_name, new_name),
        DeprecationWarning,
        # Skip this helper and the property accessor so the warning is
        # attributed to the code that touched the deprecated attribute.
        stacklevel=3,
    )


class _RetryMeta(type):
    """Metaclass providing deprecated class-level aliases.

    Each alias is a property on the metaclass, so reading or assigning the
    old name on a class using this metaclass emits a ``DeprecationWarning``
    and is forwarded to the replacement class attribute:

    * ``DEFAULT_METHOD_WHITELIST`` -> ``DEFAULT_ALLOWED_METHODS``
    * ``DEFAULT_REDIRECT_HEADERS_BLACKLIST`` ->
      ``DEFAULT_REMOVE_HEADERS_ON_REDIRECT``
    * ``BACKOFF_MAX`` -> ``DEFAULT_BACKOFF_MAX``
    """

    @property
    def DEFAULT_METHOD_WHITELIST(cls):
        _warn_renamed_retry_attribute(
            "DEFAULT_METHOD_WHITELIST", "DEFAULT_ALLOWED_METHODS"
        )
        return cls.DEFAULT_ALLOWED_METHODS

    @DEFAULT_METHOD_WHITELIST.setter
    def DEFAULT_METHOD_WHITELIST(cls, value):
        _warn_renamed_retry_attribute(
            "DEFAULT_METHOD_WHITELIST", "DEFAULT_ALLOWED_METHODS"
        )
        cls.DEFAULT_ALLOWED_METHODS = value

    @property
    def DEFAULT_REDIRECT_HEADERS_BLACKLIST(cls):
        _warn_renamed_retry_attribute(
            "DEFAULT_REDIRECT_HEADERS_BLACKLIST",
            "DEFAULT_REMOVE_HEADERS_ON_REDIRECT",
        )
        return cls.DEFAULT_REMOVE_HEADERS_ON_REDIRECT

    @DEFAULT_REDIRECT_HEADERS_BLACKLIST.setter
    def DEFAULT_REDIRECT_HEADERS_BLACKLIST(cls, value):
        _warn_renamed_retry_attribute(
            "DEFAULT_REDIRECT_HEADERS_BLACKLIST",
            "DEFAULT_REMOVE_HEADERS_ON_REDIRECT",
        )
        cls.DEFAULT_REMOVE_HEADERS_ON_REDIRECT = value

    @property
    def BACKOFF_MAX(cls):
        _warn_renamed_retry_attribute("BACKOFF_MAX", "DEFAULT_BACKOFF_MAX")
        return cls.DEFAULT_BACKOFF_MAX

    @BACKOFF_MAX.setter
    def BACKOFF_MAX(cls, value):
        _warn_renamed_retry_attribute("BACKOFF_MAX", "DEFAULT_BACKOFF_MAX")
        cls.DEFAULT_BACKOFF_MAX = value
@six.add_metaclass(_RetryMeta)
class Retry(object):
"""Retry configuration. """Retry configuration.
Each retry attempt will create a new Retry object with updated values, so Each retry attempt will create a new Retry object with updated values, so
they can be safely reused. they can be safely reused.
Retries can be defined as a default for a pool:: Retries can be defined as a default for a pool:
.. code-block:: python
retries = Retry(connect=5, read=2, redirect=5) retries = Retry(connect=5, read=2, redirect=5)
http = PoolManager(retries=retries) http = PoolManager(retries=retries)
response = http.request('GET', 'http://example.com/') response = http.request("GET", "https://example.com/")
Or per-request (which overrides the default for the pool):: Or per-request (which overrides the default for the pool):
response = http.request('GET', 'http://example.com/', retries=Retry(10)) .. code-block:: python
Retries can be disabled by passing ``False``:: response = http.request("GET", "https://example.com/", retries=Retry(10))
response = http.request('GET', 'http://example.com/', retries=False) Retries can be disabled by passing ``False``:
.. code-block:: python
response = http.request("GET", "https://example.com/", retries=False)
Errors will be wrapped in :class:`~urllib3.exceptions.MaxRetryError` unless Errors will be wrapped in :class:`~urllib3.exceptions.MaxRetryError` unless
retries are disabled, in which case the causing exception will be raised. retries are disabled, in which case the causing exception will be raised.
@ -169,21 +122,16 @@ class Retry(object):
If ``total`` is not set, it's a good idea to set this to 0 to account If ``total`` is not set, it's a good idea to set this to 0 to account
for unexpected edge cases and avoid infinite retry loops. for unexpected edge cases and avoid infinite retry loops.
:param iterable allowed_methods: :param Collection allowed_methods:
Set of uppercased HTTP method verbs that we should retry on. Set of uppercased HTTP method verbs that we should retry on.
By default, we only retry on methods which are considered to be By default, we only retry on methods which are considered to be
idempotent (multiple requests with the same parameters end with the idempotent (multiple requests with the same parameters end with the
same state). See :attr:`Retry.DEFAULT_ALLOWED_METHODS`. same state). See :attr:`Retry.DEFAULT_ALLOWED_METHODS`.
Set to a ``False`` value to retry on any verb. Set to a ``None`` value to retry on any verb.
.. warning:: :param Collection status_forcelist:
Previously this parameter was named ``method_whitelist``, that
usage is deprecated in v1.26.0 and will be removed in v2.0.
:param iterable status_forcelist:
A set of integer HTTP status codes that we should force a retry on. A set of integer HTTP status codes that we should force a retry on.
A retry is initiated if the request method is in ``allowed_methods`` A retry is initiated if the request method is in ``allowed_methods``
and the response status code is in ``status_forcelist``. and the response status code is in ``status_forcelist``.
@ -195,13 +143,17 @@ class Retry(object):
(most errors are resolved immediately by a second try without a (most errors are resolved immediately by a second try without a
delay). urllib3 will sleep for:: delay). urllib3 will sleep for::
{backoff factor} * (2 ** ({number of total retries} - 1)) {backoff factor} * (2 ** ({number of previous retries}))
seconds. If the backoff_factor is 0.1, then :func:`.sleep` will sleep seconds. If `backoff_jitter` is non-zero, this sleep is extended by::
for [0.0s, 0.2s, 0.4s, ...] between retries. It will never be longer
than :attr:`Retry.DEFAULT_BACKOFF_MAX`.
By default, backoff is disabled (set to 0). random.uniform(0, {backoff jitter})
seconds. For example, if the backoff_factor is 0.1, then :func:`Retry.sleep` will
sleep for [0.0s, 0.2s, 0.4s, 0.8s, ...] between retries. No backoff will ever
be longer than `backoff_max`.
By default, backoff is disabled (factor set to 0).
:param bool raise_on_redirect: Whether, if the number of redirects is :param bool raise_on_redirect: Whether, if the number of redirects is
exhausted, to raise a MaxRetryError, or to return a response with a exhausted, to raise a MaxRetryError, or to return a response with a
@ -220,7 +172,7 @@ class Retry(object):
Whether to respect Retry-After header on status codes defined as Whether to respect Retry-After header on status codes defined as
:attr:`Retry.RETRY_AFTER_STATUS_CODES` or not. :attr:`Retry.RETRY_AFTER_STATUS_CODES` or not.
:param iterable remove_headers_on_redirect: :param Collection remove_headers_on_redirect:
Sequence of headers to remove from the request when a response Sequence of headers to remove from the request when a response
indicating a redirect is returned before firing off the redirected indicating a redirect is returned before firing off the redirected
request. request.
@ -237,48 +189,33 @@ class Retry(object):
#: Default headers to be used for ``remove_headers_on_redirect`` #: Default headers to be used for ``remove_headers_on_redirect``
DEFAULT_REMOVE_HEADERS_ON_REDIRECT = frozenset(["Authorization"]) DEFAULT_REMOVE_HEADERS_ON_REDIRECT = frozenset(["Authorization"])
#: Maximum backoff time. #: Default maximum backoff time.
DEFAULT_BACKOFF_MAX = 120 DEFAULT_BACKOFF_MAX = 120
# Backward compatibility; assigned outside of the class.
DEFAULT: typing.ClassVar[Retry]
def __init__( def __init__(
self, self,
total=10, total: bool | int | None = 10,
connect=None, connect: int | None = None,
read=None, read: int | None = None,
redirect=None, redirect: bool | int | None = None,
status=None, status: int | None = None,
other=None, other: int | None = None,
allowed_methods=_Default, allowed_methods: typing.Collection[str] | None = DEFAULT_ALLOWED_METHODS,
status_forcelist=None, status_forcelist: typing.Collection[int] | None = None,
backoff_factor=0, backoff_factor: float = 0,
raise_on_redirect=True, backoff_max: float = DEFAULT_BACKOFF_MAX,
raise_on_status=True, raise_on_redirect: bool = True,
history=None, raise_on_status: bool = True,
respect_retry_after_header=True, history: tuple[RequestHistory, ...] | None = None,
remove_headers_on_redirect=_Default, respect_retry_after_header: bool = True,
# TODO: Deprecated, remove in v2.0 remove_headers_on_redirect: typing.Collection[
method_whitelist=_Default, str
): ] = DEFAULT_REMOVE_HEADERS_ON_REDIRECT,
backoff_jitter: float = 0.0,
if method_whitelist is not _Default: ) -> None:
if allowed_methods is not _Default:
raise ValueError(
"Using both 'allowed_methods' and "
"'method_whitelist' together is not allowed. "
"Instead only use 'allowed_methods'"
)
warnings.warn(
"Using 'method_whitelist' with Retry is deprecated and "
"will be removed in v2.0. Use 'allowed_methods' instead",
DeprecationWarning,
stacklevel=2,
)
allowed_methods = method_whitelist
if allowed_methods is _Default:
allowed_methods = self.DEFAULT_ALLOWED_METHODS
if remove_headers_on_redirect is _Default:
remove_headers_on_redirect = self.DEFAULT_REMOVE_HEADERS_ON_REDIRECT
self.total = total self.total = total
self.connect = connect self.connect = connect
self.read = read self.read = read
@ -293,15 +230,17 @@ class Retry(object):
self.status_forcelist = status_forcelist or set() self.status_forcelist = status_forcelist or set()
self.allowed_methods = allowed_methods self.allowed_methods = allowed_methods
self.backoff_factor = backoff_factor self.backoff_factor = backoff_factor
self.backoff_max = backoff_max
self.raise_on_redirect = raise_on_redirect self.raise_on_redirect = raise_on_redirect
self.raise_on_status = raise_on_status self.raise_on_status = raise_on_status
self.history = history or tuple() self.history = history or ()
self.respect_retry_after_header = respect_retry_after_header self.respect_retry_after_header = respect_retry_after_header
self.remove_headers_on_redirect = frozenset( self.remove_headers_on_redirect = frozenset(
[h.lower() for h in remove_headers_on_redirect] h.lower() for h in remove_headers_on_redirect
) )
self.backoff_jitter = backoff_jitter
def new(self, **kw): def new(self, **kw: typing.Any) -> Retry:
params = dict( params = dict(
total=self.total, total=self.total,
connect=self.connect, connect=self.connect,
@ -309,36 +248,28 @@ class Retry(object):
redirect=self.redirect, redirect=self.redirect,
status=self.status, status=self.status,
other=self.other, other=self.other,
allowed_methods=self.allowed_methods,
status_forcelist=self.status_forcelist, status_forcelist=self.status_forcelist,
backoff_factor=self.backoff_factor, backoff_factor=self.backoff_factor,
backoff_max=self.backoff_max,
raise_on_redirect=self.raise_on_redirect, raise_on_redirect=self.raise_on_redirect,
raise_on_status=self.raise_on_status, raise_on_status=self.raise_on_status,
history=self.history, history=self.history,
remove_headers_on_redirect=self.remove_headers_on_redirect, remove_headers_on_redirect=self.remove_headers_on_redirect,
respect_retry_after_header=self.respect_retry_after_header, respect_retry_after_header=self.respect_retry_after_header,
backoff_jitter=self.backoff_jitter,
) )
# TODO: If already given in **kw we use what's given to us
# If not given we need to figure out what to pass. We decide
# based on whether our class has the 'method_whitelist' property
# and if so we pass the deprecated 'method_whitelist' otherwise
# we use 'allowed_methods'. Remove in v2.0
if "method_whitelist" not in kw and "allowed_methods" not in kw:
if "method_whitelist" in self.__dict__:
warnings.warn(
"Using 'method_whitelist' with Retry is deprecated and "
"will be removed in v2.0. Use 'allowed_methods' instead",
DeprecationWarning,
)
params["method_whitelist"] = self.allowed_methods
else:
params["allowed_methods"] = self.allowed_methods
params.update(kw) params.update(kw)
return type(self)(**params) return type(self)(**params) # type: ignore[arg-type]
@classmethod @classmethod
def from_int(cls, retries, redirect=True, default=None): def from_int(
cls,
retries: Retry | bool | int | None,
redirect: bool | int | None = True,
default: Retry | bool | int | None = None,
) -> Retry:
"""Backwards-compatibility for the old retries format.""" """Backwards-compatibility for the old retries format."""
if retries is None: if retries is None:
retries = default if default is not None else cls.DEFAULT retries = default if default is not None else cls.DEFAULT
@ -351,7 +282,7 @@ class Retry(object):
log.debug("Converted retries value: %r -> %r", retries, new_retries) log.debug("Converted retries value: %r -> %r", retries, new_retries)
return new_retries return new_retries
def get_backoff_time(self): def get_backoff_time(self) -> float:
"""Formula for computing the current backoff """Formula for computing the current backoff
:rtype: float :rtype: float
@ -366,32 +297,28 @@ class Retry(object):
return 0 return 0
backoff_value = self.backoff_factor * (2 ** (consecutive_errors_len - 1)) backoff_value = self.backoff_factor * (2 ** (consecutive_errors_len - 1))
return min(self.DEFAULT_BACKOFF_MAX, backoff_value) if self.backoff_jitter != 0.0:
backoff_value += random.random() * self.backoff_jitter
return float(max(0, min(self.backoff_max, backoff_value)))
def parse_retry_after(self, retry_after): def parse_retry_after(self, retry_after: str) -> float:
seconds: float
# Whitespace: https://tools.ietf.org/html/rfc7230#section-3.2.4 # Whitespace: https://tools.ietf.org/html/rfc7230#section-3.2.4
if re.match(r"^\s*[0-9]+\s*$", retry_after): if re.match(r"^\s*[0-9]+\s*$", retry_after):
seconds = int(retry_after) seconds = int(retry_after)
else: else:
retry_date_tuple = email.utils.parsedate_tz(retry_after) retry_date_tuple = email.utils.parsedate_tz(retry_after)
if retry_date_tuple is None: if retry_date_tuple is None:
raise InvalidHeader("Invalid Retry-After header: %s" % retry_after) raise InvalidHeader(f"Invalid Retry-After header: {retry_after}")
if retry_date_tuple[9] is None: # Python 2
# Assume UTC if no timezone was specified
# On Python2.7, parsedate_tz returns None for a timezone offset
# instead of 0 if no timezone is given, where mktime_tz treats
# a None timezone offset as local time.
retry_date_tuple = retry_date_tuple[:9] + (0,) + retry_date_tuple[10:]
retry_date = email.utils.mktime_tz(retry_date_tuple) retry_date = email.utils.mktime_tz(retry_date_tuple)
seconds = retry_date - time.time() seconds = retry_date - time.time()
if seconds < 0: seconds = max(seconds, 0)
seconds = 0
return seconds return seconds
def get_retry_after(self, response): def get_retry_after(self, response: BaseHTTPResponse) -> float | None:
"""Get the value of Retry-After in seconds.""" """Get the value of Retry-After in seconds."""
retry_after = response.headers.get("Retry-After") retry_after = response.headers.get("Retry-After")
@ -401,7 +328,7 @@ class Retry(object):
return self.parse_retry_after(retry_after) return self.parse_retry_after(retry_after)
def sleep_for_retry(self, response=None): def sleep_for_retry(self, response: BaseHTTPResponse) -> bool:
retry_after = self.get_retry_after(response) retry_after = self.get_retry_after(response)
if retry_after: if retry_after:
time.sleep(retry_after) time.sleep(retry_after)
@ -409,13 +336,13 @@ class Retry(object):
return False return False
def _sleep_backoff(self): def _sleep_backoff(self) -> None:
backoff = self.get_backoff_time() backoff = self.get_backoff_time()
if backoff <= 0: if backoff <= 0:
return return
time.sleep(backoff) time.sleep(backoff)
def sleep(self, response=None): def sleep(self, response: BaseHTTPResponse | None = None) -> None:
"""Sleep between retry attempts. """Sleep between retry attempts.
This method will respect a server's ``Retry-After`` response header This method will respect a server's ``Retry-After`` response header
@ -431,7 +358,7 @@ class Retry(object):
self._sleep_backoff() self._sleep_backoff()
def _is_connection_error(self, err): def _is_connection_error(self, err: Exception) -> bool:
"""Errors when we're fairly sure that the server did not receive the """Errors when we're fairly sure that the server did not receive the
request, so it should be safe to retry. request, so it should be safe to retry.
""" """
@ -439,33 +366,23 @@ class Retry(object):
err = err.original_error err = err.original_error
return isinstance(err, ConnectTimeoutError) return isinstance(err, ConnectTimeoutError)
def _is_read_error(self, err): def _is_read_error(self, err: Exception) -> bool:
"""Errors that occur after the request has been started, so we should """Errors that occur after the request has been started, so we should
assume that the server began processing it. assume that the server began processing it.
""" """
return isinstance(err, (ReadTimeoutError, ProtocolError)) return isinstance(err, (ReadTimeoutError, ProtocolError))
def _is_method_retryable(self, method): def _is_method_retryable(self, method: str) -> bool:
"""Checks if a given HTTP method should be retried upon, depending if """Checks if a given HTTP method should be retried upon, depending if
it is included in the allowed_methods it is included in the allowed_methods
""" """
# TODO: For now favor if the Retry implementation sets its own method_whitelist if self.allowed_methods and method.upper() not in self.allowed_methods:
# property outside of our constructor to avoid breaking custom implementations.
if "method_whitelist" in self.__dict__:
warnings.warn(
"Using 'method_whitelist' with Retry is deprecated and "
"will be removed in v2.0. Use 'allowed_methods' instead",
DeprecationWarning,
)
allowed_methods = self.method_whitelist
else:
allowed_methods = self.allowed_methods
if allowed_methods and method.upper() not in allowed_methods:
return False return False
return True return True
def is_retry(self, method, status_code, has_retry_after=False): def is_retry(
self, method: str, status_code: int, has_retry_after: bool = False
) -> bool:
"""Is this method/status code retryable? (Based on allowlists and control """Is this method/status code retryable? (Based on allowlists and control
variables such as the number of total retries to allow, whether to variables such as the number of total retries to allow, whether to
respect the Retry-After header, whether this header is present, and respect the Retry-After header, whether this header is present, and
@ -478,24 +395,27 @@ class Retry(object):
if self.status_forcelist and status_code in self.status_forcelist: if self.status_forcelist and status_code in self.status_forcelist:
return True return True
return ( return bool(
self.total self.total
and self.respect_retry_after_header and self.respect_retry_after_header
and has_retry_after and has_retry_after
and (status_code in self.RETRY_AFTER_STATUS_CODES) and (status_code in self.RETRY_AFTER_STATUS_CODES)
) )
def is_exhausted(self): def is_exhausted(self) -> bool:
"""Are we out of retries?""" """Are we out of retries?"""
retry_counts = ( retry_counts = [
self.total, x
self.connect, for x in (
self.read, self.total,
self.redirect, self.connect,
self.status, self.read,
self.other, self.redirect,
) self.status,
retry_counts = list(filter(None, retry_counts)) self.other,
)
if x
]
if not retry_counts: if not retry_counts:
return False return False
@ -503,18 +423,18 @@ class Retry(object):
def increment( def increment(
self, self,
method=None, method: str | None = None,
url=None, url: str | None = None,
response=None, response: BaseHTTPResponse | None = None,
error=None, error: Exception | None = None,
_pool=None, _pool: ConnectionPool | None = None,
_stacktrace=None, _stacktrace: TracebackType | None = None,
): ) -> Retry:
"""Return a new Retry object with incremented retry counters. """Return a new Retry object with incremented retry counters.
:param response: A response object, or None, if the server did not :param response: A response object, or None, if the server did not
return a response. return a response.
:type response: :class:`~urllib3.response.HTTPResponse` :type response: :class:`~urllib3.response.BaseHTTPResponse`
:param Exception error: An error encountered during the request, or :param Exception error: An error encountered during the request, or
None if the response was received successfully. None if the response was received successfully.
@ -522,7 +442,7 @@ class Retry(object):
""" """
if self.total is False and error: if self.total is False and error:
# Disabled, indicate to re-raise the error. # Disabled, indicate to re-raise the error.
raise six.reraise(type(error), error, _stacktrace) raise reraise(type(error), error, _stacktrace)
total = self.total total = self.total
if total is not None: if total is not None:
@ -540,14 +460,14 @@ class Retry(object):
if error and self._is_connection_error(error): if error and self._is_connection_error(error):
# Connect retry? # Connect retry?
if connect is False: if connect is False:
raise six.reraise(type(error), error, _stacktrace) raise reraise(type(error), error, _stacktrace)
elif connect is not None: elif connect is not None:
connect -= 1 connect -= 1
elif error and self._is_read_error(error): elif error and self._is_read_error(error):
# Read retry? # Read retry?
if read is False or not self._is_method_retryable(method): if read is False or method is None or not self._is_method_retryable(method):
raise six.reraise(type(error), error, _stacktrace) raise reraise(type(error), error, _stacktrace)
elif read is not None: elif read is not None:
read -= 1 read -= 1
@ -561,7 +481,9 @@ class Retry(object):
if redirect is not None: if redirect is not None:
redirect -= 1 redirect -= 1
cause = "too many redirects" cause = "too many redirects"
redirect_location = response.get_redirect_location() response_redirect_location = response.get_redirect_location()
if response_redirect_location:
redirect_location = response_redirect_location
status = response.status status = response.status
else: else:
@ -589,31 +511,18 @@ class Retry(object):
) )
if new_retry.is_exhausted(): if new_retry.is_exhausted():
raise MaxRetryError(_pool, url, error or ResponseError(cause)) reason = error or ResponseError(cause)
raise MaxRetryError(_pool, url, reason) from reason # type: ignore[arg-type]
log.debug("Incremented Retry for (url='%s'): %r", url, new_retry) log.debug("Incremented Retry for (url='%s'): %r", url, new_retry)
return new_retry return new_retry
def __repr__(self): def __repr__(self) -> str:
return ( return (
"{cls.__name__}(total={self.total}, connect={self.connect}, " f"{type(self).__name__}(total={self.total}, connect={self.connect}, "
"read={self.read}, redirect={self.redirect}, status={self.status})" f"read={self.read}, redirect={self.redirect}, status={self.status})"
).format(cls=type(self), self=self) )
def __getattr__(self, item):
if item == "method_whitelist":
# TODO: Remove this deprecated alias in v2.0
warnings.warn(
"Using 'method_whitelist' with Retry is deprecated and "
"will be removed in v2.0. Use 'allowed_methods' instead",
DeprecationWarning,
)
return self.allowed_methods
try:
return getattr(super(Retry, self), item)
except AttributeError:
return getattr(Retry, item)
# For backwards compatibility (equivalent to pre-v1.9): # For backwards compatibility (equivalent to pre-v1.9):

View file

@ -1,185 +1,152 @@
from __future__ import absolute_import from __future__ import annotations
import hmac import hmac
import os import os
import socket
import sys import sys
import typing
import warnings import warnings
from binascii import hexlify, unhexlify from binascii import unhexlify
from hashlib import md5, sha1, sha256 from hashlib import md5, sha1, sha256
from ..exceptions import ( from ..exceptions import ProxySchemeUnsupported, SSLError
InsecurePlatformWarning, from .url import _BRACELESS_IPV6_ADDRZ_RE, _IPV4_RE
ProxySchemeUnsupported,
SNIMissingWarning,
SSLError,
)
from ..packages import six
from .url import BRACELESS_IPV6_ADDRZ_RE, IPV4_RE
SSLContext = None SSLContext = None
SSLTransport = None SSLTransport = None
HAS_SNI = False HAS_NEVER_CHECK_COMMON_NAME = False
IS_PYOPENSSL = False IS_PYOPENSSL = False
IS_SECURETRANSPORT = False IS_SECURETRANSPORT = False
ALPN_PROTOCOLS = ["http/1.1"] ALPN_PROTOCOLS = ["http/1.1"]
_TYPE_VERSION_INFO = typing.Tuple[int, int, int, str, int]
# Maps the length of a digest to a possible hash function producing this digest # Maps the length of a digest to a possible hash function producing this digest
HASHFUNC_MAP = {32: md5, 40: sha1, 64: sha256} HASHFUNC_MAP = {32: md5, 40: sha1, 64: sha256}
def _const_compare_digest_backport(a, b): def _is_bpo_43522_fixed(
implementation_name: str,
version_info: _TYPE_VERSION_INFO,
pypy_version_info: _TYPE_VERSION_INFO | None,
) -> bool:
"""Return True for CPython 3.8.9+, 3.9.3+ or 3.10+ and PyPy 7.3.8+ where
setting SSLContext.hostname_checks_common_name to False works.
Outside of CPython and PyPy we don't know which implementations work
or not so we conservatively use our hostname matching as we know that works
on all implementations.
https://github.com/urllib3/urllib3/issues/2192#issuecomment-821832963
https://foss.heptapod.net/pypy/pypy/-/issues/3539
""" """
Compare two digests of equal length in constant time. if implementation_name == "pypy":
# https://foss.heptapod.net/pypy/pypy/-/issues/3129
The digests must be of type str/bytes. return pypy_version_info >= (7, 3, 8) and version_info >= (3, 8) # type: ignore[operator]
Returns True if the digests match, and False otherwise. elif implementation_name == "cpython":
""" major_minor = version_info[:2]
result = abs(len(a) - len(b)) micro = version_info[2]
for left, right in zip(bytearray(a), bytearray(b)): return (
result |= left ^ right (major_minor == (3, 8) and micro >= 9)
return result == 0 or (major_minor == (3, 9) and micro >= 3)
or major_minor >= (3, 10)
)
else: # Defensive:
return False
_const_compare_digest = getattr(hmac, "compare_digest", _const_compare_digest_backport) def _is_has_never_check_common_name_reliable(
openssl_version: str,
openssl_version_number: int,
implementation_name: str,
version_info: _TYPE_VERSION_INFO,
pypy_version_info: _TYPE_VERSION_INFO | None,
) -> bool:
# As of May 2023, all released versions of LibreSSL fail to reject certificates with
# only common names, see https://github.com/urllib3/urllib3/pull/3024
is_openssl = openssl_version.startswith("OpenSSL ")
# Before fixing OpenSSL issue #14579, the SSL_new() API was not copying hostflags
# like X509_CHECK_FLAG_NEVER_CHECK_SUBJECT, which tripped up CPython.
# https://github.com/openssl/openssl/issues/14579
# This was released in OpenSSL 1.1.1l+ (>=0x101010cf)
is_openssl_issue_14579_fixed = openssl_version_number >= 0x101010CF
try: # Test for SSL features return is_openssl and (
is_openssl_issue_14579_fixed
or _is_bpo_43522_fixed(implementation_name, version_info, pypy_version_info)
)
if typing.TYPE_CHECKING:
from ssl import VerifyMode
from typing_extensions import Literal, TypedDict
from .ssltransport import SSLTransport as SSLTransportType
class _TYPE_PEER_CERT_RET_DICT(TypedDict, total=False):
subjectAltName: tuple[tuple[str, str], ...]
subject: tuple[tuple[tuple[str, str], ...], ...]
serialNumber: str
# Mapping from 'ssl.PROTOCOL_TLSX' to 'TLSVersion.X'
_SSL_VERSION_TO_TLS_VERSION: dict[int, int] = {}
try: # Do we have ssl at all?
import ssl import ssl
from ssl import CERT_REQUIRED, wrap_socket from ssl import ( # type: ignore[assignment]
except ImportError: CERT_REQUIRED,
pass HAS_NEVER_CHECK_COMMON_NAME,
OP_NO_COMPRESSION,
try: OP_NO_TICKET,
from ssl import HAS_SNI # Has SNI? OPENSSL_VERSION,
except ImportError: OPENSSL_VERSION_NUMBER,
pass PROTOCOL_TLS,
PROTOCOL_TLS_CLIENT,
try: OP_NO_SSLv2,
from .ssltransport import SSLTransport OP_NO_SSLv3,
except ImportError: SSLContext,
pass TLSVersion,
)
try: # Platform-specific: Python 3.6
from ssl import PROTOCOL_TLS
PROTOCOL_SSLv23 = PROTOCOL_TLS PROTOCOL_SSLv23 = PROTOCOL_TLS
except ImportError:
try:
from ssl import PROTOCOL_SSLv23 as PROTOCOL_TLS
PROTOCOL_SSLv23 = PROTOCOL_TLS # Setting SSLContext.hostname_checks_common_name = False didn't work before CPython
except ImportError: # 3.8.9, 3.9.3, and 3.10 (but OK on PyPy) or OpenSSL 1.1.1l+
PROTOCOL_SSLv23 = PROTOCOL_TLS = 2 if HAS_NEVER_CHECK_COMMON_NAME and not _is_has_never_check_common_name_reliable(
OPENSSL_VERSION,
OPENSSL_VERSION_NUMBER,
sys.implementation.name,
sys.version_info,
sys.pypy_version_info if sys.implementation.name == "pypy" else None, # type: ignore[attr-defined]
):
HAS_NEVER_CHECK_COMMON_NAME = False
try: # Need to be careful here in case old TLS versions get
from ssl import PROTOCOL_TLS_CLIENT # removed in future 'ssl' module implementations.
except ImportError: for attr in ("TLSv1", "TLSv1_1", "TLSv1_2"):
PROTOCOL_TLS_CLIENT = PROTOCOL_TLS try:
_SSL_VERSION_TO_TLS_VERSION[getattr(ssl, f"PROTOCOL_{attr}")] = getattr(
TLSVersion, attr
try:
from ssl import OP_NO_COMPRESSION, OP_NO_SSLv2, OP_NO_SSLv3
except ImportError:
OP_NO_SSLv2, OP_NO_SSLv3 = 0x1000000, 0x2000000
OP_NO_COMPRESSION = 0x20000
try: # OP_NO_TICKET was added in Python 3.6
from ssl import OP_NO_TICKET
except ImportError:
OP_NO_TICKET = 0x4000
# A secure default.
# Sources for more information on TLS ciphers:
#
# - https://wiki.mozilla.org/Security/Server_Side_TLS
# - https://www.ssllabs.com/projects/best-practices/index.html
# - https://hynek.me/articles/hardening-your-web-servers-ssl-ciphers/
#
# The general intent is:
# - prefer cipher suites that offer perfect forward secrecy (DHE/ECDHE),
# - prefer ECDHE over DHE for better performance,
# - prefer any AES-GCM and ChaCha20 over any AES-CBC for better performance and
# security,
# - prefer AES-GCM over ChaCha20 because hardware-accelerated AES is common,
# - disable NULL authentication, MD5 MACs, DSS, and other
# insecure ciphers for security reasons.
# - NOTE: TLS 1.3 cipher suites are managed through a different interface
# not exposed by CPython (yet!) and are enabled by default if they're available.
DEFAULT_CIPHERS = ":".join(
[
"ECDHE+AESGCM",
"ECDHE+CHACHA20",
"DHE+AESGCM",
"DHE+CHACHA20",
"ECDH+AESGCM",
"DH+AESGCM",
"ECDH+AES",
"DH+AES",
"RSA+AESGCM",
"RSA+AES",
"!aNULL",
"!eNULL",
"!MD5",
"!DSS",
]
)
try:
from ssl import SSLContext # Modern SSL?
except ImportError:
class SSLContext(object): # Platform-specific: Python 2
def __init__(self, protocol_version):
self.protocol = protocol_version
# Use default values from a real SSLContext
self.check_hostname = False
self.verify_mode = ssl.CERT_NONE
self.ca_certs = None
self.options = 0
self.certfile = None
self.keyfile = None
self.ciphers = None
def load_cert_chain(self, certfile, keyfile):
self.certfile = certfile
self.keyfile = keyfile
def load_verify_locations(self, cafile=None, capath=None, cadata=None):
self.ca_certs = cafile
if capath is not None:
raise SSLError("CA directories not supported in older Pythons")
if cadata is not None:
raise SSLError("CA data not supported in older Pythons")
def set_ciphers(self, cipher_suite):
self.ciphers = cipher_suite
def wrap_socket(self, socket, server_hostname=None, server_side=False):
warnings.warn(
"A true SSLContext object is not available. This prevents "
"urllib3 from configuring SSL appropriately and may cause "
"certain SSL connections to fail. You can upgrade to a newer "
"version of Python to solve this. For more information, see "
"https://urllib3.readthedocs.io/en/1.26.x/advanced-usage.html"
"#ssl-warnings",
InsecurePlatformWarning,
) )
kwargs = { except AttributeError: # Defensive:
"keyfile": self.keyfile, continue
"certfile": self.certfile,
"ca_certs": self.ca_certs, from .ssltransport import SSLTransport # type: ignore[assignment]
"cert_reqs": self.verify_mode, except ImportError:
"ssl_version": self.protocol, OP_NO_COMPRESSION = 0x20000 # type: ignore[assignment]
"server_side": server_side, OP_NO_TICKET = 0x4000 # type: ignore[assignment]
} OP_NO_SSLv2 = 0x1000000 # type: ignore[assignment]
return wrap_socket(socket, ciphers=self.ciphers, **kwargs) OP_NO_SSLv3 = 0x2000000 # type: ignore[assignment]
PROTOCOL_SSLv23 = PROTOCOL_TLS = 2 # type: ignore[assignment]
PROTOCOL_TLS_CLIENT = 16 # type: ignore[assignment]
def assert_fingerprint(cert, fingerprint): _TYPE_PEER_CERT_RET = typing.Union["_TYPE_PEER_CERT_RET_DICT", bytes, None]
def assert_fingerprint(cert: bytes | None, fingerprint: str) -> None:
""" """
Checks if given fingerprint matches the supplied certificate. Checks if given fingerprint matches the supplied certificate.
@ -189,26 +156,27 @@ def assert_fingerprint(cert, fingerprint):
Fingerprint as string of hexdigits, can be interspersed by colons. Fingerprint as string of hexdigits, can be interspersed by colons.
""" """
if cert is None:
raise SSLError("No certificate for the peer.")
fingerprint = fingerprint.replace(":", "").lower() fingerprint = fingerprint.replace(":", "").lower()
digest_length = len(fingerprint) digest_length = len(fingerprint)
hashfunc = HASHFUNC_MAP.get(digest_length) hashfunc = HASHFUNC_MAP.get(digest_length)
if not hashfunc: if not hashfunc:
raise SSLError("Fingerprint of invalid length: {0}".format(fingerprint)) raise SSLError(f"Fingerprint of invalid length: {fingerprint}")
# We need encode() here for py32; works on py2 and p33. # We need encode() here for py32; works on py2 and p33.
fingerprint_bytes = unhexlify(fingerprint.encode()) fingerprint_bytes = unhexlify(fingerprint.encode())
cert_digest = hashfunc(cert).digest() cert_digest = hashfunc(cert).digest()
if not _const_compare_digest(cert_digest, fingerprint_bytes): if not hmac.compare_digest(cert_digest, fingerprint_bytes):
raise SSLError( raise SSLError(
'Fingerprints did not match. Expected "{0}", got "{1}".'.format( f'Fingerprints did not match. Expected "{fingerprint}", got "{cert_digest.hex()}"'
fingerprint, hexlify(cert_digest)
)
) )
def resolve_cert_reqs(candidate): def resolve_cert_reqs(candidate: None | int | str) -> VerifyMode:
""" """
Resolves the argument to a numeric constant, which can be passed to Resolves the argument to a numeric constant, which can be passed to
the wrap_socket function/method from the ssl module. the wrap_socket function/method from the ssl module.
@ -226,12 +194,12 @@ def resolve_cert_reqs(candidate):
res = getattr(ssl, candidate, None) res = getattr(ssl, candidate, None)
if res is None: if res is None:
res = getattr(ssl, "CERT_" + candidate) res = getattr(ssl, "CERT_" + candidate)
return res return res # type: ignore[no-any-return]
return candidate return candidate # type: ignore[return-value]
def resolve_ssl_version(candidate): def resolve_ssl_version(candidate: None | int | str) -> int:
""" """
like resolve_cert_reqs like resolve_cert_reqs
""" """
@ -242,35 +210,33 @@ def resolve_ssl_version(candidate):
res = getattr(ssl, candidate, None) res = getattr(ssl, candidate, None)
if res is None: if res is None:
res = getattr(ssl, "PROTOCOL_" + candidate) res = getattr(ssl, "PROTOCOL_" + candidate)
return res return typing.cast(int, res)
return candidate return candidate
def create_urllib3_context( def create_urllib3_context(
ssl_version=None, cert_reqs=None, options=None, ciphers=None ssl_version: int | None = None,
): cert_reqs: int | None = None,
"""All arguments have the same meaning as ``ssl_wrap_socket``. options: int | None = None,
ciphers: str | None = None,
By default, this function does a lot of the same work that ssl_minimum_version: int | None = None,
``ssl.create_default_context`` does on Python 3.4+. It: ssl_maximum_version: int | None = None,
) -> ssl.SSLContext:
- Disables SSLv2, SSLv3, and compression """Creates and configures an :class:`ssl.SSLContext` instance for use with urllib3.
- Sets a restricted set of server ciphers
If you wish to enable SSLv3, you can do::
from urllib3.util import ssl_
context = ssl_.create_urllib3_context()
context.options &= ~ssl_.OP_NO_SSLv3
You can do the same to enable compression (substituting ``COMPRESSION``
for ``SSLv3`` in the last line above).
:param ssl_version: :param ssl_version:
The desired protocol version to use. This will default to The desired protocol version to use. This will default to
PROTOCOL_SSLv23 which will negotiate the highest protocol that both PROTOCOL_SSLv23 which will negotiate the highest protocol that both
the server and your installation of OpenSSL support. the server and your installation of OpenSSL support.
This parameter is deprecated instead use 'ssl_minimum_version'.
:param ssl_minimum_version:
The minimum version of TLS to be used. Use the 'ssl.TLSVersion' enum for specifying the value.
:param ssl_maximum_version:
The maximum version of TLS to be used. Use the 'ssl.TLSVersion' enum for specifying the value.
Not recommended to set to anything other than 'ssl.TLSVersion.MAXIMUM_SUPPORTED' which is the
default value.
:param cert_reqs: :param cert_reqs:
Whether to require the certificate verification. This defaults to Whether to require the certificate verification. This defaults to
``ssl.CERT_REQUIRED``. ``ssl.CERT_REQUIRED``.
@ -278,18 +244,60 @@ def create_urllib3_context(
Specific OpenSSL options. These default to ``ssl.OP_NO_SSLv2``, Specific OpenSSL options. These default to ``ssl.OP_NO_SSLv2``,
``ssl.OP_NO_SSLv3``, ``ssl.OP_NO_COMPRESSION``, and ``ssl.OP_NO_TICKET``. ``ssl.OP_NO_SSLv3``, ``ssl.OP_NO_COMPRESSION``, and ``ssl.OP_NO_TICKET``.
:param ciphers: :param ciphers:
Which cipher suites to allow the server to select. Which cipher suites to allow the server to select. Defaults to either system configured
ciphers if OpenSSL 1.1.1+, otherwise uses a secure default set of ciphers.
:returns: :returns:
Constructed SSLContext object with specified options Constructed SSLContext object with specified options
:rtype: SSLContext :rtype: SSLContext
""" """
# PROTOCOL_TLS is deprecated in Python 3.10 if SSLContext is None:
if not ssl_version or ssl_version == PROTOCOL_TLS: raise TypeError("Can't create an SSLContext object without an ssl module")
ssl_version = PROTOCOL_TLS_CLIENT
context = SSLContext(ssl_version) # This means 'ssl_version' was specified as an exact value.
if ssl_version not in (None, PROTOCOL_TLS, PROTOCOL_TLS_CLIENT):
# Disallow setting 'ssl_version' and 'ssl_minimum|maximum_version'
# to avoid conflicts.
if ssl_minimum_version is not None or ssl_maximum_version is not None:
raise ValueError(
"Can't specify both 'ssl_version' and either "
"'ssl_minimum_version' or 'ssl_maximum_version'"
)
context.set_ciphers(ciphers or DEFAULT_CIPHERS) # 'ssl_version' is deprecated and will be removed in the future.
else:
# Use 'ssl_minimum_version' and 'ssl_maximum_version' instead.
ssl_minimum_version = _SSL_VERSION_TO_TLS_VERSION.get(
ssl_version, TLSVersion.MINIMUM_SUPPORTED
)
ssl_maximum_version = _SSL_VERSION_TO_TLS_VERSION.get(
ssl_version, TLSVersion.MAXIMUM_SUPPORTED
)
# This warning message is pushing users to use 'ssl_minimum_version'
# instead of both min/max. Best practice is to only set the minimum version and
# keep the maximum version to be it's default value: 'TLSVersion.MAXIMUM_SUPPORTED'
warnings.warn(
"'ssl_version' option is deprecated and will be "
"removed in urllib3 v2.1.0. Instead use 'ssl_minimum_version'",
category=DeprecationWarning,
stacklevel=2,
)
# PROTOCOL_TLS is deprecated in Python 3.10 so we always use PROTOCOL_TLS_CLIENT
context = SSLContext(PROTOCOL_TLS_CLIENT)
if ssl_minimum_version is not None:
context.minimum_version = ssl_minimum_version
else: # Python <3.10 defaults to 'MINIMUM_SUPPORTED' so explicitly set TLSv1.2 here
context.minimum_version = TLSVersion.TLSv1_2
if ssl_maximum_version is not None:
context.maximum_version = ssl_maximum_version
# Unless we're given ciphers defer to either system ciphers in
# the case of OpenSSL 1.1.1+ or use our own secure default ciphers.
if ciphers:
context.set_ciphers(ciphers)
# Setting the default here, as we may have no ssl module on import # Setting the default here, as we may have no ssl module on import
cert_reqs = ssl.CERT_REQUIRED if cert_reqs is None else cert_reqs cert_reqs = ssl.CERT_REQUIRED if cert_reqs is None else cert_reqs
@ -322,26 +330,23 @@ def create_urllib3_context(
) is not None: ) is not None:
context.post_handshake_auth = True context.post_handshake_auth = True
def disable_check_hostname():
if (
getattr(context, "check_hostname", None) is not None
): # Platform-specific: Python 3.2
# We do our own verification, including fingerprints and alternative
# hostnames. So disable it here
context.check_hostname = False
# The order of the below lines setting verify_mode and check_hostname # The order of the below lines setting verify_mode and check_hostname
# matter due to safe-guards SSLContext has to prevent an SSLContext with # matter due to safe-guards SSLContext has to prevent an SSLContext with
# check_hostname=True, verify_mode=NONE/OPTIONAL. This is made even more # check_hostname=True, verify_mode=NONE/OPTIONAL.
# complex because we don't know whether PROTOCOL_TLS_CLIENT will be used # We always set 'check_hostname=False' for pyOpenSSL so we rely on our own
# or not so we don't know the initial state of the freshly created SSLContext. # 'ssl.match_hostname()' implementation.
if cert_reqs == ssl.CERT_REQUIRED: if cert_reqs == ssl.CERT_REQUIRED and not IS_PYOPENSSL:
context.verify_mode = cert_reqs context.verify_mode = cert_reqs
disable_check_hostname() context.check_hostname = True
else: else:
disable_check_hostname() context.check_hostname = False
context.verify_mode = cert_reqs context.verify_mode = cert_reqs
try:
context.hostname_checks_common_name = False
except AttributeError: # Defensive: for CPython < 3.8.9 and 3.9.3; for PyPy < 7.3.8
pass
# Enable logging of TLS session keys via defacto standard environment variable # Enable logging of TLS session keys via defacto standard environment variable
# 'SSLKEYLOGFILE', if the feature is available (Python 3.8+). Skip empty values. # 'SSLKEYLOGFILE', if the feature is available (Python 3.8+). Skip empty values.
if hasattr(context, "keylog_filename"): if hasattr(context, "keylog_filename"):
@ -352,21 +357,59 @@ def create_urllib3_context(
return context return context
@typing.overload
def ssl_wrap_socket( def ssl_wrap_socket(
sock, sock: socket.socket,
keyfile=None, keyfile: str | None = ...,
certfile=None, certfile: str | None = ...,
cert_reqs=None, cert_reqs: int | None = ...,
ca_certs=None, ca_certs: str | None = ...,
server_hostname=None, server_hostname: str | None = ...,
ssl_version=None, ssl_version: int | None = ...,
ciphers=None, ciphers: str | None = ...,
ssl_context=None, ssl_context: ssl.SSLContext | None = ...,
ca_cert_dir=None, ca_cert_dir: str | None = ...,
key_password=None, key_password: str | None = ...,
ca_cert_data=None, ca_cert_data: None | str | bytes = ...,
tls_in_tls=False, tls_in_tls: Literal[False] = ...,
): ) -> ssl.SSLSocket:
...
@typing.overload
def ssl_wrap_socket(
sock: socket.socket,
keyfile: str | None = ...,
certfile: str | None = ...,
cert_reqs: int | None = ...,
ca_certs: str | None = ...,
server_hostname: str | None = ...,
ssl_version: int | None = ...,
ciphers: str | None = ...,
ssl_context: ssl.SSLContext | None = ...,
ca_cert_dir: str | None = ...,
key_password: str | None = ...,
ca_cert_data: None | str | bytes = ...,
tls_in_tls: bool = ...,
) -> ssl.SSLSocket | SSLTransportType:
...
def ssl_wrap_socket(
sock: socket.socket,
keyfile: str | None = None,
certfile: str | None = None,
cert_reqs: int | None = None,
ca_certs: str | None = None,
server_hostname: str | None = None,
ssl_version: int | None = None,
ciphers: str | None = None,
ssl_context: ssl.SSLContext | None = None,
ca_cert_dir: str | None = None,
key_password: str | None = None,
ca_cert_data: None | str | bytes = None,
tls_in_tls: bool = False,
) -> ssl.SSLSocket | SSLTransportType:
""" """
All arguments except for server_hostname, ssl_context, and ca_cert_dir have All arguments except for server_hostname, ssl_context, and ca_cert_dir have
the same meaning as they do when using :func:`ssl.wrap_socket`. the same meaning as they do when using :func:`ssl.wrap_socket`.
@ -392,19 +435,18 @@ def ssl_wrap_socket(
""" """
context = ssl_context context = ssl_context
if context is None: if context is None:
# Note: This branch of code and all the variables in it are no longer # Note: This branch of code and all the variables in it are only used in tests.
# used by urllib3 itself. We should consider deprecating and removing # We should consider deprecating and removing this code.
# this code.
context = create_urllib3_context(ssl_version, cert_reqs, ciphers=ciphers) context = create_urllib3_context(ssl_version, cert_reqs, ciphers=ciphers)
if ca_certs or ca_cert_dir or ca_cert_data: if ca_certs or ca_cert_dir or ca_cert_data:
try: try:
context.load_verify_locations(ca_certs, ca_cert_dir, ca_cert_data) context.load_verify_locations(ca_certs, ca_cert_dir, ca_cert_data)
except (IOError, OSError) as e: except OSError as e:
raise SSLError(e) raise SSLError(e) from e
elif ssl_context is None and hasattr(context, "load_default_certs"): elif ssl_context is None and hasattr(context, "load_default_certs"):
# try to load OS default certs; works well on Windows (require Python3.4+) # try to load OS default certs; works well on Windows.
context.load_default_certs() context.load_default_certs()
# Attempt to detect if we get the goofy behavior of the # Attempt to detect if we get the goofy behavior of the
@ -420,56 +462,30 @@ def ssl_wrap_socket(
context.load_cert_chain(certfile, keyfile, key_password) context.load_cert_chain(certfile, keyfile, key_password)
try: try:
if hasattr(context, "set_alpn_protocols"): context.set_alpn_protocols(ALPN_PROTOCOLS)
context.set_alpn_protocols(ALPN_PROTOCOLS)
except NotImplementedError: # Defensive: in CI, we always have set_alpn_protocols except NotImplementedError: # Defensive: in CI, we always have set_alpn_protocols
pass pass
# If we detect server_hostname is an IP address then the SNI ssl_sock = _ssl_wrap_socket_impl(sock, context, tls_in_tls, server_hostname)
# extension should not be used according to RFC3546 Section 3.1
use_sni_hostname = server_hostname and not is_ipaddress(server_hostname)
# SecureTransport uses server_hostname in certificate verification.
send_sni = (use_sni_hostname and HAS_SNI) or (
IS_SECURETRANSPORT and server_hostname
)
# Do not warn the user if server_hostname is an invalid SNI hostname.
if not HAS_SNI and use_sni_hostname:
warnings.warn(
"An HTTPS request has been made, but the SNI (Server Name "
"Indication) extension to TLS is not available on this platform. "
"This may cause the server to present an incorrect TLS "
"certificate, which can cause validation failures. You can upgrade to "
"a newer version of Python to solve this. For more information, see "
"https://urllib3.readthedocs.io/en/1.26.x/advanced-usage.html"
"#ssl-warnings",
SNIMissingWarning,
)
if send_sni:
ssl_sock = _ssl_wrap_socket_impl(
sock, context, tls_in_tls, server_hostname=server_hostname
)
else:
ssl_sock = _ssl_wrap_socket_impl(sock, context, tls_in_tls)
return ssl_sock return ssl_sock
def is_ipaddress(hostname): def is_ipaddress(hostname: str | bytes) -> bool:
"""Detects whether the hostname given is an IPv4 or IPv6 address. """Detects whether the hostname given is an IPv4 or IPv6 address.
Also detects IPv6 addresses with Zone IDs. Also detects IPv6 addresses with Zone IDs.
:param str hostname: Hostname to examine. :param str hostname: Hostname to examine.
:return: True if the hostname is an IP address, False otherwise. :return: True if the hostname is an IP address, False otherwise.
""" """
if not six.PY2 and isinstance(hostname, bytes): if isinstance(hostname, bytes):
# IDN A-label bytes are ASCII compatible. # IDN A-label bytes are ASCII compatible.
hostname = hostname.decode("ascii") hostname = hostname.decode("ascii")
return bool(IPV4_RE.match(hostname) or BRACELESS_IPV6_ADDRZ_RE.match(hostname)) return bool(_IPV4_RE.match(hostname) or _BRACELESS_IPV6_ADDRZ_RE.match(hostname))
def _is_key_file_encrypted(key_file): def _is_key_file_encrypted(key_file: str) -> bool:
"""Detects if a key file is encrypted or not.""" """Detects if a key file is encrypted or not."""
with open(key_file, "r") as f: with open(key_file) as f:
for line in f: for line in f:
# Look for Proc-Type: 4,ENCRYPTED # Look for Proc-Type: 4,ENCRYPTED
if "ENCRYPTED" in line: if "ENCRYPTED" in line:
@ -478,7 +494,12 @@ def _is_key_file_encrypted(key_file):
return False return False
def _ssl_wrap_socket_impl(sock, ssl_context, tls_in_tls, server_hostname=None): def _ssl_wrap_socket_impl(
sock: socket.socket,
ssl_context: ssl.SSLContext,
tls_in_tls: bool,
server_hostname: str | None = None,
) -> ssl.SSLSocket | SSLTransportType:
if tls_in_tls: if tls_in_tls:
if not SSLTransport: if not SSLTransport:
# Import error, ssl is not available. # Import error, ssl is not available.
@ -489,7 +510,4 @@ def _ssl_wrap_socket_impl(sock, ssl_context, tls_in_tls, server_hostname=None):
SSLTransport._validate_ssl_context_for_tls_in_tls(ssl_context) SSLTransport._validate_ssl_context_for_tls_in_tls(ssl_context)
return SSLTransport(sock, ssl_context, server_hostname) return SSLTransport(sock, ssl_context, server_hostname)
if server_hostname: return ssl_context.wrap_socket(sock, server_hostname=server_hostname)
return ssl_context.wrap_socket(sock, server_hostname=server_hostname)
else:
return ssl_context.wrap_socket(sock)

View file

@ -1,19 +1,18 @@
"""The match_hostname() function from Python 3.3.3, essential when using SSL.""" """The match_hostname() function from Python 3.5, essential when using SSL."""
# Note: This file is under the PSF license as the code comes from the python # Note: This file is under the PSF license as the code comes from the python
# stdlib. http://docs.python.org/3/license.html # stdlib. http://docs.python.org/3/license.html
# It is modified to remove commonName support.
from __future__ import annotations
import ipaddress
import re import re
import sys import typing
from ipaddress import IPv4Address, IPv6Address
# ipaddress has been backported to 2.6+ in pypi. If it is installed on the if typing.TYPE_CHECKING:
# system, use it to handle IPAddress ServerAltnames (this was added in from .ssl_ import _TYPE_PEER_CERT_RET_DICT
# python-3.5) otherwise only do DNS matching. This allows
# util.ssl_match_hostname to continue to be used in Python 2.7.
try:
import ipaddress
except ImportError:
ipaddress = None
__version__ = "3.5.0.1" __version__ = "3.5.0.1"
@ -22,7 +21,9 @@ class CertificateError(ValueError):
pass pass
def _dnsname_match(dn, hostname, max_wildcards=1): def _dnsname_match(
dn: typing.Any, hostname: str, max_wildcards: int = 1
) -> typing.Match[str] | None | bool:
"""Matching according to RFC 6125, section 6.4.3 """Matching according to RFC 6125, section 6.4.3
http://tools.ietf.org/html/rfc6125#section-6.4.3 http://tools.ietf.org/html/rfc6125#section-6.4.3
@ -49,7 +50,7 @@ def _dnsname_match(dn, hostname, max_wildcards=1):
# speed up common case w/o wildcards # speed up common case w/o wildcards
if not wildcards: if not wildcards:
return dn.lower() == hostname.lower() return bool(dn.lower() == hostname.lower())
# RFC 6125, section 6.4.3, subitem 1. # RFC 6125, section 6.4.3, subitem 1.
# The client SHOULD NOT attempt to match a presented identifier in which # The client SHOULD NOT attempt to match a presented identifier in which
@ -76,26 +77,26 @@ def _dnsname_match(dn, hostname, max_wildcards=1):
return pat.match(hostname) return pat.match(hostname)
def _to_unicode(obj): def _ipaddress_match(ipname: str, host_ip: IPv4Address | IPv6Address) -> bool:
if isinstance(obj, str) and sys.version_info < (3,):
# ignored flake8 # F821 to support python 2.7 function
obj = unicode(obj, encoding="ascii", errors="strict") # noqa: F821
return obj
def _ipaddress_match(ipname, host_ip):
"""Exact matching of IP addresses. """Exact matching of IP addresses.
RFC 6125 explicitly doesn't define an algorithm for this RFC 9110 section 4.3.5: "A reference identity of IP-ID contains the decoded
(section 1.7.2 - "Out of Scope"). bytes of the IP address. An IP version 4 address is 4 octets, and an IP
version 6 address is 16 octets. [...] A reference identity of type IP-ID
matches if the address is identical to an iPAddress value of the
subjectAltName extension of the certificate."
""" """
# OpenSSL may add a trailing newline to a subjectAltName's IP address # OpenSSL may add a trailing newline to a subjectAltName's IP address
# Divergence from upstream: ipaddress can't handle byte str # Divergence from upstream: ipaddress can't handle byte str
ip = ipaddress.ip_address(_to_unicode(ipname).rstrip()) ip = ipaddress.ip_address(ipname.rstrip())
return ip == host_ip return bool(ip.packed == host_ip.packed)
def match_hostname(cert, hostname): def match_hostname(
cert: _TYPE_PEER_CERT_RET_DICT | None,
hostname: str,
hostname_checks_common_name: bool = False,
) -> None:
"""Verify that *cert* (in decoded format as returned by """Verify that *cert* (in decoded format as returned by
SSLSocket.getpeercert()) matches the *hostname*. RFC 2818 and RFC 6125 SSLSocket.getpeercert()) matches the *hostname*. RFC 2818 and RFC 6125
rules are followed, but IP addresses are not accepted for *hostname*. rules are followed, but IP addresses are not accepted for *hostname*.
@ -111,21 +112,22 @@ def match_hostname(cert, hostname):
) )
try: try:
# Divergence from upstream: ipaddress can't handle byte str # Divergence from upstream: ipaddress can't handle byte str
host_ip = ipaddress.ip_address(_to_unicode(hostname)) #
except (UnicodeError, ValueError): # The ipaddress module shipped with Python < 3.9 does not support
# ValueError: Not an IP address (common case) # scoped IPv6 addresses so we unconditionally strip the Zone IDs for
# UnicodeError: Divergence from upstream: Have to deal with ipaddress not taking # now. Once we drop support for Python 3.9 we can remove this branch.
# byte strings. addresses should be all ascii, so we consider it not if "%" in hostname:
# an ipaddress in this case host_ip = ipaddress.ip_address(hostname[: hostname.rfind("%")])
else:
host_ip = ipaddress.ip_address(hostname)
except ValueError:
# Not an IP address (common case)
host_ip = None host_ip = None
except AttributeError:
# Divergence from upstream: Make ipaddress library optional
if ipaddress is None:
host_ip = None
else: # Defensive
raise
dnsnames = [] dnsnames = []
san = cert.get("subjectAltName", ()) san: tuple[tuple[str, str], ...] = cert.get("subjectAltName", ())
key: str
value: str
for key, value in san: for key, value in san:
if key == "DNS": if key == "DNS":
if host_ip is None and _dnsname_match(value, hostname): if host_ip is None and _dnsname_match(value, hostname):
@ -135,25 +137,23 @@ def match_hostname(cert, hostname):
if host_ip is not None and _ipaddress_match(value, host_ip): if host_ip is not None and _ipaddress_match(value, host_ip):
return return
dnsnames.append(value) dnsnames.append(value)
if not dnsnames:
# The subject is only checked when there is no dNSName entry # We only check 'commonName' if it's enabled and we're not verifying
# in subjectAltName # an IP address. IP addresses aren't valid within 'commonName'.
if hostname_checks_common_name and host_ip is None and not dnsnames:
for sub in cert.get("subject", ()): for sub in cert.get("subject", ()):
for key, value in sub: for key, value in sub:
# XXX according to RFC 2818, the most specific Common Name
# must be used.
if key == "commonName": if key == "commonName":
if _dnsname_match(value, hostname): if _dnsname_match(value, hostname):
return return
dnsnames.append(value) dnsnames.append(value)
if len(dnsnames) > 1: if len(dnsnames) > 1:
raise CertificateError( raise CertificateError(
"hostname %r " "hostname %r "
"doesn't match either of %s" % (hostname, ", ".join(map(repr, dnsnames))) "doesn't match either of %s" % (hostname, ", ".join(map(repr, dnsnames)))
) )
elif len(dnsnames) == 1: elif len(dnsnames) == 1:
raise CertificateError("hostname %r doesn't match %r" % (hostname, dnsnames[0])) raise CertificateError(f"hostname {hostname!r} doesn't match {dnsnames[0]!r}")
else: else:
raise CertificateError( raise CertificateError("no appropriate subjectAltName fields were found")
"no appropriate commonName or subjectAltName fields were found"
)

View file

@ -1,9 +1,21 @@
from __future__ import annotations
import io import io
import socket import socket
import ssl import ssl
import typing
from ..exceptions import ProxySchemeUnsupported from ..exceptions import ProxySchemeUnsupported
from ..packages import six
if typing.TYPE_CHECKING:
from typing_extensions import Literal
from .ssl_ import _TYPE_PEER_CERT_RET, _TYPE_PEER_CERT_RET_DICT
_SelfT = typing.TypeVar("_SelfT", bound="SSLTransport")
_WriteBuffer = typing.Union[bytearray, memoryview]
_ReturnValue = typing.TypeVar("_ReturnValue")
SSL_BLOCKSIZE = 16384 SSL_BLOCKSIZE = 16384
@ -20,7 +32,7 @@ class SSLTransport:
""" """
@staticmethod @staticmethod
def _validate_ssl_context_for_tls_in_tls(ssl_context): def _validate_ssl_context_for_tls_in_tls(ssl_context: ssl.SSLContext) -> None:
""" """
Raises a ProxySchemeUnsupported if the provided ssl_context can't be used Raises a ProxySchemeUnsupported if the provided ssl_context can't be used
for TLS in TLS. for TLS in TLS.
@ -30,20 +42,18 @@ class SSLTransport:
""" """
if not hasattr(ssl_context, "wrap_bio"): if not hasattr(ssl_context, "wrap_bio"):
if six.PY2: raise ProxySchemeUnsupported(
raise ProxySchemeUnsupported( "TLS in TLS requires SSLContext.wrap_bio() which isn't "
"TLS in TLS requires SSLContext.wrap_bio() which isn't " "available on non-native SSLContext"
"supported on Python 2" )
)
else:
raise ProxySchemeUnsupported(
"TLS in TLS requires SSLContext.wrap_bio() which isn't "
"available on non-native SSLContext"
)
def __init__( def __init__(
self, socket, ssl_context, server_hostname=None, suppress_ragged_eofs=True self,
): socket: socket.socket,
ssl_context: ssl.SSLContext,
server_hostname: str | None = None,
suppress_ragged_eofs: bool = True,
) -> None:
""" """
Create an SSLTransport around socket using the provided ssl_context. Create an SSLTransport around socket using the provided ssl_context.
""" """
@ -60,33 +70,36 @@ class SSLTransport:
# Perform initial handshake. # Perform initial handshake.
self._ssl_io_loop(self.sslobj.do_handshake) self._ssl_io_loop(self.sslobj.do_handshake)
def __enter__(self): def __enter__(self: _SelfT) -> _SelfT:
return self return self
def __exit__(self, *_): def __exit__(self, *_: typing.Any) -> None:
self.close() self.close()
def fileno(self): def fileno(self) -> int:
return self.socket.fileno() return self.socket.fileno()
def read(self, len=1024, buffer=None): def read(self, len: int = 1024, buffer: typing.Any | None = None) -> int | bytes:
return self._wrap_ssl_read(len, buffer) return self._wrap_ssl_read(len, buffer)
def recv(self, len=1024, flags=0): def recv(self, buflen: int = 1024, flags: int = 0) -> int | bytes:
if flags != 0: if flags != 0:
raise ValueError("non-zero flags not allowed in calls to recv") raise ValueError("non-zero flags not allowed in calls to recv")
return self._wrap_ssl_read(len) return self._wrap_ssl_read(buflen)
def recv_into(self, buffer, nbytes=None, flags=0): def recv_into(
self,
buffer: _WriteBuffer,
nbytes: int | None = None,
flags: int = 0,
) -> None | int | bytes:
if flags != 0: if flags != 0:
raise ValueError("non-zero flags not allowed in calls to recv_into") raise ValueError("non-zero flags not allowed in calls to recv_into")
if buffer and (nbytes is None): if nbytes is None:
nbytes = len(buffer) nbytes = len(buffer)
elif nbytes is None:
nbytes = 1024
return self.read(nbytes, buffer) return self.read(nbytes, buffer)
def sendall(self, data, flags=0): def sendall(self, data: bytes, flags: int = 0) -> None:
if flags != 0: if flags != 0:
raise ValueError("non-zero flags not allowed in calls to sendall") raise ValueError("non-zero flags not allowed in calls to sendall")
count = 0 count = 0
@ -96,15 +109,20 @@ class SSLTransport:
v = self.send(byte_view[count:]) v = self.send(byte_view[count:])
count += v count += v
def send(self, data, flags=0): def send(self, data: bytes, flags: int = 0) -> int:
if flags != 0: if flags != 0:
raise ValueError("non-zero flags not allowed in calls to send") raise ValueError("non-zero flags not allowed in calls to send")
response = self._ssl_io_loop(self.sslobj.write, data) return self._ssl_io_loop(self.sslobj.write, data)
return response
def makefile( def makefile(
self, mode="r", buffering=None, encoding=None, errors=None, newline=None self,
): mode: str,
buffering: int | None = None,
*,
encoding: str | None = None,
errors: str | None = None,
newline: str | None = None,
) -> typing.BinaryIO | typing.TextIO | socket.SocketIO:
""" """
Python's httpclient uses makefile and buffered io when reading HTTP Python's httpclient uses makefile and buffered io when reading HTTP
messages and we need to support it. messages and we need to support it.
@ -113,7 +131,7 @@ class SSLTransport:
changes to point to the socket directly. changes to point to the socket directly.
""" """
if not set(mode) <= {"r", "w", "b"}: if not set(mode) <= {"r", "w", "b"}:
raise ValueError("invalid mode %r (only r, w, b allowed)" % (mode,)) raise ValueError(f"invalid mode {mode!r} (only r, w, b allowed)")
writing = "w" in mode writing = "w" in mode
reading = "r" in mode or not writing reading = "r" in mode or not writing
@ -124,8 +142,8 @@ class SSLTransport:
rawmode += "r" rawmode += "r"
if writing: if writing:
rawmode += "w" rawmode += "w"
raw = socket.SocketIO(self, rawmode) raw = socket.SocketIO(self, rawmode) # type: ignore[arg-type]
self.socket._io_refs += 1 self.socket._io_refs += 1 # type: ignore[attr-defined]
if buffering is None: if buffering is None:
buffering = -1 buffering = -1
if buffering < 0: if buffering < 0:
@ -134,8 +152,9 @@ class SSLTransport:
if not binary: if not binary:
raise ValueError("unbuffered streams must be binary") raise ValueError("unbuffered streams must be binary")
return raw return raw
buffer: typing.BinaryIO
if reading and writing: if reading and writing:
buffer = io.BufferedRWPair(raw, raw, buffering) buffer = io.BufferedRWPair(raw, raw, buffering) # type: ignore[assignment]
elif reading: elif reading:
buffer = io.BufferedReader(raw, buffering) buffer = io.BufferedReader(raw, buffering)
else: else:
@ -144,46 +163,56 @@ class SSLTransport:
if binary: if binary:
return buffer return buffer
text = io.TextIOWrapper(buffer, encoding, errors, newline) text = io.TextIOWrapper(buffer, encoding, errors, newline)
text.mode = mode text.mode = mode # type: ignore[misc]
return text return text
def unwrap(self): def unwrap(self) -> None:
self._ssl_io_loop(self.sslobj.unwrap) self._ssl_io_loop(self.sslobj.unwrap)
def close(self): def close(self) -> None:
self.socket.close() self.socket.close()
def getpeercert(self, binary_form=False): @typing.overload
return self.sslobj.getpeercert(binary_form) def getpeercert(
self, binary_form: Literal[False] = ...
) -> _TYPE_PEER_CERT_RET_DICT | None:
...
def version(self): @typing.overload
def getpeercert(self, binary_form: Literal[True]) -> bytes | None:
...
def getpeercert(self, binary_form: bool = False) -> _TYPE_PEER_CERT_RET:
return self.sslobj.getpeercert(binary_form) # type: ignore[return-value]
def version(self) -> str | None:
return self.sslobj.version() return self.sslobj.version()
def cipher(self): def cipher(self) -> tuple[str, str, int] | None:
return self.sslobj.cipher() return self.sslobj.cipher()
def selected_alpn_protocol(self): def selected_alpn_protocol(self) -> str | None:
return self.sslobj.selected_alpn_protocol() return self.sslobj.selected_alpn_protocol()
def selected_npn_protocol(self): def selected_npn_protocol(self) -> str | None:
return self.sslobj.selected_npn_protocol() return self.sslobj.selected_npn_protocol()
def shared_ciphers(self): def shared_ciphers(self) -> list[tuple[str, str, int]] | None:
return self.sslobj.shared_ciphers() return self.sslobj.shared_ciphers()
def compression(self): def compression(self) -> str | None:
return self.sslobj.compression() return self.sslobj.compression()
def settimeout(self, value): def settimeout(self, value: float | None) -> None:
self.socket.settimeout(value) self.socket.settimeout(value)
def gettimeout(self): def gettimeout(self) -> float | None:
return self.socket.gettimeout() return self.socket.gettimeout()
def _decref_socketios(self): def _decref_socketios(self) -> None:
self.socket._decref_socketios() self.socket._decref_socketios() # type: ignore[attr-defined]
def _wrap_ssl_read(self, len, buffer=None): def _wrap_ssl_read(self, len: int, buffer: bytearray | None = None) -> int | bytes:
try: try:
return self._ssl_io_loop(self.sslobj.read, len, buffer) return self._ssl_io_loop(self.sslobj.read, len, buffer)
except ssl.SSLError as e: except ssl.SSLError as e:
@ -192,7 +221,32 @@ class SSLTransport:
else: else:
raise raise
def _ssl_io_loop(self, func, *args): # func is sslobj.do_handshake or sslobj.unwrap
@typing.overload
def _ssl_io_loop(self, func: typing.Callable[[], None]) -> None:
...
# func is sslobj.write, arg1 is data
@typing.overload
def _ssl_io_loop(self, func: typing.Callable[[bytes], int], arg1: bytes) -> int:
...
# func is sslobj.read, arg1 is len, arg2 is buffer
@typing.overload
def _ssl_io_loop(
self,
func: typing.Callable[[int, bytearray | None], bytes],
arg1: int,
arg2: bytearray | None,
) -> bytes:
...
def _ssl_io_loop(
self,
func: typing.Callable[..., _ReturnValue],
arg1: None | bytes | int = None,
arg2: bytearray | None = None,
) -> _ReturnValue:
"""Performs an I/O loop between incoming/outgoing and the socket.""" """Performs an I/O loop between incoming/outgoing and the socket."""
should_loop = True should_loop = True
ret = None ret = None
@ -200,7 +254,12 @@ class SSLTransport:
while should_loop: while should_loop:
errno = None errno = None
try: try:
ret = func(*args) if arg1 is None and arg2 is None:
ret = func()
elif arg2 is None:
ret = func(arg1)
else:
ret = func(arg1, arg2)
except ssl.SSLError as e: except ssl.SSLError as e:
if e.errno not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE): if e.errno not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE):
# WANT_READ, and WANT_WRITE are expected, others are not. # WANT_READ, and WANT_WRITE are expected, others are not.
@ -218,4 +277,4 @@ class SSLTransport:
self.incoming.write(buf) self.incoming.write(buf)
else: else:
self.incoming.write_eof() self.incoming.write_eof()
return ret return typing.cast(_ReturnValue, ret)

View file

@ -1,44 +1,56 @@
from __future__ import absolute_import from __future__ import annotations
import time import time
import typing
# The default socket timeout, used by httplib to indicate that no timeout was; specified by the user from enum import Enum
from socket import _GLOBAL_DEFAULT_TIMEOUT, getdefaulttimeout from socket import getdefaulttimeout
from ..exceptions import TimeoutStateError from ..exceptions import TimeoutStateError
# A sentinel value to indicate that no timeout was specified by the user in if typing.TYPE_CHECKING:
# urllib3 from typing_extensions import Final
_Default = object()
# Use time.monotonic if available. class _TYPE_DEFAULT(Enum):
current_time = getattr(time, "monotonic", time.time) # This value should never be passed to socket.settimeout() so for safety we use a -1.
# socket.settimout() raises a ValueError for negative values.
token = -1
class Timeout(object): _DEFAULT_TIMEOUT: Final[_TYPE_DEFAULT] = _TYPE_DEFAULT.token
_TYPE_TIMEOUT = typing.Optional[typing.Union[float, _TYPE_DEFAULT]]
class Timeout:
"""Timeout configuration. """Timeout configuration.
Timeouts can be defined as a default for a pool: Timeouts can be defined as a default for a pool:
.. code-block:: python .. code-block:: python
timeout = Timeout(connect=2.0, read=7.0) import urllib3
http = PoolManager(timeout=timeout)
response = http.request('GET', 'http://example.com/') timeout = urllib3.util.Timeout(connect=2.0, read=7.0)
http = urllib3.PoolManager(timeout=timeout)
resp = http.request("GET", "https://example.com/")
print(resp.status)
Or per-request (which overrides the default for the pool): Or per-request (which overrides the default for the pool):
.. code-block:: python .. code-block:: python
response = http.request('GET', 'http://example.com/', timeout=Timeout(10)) response = http.request("GET", "https://example.com/", timeout=Timeout(10))
Timeouts can be disabled by setting all the parameters to ``None``: Timeouts can be disabled by setting all the parameters to ``None``:
.. code-block:: python .. code-block:: python
no_timeout = Timeout(connect=None, read=None) no_timeout = Timeout(connect=None, read=None)
response = http.request('GET', 'http://example.com/, timeout=no_timeout) response = http.request("GET", "https://example.com/", timeout=no_timeout)
:param total: :param total:
@ -96,31 +108,31 @@ class Timeout(object):
""" """
#: A sentinel object representing the default timeout value #: A sentinel object representing the default timeout value
DEFAULT_TIMEOUT = _GLOBAL_DEFAULT_TIMEOUT DEFAULT_TIMEOUT: _TYPE_TIMEOUT = _DEFAULT_TIMEOUT
def __init__(self, total=None, connect=_Default, read=_Default): def __init__(
self,
total: _TYPE_TIMEOUT = None,
connect: _TYPE_TIMEOUT = _DEFAULT_TIMEOUT,
read: _TYPE_TIMEOUT = _DEFAULT_TIMEOUT,
) -> None:
self._connect = self._validate_timeout(connect, "connect") self._connect = self._validate_timeout(connect, "connect")
self._read = self._validate_timeout(read, "read") self._read = self._validate_timeout(read, "read")
self.total = self._validate_timeout(total, "total") self.total = self._validate_timeout(total, "total")
self._start_connect = None self._start_connect: float | None = None
def __repr__(self): def __repr__(self) -> str:
return "%s(connect=%r, read=%r, total=%r)" % ( return f"{type(self).__name__}(connect={self._connect!r}, read={self._read!r}, total={self.total!r})"
type(self).__name__,
self._connect,
self._read,
self.total,
)
# __str__ provided for backwards compatibility # __str__ provided for backwards compatibility
__str__ = __repr__ __str__ = __repr__
@classmethod @staticmethod
def resolve_default_timeout(cls, timeout): def resolve_default_timeout(timeout: _TYPE_TIMEOUT) -> float | None:
return getdefaulttimeout() if timeout is cls.DEFAULT_TIMEOUT else timeout return getdefaulttimeout() if timeout is _DEFAULT_TIMEOUT else timeout
@classmethod @classmethod
def _validate_timeout(cls, value, name): def _validate_timeout(cls, value: _TYPE_TIMEOUT, name: str) -> _TYPE_TIMEOUT:
"""Check that a timeout attribute is valid. """Check that a timeout attribute is valid.
:param value: The timeout value to validate :param value: The timeout value to validate
@ -130,10 +142,7 @@ class Timeout(object):
:raises ValueError: If it is a numeric value less than or equal to :raises ValueError: If it is a numeric value less than or equal to
zero, or the type is not an integer, float, or None. zero, or the type is not an integer, float, or None.
""" """
if value is _Default: if value is None or value is _DEFAULT_TIMEOUT:
return cls.DEFAULT_TIMEOUT
if value is None or value is cls.DEFAULT_TIMEOUT:
return value return value
if isinstance(value, bool): if isinstance(value, bool):
@ -147,7 +156,7 @@ class Timeout(object):
raise ValueError( raise ValueError(
"Timeout value %s was %s, but it must be an " "Timeout value %s was %s, but it must be an "
"int, float or None." % (name, value) "int, float or None." % (name, value)
) ) from None
try: try:
if value <= 0: if value <= 0:
@ -157,16 +166,15 @@ class Timeout(object):
"than or equal to 0." % (name, value) "than or equal to 0." % (name, value)
) )
except TypeError: except TypeError:
# Python 3
raise ValueError( raise ValueError(
"Timeout value %s was %s, but it must be an " "Timeout value %s was %s, but it must be an "
"int, float or None." % (name, value) "int, float or None." % (name, value)
) ) from None
return value return value
@classmethod @classmethod
def from_float(cls, timeout): def from_float(cls, timeout: _TYPE_TIMEOUT) -> Timeout:
"""Create a new Timeout from a legacy timeout value. """Create a new Timeout from a legacy timeout value.
The timeout value used by httplib.py sets the same timeout on the The timeout value used by httplib.py sets the same timeout on the
@ -175,13 +183,13 @@ class Timeout(object):
passed to this function. passed to this function.
:param timeout: The legacy timeout value. :param timeout: The legacy timeout value.
:type timeout: integer, float, sentinel default object, or None :type timeout: integer, float, :attr:`urllib3.util.Timeout.DEFAULT_TIMEOUT`, or None
:return: Timeout object :return: Timeout object
:rtype: :class:`Timeout` :rtype: :class:`Timeout`
""" """
return Timeout(read=timeout, connect=timeout) return Timeout(read=timeout, connect=timeout)
def clone(self): def clone(self) -> Timeout:
"""Create a copy of the timeout object """Create a copy of the timeout object
Timeout properties are stored per-pool but each request needs a fresh Timeout properties are stored per-pool but each request needs a fresh
@ -195,7 +203,7 @@ class Timeout(object):
# detect the user default. # detect the user default.
return Timeout(connect=self._connect, read=self._read, total=self.total) return Timeout(connect=self._connect, read=self._read, total=self.total)
def start_connect(self): def start_connect(self) -> float:
"""Start the timeout clock, used during a connect() attempt """Start the timeout clock, used during a connect() attempt
:raises urllib3.exceptions.TimeoutStateError: if you attempt :raises urllib3.exceptions.TimeoutStateError: if you attempt
@ -203,10 +211,10 @@ class Timeout(object):
""" """
if self._start_connect is not None: if self._start_connect is not None:
raise TimeoutStateError("Timeout timer has already been started.") raise TimeoutStateError("Timeout timer has already been started.")
self._start_connect = current_time() self._start_connect = time.monotonic()
return self._start_connect return self._start_connect
def get_connect_duration(self): def get_connect_duration(self) -> float:
"""Gets the time elapsed since the call to :meth:`start_connect`. """Gets the time elapsed since the call to :meth:`start_connect`.
:return: Elapsed time in seconds. :return: Elapsed time in seconds.
@ -218,10 +226,10 @@ class Timeout(object):
raise TimeoutStateError( raise TimeoutStateError(
"Can't get connect duration for timer that has not started." "Can't get connect duration for timer that has not started."
) )
return current_time() - self._start_connect return time.monotonic() - self._start_connect
@property @property
def connect_timeout(self): def connect_timeout(self) -> _TYPE_TIMEOUT:
"""Get the value to use when setting a connection timeout. """Get the value to use when setting a connection timeout.
This will be a positive float or integer, the value None This will be a positive float or integer, the value None
@ -233,13 +241,13 @@ class Timeout(object):
if self.total is None: if self.total is None:
return self._connect return self._connect
if self._connect is None or self._connect is self.DEFAULT_TIMEOUT: if self._connect is None or self._connect is _DEFAULT_TIMEOUT:
return self.total return self.total
return min(self._connect, self.total) return min(self._connect, self.total) # type: ignore[type-var]
@property @property
def read_timeout(self): def read_timeout(self) -> float | None:
"""Get the value for the read timeout. """Get the value for the read timeout.
This assumes some time has elapsed in the connection timeout and This assumes some time has elapsed in the connection timeout and
@ -251,21 +259,21 @@ class Timeout(object):
raised. raised.
:return: Value to use for the read timeout. :return: Value to use for the read timeout.
:rtype: int, float, :attr:`Timeout.DEFAULT_TIMEOUT` or None :rtype: int, float or None
:raises urllib3.exceptions.TimeoutStateError: If :meth:`start_connect` :raises urllib3.exceptions.TimeoutStateError: If :meth:`start_connect`
has not yet been called on this object. has not yet been called on this object.
""" """
if ( if (
self.total is not None self.total is not None
and self.total is not self.DEFAULT_TIMEOUT and self.total is not _DEFAULT_TIMEOUT
and self._read is not None and self._read is not None
and self._read is not self.DEFAULT_TIMEOUT and self._read is not _DEFAULT_TIMEOUT
): ):
# In case the connect timeout has not yet been established. # In case the connect timeout has not yet been established.
if self._start_connect is None: if self._start_connect is None:
return self._read return self._read
return max(0, min(self.total - self.get_connect_duration(), self._read)) return max(0, min(self.total - self.get_connect_duration(), self._read))
elif self.total is not None and self.total is not self.DEFAULT_TIMEOUT: elif self.total is not None and self.total is not _DEFAULT_TIMEOUT:
return max(0, self.total - self.get_connect_duration()) return max(0, self.total - self.get_connect_duration())
else: else:
return self._read return self.resolve_default_timeout(self._read)

View file

@ -1,22 +1,20 @@
from __future__ import absolute_import from __future__ import annotations
import re import re
from collections import namedtuple import typing
from ..exceptions import LocationParseError from ..exceptions import LocationParseError
from ..packages import six from .util import to_str
url_attrs = ["scheme", "auth", "host", "port", "path", "query", "fragment"]
# We only want to normalize urls with an HTTP(S) scheme. # We only want to normalize urls with an HTTP(S) scheme.
# urllib3 infers URLs without a scheme (None) to be http. # urllib3 infers URLs without a scheme (None) to be http.
NORMALIZABLE_SCHEMES = ("http", "https", None) _NORMALIZABLE_SCHEMES = ("http", "https", None)
# Almost all of these patterns were derived from the # Almost all of these patterns were derived from the
# 'rfc3986' module: https://github.com/python-hyper/rfc3986 # 'rfc3986' module: https://github.com/python-hyper/rfc3986
PERCENT_RE = re.compile(r"%[a-fA-F0-9]{2}") _PERCENT_RE = re.compile(r"%[a-fA-F0-9]{2}")
SCHEME_RE = re.compile(r"^(?:[a-zA-Z][a-zA-Z0-9+-]*:|/)") _SCHEME_RE = re.compile(r"^(?:[a-zA-Z][a-zA-Z0-9+-]*:|/)")
URI_RE = re.compile( _URI_RE = re.compile(
r"^(?:([a-zA-Z][a-zA-Z0-9+.-]*):)?" r"^(?:([a-zA-Z][a-zA-Z0-9+.-]*):)?"
r"(?://([^\\/?#]*))?" r"(?://([^\\/?#]*))?"
r"([^?#]*)" r"([^?#]*)"
@ -25,10 +23,10 @@ URI_RE = re.compile(
re.UNICODE | re.DOTALL, re.UNICODE | re.DOTALL,
) )
IPV4_PAT = r"(?:[0-9]{1,3}\.){3}[0-9]{1,3}" _IPV4_PAT = r"(?:[0-9]{1,3}\.){3}[0-9]{1,3}"
HEX_PAT = "[0-9A-Fa-f]{1,4}" _HEX_PAT = "[0-9A-Fa-f]{1,4}"
LS32_PAT = "(?:{hex}:{hex}|{ipv4})".format(hex=HEX_PAT, ipv4=IPV4_PAT) _LS32_PAT = "(?:{hex}:{hex}|{ipv4})".format(hex=_HEX_PAT, ipv4=_IPV4_PAT)
_subs = {"hex": HEX_PAT, "ls32": LS32_PAT} _subs = {"hex": _HEX_PAT, "ls32": _LS32_PAT}
_variations = [ _variations = [
# 6( h16 ":" ) ls32 # 6( h16 ":" ) ls32
"(?:%(hex)s:){6}%(ls32)s", "(?:%(hex)s:){6}%(ls32)s",
@ -50,69 +48,78 @@ _variations = [
"(?:(?:%(hex)s:){0,6}%(hex)s)?::", "(?:(?:%(hex)s:){0,6}%(hex)s)?::",
] ]
UNRESERVED_PAT = r"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789._\-~" _UNRESERVED_PAT = r"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789._\-~"
IPV6_PAT = "(?:" + "|".join([x % _subs for x in _variations]) + ")" _IPV6_PAT = "(?:" + "|".join([x % _subs for x in _variations]) + ")"
ZONE_ID_PAT = "(?:%25|%)(?:[" + UNRESERVED_PAT + "]|%[a-fA-F0-9]{2})+" _ZONE_ID_PAT = "(?:%25|%)(?:[" + _UNRESERVED_PAT + "]|%[a-fA-F0-9]{2})+"
IPV6_ADDRZ_PAT = r"\[" + IPV6_PAT + r"(?:" + ZONE_ID_PAT + r")?\]" _IPV6_ADDRZ_PAT = r"\[" + _IPV6_PAT + r"(?:" + _ZONE_ID_PAT + r")?\]"
REG_NAME_PAT = r"(?:[^\[\]%:/?#]|%[a-fA-F0-9]{2})*" _REG_NAME_PAT = r"(?:[^\[\]%:/?#]|%[a-fA-F0-9]{2})*"
TARGET_RE = re.compile(r"^(/[^?#]*)(?:\?([^#]*))?(?:#.*)?$") _TARGET_RE = re.compile(r"^(/[^?#]*)(?:\?([^#]*))?(?:#.*)?$")
IPV4_RE = re.compile("^" + IPV4_PAT + "$") _IPV4_RE = re.compile("^" + _IPV4_PAT + "$")
IPV6_RE = re.compile("^" + IPV6_PAT + "$") _IPV6_RE = re.compile("^" + _IPV6_PAT + "$")
IPV6_ADDRZ_RE = re.compile("^" + IPV6_ADDRZ_PAT + "$") _IPV6_ADDRZ_RE = re.compile("^" + _IPV6_ADDRZ_PAT + "$")
BRACELESS_IPV6_ADDRZ_RE = re.compile("^" + IPV6_ADDRZ_PAT[2:-2] + "$") _BRACELESS_IPV6_ADDRZ_RE = re.compile("^" + _IPV6_ADDRZ_PAT[2:-2] + "$")
ZONE_ID_RE = re.compile("(" + ZONE_ID_PAT + r")\]$") _ZONE_ID_RE = re.compile("(" + _ZONE_ID_PAT + r")\]$")
_HOST_PORT_PAT = ("^(%s|%s|%s)(?::0*?(|0|[1-9][0-9]{0,4}))?$") % ( _HOST_PORT_PAT = ("^(%s|%s|%s)(?::0*?(|0|[1-9][0-9]{0,4}))?$") % (
REG_NAME_PAT, _REG_NAME_PAT,
IPV4_PAT, _IPV4_PAT,
IPV6_ADDRZ_PAT, _IPV6_ADDRZ_PAT,
) )
_HOST_PORT_RE = re.compile(_HOST_PORT_PAT, re.UNICODE | re.DOTALL) _HOST_PORT_RE = re.compile(_HOST_PORT_PAT, re.UNICODE | re.DOTALL)
UNRESERVED_CHARS = set( _UNRESERVED_CHARS = set(
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789._-~" "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789._-~"
) )
SUB_DELIM_CHARS = set("!$&'()*+,;=") _SUB_DELIM_CHARS = set("!$&'()*+,;=")
USERINFO_CHARS = UNRESERVED_CHARS | SUB_DELIM_CHARS | {":"} _USERINFO_CHARS = _UNRESERVED_CHARS | _SUB_DELIM_CHARS | {":"}
PATH_CHARS = USERINFO_CHARS | {"@", "/"} _PATH_CHARS = _USERINFO_CHARS | {"@", "/"}
QUERY_CHARS = FRAGMENT_CHARS = PATH_CHARS | {"?"} _QUERY_CHARS = _FRAGMENT_CHARS = _PATH_CHARS | {"?"}
class Url(namedtuple("Url", url_attrs)): class Url(
typing.NamedTuple(
"Url",
[
("scheme", typing.Optional[str]),
("auth", typing.Optional[str]),
("host", typing.Optional[str]),
("port", typing.Optional[int]),
("path", typing.Optional[str]),
("query", typing.Optional[str]),
("fragment", typing.Optional[str]),
],
)
):
""" """
Data structure for representing an HTTP URL. Used as a return value for Data structure for representing an HTTP URL. Used as a return value for
:func:`parse_url`. Both the scheme and host are normalized as they are :func:`parse_url`. Both the scheme and host are normalized as they are
both case-insensitive according to RFC 3986. both case-insensitive according to RFC 3986.
""" """
__slots__ = () def __new__( # type: ignore[no-untyped-def]
def __new__(
cls, cls,
scheme=None, scheme: str | None = None,
auth=None, auth: str | None = None,
host=None, host: str | None = None,
port=None, port: int | None = None,
path=None, path: str | None = None,
query=None, query: str | None = None,
fragment=None, fragment: str | None = None,
): ):
if path and not path.startswith("/"): if path and not path.startswith("/"):
path = "/" + path path = "/" + path
if scheme is not None: if scheme is not None:
scheme = scheme.lower() scheme = scheme.lower()
return super(Url, cls).__new__( return super().__new__(cls, scheme, auth, host, port, path, query, fragment)
cls, scheme, auth, host, port, path, query, fragment
)
@property @property
def hostname(self): def hostname(self) -> str | None:
"""For backwards-compatibility with urlparse. We're nice like that.""" """For backwards-compatibility with urlparse. We're nice like that."""
return self.host return self.host
@property @property
def request_uri(self): def request_uri(self) -> str:
"""Absolute path including the query string.""" """Absolute path including the query string."""
uri = self.path or "/" uri = self.path or "/"
@ -122,14 +129,37 @@ class Url(namedtuple("Url", url_attrs)):
return uri return uri
@property @property
def netloc(self): def authority(self) -> str | None:
"""Network location including host and port""" """
Authority component as defined in RFC 3986 3.2.
This includes userinfo (auth), host and port.
i.e.
userinfo@host:port
"""
userinfo = self.auth
netloc = self.netloc
if netloc is None or userinfo is None:
return netloc
else:
return f"{userinfo}@{netloc}"
@property
def netloc(self) -> str | None:
"""
Network location including host and port.
If you need the equivalent of urllib.parse's ``netloc``,
use the ``authority`` property instead.
"""
if self.host is None:
return None
if self.port: if self.port:
return "%s:%d" % (self.host, self.port) return f"{self.host}:{self.port}"
return self.host return self.host
@property @property
def url(self): def url(self) -> str:
""" """
Convert self into a url Convert self into a url
@ -138,88 +168,77 @@ class Url(namedtuple("Url", url_attrs)):
:func:`.parse_url`, but it should be equivalent by the RFC (e.g., urls :func:`.parse_url`, but it should be equivalent by the RFC (e.g., urls
with a blank port will have : removed). with a blank port will have : removed).
Example: :: Example:
>>> U = parse_url('http://google.com/mail/') .. code-block:: python
>>> U.url
'http://google.com/mail/' import urllib3
>>> Url('http', 'username:password', 'host.com', 80,
... '/path', 'query', 'fragment').url U = urllib3.util.parse_url("https://google.com/mail/")
'http://username:password@host.com:80/path?query#fragment'
print(U.url)
# "https://google.com/mail/"
print( urllib3.util.Url("https", "username:password",
"host.com", 80, "/path", "query", "fragment"
).url
)
# "https://username:password@host.com:80/path?query#fragment"
""" """
scheme, auth, host, port, path, query, fragment = self scheme, auth, host, port, path, query, fragment = self
url = u"" url = ""
# We use "is not None" because we want things to happen with empty strings (or 0 port) # We use "is not None" because we want things to happen with empty strings (or 0 port)
if scheme is not None: if scheme is not None:
url += scheme + u"://" url += scheme + "://"
if auth is not None: if auth is not None:
url += auth + u"@" url += auth + "@"
if host is not None: if host is not None:
url += host url += host
if port is not None: if port is not None:
url += u":" + str(port) url += ":" + str(port)
if path is not None: if path is not None:
url += path url += path
if query is not None: if query is not None:
url += u"?" + query url += "?" + query
if fragment is not None: if fragment is not None:
url += u"#" + fragment url += "#" + fragment
return url return url
def __str__(self): def __str__(self) -> str:
return self.url return self.url
def split_first(s, delims): @typing.overload
""" def _encode_invalid_chars(
.. deprecated:: 1.25 component: str, allowed_chars: typing.Container[str]
) -> str: # Abstract
Given a string and an iterable of delimiters, split on the first found ...
delimiter. Return two split parts and the matched delimiter.
If not found, then the first part is the full input string.
Example::
>>> split_first('foo/bar?baz', '?/=')
('foo', 'bar?baz', '/')
>>> split_first('foo/bar?baz', '123')
('foo/bar?baz', '', None)
Scales linearly with number of delims. Not ideal for large number of delims.
"""
min_idx = None
min_delim = None
for d in delims:
idx = s.find(d)
if idx < 0:
continue
if min_idx is None or idx < min_idx:
min_idx = idx
min_delim = d
if min_idx is None or min_idx < 0:
return s, "", None
return s[:min_idx], s[min_idx + 1 :], min_delim
def _encode_invalid_chars(component, allowed_chars, encoding="utf-8"): @typing.overload
def _encode_invalid_chars(
component: None, allowed_chars: typing.Container[str]
) -> None: # Abstract
...
def _encode_invalid_chars(
component: str | None, allowed_chars: typing.Container[str]
) -> str | None:
"""Percent-encodes a URI component without reapplying """Percent-encodes a URI component without reapplying
onto an already percent-encoded component. onto an already percent-encoded component.
""" """
if component is None: if component is None:
return component return component
component = six.ensure_text(component) component = to_str(component)
# Normalize existing percent-encoded bytes. # Normalize existing percent-encoded bytes.
# Try to see if the component we're encoding is already percent-encoded # Try to see if the component we're encoding is already percent-encoded
# so we can skip all '%' characters but still encode all others. # so we can skip all '%' characters but still encode all others.
component, percent_encodings = PERCENT_RE.subn( component, percent_encodings = _PERCENT_RE.subn(
lambda match: match.group(0).upper(), component lambda match: match.group(0).upper(), component
) )
@ -228,7 +247,7 @@ def _encode_invalid_chars(component, allowed_chars, encoding="utf-8"):
encoded_component = bytearray() encoded_component = bytearray()
for i in range(0, len(uri_bytes)): for i in range(0, len(uri_bytes)):
# Will return a single character bytestring on both Python 2 & 3 # Will return a single character bytestring
byte = uri_bytes[i : i + 1] byte = uri_bytes[i : i + 1]
byte_ord = ord(byte) byte_ord = ord(byte)
if (is_percent_encoded and byte == b"%") or ( if (is_percent_encoded and byte == b"%") or (
@ -238,10 +257,10 @@ def _encode_invalid_chars(component, allowed_chars, encoding="utf-8"):
continue continue
encoded_component.extend(b"%" + (hex(byte_ord)[2:].encode().zfill(2).upper())) encoded_component.extend(b"%" + (hex(byte_ord)[2:].encode().zfill(2).upper()))
return encoded_component.decode(encoding) return encoded_component.decode()
def _remove_path_dot_segments(path): def _remove_path_dot_segments(path: str) -> str:
# See http://tools.ietf.org/html/rfc3986#section-5.2.4 for pseudo-code # See http://tools.ietf.org/html/rfc3986#section-5.2.4 for pseudo-code
segments = path.split("/") # Turn the path into a list of segments segments = path.split("/") # Turn the path into a list of segments
output = [] # Initialize the variable to use to store output output = [] # Initialize the variable to use to store output
@ -251,7 +270,7 @@ def _remove_path_dot_segments(path):
if segment == ".": if segment == ".":
continue continue
# Anything other than '..', should be appended to the output # Anything other than '..', should be appended to the output
elif segment != "..": if segment != "..":
output.append(segment) output.append(segment)
# In this case segment == '..', if we can, we should pop the last # In this case segment == '..', if we can, we should pop the last
# element # element
@ -271,18 +290,25 @@ def _remove_path_dot_segments(path):
return "/".join(output) return "/".join(output)
def _normalize_host(host, scheme): @typing.overload
if host: def _normalize_host(host: None, scheme: str | None) -> None:
if isinstance(host, six.binary_type): ...
host = six.ensure_str(host)
if scheme in NORMALIZABLE_SCHEMES:
is_ipv6 = IPV6_ADDRZ_RE.match(host) @typing.overload
def _normalize_host(host: str, scheme: str | None) -> str:
...
def _normalize_host(host: str | None, scheme: str | None) -> str | None:
if host:
if scheme in _NORMALIZABLE_SCHEMES:
is_ipv6 = _IPV6_ADDRZ_RE.match(host)
if is_ipv6: if is_ipv6:
# IPv6 hosts of the form 'a::b%zone' are encoded in a URL as # IPv6 hosts of the form 'a::b%zone' are encoded in a URL as
# such per RFC 6874: 'a::b%25zone'. Unquote the ZoneID # such per RFC 6874: 'a::b%25zone'. Unquote the ZoneID
# separator as necessary to return a valid RFC 4007 scoped IP. # separator as necessary to return a valid RFC 4007 scoped IP.
match = ZONE_ID_RE.search(host) match = _ZONE_ID_RE.search(host)
if match: if match:
start, end = match.span(1) start, end = match.span(1)
zone_id = host[start:end] zone_id = host[start:end]
@ -291,46 +317,56 @@ def _normalize_host(host, scheme):
zone_id = zone_id[3:] zone_id = zone_id[3:]
else: else:
zone_id = zone_id[1:] zone_id = zone_id[1:]
zone_id = "%" + _encode_invalid_chars(zone_id, UNRESERVED_CHARS) zone_id = _encode_invalid_chars(zone_id, _UNRESERVED_CHARS)
return host[:start].lower() + zone_id + host[end:] return f"{host[:start].lower()}%{zone_id}{host[end:]}"
else: else:
return host.lower() return host.lower()
elif not IPV4_RE.match(host): elif not _IPV4_RE.match(host):
return six.ensure_str( return to_str(
b".".join([_idna_encode(label) for label in host.split(".")]) b".".join([_idna_encode(label) for label in host.split(".")]),
"ascii",
) )
return host return host
def _idna_encode(name): def _idna_encode(name: str) -> bytes:
if name and any(ord(x) >= 128 for x in name): if not name.isascii():
try: try:
import idna import idna
except ImportError: except ImportError:
six.raise_from( raise LocationParseError(
LocationParseError("Unable to parse URL without the 'idna' module"), "Unable to parse URL without the 'idna' module"
None, ) from None
)
try: try:
return idna.encode(name.lower(), strict=True, std3_rules=True) return idna.encode(name.lower(), strict=True, std3_rules=True)
except idna.IDNAError: except idna.IDNAError:
six.raise_from( raise LocationParseError(
LocationParseError(u"Name '%s' is not a valid IDNA label" % name), None f"Name '{name}' is not a valid IDNA label"
) ) from None
return name.lower().encode("ascii") return name.lower().encode("ascii")
def _encode_target(target): def _encode_target(target: str) -> str:
"""Percent-encodes a request target so that there are no invalid characters""" """Percent-encodes a request target so that there are no invalid characters
path, query = TARGET_RE.match(target).groups()
target = _encode_invalid_chars(path, PATH_CHARS) Pre-condition for this function is that 'target' must start with '/'.
query = _encode_invalid_chars(query, QUERY_CHARS) If that is the case then _TARGET_RE will always produce a match.
"""
match = _TARGET_RE.match(target)
if not match: # Defensive:
raise LocationParseError(f"{target!r} is not a valid request URI")
path, query = match.groups()
encoded_target = _encode_invalid_chars(path, _PATH_CHARS)
if query is not None: if query is not None:
target += "?" + query query = _encode_invalid_chars(query, _QUERY_CHARS)
return target encoded_target += "?" + query
return encoded_target
def parse_url(url): def parse_url(url: str) -> Url:
""" """
Given a url, return a parsed :class:`.Url` namedtuple. Best-effort is Given a url, return a parsed :class:`.Url` namedtuple. Best-effort is
performed to parse incomplete urls. Fields not provided will be None. performed to parse incomplete urls. Fields not provided will be None.
@ -341,28 +377,44 @@ def parse_url(url):
:param str url: URL to parse into a :class:`.Url` namedtuple. :param str url: URL to parse into a :class:`.Url` namedtuple.
Partly backwards-compatible with :mod:`urlparse`. Partly backwards-compatible with :mod:`urllib.parse`.
Example:: Example:
>>> parse_url('http://google.com/mail/') .. code-block:: python
Url(scheme='http', host='google.com', port=None, path='/mail/', ...)
>>> parse_url('google.com:80') import urllib3
Url(scheme=None, host='google.com', port=80, path=None, ...)
>>> parse_url('/foo?bar') print( urllib3.util.parse_url('http://google.com/mail/'))
Url(scheme=None, host=None, port=None, path='/foo', query='bar', ...) # Url(scheme='http', host='google.com', port=None, path='/mail/', ...)
print( urllib3.util.parse_url('google.com:80'))
# Url(scheme=None, host='google.com', port=80, path=None, ...)
print( urllib3.util.parse_url('/foo?bar'))
# Url(scheme=None, host=None, port=None, path='/foo', query='bar', ...)
""" """
if not url: if not url:
# Empty # Empty
return Url() return Url()
source_url = url source_url = url
if not SCHEME_RE.search(url): if not _SCHEME_RE.search(url):
url = "//" + url url = "//" + url
scheme: str | None
authority: str | None
auth: str | None
host: str | None
port: str | None
port_int: int | None
path: str | None
query: str | None
fragment: str | None
try: try:
scheme, authority, path, query, fragment = URI_RE.match(url).groups() scheme, authority, path, query, fragment = _URI_RE.match(url).groups() # type: ignore[union-attr]
normalize_uri = scheme is None or scheme.lower() in NORMALIZABLE_SCHEMES normalize_uri = scheme is None or scheme.lower() in _NORMALIZABLE_SCHEMES
if scheme: if scheme:
scheme = scheme.lower() scheme = scheme.lower()
@ -370,31 +422,33 @@ def parse_url(url):
if authority: if authority:
auth, _, host_port = authority.rpartition("@") auth, _, host_port = authority.rpartition("@")
auth = auth or None auth = auth or None
host, port = _HOST_PORT_RE.match(host_port).groups() host, port = _HOST_PORT_RE.match(host_port).groups() # type: ignore[union-attr]
if auth and normalize_uri: if auth and normalize_uri:
auth = _encode_invalid_chars(auth, USERINFO_CHARS) auth = _encode_invalid_chars(auth, _USERINFO_CHARS)
if port == "": if port == "":
port = None port = None
else: else:
auth, host, port = None, None, None auth, host, port = None, None, None
if port is not None: if port is not None:
port = int(port) port_int = int(port)
if not (0 <= port <= 65535): if not (0 <= port_int <= 65535):
raise LocationParseError(url) raise LocationParseError(url)
else:
port_int = None
host = _normalize_host(host, scheme) host = _normalize_host(host, scheme)
if normalize_uri and path: if normalize_uri and path:
path = _remove_path_dot_segments(path) path = _remove_path_dot_segments(path)
path = _encode_invalid_chars(path, PATH_CHARS) path = _encode_invalid_chars(path, _PATH_CHARS)
if normalize_uri and query: if normalize_uri and query:
query = _encode_invalid_chars(query, QUERY_CHARS) query = _encode_invalid_chars(query, _QUERY_CHARS)
if normalize_uri and fragment: if normalize_uri and fragment:
fragment = _encode_invalid_chars(fragment, FRAGMENT_CHARS) fragment = _encode_invalid_chars(fragment, _FRAGMENT_CHARS)
except (ValueError, AttributeError): except (ValueError, AttributeError) as e:
return six.raise_from(LocationParseError(source_url), None) raise LocationParseError(source_url) from e
# For the sake of backwards compatibility we put empty # For the sake of backwards compatibility we put empty
# string values for path if there are any defined values # string values for path if there are any defined values
@ -406,30 +460,12 @@ def parse_url(url):
else: else:
path = None path = None
# Ensure that each part of the URL is a `str` for
# backwards compatibility.
if isinstance(url, six.text_type):
ensure_func = six.ensure_text
else:
ensure_func = six.ensure_str
def ensure_type(x):
return x if x is None else ensure_func(x)
return Url( return Url(
scheme=ensure_type(scheme), scheme=scheme,
auth=ensure_type(auth), auth=auth,
host=ensure_type(host), host=host,
port=port, port=port_int,
path=ensure_type(path), path=path,
query=ensure_type(query), query=query,
fragment=ensure_type(fragment), fragment=fragment,
) )
def get_host(url):
    """
    Deprecated. Use :func:`parse_url` instead.

    Parses ``url`` with :func:`parse_url` and returns a 3-tuple of
    ``(scheme, hostname, port)``; a missing (falsy) scheme is reported
    as ``"http"``.
    """
    parsed = parse_url(url)
    scheme = parsed.scheme or "http"
    return scheme, parsed.hostname, parsed.port

View file

@ -1,32 +0,0 @@
from typing import Any, List, Optional, Tuple, Union
from .. import exceptions
LocationParseError = exceptions.LocationParseError
url_attrs: List[str]
# Type stub for the ``Url`` class: declares the constructor, which takes the
# seven URL components, and four read-only string properties.
class Url:
    slots: Any
    # NOTE(review): the implementation shown alongside this stub gives every
    # ``__new__`` parameter a default of None and converts the port with
    # ``int()`` before constructing ``Url``; the required parameters and
    # ``port: Optional[str]`` declared here may be out of date -- confirm.
    def __new__(
        cls,
        scheme: Optional[str],
        auth: Optional[str],
        host: Optional[str],
        port: Optional[str],
        path: Optional[str],
        query: Optional[str],
        fragment: Optional[str],
    ) -> Url: ...
    # NOTE(review): ``hostname`` and ``netloc`` are derived from ``host``,
    # which is declared Optional above, so they can presumably be None at
    # runtime even though ``str`` is declared -- confirm.
    @property
    def hostname(self) -> str: ...
    @property
    def request_uri(self) -> str: ...
    @property
    def netloc(self) -> str: ...
    @property
    def url(self) -> str: ...
# Stub signatures for the module-level helpers of the URL utilities.
# split_first returns (text before delimiter, text after delimiter, matched
# delimiter or None).
def split_first(s: str, delims: str) -> Tuple[str, str, Optional[str]]: ...
def parse_url(url: str) -> Url: ...
# NOTE(review): the implementation returns a 3-tuple of
# (scheme, hostname, port); the declared return type looks too narrow --
# confirm.
def get_host(url: str) -> Union[str, Tuple[str]]: ...

42
lib/urllib3/util/util.py Normal file
View file

@ -0,0 +1,42 @@
from __future__ import annotations
import typing
from types import TracebackType
def to_bytes(
    x: str | bytes, encoding: str | None = None, errors: str | None = None
) -> bytes:
    """Return *x* as ``bytes``, encoding it first when it is a ``str``.

    :param x: Value to convert. ``bytes`` input is returned unchanged.
    :param encoding: Codec used to encode a ``str``. When only ``errors``
        is given, ``"utf-8"`` is used.
    :param errors: Error handler used while encoding. When only
        ``encoding`` is given, ``"strict"`` is used.
    :return: The ``bytes`` form of *x*.
    :raises TypeError: If *x* is neither ``str`` nor ``bytes``.
    """
    if isinstance(x, bytes):
        return x
    if isinstance(x, str):
        if not (encoding or errors):
            # Neither option supplied: rely on str.encode()'s own defaults.
            return x.encode()
        return x.encode(encoding or "utf-8", errors=errors or "strict")
    raise TypeError(f"not expecting type {type(x).__name__}")
def to_str(
    x: str | bytes, encoding: str | None = None, errors: str | None = None
) -> str:
    """Return *x* as ``str``, decoding it first when it is ``bytes``.

    :param x: Value to convert. ``str`` input is returned unchanged.
    :param encoding: Codec used to decode ``bytes``. When only ``errors``
        is given, ``"utf-8"`` is used.
    :param errors: Error handler used while decoding. When only
        ``encoding`` is given, ``"strict"`` is used.
    :return: The ``str`` form of *x*.
    :raises TypeError: If *x* is neither ``str`` nor ``bytes``.
    """
    if isinstance(x, str):
        return x
    if isinstance(x, bytes):
        if not (encoding or errors):
            # Neither option supplied: rely on bytes.decode()'s own defaults.
            return x.decode()
        return x.decode(encoding or "utf-8", errors=errors or "strict")
    raise TypeError(f"not expecting type {type(x).__name__}")
def reraise(
    tp: type[BaseException] | None,
    value: BaseException,
    tb: TracebackType | None = None,
) -> typing.NoReturn:
    """Raise *value*, attaching traceback *tb* when it differs from its own.

    :param tp: Exception type; accepted for the signature but not used by
        the body.
    :param value: The exception instance to raise.
    :param tb: Traceback to attach to *value* before raising it.
    :raises BaseException: Always; *value* itself is raised.
    """
    try:
        if value.__traceback__ is tb:
            # The traceback is already the requested one: raise as-is.
            raise value
        raise value.with_traceback(tb)
    finally:
        # Drop the local references on the way out (presumably so the frame
        # does not keep the exception and its traceback alive).
        value = None  # type: ignore[assignment]
        tb = None

View file

@ -1,18 +1,10 @@
import errno from __future__ import annotations
import select import select
import sys import socket
from functools import partial from functools import partial
try: __all__ = ["wait_for_read", "wait_for_write"]
from time import monotonic
except ImportError:
from time import time as monotonic
__all__ = ["NoWayToWaitForSocketError", "wait_for_read", "wait_for_write"]
class NoWayToWaitForSocketError(Exception):
pass
# How should we wait on sockets? # How should we wait on sockets?
@ -37,37 +29,13 @@ class NoWayToWaitForSocketError(Exception):
# So: on Windows we use select(), and everywhere else we use poll(). We also # So: on Windows we use select(), and everywhere else we use poll(). We also
# fall back to select() in case poll() is somehow broken or missing. # fall back to select() in case poll() is somehow broken or missing.
if sys.version_info >= (3, 5):
# Modern Python, that retries syscalls by default
def _retry_on_intr(fn, timeout):
return fn(timeout)
else: def select_wait_for_socket(
# Old and broken Pythons. sock: socket.socket,
def _retry_on_intr(fn, timeout): read: bool = False,
if timeout is None: write: bool = False,
deadline = float("inf") timeout: float | None = None,
else: ) -> bool:
deadline = monotonic() + timeout
while True:
try:
return fn(timeout)
# OSError for 3 <= pyver < 3.5, select.error for pyver <= 2.7
except (OSError, select.error) as e:
# 'e.args[0]' incantation works for both OSError and select.error
if e.args[0] != errno.EINTR:
raise
else:
timeout = deadline - monotonic()
if timeout < 0:
timeout = 0
if timeout == float("inf"):
timeout = None
continue
def select_wait_for_socket(sock, read=False, write=False, timeout=None):
if not read and not write: if not read and not write:
raise RuntimeError("must specify at least one of read=True, write=True") raise RuntimeError("must specify at least one of read=True, write=True")
rcheck = [] rcheck = []
@ -82,11 +50,16 @@ def select_wait_for_socket(sock, read=False, write=False, timeout=None):
# sockets for both conditions. (The stdlib selectors module does the same # sockets for both conditions. (The stdlib selectors module does the same
# thing.) # thing.)
fn = partial(select.select, rcheck, wcheck, wcheck) fn = partial(select.select, rcheck, wcheck, wcheck)
rready, wready, xready = _retry_on_intr(fn, timeout) rready, wready, xready = fn(timeout)
return bool(rready or wready or xready) return bool(rready or wready or xready)
def poll_wait_for_socket(sock, read=False, write=False, timeout=None): def poll_wait_for_socket(
sock: socket.socket,
read: bool = False,
write: bool = False,
timeout: float | None = None,
) -> bool:
if not read and not write: if not read and not write:
raise RuntimeError("must specify at least one of read=True, write=True") raise RuntimeError("must specify at least one of read=True, write=True")
mask = 0 mask = 0
@ -98,32 +71,33 @@ def poll_wait_for_socket(sock, read=False, write=False, timeout=None):
poll_obj.register(sock, mask) poll_obj.register(sock, mask)
# For some reason, poll() takes timeout in milliseconds # For some reason, poll() takes timeout in milliseconds
def do_poll(t): def do_poll(t: float | None) -> list[tuple[int, int]]:
if t is not None: if t is not None:
t *= 1000 t *= 1000
return poll_obj.poll(t) return poll_obj.poll(t)
return bool(_retry_on_intr(do_poll, timeout)) return bool(do_poll(timeout))
def null_wait_for_socket(*args, **kwargs): def _have_working_poll() -> bool:
raise NoWayToWaitForSocketError("no select-equivalent available")
def _have_working_poll():
# Apparently some systems have a select.poll that fails as soon as you try # Apparently some systems have a select.poll that fails as soon as you try
# to use it, either due to strange configuration or broken monkeypatching # to use it, either due to strange configuration or broken monkeypatching
# from libraries like eventlet/greenlet. # from libraries like eventlet/greenlet.
try: try:
poll_obj = select.poll() poll_obj = select.poll()
_retry_on_intr(poll_obj.poll, 0) poll_obj.poll(0)
except (AttributeError, OSError): except (AttributeError, OSError):
return False return False
else: else:
return True return True
def wait_for_socket(*args, **kwargs): def wait_for_socket(
sock: socket.socket,
read: bool = False,
write: bool = False,
timeout: float | None = None,
) -> bool:
# We delay choosing which implementation to use until the first time we're # We delay choosing which implementation to use until the first time we're
# called. We could do it at import time, but then we might make the wrong # called. We could do it at import time, but then we might make the wrong
# decision if someone goes wild with monkeypatching select.poll after # decision if someone goes wild with monkeypatching select.poll after
@ -133,19 +107,17 @@ def wait_for_socket(*args, **kwargs):
wait_for_socket = poll_wait_for_socket wait_for_socket = poll_wait_for_socket
elif hasattr(select, "select"): elif hasattr(select, "select"):
wait_for_socket = select_wait_for_socket wait_for_socket = select_wait_for_socket
else: # Platform-specific: Appengine. return wait_for_socket(sock, read, write, timeout)
wait_for_socket = null_wait_for_socket
return wait_for_socket(*args, **kwargs)
def wait_for_read(sock, timeout=None): def wait_for_read(sock: socket.socket, timeout: float | None = None) -> bool:
"""Waits for reading to be available on a given socket. """Waits for reading to be available on a given socket.
Returns True if the socket is readable, or False if the timeout expired. Returns True if the socket is readable, or False if the timeout expired.
""" """
return wait_for_socket(sock, read=True, timeout=timeout) return wait_for_socket(sock, read=True, timeout=timeout)
def wait_for_write(sock, timeout=None): def wait_for_write(sock: socket.socket, timeout: float | None = None) -> bool:
"""Waits for writing to be available on a given socket. """Waits for writing to be available on a given socket.
Returns True if the socket is writable, or False if the timeout expired. Returns True if the socket is writable, or False if the timeout expired.
""" """