mirror of
https://github.com/SickGear/SickGear.git
synced 2024-11-24 13:55:16 +00:00
Update urllib3 1.26.15 (25cca389) → 2.0.4 (af7c78fa).
This commit is contained in:
parent
4bb3ba0a15
commit
07935763df
49 changed files with 4441 additions and 5007 deletions
|
@ -10,6 +10,7 @@
|
||||||
* Update package resource API 67.5.1 (f51eccd) to 68.1.2 (1ef36f2)
|
* Update package resource API 67.5.1 (f51eccd) to 68.1.2 (1ef36f2)
|
||||||
* Update soupsieve 2.3.2.post1 (792d566) to 2.4.1 (2e66beb)
|
* Update soupsieve 2.3.2.post1 (792d566) to 2.4.1 (2e66beb)
|
||||||
* Update Tornado Web Server 6.3.2 (e3aa6c5) to 6.3.3 (e4d6984)
|
* Update Tornado Web Server 6.3.2 (e3aa6c5) to 6.3.3 (e4d6984)
|
||||||
|
* Update urllib3 1.26.15 (25cca389) to 2.0.4 (af7c78fa)
|
||||||
* Add thefuzz 0.19.0 (c2cd4f4) as a replacement with fallback to fuzzywuzzy 0.18.0 (2188520)
|
* Add thefuzz 0.19.0 (c2cd4f4) as a replacement with fallback to fuzzywuzzy 0.18.0 (2188520)
|
||||||
* Fix regex that was not using py312 notation
|
* Fix regex that was not using py312 notation
|
||||||
* Change sort backlog and manual segment search results episode number
|
* Change sort backlog and manual segment search results episode number
|
||||||
|
|
|
@ -1,23 +1,48 @@
|
||||||
"""
|
"""
|
||||||
Python HTTP library with thread-safe connection pooling, file post support, user friendly, and more
|
Python HTTP library with thread-safe connection pooling, file post support, user friendly, and more
|
||||||
"""
|
"""
|
||||||
from __future__ import absolute_import
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
# Set default logging handler to avoid "No handler found" warnings.
|
# Set default logging handler to avoid "No handler found" warnings.
|
||||||
import logging
|
import logging
|
||||||
|
import typing
|
||||||
import warnings
|
import warnings
|
||||||
from logging import NullHandler
|
from logging import NullHandler
|
||||||
|
|
||||||
from . import exceptions
|
from . import exceptions
|
||||||
|
from ._base_connection import _TYPE_BODY
|
||||||
|
from ._collections import HTTPHeaderDict
|
||||||
from ._version import __version__
|
from ._version import __version__
|
||||||
from .connectionpool import HTTPConnectionPool, HTTPSConnectionPool, connection_from_url
|
from .connectionpool import HTTPConnectionPool, HTTPSConnectionPool, connection_from_url
|
||||||
from .filepost import encode_multipart_formdata
|
from .filepost import _TYPE_FIELDS, encode_multipart_formdata
|
||||||
from .poolmanager import PoolManager, ProxyManager, proxy_from_url
|
from .poolmanager import PoolManager, ProxyManager, proxy_from_url
|
||||||
from .response import HTTPResponse
|
from .response import BaseHTTPResponse, HTTPResponse
|
||||||
from .util.request import make_headers
|
from .util.request import make_headers
|
||||||
from .util.retry import Retry
|
from .util.retry import Retry
|
||||||
from .util.timeout import Timeout
|
from .util.timeout import Timeout
|
||||||
from .util.url import get_host
|
|
||||||
|
# Ensure that Python is compiled with OpenSSL 1.1.1+
|
||||||
|
# If the 'ssl' module isn't available at all that's
|
||||||
|
# fine, we only care if the module is available.
|
||||||
|
try:
|
||||||
|
import ssl
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
if not ssl.OPENSSL_VERSION.startswith("OpenSSL "): # Defensive:
|
||||||
|
warnings.warn(
|
||||||
|
"urllib3 v2.0 only supports OpenSSL 1.1.1+, currently "
|
||||||
|
f"the 'ssl' module is compiled with {ssl.OPENSSL_VERSION!r}. "
|
||||||
|
"See: https://github.com/urllib3/urllib3/issues/3020",
|
||||||
|
exceptions.NotOpenSSLWarning,
|
||||||
|
)
|
||||||
|
elif ssl.OPENSSL_VERSION_INFO < (1, 1, 1): # Defensive:
|
||||||
|
raise ImportError(
|
||||||
|
"urllib3 v2.0 only supports OpenSSL 1.1.1+, currently "
|
||||||
|
f"the 'ssl' module is compiled with {ssl.OPENSSL_VERSION!r}. "
|
||||||
|
"See: https://github.com/urllib3/urllib3/issues/2168"
|
||||||
|
)
|
||||||
|
|
||||||
# === NOTE TO REPACKAGERS AND VENDORS ===
|
# === NOTE TO REPACKAGERS AND VENDORS ===
|
||||||
# Please delete this block, this logic is only
|
# Please delete this block, this logic is only
|
||||||
|
@ -25,12 +50,12 @@ from .util.url import get_host
|
||||||
# See: https://github.com/urllib3/urllib3/issues/2680
|
# See: https://github.com/urllib3/urllib3/issues/2680
|
||||||
try:
|
try:
|
||||||
import urllib3_secure_extra # type: ignore # noqa: F401
|
import urllib3_secure_extra # type: ignore # noqa: F401
|
||||||
except ImportError:
|
except ModuleNotFoundError:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"'urllib3[secure]' extra is deprecated and will be removed "
|
"'urllib3[secure]' extra is deprecated and will be removed "
|
||||||
"in a future release of urllib3 2.x. Read more in this issue: "
|
"in urllib3 v2.1.0. Read more in this issue: "
|
||||||
"https://github.com/urllib3/urllib3/issues/2680",
|
"https://github.com/urllib3/urllib3/issues/2680",
|
||||||
category=DeprecationWarning,
|
category=DeprecationWarning,
|
||||||
stacklevel=2,
|
stacklevel=2,
|
||||||
|
@ -42,6 +67,7 @@ __version__ = __version__
|
||||||
|
|
||||||
__all__ = (
|
__all__ = (
|
||||||
"HTTPConnectionPool",
|
"HTTPConnectionPool",
|
||||||
|
"HTTPHeaderDict",
|
||||||
"HTTPSConnectionPool",
|
"HTTPSConnectionPool",
|
||||||
"PoolManager",
|
"PoolManager",
|
||||||
"ProxyManager",
|
"ProxyManager",
|
||||||
|
@ -52,15 +78,18 @@ __all__ = (
|
||||||
"connection_from_url",
|
"connection_from_url",
|
||||||
"disable_warnings",
|
"disable_warnings",
|
||||||
"encode_multipart_formdata",
|
"encode_multipart_formdata",
|
||||||
"get_host",
|
|
||||||
"make_headers",
|
"make_headers",
|
||||||
"proxy_from_url",
|
"proxy_from_url",
|
||||||
|
"request",
|
||||||
|
"BaseHTTPResponse",
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.getLogger(__name__).addHandler(NullHandler())
|
logging.getLogger(__name__).addHandler(NullHandler())
|
||||||
|
|
||||||
|
|
||||||
def add_stderr_logger(level=logging.DEBUG):
|
def add_stderr_logger(
|
||||||
|
level: int = logging.DEBUG,
|
||||||
|
) -> logging.StreamHandler[typing.TextIO]:
|
||||||
"""
|
"""
|
||||||
Helper for quickly adding a StreamHandler to the logger. Useful for
|
Helper for quickly adding a StreamHandler to the logger. Useful for
|
||||||
debugging.
|
debugging.
|
||||||
|
@ -87,16 +116,51 @@ del NullHandler
|
||||||
# mechanisms to silence them.
|
# mechanisms to silence them.
|
||||||
# SecurityWarning's always go off by default.
|
# SecurityWarning's always go off by default.
|
||||||
warnings.simplefilter("always", exceptions.SecurityWarning, append=True)
|
warnings.simplefilter("always", exceptions.SecurityWarning, append=True)
|
||||||
# SubjectAltNameWarning's should go off once per host
|
|
||||||
warnings.simplefilter("default", exceptions.SubjectAltNameWarning, append=True)
|
|
||||||
# InsecurePlatformWarning's don't vary between requests, so we keep it default.
|
# InsecurePlatformWarning's don't vary between requests, so we keep it default.
|
||||||
warnings.simplefilter("default", exceptions.InsecurePlatformWarning, append=True)
|
warnings.simplefilter("default", exceptions.InsecurePlatformWarning, append=True)
|
||||||
# SNIMissingWarnings should go off only once.
|
|
||||||
warnings.simplefilter("default", exceptions.SNIMissingWarning, append=True)
|
|
||||||
|
|
||||||
|
|
||||||
def disable_warnings(category=exceptions.HTTPWarning):
|
def disable_warnings(category: type[Warning] = exceptions.HTTPWarning) -> None:
|
||||||
"""
|
"""
|
||||||
Helper for quickly disabling all urllib3 warnings.
|
Helper for quickly disabling all urllib3 warnings.
|
||||||
"""
|
"""
|
||||||
warnings.simplefilter("ignore", category)
|
warnings.simplefilter("ignore", category)
|
||||||
|
|
||||||
|
|
||||||
|
_DEFAULT_POOL = PoolManager()
|
||||||
|
|
||||||
|
|
||||||
|
def request(
|
||||||
|
method: str,
|
||||||
|
url: str,
|
||||||
|
*,
|
||||||
|
body: _TYPE_BODY | None = None,
|
||||||
|
fields: _TYPE_FIELDS | None = None,
|
||||||
|
headers: typing.Mapping[str, str] | None = None,
|
||||||
|
preload_content: bool | None = True,
|
||||||
|
decode_content: bool | None = True,
|
||||||
|
redirect: bool | None = True,
|
||||||
|
retries: Retry | bool | int | None = None,
|
||||||
|
timeout: Timeout | float | int | None = 3,
|
||||||
|
json: typing.Any | None = None,
|
||||||
|
) -> BaseHTTPResponse:
|
||||||
|
"""
|
||||||
|
A convenience, top-level request method. It uses a module-global ``PoolManager`` instance.
|
||||||
|
Therefore, its side effects could be shared across dependencies relying on it.
|
||||||
|
To avoid side effects create a new ``PoolManager`` instance and use it instead.
|
||||||
|
The method does not accept low-level ``**urlopen_kw`` keyword arguments.
|
||||||
|
"""
|
||||||
|
|
||||||
|
return _DEFAULT_POOL.request(
|
||||||
|
method,
|
||||||
|
url,
|
||||||
|
body=body,
|
||||||
|
fields=fields,
|
||||||
|
headers=headers,
|
||||||
|
preload_content=preload_content,
|
||||||
|
decode_content=decode_content,
|
||||||
|
redirect=redirect,
|
||||||
|
retries=retries,
|
||||||
|
timeout=timeout,
|
||||||
|
json=json,
|
||||||
|
)
|
||||||
|
|
173
lib/urllib3/_base_connection.py
Normal file
173
lib/urllib3/_base_connection.py
Normal file
|
@ -0,0 +1,173 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import typing
|
||||||
|
|
||||||
|
from .util.connection import _TYPE_SOCKET_OPTIONS
|
||||||
|
from .util.timeout import _DEFAULT_TIMEOUT, _TYPE_TIMEOUT
|
||||||
|
from .util.url import Url
|
||||||
|
|
||||||
|
_TYPE_BODY = typing.Union[bytes, typing.IO[typing.Any], typing.Iterable[bytes], str]
|
||||||
|
|
||||||
|
|
||||||
|
class ProxyConfig(typing.NamedTuple):
|
||||||
|
ssl_context: ssl.SSLContext | None
|
||||||
|
use_forwarding_for_https: bool
|
||||||
|
assert_hostname: None | str | Literal[False]
|
||||||
|
assert_fingerprint: str | None
|
||||||
|
|
||||||
|
|
||||||
|
class _ResponseOptions(typing.NamedTuple):
|
||||||
|
# TODO: Remove this in favor of a better
|
||||||
|
# HTTP request/response lifecycle tracking.
|
||||||
|
request_method: str
|
||||||
|
request_url: str
|
||||||
|
preload_content: bool
|
||||||
|
decode_content: bool
|
||||||
|
enforce_content_length: bool
|
||||||
|
|
||||||
|
|
||||||
|
if typing.TYPE_CHECKING:
|
||||||
|
import ssl
|
||||||
|
|
||||||
|
from typing_extensions import Literal, Protocol
|
||||||
|
|
||||||
|
from .response import BaseHTTPResponse
|
||||||
|
|
||||||
|
class BaseHTTPConnection(Protocol):
|
||||||
|
default_port: typing.ClassVar[int]
|
||||||
|
default_socket_options: typing.ClassVar[_TYPE_SOCKET_OPTIONS]
|
||||||
|
|
||||||
|
host: str
|
||||||
|
port: int
|
||||||
|
timeout: None | (
|
||||||
|
float
|
||||||
|
) # Instance doesn't store _DEFAULT_TIMEOUT, must be resolved.
|
||||||
|
blocksize: int
|
||||||
|
source_address: tuple[str, int] | None
|
||||||
|
socket_options: _TYPE_SOCKET_OPTIONS | None
|
||||||
|
|
||||||
|
proxy: Url | None
|
||||||
|
proxy_config: ProxyConfig | None
|
||||||
|
|
||||||
|
is_verified: bool
|
||||||
|
proxy_is_verified: bool | None
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
host: str,
|
||||||
|
port: int | None = None,
|
||||||
|
*,
|
||||||
|
timeout: _TYPE_TIMEOUT = _DEFAULT_TIMEOUT,
|
||||||
|
source_address: tuple[str, int] | None = None,
|
||||||
|
blocksize: int = 8192,
|
||||||
|
socket_options: _TYPE_SOCKET_OPTIONS | None = ...,
|
||||||
|
proxy: Url | None = None,
|
||||||
|
proxy_config: ProxyConfig | None = None,
|
||||||
|
) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
def set_tunnel(
|
||||||
|
self,
|
||||||
|
host: str,
|
||||||
|
port: int | None = None,
|
||||||
|
headers: typing.Mapping[str, str] | None = None,
|
||||||
|
scheme: str = "http",
|
||||||
|
) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
def connect(self) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
def request(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
url: str,
|
||||||
|
body: _TYPE_BODY | None = None,
|
||||||
|
headers: typing.Mapping[str, str] | None = None,
|
||||||
|
# We know *at least* botocore is depending on the order of the
|
||||||
|
# first 3 parameters so to be safe we only mark the later ones
|
||||||
|
# as keyword-only to ensure we have space to extend.
|
||||||
|
*,
|
||||||
|
chunked: bool = False,
|
||||||
|
preload_content: bool = True,
|
||||||
|
decode_content: bool = True,
|
||||||
|
enforce_content_length: bool = True,
|
||||||
|
) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
def getresponse(self) -> BaseHTTPResponse:
|
||||||
|
...
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_closed(self) -> bool:
|
||||||
|
"""Whether the connection either is brand new or has been previously closed.
|
||||||
|
If this property is True then both ``is_connected`` and ``has_connected_to_proxy``
|
||||||
|
properties must be False.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_connected(self) -> bool:
|
||||||
|
"""Whether the connection is actively connected to any origin (proxy or target)"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def has_connected_to_proxy(self) -> bool:
|
||||||
|
"""Whether the connection has successfully connected to its proxy.
|
||||||
|
This returns False if no proxy is in use. Used to determine whether
|
||||||
|
errors are coming from the proxy layer or from tunnelling to the target origin.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class BaseHTTPSConnection(BaseHTTPConnection, Protocol):
|
||||||
|
default_port: typing.ClassVar[int]
|
||||||
|
default_socket_options: typing.ClassVar[_TYPE_SOCKET_OPTIONS]
|
||||||
|
|
||||||
|
# Certificate verification methods
|
||||||
|
cert_reqs: int | str | None
|
||||||
|
assert_hostname: None | str | Literal[False]
|
||||||
|
assert_fingerprint: str | None
|
||||||
|
ssl_context: ssl.SSLContext | None
|
||||||
|
|
||||||
|
# Trusted CAs
|
||||||
|
ca_certs: str | None
|
||||||
|
ca_cert_dir: str | None
|
||||||
|
ca_cert_data: None | str | bytes
|
||||||
|
|
||||||
|
# TLS version
|
||||||
|
ssl_minimum_version: int | None
|
||||||
|
ssl_maximum_version: int | None
|
||||||
|
ssl_version: int | str | None # Deprecated
|
||||||
|
|
||||||
|
# Client certificates
|
||||||
|
cert_file: str | None
|
||||||
|
key_file: str | None
|
||||||
|
key_password: str | None
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
host: str,
|
||||||
|
port: int | None = None,
|
||||||
|
*,
|
||||||
|
timeout: _TYPE_TIMEOUT = _DEFAULT_TIMEOUT,
|
||||||
|
source_address: tuple[str, int] | None = None,
|
||||||
|
blocksize: int = 16384,
|
||||||
|
socket_options: _TYPE_SOCKET_OPTIONS | None = ...,
|
||||||
|
proxy: Url | None = None,
|
||||||
|
proxy_config: ProxyConfig | None = None,
|
||||||
|
cert_reqs: int | str | None = None,
|
||||||
|
assert_hostname: None | str | Literal[False] = None,
|
||||||
|
assert_fingerprint: str | None = None,
|
||||||
|
server_hostname: str | None = None,
|
||||||
|
ssl_context: ssl.SSLContext | None = None,
|
||||||
|
ca_certs: str | None = None,
|
||||||
|
ca_cert_dir: str | None = None,
|
||||||
|
ca_cert_data: None | str | bytes = None,
|
||||||
|
ssl_minimum_version: int | None = None,
|
||||||
|
ssl_maximum_version: int | None = None,
|
||||||
|
ssl_version: int | str | None = None, # Deprecated
|
||||||
|
cert_file: str | None = None,
|
||||||
|
key_file: str | None = None,
|
||||||
|
key_password: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
...
|
|
@ -1,34 +1,66 @@
|
||||||
from __future__ import absolute_import
|
from __future__ import annotations
|
||||||
|
|
||||||
try:
|
|
||||||
from collections.abc import Mapping, MutableMapping
|
|
||||||
except ImportError:
|
|
||||||
from collections import Mapping, MutableMapping
|
|
||||||
try:
|
|
||||||
from threading import RLock
|
|
||||||
except ImportError: # Platform-specific: No threads available
|
|
||||||
|
|
||||||
class RLock:
|
|
||||||
def __enter__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_value, traceback):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
|
import typing
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
from enum import Enum, auto
|
||||||
|
from threading import RLock
|
||||||
|
|
||||||
|
if typing.TYPE_CHECKING:
|
||||||
|
# We can only import Protocol if TYPE_CHECKING because it's a development
|
||||||
|
# dependency, and is not available at runtime.
|
||||||
|
from typing_extensions import Protocol
|
||||||
|
|
||||||
|
class HasGettableStringKeys(Protocol):
|
||||||
|
def keys(self) -> typing.Iterator[str]:
|
||||||
|
...
|
||||||
|
|
||||||
|
def __getitem__(self, key: str) -> str:
|
||||||
|
...
|
||||||
|
|
||||||
from .exceptions import InvalidHeader
|
|
||||||
from .packages import six
|
|
||||||
from .packages.six import iterkeys, itervalues
|
|
||||||
|
|
||||||
__all__ = ["RecentlyUsedContainer", "HTTPHeaderDict"]
|
__all__ = ["RecentlyUsedContainer", "HTTPHeaderDict"]
|
||||||
|
|
||||||
|
|
||||||
_Null = object()
|
# Key type
|
||||||
|
_KT = typing.TypeVar("_KT")
|
||||||
|
# Value type
|
||||||
|
_VT = typing.TypeVar("_VT")
|
||||||
|
# Default type
|
||||||
|
_DT = typing.TypeVar("_DT")
|
||||||
|
|
||||||
|
ValidHTTPHeaderSource = typing.Union[
|
||||||
|
"HTTPHeaderDict",
|
||||||
|
typing.Mapping[str, str],
|
||||||
|
typing.Iterable[typing.Tuple[str, str]],
|
||||||
|
"HasGettableStringKeys",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class RecentlyUsedContainer(MutableMapping):
|
class _Sentinel(Enum):
|
||||||
|
not_passed = auto()
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_can_construct_http_header_dict(
|
||||||
|
potential: object,
|
||||||
|
) -> ValidHTTPHeaderSource | None:
|
||||||
|
if isinstance(potential, HTTPHeaderDict):
|
||||||
|
return potential
|
||||||
|
elif isinstance(potential, typing.Mapping):
|
||||||
|
# Full runtime checking of the contents of a Mapping is expensive, so for the
|
||||||
|
# purposes of typechecking, we assume that any Mapping is the right shape.
|
||||||
|
return typing.cast(typing.Mapping[str, str], potential)
|
||||||
|
elif isinstance(potential, typing.Iterable):
|
||||||
|
# Similarly to Mapping, full runtime checking of the contents of an Iterable is
|
||||||
|
# expensive, so for the purposes of typechecking, we assume that any Iterable
|
||||||
|
# is the right shape.
|
||||||
|
return typing.cast(typing.Iterable[typing.Tuple[str, str]], potential)
|
||||||
|
elif hasattr(potential, "keys") and hasattr(potential, "__getitem__"):
|
||||||
|
return typing.cast("HasGettableStringKeys", potential)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class RecentlyUsedContainer(typing.Generic[_KT, _VT], typing.MutableMapping[_KT, _VT]):
|
||||||
"""
|
"""
|
||||||
Provides a thread-safe dict-like container which maintains up to
|
Provides a thread-safe dict-like container which maintains up to
|
||||||
``maxsize`` keys while throwing away the least-recently-used keys beyond
|
``maxsize`` keys while throwing away the least-recently-used keys beyond
|
||||||
|
@ -42,69 +74,134 @@ class RecentlyUsedContainer(MutableMapping):
|
||||||
``dispose_func(value)`` is called. Callback which will get called
|
``dispose_func(value)`` is called. Callback which will get called
|
||||||
"""
|
"""
|
||||||
|
|
||||||
ContainerCls = OrderedDict
|
_container: typing.OrderedDict[_KT, _VT]
|
||||||
|
_maxsize: int
|
||||||
|
dispose_func: typing.Callable[[_VT], None] | None
|
||||||
|
lock: RLock
|
||||||
|
|
||||||
def __init__(self, maxsize=10, dispose_func=None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
maxsize: int = 10,
|
||||||
|
dispose_func: typing.Callable[[_VT], None] | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
self._maxsize = maxsize
|
self._maxsize = maxsize
|
||||||
self.dispose_func = dispose_func
|
self.dispose_func = dispose_func
|
||||||
|
self._container = OrderedDict()
|
||||||
self._container = self.ContainerCls()
|
|
||||||
self.lock = RLock()
|
self.lock = RLock()
|
||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key: _KT) -> _VT:
|
||||||
# Re-insert the item, moving it to the end of the eviction line.
|
# Re-insert the item, moving it to the end of the eviction line.
|
||||||
with self.lock:
|
with self.lock:
|
||||||
item = self._container.pop(key)
|
item = self._container.pop(key)
|
||||||
self._container[key] = item
|
self._container[key] = item
|
||||||
return item
|
return item
|
||||||
|
|
||||||
def __setitem__(self, key, value):
|
def __setitem__(self, key: _KT, value: _VT) -> None:
|
||||||
evicted_value = _Null
|
evicted_item = None
|
||||||
with self.lock:
|
with self.lock:
|
||||||
# Possibly evict the existing value of 'key'
|
# Possibly evict the existing value of 'key'
|
||||||
evicted_value = self._container.get(key, _Null)
|
try:
|
||||||
self._container[key] = value
|
# If the key exists, we'll overwrite it, which won't change the
|
||||||
|
# size of the pool. Because accessing a key should move it to
|
||||||
|
# the end of the eviction line, we pop it out first.
|
||||||
|
evicted_item = key, self._container.pop(key)
|
||||||
|
self._container[key] = value
|
||||||
|
except KeyError:
|
||||||
|
# When the key does not exist, we insert the value first so that
|
||||||
|
# evicting works in all cases, including when self._maxsize is 0
|
||||||
|
self._container[key] = value
|
||||||
|
if len(self._container) > self._maxsize:
|
||||||
|
# If we didn't evict an existing value, and we've hit our maximum
|
||||||
|
# size, then we have to evict the least recently used item from
|
||||||
|
# the beginning of the container.
|
||||||
|
evicted_item = self._container.popitem(last=False)
|
||||||
|
|
||||||
# If we didn't evict an existing value, we might have to evict the
|
# After releasing the lock on the pool, dispose of any evicted value.
|
||||||
# least recently used item from the beginning of the container.
|
if evicted_item is not None and self.dispose_func:
|
||||||
if len(self._container) > self._maxsize:
|
_, evicted_value = evicted_item
|
||||||
_key, evicted_value = self._container.popitem(last=False)
|
|
||||||
|
|
||||||
if self.dispose_func and evicted_value is not _Null:
|
|
||||||
self.dispose_func(evicted_value)
|
self.dispose_func(evicted_value)
|
||||||
|
|
||||||
def __delitem__(self, key):
|
def __delitem__(self, key: _KT) -> None:
|
||||||
with self.lock:
|
with self.lock:
|
||||||
value = self._container.pop(key)
|
value = self._container.pop(key)
|
||||||
|
|
||||||
if self.dispose_func:
|
if self.dispose_func:
|
||||||
self.dispose_func(value)
|
self.dispose_func(value)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self) -> int:
|
||||||
with self.lock:
|
with self.lock:
|
||||||
return len(self._container)
|
return len(self._container)
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self) -> typing.NoReturn:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Iteration over this class is unlikely to be threadsafe."
|
"Iteration over this class is unlikely to be threadsafe."
|
||||||
)
|
)
|
||||||
|
|
||||||
def clear(self):
|
def clear(self) -> None:
|
||||||
with self.lock:
|
with self.lock:
|
||||||
# Copy pointers to all values, then wipe the mapping
|
# Copy pointers to all values, then wipe the mapping
|
||||||
values = list(itervalues(self._container))
|
values = list(self._container.values())
|
||||||
self._container.clear()
|
self._container.clear()
|
||||||
|
|
||||||
if self.dispose_func:
|
if self.dispose_func:
|
||||||
for value in values:
|
for value in values:
|
||||||
self.dispose_func(value)
|
self.dispose_func(value)
|
||||||
|
|
||||||
def keys(self):
|
def keys(self) -> set[_KT]: # type: ignore[override]
|
||||||
with self.lock:
|
with self.lock:
|
||||||
return list(iterkeys(self._container))
|
return set(self._container.keys())
|
||||||
|
|
||||||
|
|
||||||
class HTTPHeaderDict(MutableMapping):
|
class HTTPHeaderDictItemView(typing.Set[typing.Tuple[str, str]]):
|
||||||
|
"""
|
||||||
|
HTTPHeaderDict is unusual for a Mapping[str, str] in that it has two modes of
|
||||||
|
address.
|
||||||
|
|
||||||
|
If we directly try to get an item with a particular name, we will get a string
|
||||||
|
back that is the concatenated version of all the values:
|
||||||
|
|
||||||
|
>>> d['X-Header-Name']
|
||||||
|
'Value1, Value2, Value3'
|
||||||
|
|
||||||
|
However, if we iterate over an HTTPHeaderDict's items, we will optionally combine
|
||||||
|
these values based on whether combine=True was called when building up the dictionary
|
||||||
|
|
||||||
|
>>> d = HTTPHeaderDict({"A": "1", "B": "foo"})
|
||||||
|
>>> d.add("A", "2", combine=True)
|
||||||
|
>>> d.add("B", "bar")
|
||||||
|
>>> list(d.items())
|
||||||
|
[
|
||||||
|
('A', '1, 2'),
|
||||||
|
('B', 'foo'),
|
||||||
|
('B', 'bar'),
|
||||||
|
]
|
||||||
|
|
||||||
|
This class conforms to the interface required by the MutableMapping ABC while
|
||||||
|
also giving us the nonstandard iteration behavior we want; items with duplicate
|
||||||
|
keys, ordered by time of first insertion.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_headers: HTTPHeaderDict
|
||||||
|
|
||||||
|
def __init__(self, headers: HTTPHeaderDict) -> None:
|
||||||
|
self._headers = headers
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(list(self._headers.iteritems()))
|
||||||
|
|
||||||
|
def __iter__(self) -> typing.Iterator[tuple[str, str]]:
|
||||||
|
return self._headers.iteritems()
|
||||||
|
|
||||||
|
def __contains__(self, item: object) -> bool:
|
||||||
|
if isinstance(item, tuple) and len(item) == 2:
|
||||||
|
passed_key, passed_val = item
|
||||||
|
if isinstance(passed_key, str) and isinstance(passed_val, str):
|
||||||
|
return self._headers._has_value_for_header(passed_key, passed_val)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class HTTPHeaderDict(typing.MutableMapping[str, str]):
|
||||||
"""
|
"""
|
||||||
:param headers:
|
:param headers:
|
||||||
An iterable of field-value pairs. Must not contain multiple field names
|
An iterable of field-value pairs. Must not contain multiple field names
|
||||||
|
@ -138,9 +235,11 @@ class HTTPHeaderDict(MutableMapping):
|
||||||
'7'
|
'7'
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, headers=None, **kwargs):
|
_container: typing.MutableMapping[str, list[str]]
|
||||||
super(HTTPHeaderDict, self).__init__()
|
|
||||||
self._container = OrderedDict()
|
def __init__(self, headers: ValidHTTPHeaderSource | None = None, **kwargs: str):
|
||||||
|
super().__init__()
|
||||||
|
self._container = {} # 'dict' is insert-ordered in Python 3.7+
|
||||||
if headers is not None:
|
if headers is not None:
|
||||||
if isinstance(headers, HTTPHeaderDict):
|
if isinstance(headers, HTTPHeaderDict):
|
||||||
self._copy_from(headers)
|
self._copy_from(headers)
|
||||||
|
@ -149,123 +248,147 @@ class HTTPHeaderDict(MutableMapping):
|
||||||
if kwargs:
|
if kwargs:
|
||||||
self.extend(kwargs)
|
self.extend(kwargs)
|
||||||
|
|
||||||
def __setitem__(self, key, val):
|
def __setitem__(self, key: str, val: str) -> None:
|
||||||
|
# avoid a bytes/str comparison by decoding before httplib
|
||||||
|
if isinstance(key, bytes):
|
||||||
|
key = key.decode("latin-1")
|
||||||
self._container[key.lower()] = [key, val]
|
self._container[key.lower()] = [key, val]
|
||||||
return self._container[key.lower()]
|
|
||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key: str) -> str:
|
||||||
val = self._container[key.lower()]
|
val = self._container[key.lower()]
|
||||||
return ", ".join(val[1:])
|
return ", ".join(val[1:])
|
||||||
|
|
||||||
def __delitem__(self, key):
|
def __delitem__(self, key: str) -> None:
|
||||||
del self._container[key.lower()]
|
del self._container[key.lower()]
|
||||||
|
|
||||||
def __contains__(self, key):
|
def __contains__(self, key: object) -> bool:
|
||||||
return key.lower() in self._container
|
if isinstance(key, str):
|
||||||
|
return key.lower() in self._container
|
||||||
|
return False
|
||||||
|
|
||||||
def __eq__(self, other):
|
def setdefault(self, key: str, default: str = "") -> str:
|
||||||
if not isinstance(other, Mapping) and not hasattr(other, "keys"):
|
return super().setdefault(key, default)
|
||||||
|
|
||||||
|
def __eq__(self, other: object) -> bool:
|
||||||
|
maybe_constructable = ensure_can_construct_http_header_dict(other)
|
||||||
|
if maybe_constructable is None:
|
||||||
return False
|
return False
|
||||||
if not isinstance(other, type(self)):
|
else:
|
||||||
other = type(self)(other)
|
other_as_http_header_dict = type(self)(maybe_constructable)
|
||||||
return dict((k.lower(), v) for k, v in self.itermerged()) == dict(
|
|
||||||
(k.lower(), v) for k, v in other.itermerged()
|
|
||||||
)
|
|
||||||
|
|
||||||
def __ne__(self, other):
|
return {k.lower(): v for k, v in self.itermerged()} == {
|
||||||
|
k.lower(): v for k, v in other_as_http_header_dict.itermerged()
|
||||||
|
}
|
||||||
|
|
||||||
|
def __ne__(self, other: object) -> bool:
|
||||||
return not self.__eq__(other)
|
return not self.__eq__(other)
|
||||||
|
|
||||||
if six.PY2: # Python 2
|
def __len__(self) -> int:
|
||||||
iterkeys = MutableMapping.iterkeys
|
|
||||||
itervalues = MutableMapping.itervalues
|
|
||||||
|
|
||||||
__marker = object()
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self._container)
|
return len(self._container)
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self) -> typing.Iterator[str]:
|
||||||
# Only provide the originally cased names
|
# Only provide the originally cased names
|
||||||
for vals in self._container.values():
|
for vals in self._container.values():
|
||||||
yield vals[0]
|
yield vals[0]
|
||||||
|
|
||||||
def pop(self, key, default=__marker):
|
def discard(self, key: str) -> None:
|
||||||
"""D.pop(k[,d]) -> v, remove specified key and return the corresponding value.
|
|
||||||
If key is not found, d is returned if given, otherwise KeyError is raised.
|
|
||||||
"""
|
|
||||||
# Using the MutableMapping function directly fails due to the private marker.
|
|
||||||
# Using ordinary dict.pop would expose the internal structures.
|
|
||||||
# So let's reinvent the wheel.
|
|
||||||
try:
|
|
||||||
value = self[key]
|
|
||||||
except KeyError:
|
|
||||||
if default is self.__marker:
|
|
||||||
raise
|
|
||||||
return default
|
|
||||||
else:
|
|
||||||
del self[key]
|
|
||||||
return value
|
|
||||||
|
|
||||||
def discard(self, key):
|
|
||||||
try:
|
try:
|
||||||
del self[key]
|
del self[key]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def add(self, key, val):
|
def add(self, key: str, val: str, *, combine: bool = False) -> None:
|
||||||
"""Adds a (name, value) pair, doesn't overwrite the value if it already
|
"""Adds a (name, value) pair, doesn't overwrite the value if it already
|
||||||
exists.
|
exists.
|
||||||
|
|
||||||
|
If this is called with combine=True, instead of adding a new header value
|
||||||
|
as a distinct item during iteration, this will instead append the value to
|
||||||
|
any existing header value with a comma. If no existing header value exists
|
||||||
|
for the key, then the value will simply be added, ignoring the combine parameter.
|
||||||
|
|
||||||
>>> headers = HTTPHeaderDict(foo='bar')
|
>>> headers = HTTPHeaderDict(foo='bar')
|
||||||
>>> headers.add('Foo', 'baz')
|
>>> headers.add('Foo', 'baz')
|
||||||
>>> headers['foo']
|
>>> headers['foo']
|
||||||
'bar, baz'
|
'bar, baz'
|
||||||
|
>>> list(headers.items())
|
||||||
|
[('foo', 'bar'), ('foo', 'baz')]
|
||||||
|
>>> headers.add('foo', 'quz', combine=True)
|
||||||
|
>>> list(headers.items())
|
||||||
|
[('foo', 'bar, baz, quz')]
|
||||||
"""
|
"""
|
||||||
|
# avoid a bytes/str comparison by decoding before httplib
|
||||||
|
if isinstance(key, bytes):
|
||||||
|
key = key.decode("latin-1")
|
||||||
key_lower = key.lower()
|
key_lower = key.lower()
|
||||||
new_vals = [key, val]
|
new_vals = [key, val]
|
||||||
# Keep the common case aka no item present as fast as possible
|
# Keep the common case aka no item present as fast as possible
|
||||||
vals = self._container.setdefault(key_lower, new_vals)
|
vals = self._container.setdefault(key_lower, new_vals)
|
||||||
if new_vals is not vals:
|
if new_vals is not vals:
|
||||||
vals.append(val)
|
# if there are values here, then there is at least the initial
|
||||||
|
# key/value pair
|
||||||
|
assert len(vals) >= 2
|
||||||
|
if combine:
|
||||||
|
vals[-1] = vals[-1] + ", " + val
|
||||||
|
else:
|
||||||
|
vals.append(val)
|
||||||
|
|
||||||
def extend(self, *args, **kwargs):
|
def extend(self, *args: ValidHTTPHeaderSource, **kwargs: str) -> None:
|
||||||
"""Generic import function for any type of header-like object.
|
"""Generic import function for any type of header-like object.
|
||||||
Adapted version of MutableMapping.update in order to insert items
|
Adapted version of MutableMapping.update in order to insert items
|
||||||
with self.add instead of self.__setitem__
|
with self.add instead of self.__setitem__
|
||||||
"""
|
"""
|
||||||
if len(args) > 1:
|
if len(args) > 1:
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
"extend() takes at most 1 positional "
|
f"extend() takes at most 1 positional arguments ({len(args)} given)"
|
||||||
"arguments ({0} given)".format(len(args))
|
|
||||||
)
|
)
|
||||||
other = args[0] if len(args) >= 1 else ()
|
other = args[0] if len(args) >= 1 else ()
|
||||||
|
|
||||||
if isinstance(other, HTTPHeaderDict):
|
if isinstance(other, HTTPHeaderDict):
|
||||||
for key, val in other.iteritems():
|
for key, val in other.iteritems():
|
||||||
self.add(key, val)
|
self.add(key, val)
|
||||||
elif isinstance(other, Mapping):
|
elif isinstance(other, typing.Mapping):
|
||||||
for key in other:
|
for key, val in other.items():
|
||||||
self.add(key, other[key])
|
self.add(key, val)
|
||||||
elif hasattr(other, "keys"):
|
elif isinstance(other, typing.Iterable):
|
||||||
for key in other.keys():
|
other = typing.cast(typing.Iterable[typing.Tuple[str, str]], other)
|
||||||
self.add(key, other[key])
|
|
||||||
else:
|
|
||||||
for key, value in other:
|
for key, value in other:
|
||||||
self.add(key, value)
|
self.add(key, value)
|
||||||
|
elif hasattr(other, "keys") and hasattr(other, "__getitem__"):
|
||||||
|
# THIS IS NOT A TYPESAFE BRANCH
|
||||||
|
# In this branch, the object has a `keys` attr but is not a Mapping or any of
|
||||||
|
# the other types indicated in the method signature. We do some stuff with
|
||||||
|
# it as though it partially implements the Mapping interface, but we're not
|
||||||
|
# doing that stuff safely AT ALL.
|
||||||
|
for key in other.keys():
|
||||||
|
self.add(key, other[key])
|
||||||
|
|
||||||
for key, value in kwargs.items():
|
for key, value in kwargs.items():
|
||||||
self.add(key, value)
|
self.add(key, value)
|
||||||
|
|
||||||
def getlist(self, key, default=__marker):
|
@typing.overload
|
||||||
|
def getlist(self, key: str) -> list[str]:
|
||||||
|
...
|
||||||
|
|
||||||
|
@typing.overload
|
||||||
|
def getlist(self, key: str, default: _DT) -> list[str] | _DT:
|
||||||
|
...
|
||||||
|
|
||||||
|
def getlist(
|
||||||
|
self, key: str, default: _Sentinel | _DT = _Sentinel.not_passed
|
||||||
|
) -> list[str] | _DT:
|
||||||
"""Returns a list of all the values for the named field. Returns an
|
"""Returns a list of all the values for the named field. Returns an
|
||||||
empty list if the key doesn't exist."""
|
empty list if the key doesn't exist."""
|
||||||
try:
|
try:
|
||||||
vals = self._container[key.lower()]
|
vals = self._container[key.lower()]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
if default is self.__marker:
|
if default is _Sentinel.not_passed:
|
||||||
|
# _DT is unbound; empty list is instance of List[str]
|
||||||
return []
|
return []
|
||||||
|
# _DT is bound; default is instance of _DT
|
||||||
return default
|
return default
|
||||||
else:
|
else:
|
||||||
|
# _DT may or may not be bound; vals[1:] is instance of List[str], which
|
||||||
|
# meets our external interface requirement of `Union[List[str], _DT]`.
|
||||||
return vals[1:]
|
return vals[1:]
|
||||||
|
|
||||||
# Backwards compatibility for httplib
|
# Backwards compatibility for httplib
|
||||||
|
@ -276,62 +399,65 @@ class HTTPHeaderDict(MutableMapping):
|
||||||
# Backwards compatibility for http.cookiejar
|
# Backwards compatibility for http.cookiejar
|
||||||
get_all = getlist
|
get_all = getlist
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self) -> str:
|
||||||
return "%s(%s)" % (type(self).__name__, dict(self.itermerged()))
|
return f"{type(self).__name__}({dict(self.itermerged())})"
|
||||||
|
|
||||||
def _copy_from(self, other):
|
def _copy_from(self, other: HTTPHeaderDict) -> None:
|
||||||
for key in other:
|
for key in other:
|
||||||
val = other.getlist(key)
|
val = other.getlist(key)
|
||||||
if isinstance(val, list):
|
self._container[key.lower()] = [key, *val]
|
||||||
# Don't need to convert tuples
|
|
||||||
val = list(val)
|
|
||||||
self._container[key.lower()] = [key] + val
|
|
||||||
|
|
||||||
def copy(self):
|
def copy(self) -> HTTPHeaderDict:
|
||||||
clone = type(self)()
|
clone = type(self)()
|
||||||
clone._copy_from(self)
|
clone._copy_from(self)
|
||||||
return clone
|
return clone
|
||||||
|
|
||||||
def iteritems(self):
|
def iteritems(self) -> typing.Iterator[tuple[str, str]]:
|
||||||
"""Iterate over all header lines, including duplicate ones."""
|
"""Iterate over all header lines, including duplicate ones."""
|
||||||
for key in self:
|
for key in self:
|
||||||
vals = self._container[key.lower()]
|
vals = self._container[key.lower()]
|
||||||
for val in vals[1:]:
|
for val in vals[1:]:
|
||||||
yield vals[0], val
|
yield vals[0], val
|
||||||
|
|
||||||
def itermerged(self):
|
def itermerged(self) -> typing.Iterator[tuple[str, str]]:
|
||||||
"""Iterate over all headers, merging duplicate ones together."""
|
"""Iterate over all headers, merging duplicate ones together."""
|
||||||
for key in self:
|
for key in self:
|
||||||
val = self._container[key.lower()]
|
val = self._container[key.lower()]
|
||||||
yield val[0], ", ".join(val[1:])
|
yield val[0], ", ".join(val[1:])
|
||||||
|
|
||||||
def items(self):
|
def items(self) -> HTTPHeaderDictItemView: # type: ignore[override]
|
||||||
return list(self.iteritems())
|
return HTTPHeaderDictItemView(self)
|
||||||
|
|
||||||
@classmethod
|
def _has_value_for_header(self, header_name: str, potential_value: str) -> bool:
|
||||||
def from_httplib(cls, message): # Python 2
|
if header_name in self:
|
||||||
"""Read headers from a Python 2 httplib message object."""
|
return potential_value in self._container[header_name.lower()][1:]
|
||||||
# python2.7 does not expose a proper API for exporting multiheaders
|
return False
|
||||||
# efficiently. This function re-reads raw lines from the message
|
|
||||||
# object and extracts the multiheaders properly.
|
|
||||||
obs_fold_continued_leaders = (" ", "\t")
|
|
||||||
headers = []
|
|
||||||
|
|
||||||
for line in message.headers:
|
def __ior__(self, other: object) -> HTTPHeaderDict:
|
||||||
if line.startswith(obs_fold_continued_leaders):
|
# Supports extending a header dict in-place using operator |=
|
||||||
if not headers:
|
# combining items with add instead of __setitem__
|
||||||
# We received a header line that starts with OWS as described
|
maybe_constructable = ensure_can_construct_http_header_dict(other)
|
||||||
# in RFC-7230 S3.2.4. This indicates a multiline header, but
|
if maybe_constructable is None:
|
||||||
# there exists no previous header to which we can attach it.
|
return NotImplemented
|
||||||
raise InvalidHeader(
|
self.extend(maybe_constructable)
|
||||||
"Header continuation with no previous header: %s" % line
|
return self
|
||||||
)
|
|
||||||
else:
|
|
||||||
key, value = headers[-1]
|
|
||||||
headers[-1] = (key, value + " " + line.strip())
|
|
||||||
continue
|
|
||||||
|
|
||||||
key, value = line.split(":", 1)
|
def __or__(self, other: object) -> HTTPHeaderDict:
|
||||||
headers.append((key, value.strip()))
|
# Supports merging header dicts using operator |
|
||||||
|
# combining items with add instead of __setitem__
|
||||||
|
maybe_constructable = ensure_can_construct_http_header_dict(other)
|
||||||
|
if maybe_constructable is None:
|
||||||
|
return NotImplemented
|
||||||
|
result = self.copy()
|
||||||
|
result.extend(maybe_constructable)
|
||||||
|
return result
|
||||||
|
|
||||||
return cls(headers)
|
def __ror__(self, other: object) -> HTTPHeaderDict:
|
||||||
|
# Supports merging header dicts using operator | when other is on left side
|
||||||
|
# combining items with add instead of __setitem__
|
||||||
|
maybe_constructable = ensure_can_construct_http_header_dict(other)
|
||||||
|
if maybe_constructable is None:
|
||||||
|
return NotImplemented
|
||||||
|
result = type(self)(maybe_constructable)
|
||||||
|
result.extend(self)
|
||||||
|
return result
|
||||||
|
|
|
@ -1,12 +1,23 @@
|
||||||
from __future__ import absolute_import
|
from __future__ import annotations
|
||||||
|
|
||||||
from .filepost import encode_multipart_formdata
|
import json as _json
|
||||||
from .packages.six.moves.urllib.parse import urlencode
|
import typing
|
||||||
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
|
from ._base_connection import _TYPE_BODY
|
||||||
|
from ._collections import HTTPHeaderDict
|
||||||
|
from .filepost import _TYPE_FIELDS, encode_multipart_formdata
|
||||||
|
from .response import BaseHTTPResponse
|
||||||
|
|
||||||
__all__ = ["RequestMethods"]
|
__all__ = ["RequestMethods"]
|
||||||
|
|
||||||
|
_TYPE_ENCODE_URL_FIELDS = typing.Union[
|
||||||
|
typing.Sequence[typing.Tuple[str, typing.Union[str, bytes]]],
|
||||||
|
typing.Mapping[str, typing.Union[str, bytes]],
|
||||||
|
]
|
||||||
|
|
||||||
class RequestMethods(object):
|
|
||||||
|
class RequestMethods:
|
||||||
"""
|
"""
|
||||||
Convenience mixin for classes who implement a :meth:`urlopen` method, such
|
Convenience mixin for classes who implement a :meth:`urlopen` method, such
|
||||||
as :class:`urllib3.HTTPConnectionPool` and
|
as :class:`urllib3.HTTPConnectionPool` and
|
||||||
|
@ -37,25 +48,34 @@ class RequestMethods(object):
|
||||||
|
|
||||||
_encode_url_methods = {"DELETE", "GET", "HEAD", "OPTIONS"}
|
_encode_url_methods = {"DELETE", "GET", "HEAD", "OPTIONS"}
|
||||||
|
|
||||||
def __init__(self, headers=None):
|
def __init__(self, headers: typing.Mapping[str, str] | None = None) -> None:
|
||||||
self.headers = headers or {}
|
self.headers = headers or {}
|
||||||
|
|
||||||
def urlopen(
|
def urlopen(
|
||||||
self,
|
self,
|
||||||
method,
|
method: str,
|
||||||
url,
|
url: str,
|
||||||
body=None,
|
body: _TYPE_BODY | None = None,
|
||||||
headers=None,
|
headers: typing.Mapping[str, str] | None = None,
|
||||||
encode_multipart=True,
|
encode_multipart: bool = True,
|
||||||
multipart_boundary=None,
|
multipart_boundary: str | None = None,
|
||||||
**kw
|
**kw: typing.Any,
|
||||||
): # Abstract
|
) -> BaseHTTPResponse: # Abstract
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Classes extending RequestMethods must implement "
|
"Classes extending RequestMethods must implement "
|
||||||
"their own ``urlopen`` method."
|
"their own ``urlopen`` method."
|
||||||
)
|
)
|
||||||
|
|
||||||
def request(self, method, url, fields=None, headers=None, **urlopen_kw):
|
def request(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
url: str,
|
||||||
|
body: _TYPE_BODY | None = None,
|
||||||
|
fields: _TYPE_FIELDS | None = None,
|
||||||
|
headers: typing.Mapping[str, str] | None = None,
|
||||||
|
json: typing.Any | None = None,
|
||||||
|
**urlopen_kw: typing.Any,
|
||||||
|
) -> BaseHTTPResponse:
|
||||||
"""
|
"""
|
||||||
Make a request using :meth:`urlopen` with the appropriate encoding of
|
Make a request using :meth:`urlopen` with the appropriate encoding of
|
||||||
``fields`` based on the ``method`` used.
|
``fields`` based on the ``method`` used.
|
||||||
|
@ -68,18 +88,45 @@ class RequestMethods(object):
|
||||||
"""
|
"""
|
||||||
method = method.upper()
|
method = method.upper()
|
||||||
|
|
||||||
urlopen_kw["request_url"] = url
|
if json is not None and body is not None:
|
||||||
|
raise TypeError(
|
||||||
|
"request got values for both 'body' and 'json' parameters which are mutually exclusive"
|
||||||
|
)
|
||||||
|
|
||||||
|
if json is not None:
|
||||||
|
if headers is None:
|
||||||
|
headers = self.headers.copy() # type: ignore
|
||||||
|
if not ("content-type" in map(str.lower, headers.keys())):
|
||||||
|
headers["Content-Type"] = "application/json" # type: ignore
|
||||||
|
|
||||||
|
body = _json.dumps(json, separators=(",", ":"), ensure_ascii=False).encode(
|
||||||
|
"utf-8"
|
||||||
|
)
|
||||||
|
|
||||||
|
if body is not None:
|
||||||
|
urlopen_kw["body"] = body
|
||||||
|
|
||||||
if method in self._encode_url_methods:
|
if method in self._encode_url_methods:
|
||||||
return self.request_encode_url(
|
return self.request_encode_url(
|
||||||
method, url, fields=fields, headers=headers, **urlopen_kw
|
method,
|
||||||
|
url,
|
||||||
|
fields=fields, # type: ignore[arg-type]
|
||||||
|
headers=headers,
|
||||||
|
**urlopen_kw,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return self.request_encode_body(
|
return self.request_encode_body(
|
||||||
method, url, fields=fields, headers=headers, **urlopen_kw
|
method, url, fields=fields, headers=headers, **urlopen_kw
|
||||||
)
|
)
|
||||||
|
|
||||||
def request_encode_url(self, method, url, fields=None, headers=None, **urlopen_kw):
|
def request_encode_url(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
url: str,
|
||||||
|
fields: _TYPE_ENCODE_URL_FIELDS | None = None,
|
||||||
|
headers: typing.Mapping[str, str] | None = None,
|
||||||
|
**urlopen_kw: str,
|
||||||
|
) -> BaseHTTPResponse:
|
||||||
"""
|
"""
|
||||||
Make a request using :meth:`urlopen` with the ``fields`` encoded in
|
Make a request using :meth:`urlopen` with the ``fields`` encoded in
|
||||||
the url. This is useful for request methods like GET, HEAD, DELETE, etc.
|
the url. This is useful for request methods like GET, HEAD, DELETE, etc.
|
||||||
|
@ -87,7 +134,7 @@ class RequestMethods(object):
|
||||||
if headers is None:
|
if headers is None:
|
||||||
headers = self.headers
|
headers = self.headers
|
||||||
|
|
||||||
extra_kw = {"headers": headers}
|
extra_kw: dict[str, typing.Any] = {"headers": headers}
|
||||||
extra_kw.update(urlopen_kw)
|
extra_kw.update(urlopen_kw)
|
||||||
|
|
||||||
if fields:
|
if fields:
|
||||||
|
@ -97,14 +144,14 @@ class RequestMethods(object):
|
||||||
|
|
||||||
def request_encode_body(
|
def request_encode_body(
|
||||||
self,
|
self,
|
||||||
method,
|
method: str,
|
||||||
url,
|
url: str,
|
||||||
fields=None,
|
fields: _TYPE_FIELDS | None = None,
|
||||||
headers=None,
|
headers: typing.Mapping[str, str] | None = None,
|
||||||
encode_multipart=True,
|
encode_multipart: bool = True,
|
||||||
multipart_boundary=None,
|
multipart_boundary: str | None = None,
|
||||||
**urlopen_kw
|
**urlopen_kw: str,
|
||||||
):
|
) -> BaseHTTPResponse:
|
||||||
"""
|
"""
|
||||||
Make a request using :meth:`urlopen` with the ``fields`` encoded in
|
Make a request using :meth:`urlopen` with the ``fields`` encoded in
|
||||||
the body. This is useful for request methods like POST, PUT, PATCH, etc.
|
the body. This is useful for request methods like POST, PUT, PATCH, etc.
|
||||||
|
@ -143,7 +190,8 @@ class RequestMethods(object):
|
||||||
if headers is None:
|
if headers is None:
|
||||||
headers = self.headers
|
headers = self.headers
|
||||||
|
|
||||||
extra_kw = {"headers": {}}
|
extra_kw: dict[str, typing.Any] = {"headers": HTTPHeaderDict(headers)}
|
||||||
|
body: bytes | str
|
||||||
|
|
||||||
if fields:
|
if fields:
|
||||||
if "body" in urlopen_kw:
|
if "body" in urlopen_kw:
|
||||||
|
@ -157,14 +205,13 @@ class RequestMethods(object):
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
body, content_type = (
|
body, content_type = (
|
||||||
urlencode(fields),
|
urlencode(fields), # type: ignore[arg-type]
|
||||||
"application/x-www-form-urlencoded",
|
"application/x-www-form-urlencoded",
|
||||||
)
|
)
|
||||||
|
|
||||||
extra_kw["body"] = body
|
extra_kw["body"] = body
|
||||||
extra_kw["headers"] = {"Content-Type": content_type}
|
extra_kw["headers"].setdefault("Content-Type", content_type)
|
||||||
|
|
||||||
extra_kw["headers"].update(headers)
|
|
||||||
extra_kw.update(urlopen_kw)
|
extra_kw.update(urlopen_kw)
|
||||||
|
|
||||||
return self.urlopen(method, url, **extra_kw)
|
return self.urlopen(method, url, **extra_kw)
|
|
@ -1,2 +1,4 @@
|
||||||
# This file is protected via CODEOWNERS
|
# This file is protected via CODEOWNERS
|
||||||
__version__ = "1.26.15"
|
from __future__ import annotations
|
||||||
|
|
||||||
|
__version__ = "2.0.4"
|
||||||
|
|
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
|
@ -1,36 +0,0 @@
|
||||||
"""
|
|
||||||
This module provides means to detect the App Engine environment.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
|
|
||||||
def is_appengine():
|
|
||||||
return is_local_appengine() or is_prod_appengine()
|
|
||||||
|
|
||||||
|
|
||||||
def is_appengine_sandbox():
|
|
||||||
"""Reports if the app is running in the first generation sandbox.
|
|
||||||
|
|
||||||
The second generation runtimes are technically still in a sandbox, but it
|
|
||||||
is much less restrictive, so generally you shouldn't need to check for it.
|
|
||||||
see https://cloud.google.com/appengine/docs/standard/runtimes
|
|
||||||
"""
|
|
||||||
return is_appengine() and os.environ["APPENGINE_RUNTIME"] == "python27"
|
|
||||||
|
|
||||||
|
|
||||||
def is_local_appengine():
|
|
||||||
return "APPENGINE_RUNTIME" in os.environ and os.environ.get(
|
|
||||||
"SERVER_SOFTWARE", ""
|
|
||||||
).startswith("Development/")
|
|
||||||
|
|
||||||
|
|
||||||
def is_prod_appengine():
|
|
||||||
return "APPENGINE_RUNTIME" in os.environ and os.environ.get(
|
|
||||||
"SERVER_SOFTWARE", ""
|
|
||||||
).startswith("Google App Engine/")
|
|
||||||
|
|
||||||
|
|
||||||
def is_prod_appengine_mvms():
|
|
||||||
"""Deprecated."""
|
|
||||||
return False
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
# type: ignore
|
||||||
|
|
||||||
"""
|
"""
|
||||||
This module uses ctypes to bind a whole bunch of functions and constants from
|
This module uses ctypes to bind a whole bunch of functions and constants from
|
||||||
SecureTransport. The goal here is to provide the low-level API to
|
SecureTransport. The goal here is to provide the low-level API to
|
||||||
|
@ -29,7 +31,8 @@ license and by oscrypto's:
|
||||||
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
DEALINGS IN THE SOFTWARE.
|
DEALINGS IN THE SOFTWARE.
|
||||||
"""
|
"""
|
||||||
from __future__ import absolute_import
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import platform
|
import platform
|
||||||
from ctypes import (
|
from ctypes import (
|
||||||
|
@ -48,8 +51,6 @@ from ctypes import (
|
||||||
)
|
)
|
||||||
from ctypes.util import find_library
|
from ctypes.util import find_library
|
||||||
|
|
||||||
from ...packages.six import raise_from
|
|
||||||
|
|
||||||
if platform.system() != "Darwin":
|
if platform.system() != "Darwin":
|
||||||
raise ImportError("Only macOS is supported")
|
raise ImportError("Only macOS is supported")
|
||||||
|
|
||||||
|
@ -57,16 +58,16 @@ version = platform.mac_ver()[0]
|
||||||
version_info = tuple(map(int, version.split(".")))
|
version_info = tuple(map(int, version.split(".")))
|
||||||
if version_info < (10, 8):
|
if version_info < (10, 8):
|
||||||
raise OSError(
|
raise OSError(
|
||||||
"Only OS X 10.8 and newer are supported, not %s.%s"
|
f"Only OS X 10.8 and newer are supported, not {version_info[0]}.{version_info[1]}"
|
||||||
% (version_info[0], version_info[1])
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_cdll(name, macos10_16_path):
|
def load_cdll(name: str, macos10_16_path: str) -> CDLL:
|
||||||
"""Loads a CDLL by name, falling back to known path on 10.16+"""
|
"""Loads a CDLL by name, falling back to known path on 10.16+"""
|
||||||
try:
|
try:
|
||||||
# Big Sur is technically 11 but we use 10.16 due to the Big Sur
|
# Big Sur is technically 11 but we use 10.16 due to the Big Sur
|
||||||
# beta being labeled as 10.16.
|
# beta being labeled as 10.16.
|
||||||
|
path: str | None
|
||||||
if version_info >= (10, 16):
|
if version_info >= (10, 16):
|
||||||
path = macos10_16_path
|
path = macos10_16_path
|
||||||
else:
|
else:
|
||||||
|
@ -75,7 +76,7 @@ def load_cdll(name, macos10_16_path):
|
||||||
raise OSError # Caught and reraised as 'ImportError'
|
raise OSError # Caught and reraised as 'ImportError'
|
||||||
return CDLL(path, use_errno=True)
|
return CDLL(path, use_errno=True)
|
||||||
except OSError:
|
except OSError:
|
||||||
raise_from(ImportError("The library %s failed to load" % name), None)
|
raise ImportError(f"The library {name} failed to load") from None
|
||||||
|
|
||||||
|
|
||||||
Security = load_cdll(
|
Security = load_cdll(
|
||||||
|
@ -416,104 +417,14 @@ try:
|
||||||
CoreFoundation.CFStringRef = CFStringRef
|
CoreFoundation.CFStringRef = CFStringRef
|
||||||
CoreFoundation.CFDictionaryRef = CFDictionaryRef
|
CoreFoundation.CFDictionaryRef = CFDictionaryRef
|
||||||
|
|
||||||
except (AttributeError):
|
except AttributeError:
|
||||||
raise ImportError("Error initializing ctypes")
|
raise ImportError("Error initializing ctypes") from None
|
||||||
|
|
||||||
|
|
||||||
class CFConst(object):
|
class CFConst:
|
||||||
"""
|
"""
|
||||||
A class object that acts as essentially a namespace for CoreFoundation
|
A class object that acts as essentially a namespace for CoreFoundation
|
||||||
constants.
|
constants.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
kCFStringEncodingUTF8 = CFStringEncoding(0x08000100)
|
kCFStringEncodingUTF8 = CFStringEncoding(0x08000100)
|
||||||
|
|
||||||
|
|
||||||
class SecurityConst(object):
|
|
||||||
"""
|
|
||||||
A class object that acts as essentially a namespace for Security constants.
|
|
||||||
"""
|
|
||||||
|
|
||||||
kSSLSessionOptionBreakOnServerAuth = 0
|
|
||||||
|
|
||||||
kSSLProtocol2 = 1
|
|
||||||
kSSLProtocol3 = 2
|
|
||||||
kTLSProtocol1 = 4
|
|
||||||
kTLSProtocol11 = 7
|
|
||||||
kTLSProtocol12 = 8
|
|
||||||
# SecureTransport does not support TLS 1.3 even if there's a constant for it
|
|
||||||
kTLSProtocol13 = 10
|
|
||||||
kTLSProtocolMaxSupported = 999
|
|
||||||
|
|
||||||
kSSLClientSide = 1
|
|
||||||
kSSLStreamType = 0
|
|
||||||
|
|
||||||
kSecFormatPEMSequence = 10
|
|
||||||
|
|
||||||
kSecTrustResultInvalid = 0
|
|
||||||
kSecTrustResultProceed = 1
|
|
||||||
# This gap is present on purpose: this was kSecTrustResultConfirm, which
|
|
||||||
# is deprecated.
|
|
||||||
kSecTrustResultDeny = 3
|
|
||||||
kSecTrustResultUnspecified = 4
|
|
||||||
kSecTrustResultRecoverableTrustFailure = 5
|
|
||||||
kSecTrustResultFatalTrustFailure = 6
|
|
||||||
kSecTrustResultOtherError = 7
|
|
||||||
|
|
||||||
errSSLProtocol = -9800
|
|
||||||
errSSLWouldBlock = -9803
|
|
||||||
errSSLClosedGraceful = -9805
|
|
||||||
errSSLClosedNoNotify = -9816
|
|
||||||
errSSLClosedAbort = -9806
|
|
||||||
|
|
||||||
errSSLXCertChainInvalid = -9807
|
|
||||||
errSSLCrypto = -9809
|
|
||||||
errSSLInternal = -9810
|
|
||||||
errSSLCertExpired = -9814
|
|
||||||
errSSLCertNotYetValid = -9815
|
|
||||||
errSSLUnknownRootCert = -9812
|
|
||||||
errSSLNoRootCert = -9813
|
|
||||||
errSSLHostNameMismatch = -9843
|
|
||||||
errSSLPeerHandshakeFail = -9824
|
|
||||||
errSSLPeerUserCancelled = -9839
|
|
||||||
errSSLWeakPeerEphemeralDHKey = -9850
|
|
||||||
errSSLServerAuthCompleted = -9841
|
|
||||||
errSSLRecordOverflow = -9847
|
|
||||||
|
|
||||||
errSecVerifyFailed = -67808
|
|
||||||
errSecNoTrustSettings = -25263
|
|
||||||
errSecItemNotFound = -25300
|
|
||||||
errSecInvalidTrustSettings = -25262
|
|
||||||
|
|
||||||
# Cipher suites. We only pick the ones our default cipher string allows.
|
|
||||||
# Source: https://developer.apple.com/documentation/security/1550981-ssl_cipher_suite_values
|
|
||||||
TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 = 0xC02C
|
|
||||||
TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 = 0xC030
|
|
||||||
TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 = 0xC02B
|
|
||||||
TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 = 0xC02F
|
|
||||||
TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 = 0xCCA9
|
|
||||||
TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 = 0xCCA8
|
|
||||||
TLS_DHE_RSA_WITH_AES_256_GCM_SHA384 = 0x009F
|
|
||||||
TLS_DHE_RSA_WITH_AES_128_GCM_SHA256 = 0x009E
|
|
||||||
TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384 = 0xC024
|
|
||||||
TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384 = 0xC028
|
|
||||||
TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA = 0xC00A
|
|
||||||
TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA = 0xC014
|
|
||||||
TLS_DHE_RSA_WITH_AES_256_CBC_SHA256 = 0x006B
|
|
||||||
TLS_DHE_RSA_WITH_AES_256_CBC_SHA = 0x0039
|
|
||||||
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 = 0xC023
|
|
||||||
TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 = 0xC027
|
|
||||||
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA = 0xC009
|
|
||||||
TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA = 0xC013
|
|
||||||
TLS_DHE_RSA_WITH_AES_128_CBC_SHA256 = 0x0067
|
|
||||||
TLS_DHE_RSA_WITH_AES_128_CBC_SHA = 0x0033
|
|
||||||
TLS_RSA_WITH_AES_256_GCM_SHA384 = 0x009D
|
|
||||||
TLS_RSA_WITH_AES_128_GCM_SHA256 = 0x009C
|
|
||||||
TLS_RSA_WITH_AES_256_CBC_SHA256 = 0x003D
|
|
||||||
TLS_RSA_WITH_AES_128_CBC_SHA256 = 0x003C
|
|
||||||
TLS_RSA_WITH_AES_256_CBC_SHA = 0x0035
|
|
||||||
TLS_RSA_WITH_AES_128_CBC_SHA = 0x002F
|
|
||||||
TLS_AES_128_GCM_SHA256 = 0x1301
|
|
||||||
TLS_AES_256_GCM_SHA384 = 0x1302
|
|
||||||
TLS_AES_128_CCM_8_SHA256 = 0x1305
|
|
||||||
TLS_AES_128_CCM_SHA256 = 0x1304
|
|
||||||
|
|
|
@ -7,6 +7,8 @@ CoreFoundation messing about and memory management. The concerns in this module
|
||||||
are almost entirely about trying to avoid memory leaks and providing
|
are almost entirely about trying to avoid memory leaks and providing
|
||||||
appropriate and useful assistance to the higher-level code.
|
appropriate and useful assistance to the higher-level code.
|
||||||
"""
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import ctypes
|
import ctypes
|
||||||
import itertools
|
import itertools
|
||||||
|
@ -15,8 +17,20 @@ import re
|
||||||
import ssl
|
import ssl
|
||||||
import struct
|
import struct
|
||||||
import tempfile
|
import tempfile
|
||||||
|
import typing
|
||||||
|
|
||||||
from .bindings import CFConst, CoreFoundation, Security
|
from .bindings import ( # type: ignore[attr-defined]
|
||||||
|
CFArray,
|
||||||
|
CFConst,
|
||||||
|
CFData,
|
||||||
|
CFDictionary,
|
||||||
|
CFMutableArray,
|
||||||
|
CFString,
|
||||||
|
CFTypeRef,
|
||||||
|
CoreFoundation,
|
||||||
|
SecKeychainRef,
|
||||||
|
Security,
|
||||||
|
)
|
||||||
|
|
||||||
# This regular expression is used to grab PEM data out of a PEM bundle.
|
# This regular expression is used to grab PEM data out of a PEM bundle.
|
||||||
_PEM_CERTS_RE = re.compile(
|
_PEM_CERTS_RE = re.compile(
|
||||||
|
@ -24,7 +38,7 @@ _PEM_CERTS_RE = re.compile(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _cf_data_from_bytes(bytestring):
|
def _cf_data_from_bytes(bytestring: bytes) -> CFData:
|
||||||
"""
|
"""
|
||||||
Given a bytestring, create a CFData object from it. This CFData object must
|
Given a bytestring, create a CFData object from it. This CFData object must
|
||||||
be CFReleased by the caller.
|
be CFReleased by the caller.
|
||||||
|
@ -34,7 +48,9 @@ def _cf_data_from_bytes(bytestring):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _cf_dictionary_from_tuples(tuples):
|
def _cf_dictionary_from_tuples(
|
||||||
|
tuples: list[tuple[typing.Any, typing.Any]]
|
||||||
|
) -> CFDictionary:
|
||||||
"""
|
"""
|
||||||
Given a list of Python tuples, create an associated CFDictionary.
|
Given a list of Python tuples, create an associated CFDictionary.
|
||||||
"""
|
"""
|
||||||
|
@ -56,7 +72,7 @@ def _cf_dictionary_from_tuples(tuples):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _cfstr(py_bstr):
|
def _cfstr(py_bstr: bytes) -> CFString:
|
||||||
"""
|
"""
|
||||||
Given a Python binary data, create a CFString.
|
Given a Python binary data, create a CFString.
|
||||||
The string must be CFReleased by the caller.
|
The string must be CFReleased by the caller.
|
||||||
|
@ -70,7 +86,7 @@ def _cfstr(py_bstr):
|
||||||
return cf_str
|
return cf_str
|
||||||
|
|
||||||
|
|
||||||
def _create_cfstring_array(lst):
|
def _create_cfstring_array(lst: list[bytes]) -> CFMutableArray:
|
||||||
"""
|
"""
|
||||||
Given a list of Python binary data, create an associated CFMutableArray.
|
Given a list of Python binary data, create an associated CFMutableArray.
|
||||||
The array must be CFReleased by the caller.
|
The array must be CFReleased by the caller.
|
||||||
|
@ -97,11 +113,11 @@ def _create_cfstring_array(lst):
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
if cf_arr:
|
if cf_arr:
|
||||||
CoreFoundation.CFRelease(cf_arr)
|
CoreFoundation.CFRelease(cf_arr)
|
||||||
raise ssl.SSLError("Unable to allocate array: %s" % (e,))
|
raise ssl.SSLError(f"Unable to allocate array: {e}") from None
|
||||||
return cf_arr
|
return cf_arr
|
||||||
|
|
||||||
|
|
||||||
def _cf_string_to_unicode(value):
|
def _cf_string_to_unicode(value: CFString) -> str | None:
|
||||||
"""
|
"""
|
||||||
Creates a Unicode string from a CFString object. Used entirely for error
|
Creates a Unicode string from a CFString object. Used entirely for error
|
||||||
reporting.
|
reporting.
|
||||||
|
@ -123,10 +139,12 @@ def _cf_string_to_unicode(value):
|
||||||
string = buffer.value
|
string = buffer.value
|
||||||
if string is not None:
|
if string is not None:
|
||||||
string = string.decode("utf-8")
|
string = string.decode("utf-8")
|
||||||
return string
|
return string # type: ignore[no-any-return]
|
||||||
|
|
||||||
|
|
||||||
def _assert_no_error(error, exception_class=None):
|
def _assert_no_error(
|
||||||
|
error: int, exception_class: type[BaseException] | None = None
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Checks the return code and throws an exception if there is an error to
|
Checks the return code and throws an exception if there is an error to
|
||||||
report
|
report
|
||||||
|
@ -138,8 +156,8 @@ def _assert_no_error(error, exception_class=None):
|
||||||
output = _cf_string_to_unicode(cf_error_string)
|
output = _cf_string_to_unicode(cf_error_string)
|
||||||
CoreFoundation.CFRelease(cf_error_string)
|
CoreFoundation.CFRelease(cf_error_string)
|
||||||
|
|
||||||
if output is None or output == u"":
|
if output is None or output == "":
|
||||||
output = u"OSStatus %s" % error
|
output = f"OSStatus {error}"
|
||||||
|
|
||||||
if exception_class is None:
|
if exception_class is None:
|
||||||
exception_class = ssl.SSLError
|
exception_class = ssl.SSLError
|
||||||
|
@ -147,7 +165,7 @@ def _assert_no_error(error, exception_class=None):
|
||||||
raise exception_class(output)
|
raise exception_class(output)
|
||||||
|
|
||||||
|
|
||||||
def _cert_array_from_pem(pem_bundle):
|
def _cert_array_from_pem(pem_bundle: bytes) -> CFArray:
|
||||||
"""
|
"""
|
||||||
Given a bundle of certs in PEM format, turns them into a CFArray of certs
|
Given a bundle of certs in PEM format, turns them into a CFArray of certs
|
||||||
that can be used to validate a cert chain.
|
that can be used to validate a cert chain.
|
||||||
|
@ -193,23 +211,23 @@ def _cert_array_from_pem(pem_bundle):
|
||||||
return cert_array
|
return cert_array
|
||||||
|
|
||||||
|
|
||||||
def _is_cert(item):
|
def _is_cert(item: CFTypeRef) -> bool:
|
||||||
"""
|
"""
|
||||||
Returns True if a given CFTypeRef is a certificate.
|
Returns True if a given CFTypeRef is a certificate.
|
||||||
"""
|
"""
|
||||||
expected = Security.SecCertificateGetTypeID()
|
expected = Security.SecCertificateGetTypeID()
|
||||||
return CoreFoundation.CFGetTypeID(item) == expected
|
return CoreFoundation.CFGetTypeID(item) == expected # type: ignore[no-any-return]
|
||||||
|
|
||||||
|
|
||||||
def _is_identity(item):
|
def _is_identity(item: CFTypeRef) -> bool:
|
||||||
"""
|
"""
|
||||||
Returns True if a given CFTypeRef is an identity.
|
Returns True if a given CFTypeRef is an identity.
|
||||||
"""
|
"""
|
||||||
expected = Security.SecIdentityGetTypeID()
|
expected = Security.SecIdentityGetTypeID()
|
||||||
return CoreFoundation.CFGetTypeID(item) == expected
|
return CoreFoundation.CFGetTypeID(item) == expected # type: ignore[no-any-return]
|
||||||
|
|
||||||
|
|
||||||
def _temporary_keychain():
|
def _temporary_keychain() -> tuple[SecKeychainRef, str]:
|
||||||
"""
|
"""
|
||||||
This function creates a temporary Mac keychain that we can use to work with
|
This function creates a temporary Mac keychain that we can use to work with
|
||||||
credentials. This keychain uses a one-time password and a temporary file to
|
credentials. This keychain uses a one-time password and a temporary file to
|
||||||
|
@ -244,7 +262,9 @@ def _temporary_keychain():
|
||||||
return keychain, tempdirectory
|
return keychain, tempdirectory
|
||||||
|
|
||||||
|
|
||||||
def _load_items_from_file(keychain, path):
|
def _load_items_from_file(
|
||||||
|
keychain: SecKeychainRef, path: str
|
||||||
|
) -> tuple[list[CFTypeRef], list[CFTypeRef]]:
|
||||||
"""
|
"""
|
||||||
Given a single file, loads all the trust objects from it into arrays and
|
Given a single file, loads all the trust objects from it into arrays and
|
||||||
the keychain.
|
the keychain.
|
||||||
|
@ -299,7 +319,7 @@ def _load_items_from_file(keychain, path):
|
||||||
return (identities, certificates)
|
return (identities, certificates)
|
||||||
|
|
||||||
|
|
||||||
def _load_client_cert_chain(keychain, *paths):
|
def _load_client_cert_chain(keychain: SecKeychainRef, *paths: str | None) -> CFArray:
|
||||||
"""
|
"""
|
||||||
Load certificates and maybe keys from a number of files. Has the end goal
|
Load certificates and maybe keys from a number of files. Has the end goal
|
||||||
of returning a CFArray containing one SecIdentityRef, and then zero or more
|
of returning a CFArray containing one SecIdentityRef, and then zero or more
|
||||||
|
@ -335,10 +355,10 @@ def _load_client_cert_chain(keychain, *paths):
|
||||||
identities = []
|
identities = []
|
||||||
|
|
||||||
# Filter out bad paths.
|
# Filter out bad paths.
|
||||||
paths = (path for path in paths if path)
|
filtered_paths = (path for path in paths if path)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for file_path in paths:
|
for file_path in filtered_paths:
|
||||||
new_identities, new_certs = _load_items_from_file(keychain, file_path)
|
new_identities, new_certs = _load_items_from_file(keychain, file_path)
|
||||||
identities.extend(new_identities)
|
identities.extend(new_identities)
|
||||||
certificates.extend(new_certs)
|
certificates.extend(new_certs)
|
||||||
|
@ -383,7 +403,7 @@ TLS_PROTOCOL_VERSIONS = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _build_tls_unknown_ca_alert(version):
|
def _build_tls_unknown_ca_alert(version: str) -> bytes:
|
||||||
"""
|
"""
|
||||||
Builds a TLS alert record for an unknown CA.
|
Builds a TLS alert record for an unknown CA.
|
||||||
"""
|
"""
|
||||||
|
@ -395,3 +415,60 @@ def _build_tls_unknown_ca_alert(version):
|
||||||
record_type_alert = 0x15
|
record_type_alert = 0x15
|
||||||
record = struct.pack(">BBBH", record_type_alert, ver_maj, ver_min, msg_len) + msg
|
record = struct.pack(">BBBH", record_type_alert, ver_maj, ver_min, msg_len) + msg
|
||||||
return record
|
return record
|
||||||
|
|
||||||
|
|
||||||
|
class SecurityConst:
|
||||||
|
"""
|
||||||
|
A class object that acts as essentially a namespace for Security constants.
|
||||||
|
"""
|
||||||
|
|
||||||
|
kSSLSessionOptionBreakOnServerAuth = 0
|
||||||
|
|
||||||
|
kSSLProtocol2 = 1
|
||||||
|
kSSLProtocol3 = 2
|
||||||
|
kTLSProtocol1 = 4
|
||||||
|
kTLSProtocol11 = 7
|
||||||
|
kTLSProtocol12 = 8
|
||||||
|
# SecureTransport does not support TLS 1.3 even if there's a constant for it
|
||||||
|
kTLSProtocol13 = 10
|
||||||
|
kTLSProtocolMaxSupported = 999
|
||||||
|
|
||||||
|
kSSLClientSide = 1
|
||||||
|
kSSLStreamType = 0
|
||||||
|
|
||||||
|
kSecFormatPEMSequence = 10
|
||||||
|
|
||||||
|
kSecTrustResultInvalid = 0
|
||||||
|
kSecTrustResultProceed = 1
|
||||||
|
# This gap is present on purpose: this was kSecTrustResultConfirm, which
|
||||||
|
# is deprecated.
|
||||||
|
kSecTrustResultDeny = 3
|
||||||
|
kSecTrustResultUnspecified = 4
|
||||||
|
kSecTrustResultRecoverableTrustFailure = 5
|
||||||
|
kSecTrustResultFatalTrustFailure = 6
|
||||||
|
kSecTrustResultOtherError = 7
|
||||||
|
|
||||||
|
errSSLProtocol = -9800
|
||||||
|
errSSLWouldBlock = -9803
|
||||||
|
errSSLClosedGraceful = -9805
|
||||||
|
errSSLClosedNoNotify = -9816
|
||||||
|
errSSLClosedAbort = -9806
|
||||||
|
|
||||||
|
errSSLXCertChainInvalid = -9807
|
||||||
|
errSSLCrypto = -9809
|
||||||
|
errSSLInternal = -9810
|
||||||
|
errSSLCertExpired = -9814
|
||||||
|
errSSLCertNotYetValid = -9815
|
||||||
|
errSSLUnknownRootCert = -9812
|
||||||
|
errSSLNoRootCert = -9813
|
||||||
|
errSSLHostNameMismatch = -9843
|
||||||
|
errSSLPeerHandshakeFail = -9824
|
||||||
|
errSSLPeerUserCancelled = -9839
|
||||||
|
errSSLWeakPeerEphemeralDHKey = -9850
|
||||||
|
errSSLServerAuthCompleted = -9841
|
||||||
|
errSSLRecordOverflow = -9847
|
||||||
|
|
||||||
|
errSecVerifyFailed = -67808
|
||||||
|
errSecNoTrustSettings = -25263
|
||||||
|
errSecItemNotFound = -25300
|
||||||
|
errSecInvalidTrustSettings = -25262
|
||||||
|
|
|
@ -1,314 +0,0 @@
|
||||||
"""
|
|
||||||
This module provides a pool manager that uses Google App Engine's
|
|
||||||
`URLFetch Service <https://cloud.google.com/appengine/docs/python/urlfetch>`_.
|
|
||||||
|
|
||||||
Example usage::
|
|
||||||
|
|
||||||
from urllib3 import PoolManager
|
|
||||||
from urllib3.contrib.appengine import AppEngineManager, is_appengine_sandbox
|
|
||||||
|
|
||||||
if is_appengine_sandbox():
|
|
||||||
# AppEngineManager uses AppEngine's URLFetch API behind the scenes
|
|
||||||
http = AppEngineManager()
|
|
||||||
else:
|
|
||||||
# PoolManager uses a socket-level API behind the scenes
|
|
||||||
http = PoolManager()
|
|
||||||
|
|
||||||
r = http.request('GET', 'https://google.com/')
|
|
||||||
|
|
||||||
There are `limitations <https://cloud.google.com/appengine/docs/python/\
|
|
||||||
urlfetch/#Python_Quotas_and_limits>`_ to the URLFetch service and it may not be
|
|
||||||
the best choice for your application. There are three options for using
|
|
||||||
urllib3 on Google App Engine:
|
|
||||||
|
|
||||||
1. You can use :class:`AppEngineManager` with URLFetch. URLFetch is
|
|
||||||
cost-effective in many circumstances as long as your usage is within the
|
|
||||||
limitations.
|
|
||||||
2. You can use a normal :class:`~urllib3.PoolManager` by enabling sockets.
|
|
||||||
Sockets also have `limitations and restrictions
|
|
||||||
<https://cloud.google.com/appengine/docs/python/sockets/\
|
|
||||||
#limitations-and-restrictions>`_ and have a lower free quota than URLFetch.
|
|
||||||
To use sockets, be sure to specify the following in your ``app.yaml``::
|
|
||||||
|
|
||||||
env_variables:
|
|
||||||
GAE_USE_SOCKETS_HTTPLIB : 'true'
|
|
||||||
|
|
||||||
3. If you are using `App Engine Flexible
|
|
||||||
<https://cloud.google.com/appengine/docs/flexible/>`_, you can use the standard
|
|
||||||
:class:`PoolManager` without any configuration or special environment variables.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
|
|
||||||
import io
|
|
||||||
import logging
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
from ..exceptions import (
|
|
||||||
HTTPError,
|
|
||||||
HTTPWarning,
|
|
||||||
MaxRetryError,
|
|
||||||
ProtocolError,
|
|
||||||
SSLError,
|
|
||||||
TimeoutError,
|
|
||||||
)
|
|
||||||
from ..packages.six.moves.urllib.parse import urljoin
|
|
||||||
from ..request import RequestMethods
|
|
||||||
from ..response import HTTPResponse
|
|
||||||
from ..util.retry import Retry
|
|
||||||
from ..util.timeout import Timeout
|
|
||||||
from . import _appengine_environ
|
|
||||||
|
|
||||||
try:
|
|
||||||
from google.appengine.api import urlfetch
|
|
||||||
except ImportError:
|
|
||||||
urlfetch = None
|
|
||||||
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class AppEnginePlatformWarning(HTTPWarning):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class AppEnginePlatformError(HTTPError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class AppEngineManager(RequestMethods):
|
|
||||||
"""
|
|
||||||
Connection manager for Google App Engine sandbox applications.
|
|
||||||
|
|
||||||
This manager uses the URLFetch service directly instead of using the
|
|
||||||
emulated httplib, and is subject to URLFetch limitations as described in
|
|
||||||
the App Engine documentation `here
|
|
||||||
<https://cloud.google.com/appengine/docs/python/urlfetch>`_.
|
|
||||||
|
|
||||||
Notably it will raise an :class:`AppEnginePlatformError` if:
|
|
||||||
* URLFetch is not available.
|
|
||||||
* If you attempt to use this on App Engine Flexible, as full socket
|
|
||||||
support is available.
|
|
||||||
* If a request size is more than 10 megabytes.
|
|
||||||
* If a response size is more than 32 megabytes.
|
|
||||||
* If you use an unsupported request method such as OPTIONS.
|
|
||||||
|
|
||||||
Beyond those cases, it will raise normal urllib3 errors.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
headers=None,
|
|
||||||
retries=None,
|
|
||||||
validate_certificate=True,
|
|
||||||
urlfetch_retries=True,
|
|
||||||
):
|
|
||||||
if not urlfetch:
|
|
||||||
raise AppEnginePlatformError(
|
|
||||||
"URLFetch is not available in this environment."
|
|
||||||
)
|
|
||||||
|
|
||||||
warnings.warn(
|
|
||||||
"urllib3 is using URLFetch on Google App Engine sandbox instead "
|
|
||||||
"of sockets. To use sockets directly instead of URLFetch see "
|
|
||||||
"https://urllib3.readthedocs.io/en/1.26.x/reference/urllib3.contrib.html.",
|
|
||||||
AppEnginePlatformWarning,
|
|
||||||
)
|
|
||||||
|
|
||||||
RequestMethods.__init__(self, headers)
|
|
||||||
self.validate_certificate = validate_certificate
|
|
||||||
self.urlfetch_retries = urlfetch_retries
|
|
||||||
|
|
||||||
self.retries = retries or Retry.DEFAULT
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
||||||
# Return False to re-raise any potential exceptions
|
|
||||||
return False
|
|
||||||
|
|
||||||
def urlopen(
|
|
||||||
self,
|
|
||||||
method,
|
|
||||||
url,
|
|
||||||
body=None,
|
|
||||||
headers=None,
|
|
||||||
retries=None,
|
|
||||||
redirect=True,
|
|
||||||
timeout=Timeout.DEFAULT_TIMEOUT,
|
|
||||||
**response_kw
|
|
||||||
):
|
|
||||||
|
|
||||||
retries = self._get_retries(retries, redirect)
|
|
||||||
|
|
||||||
try:
|
|
||||||
follow_redirects = redirect and retries.redirect != 0 and retries.total
|
|
||||||
response = urlfetch.fetch(
|
|
||||||
url,
|
|
||||||
payload=body,
|
|
||||||
method=method,
|
|
||||||
headers=headers or {},
|
|
||||||
allow_truncated=False,
|
|
||||||
follow_redirects=self.urlfetch_retries and follow_redirects,
|
|
||||||
deadline=self._get_absolute_timeout(timeout),
|
|
||||||
validate_certificate=self.validate_certificate,
|
|
||||||
)
|
|
||||||
except urlfetch.DeadlineExceededError as e:
|
|
||||||
raise TimeoutError(self, e)
|
|
||||||
|
|
||||||
except urlfetch.InvalidURLError as e:
|
|
||||||
if "too large" in str(e):
|
|
||||||
raise AppEnginePlatformError(
|
|
||||||
"URLFetch request too large, URLFetch only "
|
|
||||||
"supports requests up to 10mb in size.",
|
|
||||||
e,
|
|
||||||
)
|
|
||||||
raise ProtocolError(e)
|
|
||||||
|
|
||||||
except urlfetch.DownloadError as e:
|
|
||||||
if "Too many redirects" in str(e):
|
|
||||||
raise MaxRetryError(self, url, reason=e)
|
|
||||||
raise ProtocolError(e)
|
|
||||||
|
|
||||||
except urlfetch.ResponseTooLargeError as e:
|
|
||||||
raise AppEnginePlatformError(
|
|
||||||
"URLFetch response too large, URLFetch only supports"
|
|
||||||
"responses up to 32mb in size.",
|
|
||||||
e,
|
|
||||||
)
|
|
||||||
|
|
||||||
except urlfetch.SSLCertificateError as e:
|
|
||||||
raise SSLError(e)
|
|
||||||
|
|
||||||
except urlfetch.InvalidMethodError as e:
|
|
||||||
raise AppEnginePlatformError(
|
|
||||||
"URLFetch does not support method: %s" % method, e
|
|
||||||
)
|
|
||||||
|
|
||||||
http_response = self._urlfetch_response_to_http_response(
|
|
||||||
response, retries=retries, **response_kw
|
|
||||||
)
|
|
||||||
|
|
||||||
# Handle redirect?
|
|
||||||
redirect_location = redirect and http_response.get_redirect_location()
|
|
||||||
if redirect_location:
|
|
||||||
# Check for redirect response
|
|
||||||
if self.urlfetch_retries and retries.raise_on_redirect:
|
|
||||||
raise MaxRetryError(self, url, "too many redirects")
|
|
||||||
else:
|
|
||||||
if http_response.status == 303:
|
|
||||||
method = "GET"
|
|
||||||
|
|
||||||
try:
|
|
||||||
retries = retries.increment(
|
|
||||||
method, url, response=http_response, _pool=self
|
|
||||||
)
|
|
||||||
except MaxRetryError:
|
|
||||||
if retries.raise_on_redirect:
|
|
||||||
raise MaxRetryError(self, url, "too many redirects")
|
|
||||||
return http_response
|
|
||||||
|
|
||||||
retries.sleep_for_retry(http_response)
|
|
||||||
log.debug("Redirecting %s -> %s", url, redirect_location)
|
|
||||||
redirect_url = urljoin(url, redirect_location)
|
|
||||||
return self.urlopen(
|
|
||||||
method,
|
|
||||||
redirect_url,
|
|
||||||
body,
|
|
||||||
headers,
|
|
||||||
retries=retries,
|
|
||||||
redirect=redirect,
|
|
||||||
timeout=timeout,
|
|
||||||
**response_kw
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check if we should retry the HTTP response.
|
|
||||||
has_retry_after = bool(http_response.headers.get("Retry-After"))
|
|
||||||
if retries.is_retry(method, http_response.status, has_retry_after):
|
|
||||||
retries = retries.increment(method, url, response=http_response, _pool=self)
|
|
||||||
log.debug("Retry: %s", url)
|
|
||||||
retries.sleep(http_response)
|
|
||||||
return self.urlopen(
|
|
||||||
method,
|
|
||||||
url,
|
|
||||||
body=body,
|
|
||||||
headers=headers,
|
|
||||||
retries=retries,
|
|
||||||
redirect=redirect,
|
|
||||||
timeout=timeout,
|
|
||||||
**response_kw
|
|
||||||
)
|
|
||||||
|
|
||||||
return http_response
|
|
||||||
|
|
||||||
def _urlfetch_response_to_http_response(self, urlfetch_resp, **response_kw):
|
|
||||||
|
|
||||||
if is_prod_appengine():
|
|
||||||
# Production GAE handles deflate encoding automatically, but does
|
|
||||||
# not remove the encoding header.
|
|
||||||
content_encoding = urlfetch_resp.headers.get("content-encoding")
|
|
||||||
|
|
||||||
if content_encoding == "deflate":
|
|
||||||
del urlfetch_resp.headers["content-encoding"]
|
|
||||||
|
|
||||||
transfer_encoding = urlfetch_resp.headers.get("transfer-encoding")
|
|
||||||
# We have a full response's content,
|
|
||||||
# so let's make sure we don't report ourselves as chunked data.
|
|
||||||
if transfer_encoding == "chunked":
|
|
||||||
encodings = transfer_encoding.split(",")
|
|
||||||
encodings.remove("chunked")
|
|
||||||
urlfetch_resp.headers["transfer-encoding"] = ",".join(encodings)
|
|
||||||
|
|
||||||
original_response = HTTPResponse(
|
|
||||||
# In order for decoding to work, we must present the content as
|
|
||||||
# a file-like object.
|
|
||||||
body=io.BytesIO(urlfetch_resp.content),
|
|
||||||
msg=urlfetch_resp.header_msg,
|
|
||||||
headers=urlfetch_resp.headers,
|
|
||||||
status=urlfetch_resp.status_code,
|
|
||||||
**response_kw
|
|
||||||
)
|
|
||||||
|
|
||||||
return HTTPResponse(
|
|
||||||
body=io.BytesIO(urlfetch_resp.content),
|
|
||||||
headers=urlfetch_resp.headers,
|
|
||||||
status=urlfetch_resp.status_code,
|
|
||||||
original_response=original_response,
|
|
||||||
**response_kw
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_absolute_timeout(self, timeout):
|
|
||||||
if timeout is Timeout.DEFAULT_TIMEOUT:
|
|
||||||
return None # Defer to URLFetch's default.
|
|
||||||
if isinstance(timeout, Timeout):
|
|
||||||
if timeout._read is not None or timeout._connect is not None:
|
|
||||||
warnings.warn(
|
|
||||||
"URLFetch does not support granular timeout settings, "
|
|
||||||
"reverting to total or default URLFetch timeout.",
|
|
||||||
AppEnginePlatformWarning,
|
|
||||||
)
|
|
||||||
return timeout.total
|
|
||||||
return timeout
|
|
||||||
|
|
||||||
def _get_retries(self, retries, redirect):
|
|
||||||
if not isinstance(retries, Retry):
|
|
||||||
retries = Retry.from_int(retries, redirect=redirect, default=self.retries)
|
|
||||||
|
|
||||||
if retries.connect or retries.read or retries.redirect:
|
|
||||||
warnings.warn(
|
|
||||||
"URLFetch only supports total retries and does not "
|
|
||||||
"recognize connect, read, or redirect retry parameters.",
|
|
||||||
AppEnginePlatformWarning,
|
|
||||||
)
|
|
||||||
|
|
||||||
return retries
|
|
||||||
|
|
||||||
|
|
||||||
# Alias methods from _appengine_environ to maintain public API interface.
|
|
||||||
|
|
||||||
is_appengine = _appengine_environ.is_appengine
|
|
||||||
is_appengine_sandbox = _appengine_environ.is_appengine_sandbox
|
|
||||||
is_local_appengine = _appengine_environ.is_local_appengine
|
|
||||||
is_prod_appengine = _appengine_environ.is_prod_appengine
|
|
||||||
is_prod_appengine_mvms = _appengine_environ.is_prod_appengine_mvms
|
|
|
@ -1,130 +0,0 @@
|
||||||
"""
|
|
||||||
NTLM authenticating pool, contributed by erikcederstran
|
|
||||||
|
|
||||||
Issue #10, see: http://code.google.com/p/urllib3/issues/detail?id=10
|
|
||||||
"""
|
|
||||||
from __future__ import absolute_import
|
|
||||||
|
|
||||||
import warnings
|
|
||||||
from logging import getLogger
|
|
||||||
|
|
||||||
from ntlm import ntlm
|
|
||||||
|
|
||||||
from .. import HTTPSConnectionPool
|
|
||||||
from ..packages.six.moves.http_client import HTTPSConnection
|
|
||||||
|
|
||||||
warnings.warn(
|
|
||||||
"The 'urllib3.contrib.ntlmpool' module is deprecated and will be removed "
|
|
||||||
"in urllib3 v2.0 release, urllib3 is not able to support it properly due "
|
|
||||||
"to reasons listed in issue: https://github.com/urllib3/urllib3/issues/2282. "
|
|
||||||
"If you are a user of this module please comment in the mentioned issue.",
|
|
||||||
DeprecationWarning,
|
|
||||||
)
|
|
||||||
|
|
||||||
log = getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class NTLMConnectionPool(HTTPSConnectionPool):
|
|
||||||
"""
|
|
||||||
Implements an NTLM authentication version of an urllib3 connection pool
|
|
||||||
"""
|
|
||||||
|
|
||||||
scheme = "https"
|
|
||||||
|
|
||||||
def __init__(self, user, pw, authurl, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
authurl is a random URL on the server that is protected by NTLM.
|
|
||||||
user is the Windows user, probably in the DOMAIN\\username format.
|
|
||||||
pw is the password for the user.
|
|
||||||
"""
|
|
||||||
super(NTLMConnectionPool, self).__init__(*args, **kwargs)
|
|
||||||
self.authurl = authurl
|
|
||||||
self.rawuser = user
|
|
||||||
user_parts = user.split("\\", 1)
|
|
||||||
self.domain = user_parts[0].upper()
|
|
||||||
self.user = user_parts[1]
|
|
||||||
self.pw = pw
|
|
||||||
|
|
||||||
def _new_conn(self):
|
|
||||||
# Performs the NTLM handshake that secures the connection. The socket
|
|
||||||
# must be kept open while requests are performed.
|
|
||||||
self.num_connections += 1
|
|
||||||
log.debug(
|
|
||||||
"Starting NTLM HTTPS connection no. %d: https://%s%s",
|
|
||||||
self.num_connections,
|
|
||||||
self.host,
|
|
||||||
self.authurl,
|
|
||||||
)
|
|
||||||
|
|
||||||
headers = {"Connection": "Keep-Alive"}
|
|
||||||
req_header = "Authorization"
|
|
||||||
resp_header = "www-authenticate"
|
|
||||||
|
|
||||||
conn = HTTPSConnection(host=self.host, port=self.port)
|
|
||||||
|
|
||||||
# Send negotiation message
|
|
||||||
headers[req_header] = "NTLM %s" % ntlm.create_NTLM_NEGOTIATE_MESSAGE(
|
|
||||||
self.rawuser
|
|
||||||
)
|
|
||||||
log.debug("Request headers: %s", headers)
|
|
||||||
conn.request("GET", self.authurl, None, headers)
|
|
||||||
res = conn.getresponse()
|
|
||||||
reshdr = dict(res.headers)
|
|
||||||
log.debug("Response status: %s %s", res.status, res.reason)
|
|
||||||
log.debug("Response headers: %s", reshdr)
|
|
||||||
log.debug("Response data: %s [...]", res.read(100))
|
|
||||||
|
|
||||||
# Remove the reference to the socket, so that it can not be closed by
|
|
||||||
# the response object (we want to keep the socket open)
|
|
||||||
res.fp = None
|
|
||||||
|
|
||||||
# Server should respond with a challenge message
|
|
||||||
auth_header_values = reshdr[resp_header].split(", ")
|
|
||||||
auth_header_value = None
|
|
||||||
for s in auth_header_values:
|
|
||||||
if s[:5] == "NTLM ":
|
|
||||||
auth_header_value = s[5:]
|
|
||||||
if auth_header_value is None:
|
|
||||||
raise Exception(
|
|
||||||
"Unexpected %s response header: %s" % (resp_header, reshdr[resp_header])
|
|
||||||
)
|
|
||||||
|
|
||||||
# Send authentication message
|
|
||||||
ServerChallenge, NegotiateFlags = ntlm.parse_NTLM_CHALLENGE_MESSAGE(
|
|
||||||
auth_header_value
|
|
||||||
)
|
|
||||||
auth_msg = ntlm.create_NTLM_AUTHENTICATE_MESSAGE(
|
|
||||||
ServerChallenge, self.user, self.domain, self.pw, NegotiateFlags
|
|
||||||
)
|
|
||||||
headers[req_header] = "NTLM %s" % auth_msg
|
|
||||||
log.debug("Request headers: %s", headers)
|
|
||||||
conn.request("GET", self.authurl, None, headers)
|
|
||||||
res = conn.getresponse()
|
|
||||||
log.debug("Response status: %s %s", res.status, res.reason)
|
|
||||||
log.debug("Response headers: %s", dict(res.headers))
|
|
||||||
log.debug("Response data: %s [...]", res.read()[:100])
|
|
||||||
if res.status != 200:
|
|
||||||
if res.status == 401:
|
|
||||||
raise Exception("Server rejected request: wrong username or password")
|
|
||||||
raise Exception("Wrong server response: %s %s" % (res.status, res.reason))
|
|
||||||
|
|
||||||
res.fp = None
|
|
||||||
log.debug("Connection established")
|
|
||||||
return conn
|
|
||||||
|
|
||||||
def urlopen(
|
|
||||||
self,
|
|
||||||
method,
|
|
||||||
url,
|
|
||||||
body=None,
|
|
||||||
headers=None,
|
|
||||||
retries=3,
|
|
||||||
redirect=True,
|
|
||||||
assert_same_host=True,
|
|
||||||
):
|
|
||||||
if headers is None:
|
|
||||||
headers = {}
|
|
||||||
headers["Connection"] = "Keep-Alive"
|
|
||||||
return super(NTLMConnectionPool, self).urlopen(
|
|
||||||
method, url, body, headers, retries, redirect, assert_same_host
|
|
||||||
)
|
|
|
@ -1,8 +1,8 @@
|
||||||
"""
|
"""
|
||||||
TLS with SNI_-support for Python 2. Follow these instructions if you would
|
Module for using pyOpenSSL as a TLS backend. This module was relevant before
|
||||||
like to verify TLS certificates in Python 2. Note, the default libraries do
|
the standard library ``ssl`` module supported SNI, but now that we've dropped
|
||||||
*not* do certificate checking; you need to do additional work to validate
|
support for Python 2.7 all relevant Python versions support SNI so
|
||||||
certificates yourself.
|
**this module is no longer recommended**.
|
||||||
|
|
||||||
This needs the following packages installed:
|
This needs the following packages installed:
|
||||||
|
|
||||||
|
@ -10,7 +10,7 @@ This needs the following packages installed:
|
||||||
* `cryptography`_ (minimum 1.3.4, from pyopenssl)
|
* `cryptography`_ (minimum 1.3.4, from pyopenssl)
|
||||||
* `idna`_ (minimum 2.0, from cryptography)
|
* `idna`_ (minimum 2.0, from cryptography)
|
||||||
|
|
||||||
However, pyopenssl depends on cryptography, which depends on idna, so while we
|
However, pyOpenSSL depends on cryptography, which depends on idna, so while we
|
||||||
use all three directly here we end up having relatively few packages required.
|
use all three directly here we end up having relatively few packages required.
|
||||||
|
|
||||||
You can install them with the following command:
|
You can install them with the following command:
|
||||||
|
@ -33,75 +33,55 @@ like this:
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
Now you can use :mod:`urllib3` as you normally would, and it will support SNI
|
|
||||||
when the required modules are installed.
|
|
||||||
|
|
||||||
Activating this module also has the positive side effect of disabling SSL/TLS
|
|
||||||
compression in Python 2 (see `CRIME attack`_).
|
|
||||||
|
|
||||||
.. _sni: https://en.wikipedia.org/wiki/Server_Name_Indication
|
|
||||||
.. _crime attack: https://en.wikipedia.org/wiki/CRIME_(security_exploit)
|
|
||||||
.. _pyopenssl: https://www.pyopenssl.org
|
.. _pyopenssl: https://www.pyopenssl.org
|
||||||
.. _cryptography: https://cryptography.io
|
.. _cryptography: https://cryptography.io
|
||||||
.. _idna: https://github.com/kjd/idna
|
.. _idna: https://github.com/kjd/idna
|
||||||
"""
|
"""
|
||||||
from __future__ import absolute_import
|
|
||||||
|
|
||||||
import OpenSSL.crypto
|
from __future__ import annotations
|
||||||
import OpenSSL.SSL
|
|
||||||
|
import OpenSSL.SSL # type: ignore[import]
|
||||||
from cryptography import x509
|
from cryptography import x509
|
||||||
from cryptography.hazmat.backends.openssl import backend as openssl_backend
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from cryptography.x509 import UnsupportedExtension
|
from cryptography.x509 import UnsupportedExtension # type: ignore[attr-defined]
|
||||||
except ImportError:
|
except ImportError:
|
||||||
# UnsupportedExtension is gone in cryptography >= 2.1.0
|
# UnsupportedExtension is gone in cryptography >= 2.1.0
|
||||||
class UnsupportedExtension(Exception):
|
class UnsupportedExtension(Exception): # type: ignore[no-redef]
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
from io import BytesIO
|
|
||||||
from socket import error as SocketError
|
|
||||||
from socket import timeout
|
|
||||||
|
|
||||||
try: # Platform-specific: Python 2
|
|
||||||
from socket import _fileobject
|
|
||||||
except ImportError: # Platform-specific: Python 3
|
|
||||||
_fileobject = None
|
|
||||||
from ..packages.backports.makefile import backport_makefile
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import ssl
|
import ssl
|
||||||
import sys
|
import typing
|
||||||
import warnings
|
import warnings
|
||||||
|
from io import BytesIO
|
||||||
|
from socket import socket as socket_cls
|
||||||
|
from socket import timeout
|
||||||
|
|
||||||
from .. import util
|
from .. import util
|
||||||
from ..packages import six
|
|
||||||
from ..util.ssl_ import PROTOCOL_TLS_CLIENT
|
|
||||||
|
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"'urllib3.contrib.pyopenssl' module is deprecated and will be removed "
|
"'urllib3.contrib.pyopenssl' module is deprecated and will be removed "
|
||||||
"in a future release of urllib3 2.x. Read more in this issue: "
|
"in urllib3 v2.1.0. Read more in this issue: "
|
||||||
"https://github.com/urllib3/urllib3/issues/2680",
|
"https://github.com/urllib3/urllib3/issues/2680",
|
||||||
category=DeprecationWarning,
|
category=DeprecationWarning,
|
||||||
stacklevel=2,
|
stacklevel=2,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = ["inject_into_urllib3", "extract_from_urllib3"]
|
if typing.TYPE_CHECKING:
|
||||||
|
from OpenSSL.crypto import X509 # type: ignore[import]
|
||||||
|
|
||||||
# SNI always works.
|
|
||||||
HAS_SNI = True
|
__all__ = ["inject_into_urllib3", "extract_from_urllib3"]
|
||||||
|
|
||||||
# Map from urllib3 to PyOpenSSL compatible parameter-values.
|
# Map from urllib3 to PyOpenSSL compatible parameter-values.
|
||||||
_openssl_versions = {
|
_openssl_versions = {
|
||||||
util.PROTOCOL_TLS: OpenSSL.SSL.SSLv23_METHOD,
|
util.ssl_.PROTOCOL_TLS: OpenSSL.SSL.SSLv23_METHOD, # type: ignore[attr-defined]
|
||||||
PROTOCOL_TLS_CLIENT: OpenSSL.SSL.SSLv23_METHOD,
|
util.ssl_.PROTOCOL_TLS_CLIENT: OpenSSL.SSL.SSLv23_METHOD, # type: ignore[attr-defined]
|
||||||
ssl.PROTOCOL_TLSv1: OpenSSL.SSL.TLSv1_METHOD,
|
ssl.PROTOCOL_TLSv1: OpenSSL.SSL.TLSv1_METHOD,
|
||||||
}
|
}
|
||||||
|
|
||||||
if hasattr(ssl, "PROTOCOL_SSLv3") and hasattr(OpenSSL.SSL, "SSLv3_METHOD"):
|
|
||||||
_openssl_versions[ssl.PROTOCOL_SSLv3] = OpenSSL.SSL.SSLv3_METHOD
|
|
||||||
|
|
||||||
if hasattr(ssl, "PROTOCOL_TLSv1_1") and hasattr(OpenSSL.SSL, "TLSv1_1_METHOD"):
|
if hasattr(ssl, "PROTOCOL_TLSv1_1") and hasattr(OpenSSL.SSL, "TLSv1_1_METHOD"):
|
||||||
_openssl_versions[ssl.PROTOCOL_TLSv1_1] = OpenSSL.SSL.TLSv1_1_METHOD
|
_openssl_versions[ssl.PROTOCOL_TLSv1_1] = OpenSSL.SSL.TLSv1_1_METHOD
|
||||||
|
|
||||||
|
@ -115,43 +95,77 @@ _stdlib_to_openssl_verify = {
|
||||||
ssl.CERT_REQUIRED: OpenSSL.SSL.VERIFY_PEER
|
ssl.CERT_REQUIRED: OpenSSL.SSL.VERIFY_PEER
|
||||||
+ OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT,
|
+ OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT,
|
||||||
}
|
}
|
||||||
_openssl_to_stdlib_verify = dict((v, k) for k, v in _stdlib_to_openssl_verify.items())
|
_openssl_to_stdlib_verify = {v: k for k, v in _stdlib_to_openssl_verify.items()}
|
||||||
|
|
||||||
|
# The SSLvX values are the most likely to be missing in the future
|
||||||
|
# but we check them all just to be sure.
|
||||||
|
_OP_NO_SSLv2_OR_SSLv3: int = getattr(OpenSSL.SSL, "OP_NO_SSLv2", 0) | getattr(
|
||||||
|
OpenSSL.SSL, "OP_NO_SSLv3", 0
|
||||||
|
)
|
||||||
|
_OP_NO_TLSv1: int = getattr(OpenSSL.SSL, "OP_NO_TLSv1", 0)
|
||||||
|
_OP_NO_TLSv1_1: int = getattr(OpenSSL.SSL, "OP_NO_TLSv1_1", 0)
|
||||||
|
_OP_NO_TLSv1_2: int = getattr(OpenSSL.SSL, "OP_NO_TLSv1_2", 0)
|
||||||
|
_OP_NO_TLSv1_3: int = getattr(OpenSSL.SSL, "OP_NO_TLSv1_3", 0)
|
||||||
|
|
||||||
|
_openssl_to_ssl_minimum_version: dict[int, int] = {
|
||||||
|
ssl.TLSVersion.MINIMUM_SUPPORTED: _OP_NO_SSLv2_OR_SSLv3,
|
||||||
|
ssl.TLSVersion.TLSv1: _OP_NO_SSLv2_OR_SSLv3,
|
||||||
|
ssl.TLSVersion.TLSv1_1: _OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1,
|
||||||
|
ssl.TLSVersion.TLSv1_2: _OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1 | _OP_NO_TLSv1_1,
|
||||||
|
ssl.TLSVersion.TLSv1_3: (
|
||||||
|
_OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1 | _OP_NO_TLSv1_1 | _OP_NO_TLSv1_2
|
||||||
|
),
|
||||||
|
ssl.TLSVersion.MAXIMUM_SUPPORTED: (
|
||||||
|
_OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1 | _OP_NO_TLSv1_1 | _OP_NO_TLSv1_2
|
||||||
|
),
|
||||||
|
}
|
||||||
|
_openssl_to_ssl_maximum_version: dict[int, int] = {
|
||||||
|
ssl.TLSVersion.MINIMUM_SUPPORTED: (
|
||||||
|
_OP_NO_SSLv2_OR_SSLv3
|
||||||
|
| _OP_NO_TLSv1
|
||||||
|
| _OP_NO_TLSv1_1
|
||||||
|
| _OP_NO_TLSv1_2
|
||||||
|
| _OP_NO_TLSv1_3
|
||||||
|
),
|
||||||
|
ssl.TLSVersion.TLSv1: (
|
||||||
|
_OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1_1 | _OP_NO_TLSv1_2 | _OP_NO_TLSv1_3
|
||||||
|
),
|
||||||
|
ssl.TLSVersion.TLSv1_1: _OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1_2 | _OP_NO_TLSv1_3,
|
||||||
|
ssl.TLSVersion.TLSv1_2: _OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1_3,
|
||||||
|
ssl.TLSVersion.TLSv1_3: _OP_NO_SSLv2_OR_SSLv3,
|
||||||
|
ssl.TLSVersion.MAXIMUM_SUPPORTED: _OP_NO_SSLv2_OR_SSLv3,
|
||||||
|
}
|
||||||
|
|
||||||
# OpenSSL will only write 16K at a time
|
# OpenSSL will only write 16K at a time
|
||||||
SSL_WRITE_BLOCKSIZE = 16384
|
SSL_WRITE_BLOCKSIZE = 16384
|
||||||
|
|
||||||
orig_util_HAS_SNI = util.HAS_SNI
|
|
||||||
orig_util_SSLContext = util.ssl_.SSLContext
|
orig_util_SSLContext = util.ssl_.SSLContext
|
||||||
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def inject_into_urllib3():
|
def inject_into_urllib3() -> None:
|
||||||
"Monkey-patch urllib3 with PyOpenSSL-backed SSL-support."
|
"Monkey-patch urllib3 with PyOpenSSL-backed SSL-support."
|
||||||
|
|
||||||
_validate_dependencies_met()
|
_validate_dependencies_met()
|
||||||
|
|
||||||
util.SSLContext = PyOpenSSLContext
|
util.SSLContext = PyOpenSSLContext # type: ignore[assignment]
|
||||||
util.ssl_.SSLContext = PyOpenSSLContext
|
util.ssl_.SSLContext = PyOpenSSLContext # type: ignore[assignment]
|
||||||
util.HAS_SNI = HAS_SNI
|
|
||||||
util.ssl_.HAS_SNI = HAS_SNI
|
|
||||||
util.IS_PYOPENSSL = True
|
util.IS_PYOPENSSL = True
|
||||||
util.ssl_.IS_PYOPENSSL = True
|
util.ssl_.IS_PYOPENSSL = True
|
||||||
|
|
||||||
|
|
||||||
def extract_from_urllib3():
|
def extract_from_urllib3() -> None:
|
||||||
"Undo monkey-patching by :func:`inject_into_urllib3`."
|
"Undo monkey-patching by :func:`inject_into_urllib3`."
|
||||||
|
|
||||||
util.SSLContext = orig_util_SSLContext
|
util.SSLContext = orig_util_SSLContext
|
||||||
util.ssl_.SSLContext = orig_util_SSLContext
|
util.ssl_.SSLContext = orig_util_SSLContext
|
||||||
util.HAS_SNI = orig_util_HAS_SNI
|
|
||||||
util.ssl_.HAS_SNI = orig_util_HAS_SNI
|
|
||||||
util.IS_PYOPENSSL = False
|
util.IS_PYOPENSSL = False
|
||||||
util.ssl_.IS_PYOPENSSL = False
|
util.ssl_.IS_PYOPENSSL = False
|
||||||
|
|
||||||
|
|
||||||
def _validate_dependencies_met():
|
def _validate_dependencies_met() -> None:
|
||||||
"""
|
"""
|
||||||
Verifies that PyOpenSSL's package-level dependencies have been met.
|
Verifies that PyOpenSSL's package-level dependencies have been met.
|
||||||
Throws `ImportError` if they are not met.
|
Throws `ImportError` if they are not met.
|
||||||
|
@ -177,7 +191,7 @@ def _validate_dependencies_met():
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _dnsname_to_stdlib(name):
|
def _dnsname_to_stdlib(name: str) -> str | None:
|
||||||
"""
|
"""
|
||||||
Converts a dNSName SubjectAlternativeName field to the form used by the
|
Converts a dNSName SubjectAlternativeName field to the form used by the
|
||||||
standard library on the given Python version.
|
standard library on the given Python version.
|
||||||
|
@ -191,7 +205,7 @@ def _dnsname_to_stdlib(name):
|
||||||
the name given should be skipped.
|
the name given should be skipped.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def idna_encode(name):
|
def idna_encode(name: str) -> bytes | None:
|
||||||
"""
|
"""
|
||||||
Borrowed wholesale from the Python Cryptography Project. It turns out
|
Borrowed wholesale from the Python Cryptography Project. It turns out
|
||||||
that we can't just safely call `idna.encode`: it can explode for
|
that we can't just safely call `idna.encode`: it can explode for
|
||||||
|
@ -200,7 +214,7 @@ def _dnsname_to_stdlib(name):
|
||||||
import idna
|
import idna
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for prefix in [u"*.", u"."]:
|
for prefix in ["*.", "."]:
|
||||||
if name.startswith(prefix):
|
if name.startswith(prefix):
|
||||||
name = name[len(prefix) :]
|
name = name[len(prefix) :]
|
||||||
return prefix.encode("ascii") + idna.encode(name)
|
return prefix.encode("ascii") + idna.encode(name)
|
||||||
|
@ -212,24 +226,17 @@ def _dnsname_to_stdlib(name):
|
||||||
if ":" in name:
|
if ":" in name:
|
||||||
return name
|
return name
|
||||||
|
|
||||||
name = idna_encode(name)
|
encoded_name = idna_encode(name)
|
||||||
if name is None:
|
if encoded_name is None:
|
||||||
return None
|
return None
|
||||||
elif sys.version_info >= (3, 0):
|
return encoded_name.decode("utf-8")
|
||||||
name = name.decode("utf-8")
|
|
||||||
return name
|
|
||||||
|
|
||||||
|
|
||||||
def get_subj_alt_name(peer_cert):
|
def get_subj_alt_name(peer_cert: X509) -> list[tuple[str, str]]:
|
||||||
"""
|
"""
|
||||||
Given an PyOpenSSL certificate, provides all the subject alternative names.
|
Given an PyOpenSSL certificate, provides all the subject alternative names.
|
||||||
"""
|
"""
|
||||||
# Pass the cert to cryptography, which has much better APIs for this.
|
cert = peer_cert.to_cryptography()
|
||||||
if hasattr(peer_cert, "to_cryptography"):
|
|
||||||
cert = peer_cert.to_cryptography()
|
|
||||||
else:
|
|
||||||
der = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_ASN1, peer_cert)
|
|
||||||
cert = x509.load_der_x509_certificate(der, openssl_backend)
|
|
||||||
|
|
||||||
# We want to find the SAN extension. Ask Cryptography to locate it (it's
|
# We want to find the SAN extension. Ask Cryptography to locate it (it's
|
||||||
# faster than looping in Python)
|
# faster than looping in Python)
|
||||||
|
@ -273,93 +280,94 @@ def get_subj_alt_name(peer_cert):
|
||||||
return names
|
return names
|
||||||
|
|
||||||
|
|
||||||
class WrappedSocket(object):
|
class WrappedSocket:
|
||||||
"""API-compatibility wrapper for Python OpenSSL's Connection-class.
|
"""API-compatibility wrapper for Python OpenSSL's Connection-class."""
|
||||||
|
|
||||||
Note: _makefile_refs, _drop() and _reuse() are needed for the garbage
|
def __init__(
|
||||||
collector of pypy.
|
self,
|
||||||
"""
|
connection: OpenSSL.SSL.Connection,
|
||||||
|
socket: socket_cls,
|
||||||
def __init__(self, connection, socket, suppress_ragged_eofs=True):
|
suppress_ragged_eofs: bool = True,
|
||||||
|
) -> None:
|
||||||
self.connection = connection
|
self.connection = connection
|
||||||
self.socket = socket
|
self.socket = socket
|
||||||
self.suppress_ragged_eofs = suppress_ragged_eofs
|
self.suppress_ragged_eofs = suppress_ragged_eofs
|
||||||
self._makefile_refs = 0
|
self._io_refs = 0
|
||||||
self._closed = False
|
self._closed = False
|
||||||
|
|
||||||
def fileno(self):
|
def fileno(self) -> int:
|
||||||
return self.socket.fileno()
|
return self.socket.fileno()
|
||||||
|
|
||||||
# Copy-pasted from Python 3.5 source code
|
# Copy-pasted from Python 3.5 source code
|
||||||
def _decref_socketios(self):
|
def _decref_socketios(self) -> None:
|
||||||
if self._makefile_refs > 0:
|
if self._io_refs > 0:
|
||||||
self._makefile_refs -= 1
|
self._io_refs -= 1
|
||||||
if self._closed:
|
if self._closed:
|
||||||
self.close()
|
self.close()
|
||||||
|
|
||||||
def recv(self, *args, **kwargs):
|
def recv(self, *args: typing.Any, **kwargs: typing.Any) -> bytes:
|
||||||
try:
|
try:
|
||||||
data = self.connection.recv(*args, **kwargs)
|
data = self.connection.recv(*args, **kwargs)
|
||||||
except OpenSSL.SSL.SysCallError as e:
|
except OpenSSL.SSL.SysCallError as e:
|
||||||
if self.suppress_ragged_eofs and e.args == (-1, "Unexpected EOF"):
|
if self.suppress_ragged_eofs and e.args == (-1, "Unexpected EOF"):
|
||||||
return b""
|
return b""
|
||||||
else:
|
else:
|
||||||
raise SocketError(str(e))
|
raise OSError(e.args[0], str(e)) from e
|
||||||
except OpenSSL.SSL.ZeroReturnError:
|
except OpenSSL.SSL.ZeroReturnError:
|
||||||
if self.connection.get_shutdown() == OpenSSL.SSL.RECEIVED_SHUTDOWN:
|
if self.connection.get_shutdown() == OpenSSL.SSL.RECEIVED_SHUTDOWN:
|
||||||
return b""
|
return b""
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
except OpenSSL.SSL.WantReadError:
|
except OpenSSL.SSL.WantReadError as e:
|
||||||
if not util.wait_for_read(self.socket, self.socket.gettimeout()):
|
if not util.wait_for_read(self.socket, self.socket.gettimeout()):
|
||||||
raise timeout("The read operation timed out")
|
raise timeout("The read operation timed out") from e
|
||||||
else:
|
else:
|
||||||
return self.recv(*args, **kwargs)
|
return self.recv(*args, **kwargs)
|
||||||
|
|
||||||
# TLS 1.3 post-handshake authentication
|
# TLS 1.3 post-handshake authentication
|
||||||
except OpenSSL.SSL.Error as e:
|
except OpenSSL.SSL.Error as e:
|
||||||
raise ssl.SSLError("read error: %r" % e)
|
raise ssl.SSLError(f"read error: {e!r}") from e
|
||||||
else:
|
else:
|
||||||
return data
|
return data # type: ignore[no-any-return]
|
||||||
|
|
||||||
def recv_into(self, *args, **kwargs):
|
def recv_into(self, *args: typing.Any, **kwargs: typing.Any) -> int:
|
||||||
try:
|
try:
|
||||||
return self.connection.recv_into(*args, **kwargs)
|
return self.connection.recv_into(*args, **kwargs) # type: ignore[no-any-return]
|
||||||
except OpenSSL.SSL.SysCallError as e:
|
except OpenSSL.SSL.SysCallError as e:
|
||||||
if self.suppress_ragged_eofs and e.args == (-1, "Unexpected EOF"):
|
if self.suppress_ragged_eofs and e.args == (-1, "Unexpected EOF"):
|
||||||
return 0
|
return 0
|
||||||
else:
|
else:
|
||||||
raise SocketError(str(e))
|
raise OSError(e.args[0], str(e)) from e
|
||||||
except OpenSSL.SSL.ZeroReturnError:
|
except OpenSSL.SSL.ZeroReturnError:
|
||||||
if self.connection.get_shutdown() == OpenSSL.SSL.RECEIVED_SHUTDOWN:
|
if self.connection.get_shutdown() == OpenSSL.SSL.RECEIVED_SHUTDOWN:
|
||||||
return 0
|
return 0
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
except OpenSSL.SSL.WantReadError:
|
except OpenSSL.SSL.WantReadError as e:
|
||||||
if not util.wait_for_read(self.socket, self.socket.gettimeout()):
|
if not util.wait_for_read(self.socket, self.socket.gettimeout()):
|
||||||
raise timeout("The read operation timed out")
|
raise timeout("The read operation timed out") from e
|
||||||
else:
|
else:
|
||||||
return self.recv_into(*args, **kwargs)
|
return self.recv_into(*args, **kwargs)
|
||||||
|
|
||||||
# TLS 1.3 post-handshake authentication
|
# TLS 1.3 post-handshake authentication
|
||||||
except OpenSSL.SSL.Error as e:
|
except OpenSSL.SSL.Error as e:
|
||||||
raise ssl.SSLError("read error: %r" % e)
|
raise ssl.SSLError(f"read error: {e!r}") from e
|
||||||
|
|
||||||
def settimeout(self, timeout):
|
def settimeout(self, timeout: float) -> None:
|
||||||
return self.socket.settimeout(timeout)
|
return self.socket.settimeout(timeout)
|
||||||
|
|
||||||
def _send_until_done(self, data):
|
def _send_until_done(self, data: bytes) -> int:
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
return self.connection.send(data)
|
return self.connection.send(data) # type: ignore[no-any-return]
|
||||||
except OpenSSL.SSL.WantWriteError:
|
except OpenSSL.SSL.WantWriteError as e:
|
||||||
if not util.wait_for_write(self.socket, self.socket.gettimeout()):
|
if not util.wait_for_write(self.socket, self.socket.gettimeout()):
|
||||||
raise timeout()
|
raise timeout() from e
|
||||||
continue
|
continue
|
||||||
except OpenSSL.SSL.SysCallError as e:
|
except OpenSSL.SSL.SysCallError as e:
|
||||||
raise SocketError(str(e))
|
raise OSError(e.args[0], str(e)) from e
|
||||||
|
|
||||||
def sendall(self, data):
|
def sendall(self, data: bytes) -> None:
|
||||||
total_sent = 0
|
total_sent = 0
|
||||||
while total_sent < len(data):
|
while total_sent < len(data):
|
||||||
sent = self._send_until_done(
|
sent = self._send_until_done(
|
||||||
|
@ -367,135 +375,135 @@ class WrappedSocket(object):
|
||||||
)
|
)
|
||||||
total_sent += sent
|
total_sent += sent
|
||||||
|
|
||||||
def shutdown(self):
|
def shutdown(self) -> None:
|
||||||
# FIXME rethrow compatible exceptions should we ever use this
|
# FIXME rethrow compatible exceptions should we ever use this
|
||||||
self.connection.shutdown()
|
self.connection.shutdown()
|
||||||
|
|
||||||
def close(self):
|
def close(self) -> None:
|
||||||
if self._makefile_refs < 1:
|
self._closed = True
|
||||||
try:
|
if self._io_refs <= 0:
|
||||||
self._closed = True
|
self._real_close()
|
||||||
return self.connection.close()
|
|
||||||
except OpenSSL.SSL.Error:
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
self._makefile_refs -= 1
|
|
||||||
|
|
||||||
def getpeercert(self, binary_form=False):
|
def _real_close(self) -> None:
|
||||||
|
try:
|
||||||
|
return self.connection.close() # type: ignore[no-any-return]
|
||||||
|
except OpenSSL.SSL.Error:
|
||||||
|
return
|
||||||
|
|
||||||
|
def getpeercert(
|
||||||
|
self, binary_form: bool = False
|
||||||
|
) -> dict[str, list[typing.Any]] | None:
|
||||||
x509 = self.connection.get_peer_certificate()
|
x509 = self.connection.get_peer_certificate()
|
||||||
|
|
||||||
if not x509:
|
if not x509:
|
||||||
return x509
|
return x509 # type: ignore[no-any-return]
|
||||||
|
|
||||||
if binary_form:
|
if binary_form:
|
||||||
return OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_ASN1, x509)
|
return OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_ASN1, x509) # type: ignore[no-any-return]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"subject": ((("commonName", x509.get_subject().CN),),),
|
"subject": ((("commonName", x509.get_subject().CN),),), # type: ignore[dict-item]
|
||||||
"subjectAltName": get_subj_alt_name(x509),
|
"subjectAltName": get_subj_alt_name(x509),
|
||||||
}
|
}
|
||||||
|
|
||||||
def version(self):
|
def version(self) -> str:
|
||||||
return self.connection.get_protocol_version_name()
|
return self.connection.get_protocol_version_name() # type: ignore[no-any-return]
|
||||||
|
|
||||||
def _reuse(self):
|
|
||||||
self._makefile_refs += 1
|
|
||||||
|
|
||||||
def _drop(self):
|
|
||||||
if self._makefile_refs < 1:
|
|
||||||
self.close()
|
|
||||||
else:
|
|
||||||
self._makefile_refs -= 1
|
|
||||||
|
|
||||||
|
|
||||||
if _fileobject: # Platform-specific: Python 2
|
WrappedSocket.makefile = socket_cls.makefile # type: ignore[attr-defined]
|
||||||
|
|
||||||
def makefile(self, mode, bufsize=-1):
|
|
||||||
self._makefile_refs += 1
|
|
||||||
return _fileobject(self, mode, bufsize, close=True)
|
|
||||||
|
|
||||||
else: # Platform-specific: Python 3
|
|
||||||
makefile = backport_makefile
|
|
||||||
|
|
||||||
WrappedSocket.makefile = makefile
|
|
||||||
|
|
||||||
|
|
||||||
class PyOpenSSLContext(object):
|
class PyOpenSSLContext:
|
||||||
"""
|
"""
|
||||||
I am a wrapper class for the PyOpenSSL ``Context`` object. I am responsible
|
I am a wrapper class for the PyOpenSSL ``Context`` object. I am responsible
|
||||||
for translating the interface of the standard library ``SSLContext`` object
|
for translating the interface of the standard library ``SSLContext`` object
|
||||||
to calls into PyOpenSSL.
|
to calls into PyOpenSSL.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, protocol):
|
def __init__(self, protocol: int) -> None:
|
||||||
self.protocol = _openssl_versions[protocol]
|
self.protocol = _openssl_versions[protocol]
|
||||||
self._ctx = OpenSSL.SSL.Context(self.protocol)
|
self._ctx = OpenSSL.SSL.Context(self.protocol)
|
||||||
self._options = 0
|
self._options = 0
|
||||||
self.check_hostname = False
|
self.check_hostname = False
|
||||||
|
self._minimum_version: int = ssl.TLSVersion.MINIMUM_SUPPORTED
|
||||||
|
self._maximum_version: int = ssl.TLSVersion.MAXIMUM_SUPPORTED
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def options(self):
|
def options(self) -> int:
|
||||||
return self._options
|
return self._options
|
||||||
|
|
||||||
@options.setter
|
@options.setter
|
||||||
def options(self, value):
|
def options(self, value: int) -> None:
|
||||||
self._options = value
|
self._options = value
|
||||||
self._ctx.set_options(value)
|
self._set_ctx_options()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def verify_mode(self):
|
def verify_mode(self) -> int:
|
||||||
return _openssl_to_stdlib_verify[self._ctx.get_verify_mode()]
|
return _openssl_to_stdlib_verify[self._ctx.get_verify_mode()]
|
||||||
|
|
||||||
@verify_mode.setter
|
@verify_mode.setter
|
||||||
def verify_mode(self, value):
|
def verify_mode(self, value: ssl.VerifyMode) -> None:
|
||||||
self._ctx.set_verify(_stdlib_to_openssl_verify[value], _verify_callback)
|
self._ctx.set_verify(_stdlib_to_openssl_verify[value], _verify_callback)
|
||||||
|
|
||||||
def set_default_verify_paths(self):
|
def set_default_verify_paths(self) -> None:
|
||||||
self._ctx.set_default_verify_paths()
|
self._ctx.set_default_verify_paths()
|
||||||
|
|
||||||
def set_ciphers(self, ciphers):
|
def set_ciphers(self, ciphers: bytes | str) -> None:
|
||||||
if isinstance(ciphers, six.text_type):
|
if isinstance(ciphers, str):
|
||||||
ciphers = ciphers.encode("utf-8")
|
ciphers = ciphers.encode("utf-8")
|
||||||
self._ctx.set_cipher_list(ciphers)
|
self._ctx.set_cipher_list(ciphers)
|
||||||
|
|
||||||
def load_verify_locations(self, cafile=None, capath=None, cadata=None):
|
def load_verify_locations(
|
||||||
|
self,
|
||||||
|
cafile: str | None = None,
|
||||||
|
capath: str | None = None,
|
||||||
|
cadata: bytes | None = None,
|
||||||
|
) -> None:
|
||||||
if cafile is not None:
|
if cafile is not None:
|
||||||
cafile = cafile.encode("utf-8")
|
cafile = cafile.encode("utf-8") # type: ignore[assignment]
|
||||||
if capath is not None:
|
if capath is not None:
|
||||||
capath = capath.encode("utf-8")
|
capath = capath.encode("utf-8") # type: ignore[assignment]
|
||||||
try:
|
try:
|
||||||
self._ctx.load_verify_locations(cafile, capath)
|
self._ctx.load_verify_locations(cafile, capath)
|
||||||
if cadata is not None:
|
if cadata is not None:
|
||||||
self._ctx.load_verify_locations(BytesIO(cadata))
|
self._ctx.load_verify_locations(BytesIO(cadata))
|
||||||
except OpenSSL.SSL.Error as e:
|
except OpenSSL.SSL.Error as e:
|
||||||
raise ssl.SSLError("unable to load trusted certificates: %r" % e)
|
raise ssl.SSLError(f"unable to load trusted certificates: {e!r}") from e
|
||||||
|
|
||||||
def load_cert_chain(self, certfile, keyfile=None, password=None):
|
def load_cert_chain(
|
||||||
self._ctx.use_certificate_chain_file(certfile)
|
self,
|
||||||
if password is not None:
|
certfile: str,
|
||||||
if not isinstance(password, six.binary_type):
|
keyfile: str | None = None,
|
||||||
password = password.encode("utf-8")
|
password: str | None = None,
|
||||||
self._ctx.set_passwd_cb(lambda *_: password)
|
) -> None:
|
||||||
self._ctx.use_privatekey_file(keyfile or certfile)
|
try:
|
||||||
|
self._ctx.use_certificate_chain_file(certfile)
|
||||||
|
if password is not None:
|
||||||
|
if not isinstance(password, bytes):
|
||||||
|
password = password.encode("utf-8") # type: ignore[assignment]
|
||||||
|
self._ctx.set_passwd_cb(lambda *_: password)
|
||||||
|
self._ctx.use_privatekey_file(keyfile or certfile)
|
||||||
|
except OpenSSL.SSL.Error as e:
|
||||||
|
raise ssl.SSLError(f"Unable to load certificate chain: {e!r}") from e
|
||||||
|
|
||||||
def set_alpn_protocols(self, protocols):
|
def set_alpn_protocols(self, protocols: list[bytes | str]) -> None:
|
||||||
protocols = [six.ensure_binary(p) for p in protocols]
|
protocols = [util.util.to_bytes(p, "ascii") for p in protocols]
|
||||||
return self._ctx.set_alpn_protos(protocols)
|
return self._ctx.set_alpn_protos(protocols) # type: ignore[no-any-return]
|
||||||
|
|
||||||
def wrap_socket(
|
def wrap_socket(
|
||||||
self,
|
self,
|
||||||
sock,
|
sock: socket_cls,
|
||||||
server_side=False,
|
server_side: bool = False,
|
||||||
do_handshake_on_connect=True,
|
do_handshake_on_connect: bool = True,
|
||||||
suppress_ragged_eofs=True,
|
suppress_ragged_eofs: bool = True,
|
||||||
server_hostname=None,
|
server_hostname: bytes | str | None = None,
|
||||||
):
|
) -> WrappedSocket:
|
||||||
cnx = OpenSSL.SSL.Connection(self._ctx, sock)
|
cnx = OpenSSL.SSL.Connection(self._ctx, sock)
|
||||||
|
|
||||||
if isinstance(server_hostname, six.text_type): # Platform-specific: Python 3
|
# If server_hostname is an IP, don't use it for SNI, per RFC6066 Section 3
|
||||||
server_hostname = server_hostname.encode("utf-8")
|
if server_hostname and not util.ssl_.is_ipaddress(server_hostname):
|
||||||
|
if isinstance(server_hostname, str):
|
||||||
if server_hostname is not None:
|
server_hostname = server_hostname.encode("utf-8")
|
||||||
cnx.set_tlsext_host_name(server_hostname)
|
cnx.set_tlsext_host_name(server_hostname)
|
||||||
|
|
||||||
cnx.set_connect_state()
|
cnx.set_connect_state()
|
||||||
|
@ -503,16 +511,47 @@ class PyOpenSSLContext(object):
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
cnx.do_handshake()
|
cnx.do_handshake()
|
||||||
except OpenSSL.SSL.WantReadError:
|
except OpenSSL.SSL.WantReadError as e:
|
||||||
if not util.wait_for_read(sock, sock.gettimeout()):
|
if not util.wait_for_read(sock, sock.gettimeout()):
|
||||||
raise timeout("select timed out")
|
raise timeout("select timed out") from e
|
||||||
continue
|
continue
|
||||||
except OpenSSL.SSL.Error as e:
|
except OpenSSL.SSL.Error as e:
|
||||||
raise ssl.SSLError("bad handshake: %r" % e)
|
raise ssl.SSLError(f"bad handshake: {e!r}") from e
|
||||||
break
|
break
|
||||||
|
|
||||||
return WrappedSocket(cnx, sock)
|
return WrappedSocket(cnx, sock)
|
||||||
|
|
||||||
|
def _set_ctx_options(self) -> None:
|
||||||
|
self._ctx.set_options(
|
||||||
|
self._options
|
||||||
|
| _openssl_to_ssl_minimum_version[self._minimum_version]
|
||||||
|
| _openssl_to_ssl_maximum_version[self._maximum_version]
|
||||||
|
)
|
||||||
|
|
||||||
def _verify_callback(cnx, x509, err_no, err_depth, return_code):
|
@property
|
||||||
|
def minimum_version(self) -> int:
|
||||||
|
return self._minimum_version
|
||||||
|
|
||||||
|
@minimum_version.setter
|
||||||
|
def minimum_version(self, minimum_version: int) -> None:
|
||||||
|
self._minimum_version = minimum_version
|
||||||
|
self._set_ctx_options()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def maximum_version(self) -> int:
|
||||||
|
return self._maximum_version
|
||||||
|
|
||||||
|
@maximum_version.setter
|
||||||
|
def maximum_version(self, maximum_version: int) -> None:
|
||||||
|
self._maximum_version = maximum_version
|
||||||
|
self._set_ctx_options()
|
||||||
|
|
||||||
|
|
||||||
|
def _verify_callback(
|
||||||
|
cnx: OpenSSL.SSL.Connection,
|
||||||
|
x509: X509,
|
||||||
|
err_no: int,
|
||||||
|
err_depth: int,
|
||||||
|
return_code: int,
|
||||||
|
) -> bool:
|
||||||
return err_no == 0
|
return err_no == 0
|
||||||
|
|
|
@ -51,7 +51,8 @@ license and by oscrypto's:
|
||||||
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
DEALINGS IN THE SOFTWARE.
|
DEALINGS IN THE SOFTWARE.
|
||||||
"""
|
"""
|
||||||
from __future__ import absolute_import
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
import ctypes
|
import ctypes
|
||||||
|
@ -62,14 +63,18 @@ import socket
|
||||||
import ssl
|
import ssl
|
||||||
import struct
|
import struct
|
||||||
import threading
|
import threading
|
||||||
|
import typing
|
||||||
|
import warnings
|
||||||
import weakref
|
import weakref
|
||||||
|
from socket import socket as socket_cls
|
||||||
import six
|
|
||||||
|
|
||||||
from .. import util
|
from .. import util
|
||||||
from ..util.ssl_ import PROTOCOL_TLS_CLIENT
|
from ._securetransport.bindings import ( # type: ignore[attr-defined]
|
||||||
from ._securetransport.bindings import CoreFoundation, Security, SecurityConst
|
CoreFoundation,
|
||||||
|
Security,
|
||||||
|
)
|
||||||
from ._securetransport.low_level import (
|
from ._securetransport.low_level import (
|
||||||
|
SecurityConst,
|
||||||
_assert_no_error,
|
_assert_no_error,
|
||||||
_build_tls_unknown_ca_alert,
|
_build_tls_unknown_ca_alert,
|
||||||
_cert_array_from_pem,
|
_cert_array_from_pem,
|
||||||
|
@ -78,18 +83,19 @@ from ._securetransport.low_level import (
|
||||||
_temporary_keychain,
|
_temporary_keychain,
|
||||||
)
|
)
|
||||||
|
|
||||||
try: # Platform-specific: Python 2
|
warnings.warn(
|
||||||
from socket import _fileobject
|
"'urllib3.contrib.securetransport' module is deprecated and will be removed "
|
||||||
except ImportError: # Platform-specific: Python 3
|
"in urllib3 v2.1.0. Read more in this issue: "
|
||||||
_fileobject = None
|
"https://github.com/urllib3/urllib3/issues/2681",
|
||||||
from ..packages.backports.makefile import backport_makefile
|
category=DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
if typing.TYPE_CHECKING:
|
||||||
|
from typing_extensions import Literal
|
||||||
|
|
||||||
__all__ = ["inject_into_urllib3", "extract_from_urllib3"]
|
__all__ = ["inject_into_urllib3", "extract_from_urllib3"]
|
||||||
|
|
||||||
# SNI always works
|
|
||||||
HAS_SNI = True
|
|
||||||
|
|
||||||
orig_util_HAS_SNI = util.HAS_SNI
|
|
||||||
orig_util_SSLContext = util.ssl_.SSLContext
|
orig_util_SSLContext = util.ssl_.SSLContext
|
||||||
|
|
||||||
# This dictionary is used by the read callback to obtain a handle to the
|
# This dictionary is used by the read callback to obtain a handle to the
|
||||||
|
@ -108,55 +114,24 @@ orig_util_SSLContext = util.ssl_.SSLContext
|
||||||
#
|
#
|
||||||
# This is good: if we had to lock in the callbacks we'd drastically slow down
|
# This is good: if we had to lock in the callbacks we'd drastically slow down
|
||||||
# the performance of this code.
|
# the performance of this code.
|
||||||
_connection_refs = weakref.WeakValueDictionary()
|
_connection_refs: weakref.WeakValueDictionary[
|
||||||
|
int, WrappedSocket
|
||||||
|
] = weakref.WeakValueDictionary()
|
||||||
_connection_ref_lock = threading.Lock()
|
_connection_ref_lock = threading.Lock()
|
||||||
|
|
||||||
# Limit writes to 16kB. This is OpenSSL's limit, but we'll cargo-cult it over
|
# Limit writes to 16kB. This is OpenSSL's limit, but we'll cargo-cult it over
|
||||||
# for no better reason than we need *a* limit, and this one is right there.
|
# for no better reason than we need *a* limit, and this one is right there.
|
||||||
SSL_WRITE_BLOCKSIZE = 16384
|
SSL_WRITE_BLOCKSIZE = 16384
|
||||||
|
|
||||||
# This is our equivalent of util.ssl_.DEFAULT_CIPHERS, but expanded out to
|
|
||||||
# individual cipher suites. We need to do this because this is how
|
|
||||||
# SecureTransport wants them.
|
|
||||||
CIPHER_SUITES = [
|
|
||||||
SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
|
||||||
SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
|
||||||
SecurityConst.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
|
||||||
SecurityConst.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
|
||||||
SecurityConst.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
|
|
||||||
SecurityConst.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
|
|
||||||
SecurityConst.TLS_DHE_RSA_WITH_AES_256_GCM_SHA384,
|
|
||||||
SecurityConst.TLS_DHE_RSA_WITH_AES_128_GCM_SHA256,
|
|
||||||
SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384,
|
|
||||||
SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
|
|
||||||
SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256,
|
|
||||||
SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
|
|
||||||
SecurityConst.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384,
|
|
||||||
SecurityConst.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
|
|
||||||
SecurityConst.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,
|
|
||||||
SecurityConst.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
|
|
||||||
SecurityConst.TLS_DHE_RSA_WITH_AES_256_CBC_SHA256,
|
|
||||||
SecurityConst.TLS_DHE_RSA_WITH_AES_256_CBC_SHA,
|
|
||||||
SecurityConst.TLS_DHE_RSA_WITH_AES_128_CBC_SHA256,
|
|
||||||
SecurityConst.TLS_DHE_RSA_WITH_AES_128_CBC_SHA,
|
|
||||||
SecurityConst.TLS_AES_256_GCM_SHA384,
|
|
||||||
SecurityConst.TLS_AES_128_GCM_SHA256,
|
|
||||||
SecurityConst.TLS_RSA_WITH_AES_256_GCM_SHA384,
|
|
||||||
SecurityConst.TLS_RSA_WITH_AES_128_GCM_SHA256,
|
|
||||||
SecurityConst.TLS_AES_128_CCM_8_SHA256,
|
|
||||||
SecurityConst.TLS_AES_128_CCM_SHA256,
|
|
||||||
SecurityConst.TLS_RSA_WITH_AES_256_CBC_SHA256,
|
|
||||||
SecurityConst.TLS_RSA_WITH_AES_128_CBC_SHA256,
|
|
||||||
SecurityConst.TLS_RSA_WITH_AES_256_CBC_SHA,
|
|
||||||
SecurityConst.TLS_RSA_WITH_AES_128_CBC_SHA,
|
|
||||||
]
|
|
||||||
|
|
||||||
# Basically this is simple: for PROTOCOL_SSLv23 we turn it into a low of
|
# Basically this is simple: for PROTOCOL_SSLv23 we turn it into a low of
|
||||||
# TLSv1 and a high of TLSv1.2. For everything else, we pin to that version.
|
# TLSv1 and a high of TLSv1.2. For everything else, we pin to that version.
|
||||||
# TLSv1 to 1.2 are supported on macOS 10.8+
|
# TLSv1 to 1.2 are supported on macOS 10.8+
|
||||||
_protocol_to_min_max = {
|
_protocol_to_min_max = {
|
||||||
util.PROTOCOL_TLS: (SecurityConst.kTLSProtocol1, SecurityConst.kTLSProtocol12),
|
util.ssl_.PROTOCOL_TLS: (SecurityConst.kTLSProtocol1, SecurityConst.kTLSProtocol12), # type: ignore[attr-defined]
|
||||||
PROTOCOL_TLS_CLIENT: (SecurityConst.kTLSProtocol1, SecurityConst.kTLSProtocol12),
|
util.ssl_.PROTOCOL_TLS_CLIENT: ( # type: ignore[attr-defined]
|
||||||
|
SecurityConst.kTLSProtocol1,
|
||||||
|
SecurityConst.kTLSProtocol12,
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
if hasattr(ssl, "PROTOCOL_SSLv2"):
|
if hasattr(ssl, "PROTOCOL_SSLv2"):
|
||||||
|
@ -186,31 +161,38 @@ if hasattr(ssl, "PROTOCOL_TLSv1_2"):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def inject_into_urllib3():
|
_tls_version_to_st: dict[int, int] = {
|
||||||
|
ssl.TLSVersion.MINIMUM_SUPPORTED: SecurityConst.kTLSProtocol1,
|
||||||
|
ssl.TLSVersion.TLSv1: SecurityConst.kTLSProtocol1,
|
||||||
|
ssl.TLSVersion.TLSv1_1: SecurityConst.kTLSProtocol11,
|
||||||
|
ssl.TLSVersion.TLSv1_2: SecurityConst.kTLSProtocol12,
|
||||||
|
ssl.TLSVersion.MAXIMUM_SUPPORTED: SecurityConst.kTLSProtocol12,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def inject_into_urllib3() -> None:
|
||||||
"""
|
"""
|
||||||
Monkey-patch urllib3 with SecureTransport-backed SSL-support.
|
Monkey-patch urllib3 with SecureTransport-backed SSL-support.
|
||||||
"""
|
"""
|
||||||
util.SSLContext = SecureTransportContext
|
util.SSLContext = SecureTransportContext # type: ignore[assignment]
|
||||||
util.ssl_.SSLContext = SecureTransportContext
|
util.ssl_.SSLContext = SecureTransportContext # type: ignore[assignment]
|
||||||
util.HAS_SNI = HAS_SNI
|
|
||||||
util.ssl_.HAS_SNI = HAS_SNI
|
|
||||||
util.IS_SECURETRANSPORT = True
|
util.IS_SECURETRANSPORT = True
|
||||||
util.ssl_.IS_SECURETRANSPORT = True
|
util.ssl_.IS_SECURETRANSPORT = True
|
||||||
|
|
||||||
|
|
||||||
def extract_from_urllib3():
|
def extract_from_urllib3() -> None:
|
||||||
"""
|
"""
|
||||||
Undo monkey-patching by :func:`inject_into_urllib3`.
|
Undo monkey-patching by :func:`inject_into_urllib3`.
|
||||||
"""
|
"""
|
||||||
util.SSLContext = orig_util_SSLContext
|
util.SSLContext = orig_util_SSLContext
|
||||||
util.ssl_.SSLContext = orig_util_SSLContext
|
util.ssl_.SSLContext = orig_util_SSLContext
|
||||||
util.HAS_SNI = orig_util_HAS_SNI
|
|
||||||
util.ssl_.HAS_SNI = orig_util_HAS_SNI
|
|
||||||
util.IS_SECURETRANSPORT = False
|
util.IS_SECURETRANSPORT = False
|
||||||
util.ssl_.IS_SECURETRANSPORT = False
|
util.ssl_.IS_SECURETRANSPORT = False
|
||||||
|
|
||||||
|
|
||||||
def _read_callback(connection_id, data_buffer, data_length_pointer):
|
def _read_callback(
|
||||||
|
connection_id: int, data_buffer: int, data_length_pointer: bytearray
|
||||||
|
) -> int:
|
||||||
"""
|
"""
|
||||||
SecureTransport read callback. This is called by ST to request that data
|
SecureTransport read callback. This is called by ST to request that data
|
||||||
be returned from the socket.
|
be returned from the socket.
|
||||||
|
@ -232,7 +214,7 @@ def _read_callback(connection_id, data_buffer, data_length_pointer):
|
||||||
while read_count < requested_length:
|
while read_count < requested_length:
|
||||||
if timeout is None or timeout >= 0:
|
if timeout is None or timeout >= 0:
|
||||||
if not util.wait_for_read(base_socket, timeout):
|
if not util.wait_for_read(base_socket, timeout):
|
||||||
raise socket.error(errno.EAGAIN, "timed out")
|
raise OSError(errno.EAGAIN, "timed out")
|
||||||
|
|
||||||
remaining = requested_length - read_count
|
remaining = requested_length - read_count
|
||||||
buffer = (ctypes.c_char * remaining).from_address(
|
buffer = (ctypes.c_char * remaining).from_address(
|
||||||
|
@ -244,7 +226,7 @@ def _read_callback(connection_id, data_buffer, data_length_pointer):
|
||||||
if not read_count:
|
if not read_count:
|
||||||
return SecurityConst.errSSLClosedGraceful
|
return SecurityConst.errSSLClosedGraceful
|
||||||
break
|
break
|
||||||
except (socket.error) as e:
|
except OSError as e:
|
||||||
error = e.errno
|
error = e.errno
|
||||||
|
|
||||||
if error is not None and error != errno.EAGAIN:
|
if error is not None and error != errno.EAGAIN:
|
||||||
|
@ -265,7 +247,9 @@ def _read_callback(connection_id, data_buffer, data_length_pointer):
|
||||||
return SecurityConst.errSSLInternal
|
return SecurityConst.errSSLInternal
|
||||||
|
|
||||||
|
|
||||||
def _write_callback(connection_id, data_buffer, data_length_pointer):
|
def _write_callback(
|
||||||
|
connection_id: int, data_buffer: int, data_length_pointer: bytearray
|
||||||
|
) -> int:
|
||||||
"""
|
"""
|
||||||
SecureTransport write callback. This is called by ST to request that data
|
SecureTransport write callback. This is called by ST to request that data
|
||||||
actually be sent on the network.
|
actually be sent on the network.
|
||||||
|
@ -288,14 +272,14 @@ def _write_callback(connection_id, data_buffer, data_length_pointer):
|
||||||
while sent < bytes_to_write:
|
while sent < bytes_to_write:
|
||||||
if timeout is None or timeout >= 0:
|
if timeout is None or timeout >= 0:
|
||||||
if not util.wait_for_write(base_socket, timeout):
|
if not util.wait_for_write(base_socket, timeout):
|
||||||
raise socket.error(errno.EAGAIN, "timed out")
|
raise OSError(errno.EAGAIN, "timed out")
|
||||||
chunk_sent = base_socket.send(data)
|
chunk_sent = base_socket.send(data)
|
||||||
sent += chunk_sent
|
sent += chunk_sent
|
||||||
|
|
||||||
# This has some needless copying here, but I'm not sure there's
|
# This has some needless copying here, but I'm not sure there's
|
||||||
# much value in optimising this data path.
|
# much value in optimising this data path.
|
||||||
data = data[chunk_sent:]
|
data = data[chunk_sent:]
|
||||||
except (socket.error) as e:
|
except OSError as e:
|
||||||
error = e.errno
|
error = e.errno
|
||||||
|
|
||||||
if error is not None and error != errno.EAGAIN:
|
if error is not None and error != errno.EAGAIN:
|
||||||
|
@ -323,22 +307,20 @@ _read_callback_pointer = Security.SSLReadFunc(_read_callback)
|
||||||
_write_callback_pointer = Security.SSLWriteFunc(_write_callback)
|
_write_callback_pointer = Security.SSLWriteFunc(_write_callback)
|
||||||
|
|
||||||
|
|
||||||
class WrappedSocket(object):
|
class WrappedSocket:
|
||||||
"""
|
"""
|
||||||
API-compatibility wrapper for Python's OpenSSL wrapped socket object.
|
API-compatibility wrapper for Python's OpenSSL wrapped socket object.
|
||||||
|
|
||||||
Note: _makefile_refs, _drop(), and _reuse() are needed for the garbage
|
|
||||||
collector of PyPy.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, socket):
|
def __init__(self, socket: socket_cls) -> None:
|
||||||
self.socket = socket
|
self.socket = socket
|
||||||
self.context = None
|
self.context = None
|
||||||
self._makefile_refs = 0
|
self._io_refs = 0
|
||||||
self._closed = False
|
self._closed = False
|
||||||
self._exception = None
|
self._real_closed = False
|
||||||
|
self._exception: Exception | None = None
|
||||||
self._keychain = None
|
self._keychain = None
|
||||||
self._keychain_dir = None
|
self._keychain_dir: str | None = None
|
||||||
self._client_cert_chain = None
|
self._client_cert_chain = None
|
||||||
|
|
||||||
# We save off the previously-configured timeout and then set it to
|
# We save off the previously-configured timeout and then set it to
|
||||||
|
@ -350,7 +332,7 @@ class WrappedSocket(object):
|
||||||
self.socket.settimeout(0)
|
self.socket.settimeout(0)
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def _raise_on_error(self):
|
def _raise_on_error(self) -> typing.Generator[None, None, None]:
|
||||||
"""
|
"""
|
||||||
A context manager that can be used to wrap calls that do I/O from
|
A context manager that can be used to wrap calls that do I/O from
|
||||||
SecureTransport. If any of the I/O callbacks hit an exception, this
|
SecureTransport. If any of the I/O callbacks hit an exception, this
|
||||||
|
@ -367,23 +349,10 @@ class WrappedSocket(object):
|
||||||
yield
|
yield
|
||||||
if self._exception is not None:
|
if self._exception is not None:
|
||||||
exception, self._exception = self._exception, None
|
exception, self._exception = self._exception, None
|
||||||
self.close()
|
self._real_close()
|
||||||
raise exception
|
raise exception
|
||||||
|
|
||||||
def _set_ciphers(self):
|
def _set_alpn_protocols(self, protocols: list[bytes] | None) -> None:
|
||||||
"""
|
|
||||||
Sets up the allowed ciphers. By default this matches the set in
|
|
||||||
util.ssl_.DEFAULT_CIPHERS, at least as supported by macOS. This is done
|
|
||||||
custom and doesn't allow changing at this time, mostly because parsing
|
|
||||||
OpenSSL cipher strings is going to be a freaking nightmare.
|
|
||||||
"""
|
|
||||||
ciphers = (Security.SSLCipherSuite * len(CIPHER_SUITES))(*CIPHER_SUITES)
|
|
||||||
result = Security.SSLSetEnabledCiphers(
|
|
||||||
self.context, ciphers, len(CIPHER_SUITES)
|
|
||||||
)
|
|
||||||
_assert_no_error(result)
|
|
||||||
|
|
||||||
def _set_alpn_protocols(self, protocols):
|
|
||||||
"""
|
"""
|
||||||
Sets up the ALPN protocols on the context.
|
Sets up the ALPN protocols on the context.
|
||||||
"""
|
"""
|
||||||
|
@ -396,7 +365,7 @@ class WrappedSocket(object):
|
||||||
finally:
|
finally:
|
||||||
CoreFoundation.CFRelease(protocols_arr)
|
CoreFoundation.CFRelease(protocols_arr)
|
||||||
|
|
||||||
def _custom_validate(self, verify, trust_bundle):
|
def _custom_validate(self, verify: bool, trust_bundle: bytes | None) -> None:
|
||||||
"""
|
"""
|
||||||
Called when we have set custom validation. We do this in two cases:
|
Called when we have set custom validation. We do this in two cases:
|
||||||
first, when cert validation is entirely disabled; and second, when
|
first, when cert validation is entirely disabled; and second, when
|
||||||
|
@ -404,7 +373,7 @@ class WrappedSocket(object):
|
||||||
Raises an SSLError if the connection is not trusted.
|
Raises an SSLError if the connection is not trusted.
|
||||||
"""
|
"""
|
||||||
# If we disabled cert validation, just say: cool.
|
# If we disabled cert validation, just say: cool.
|
||||||
if not verify:
|
if not verify or trust_bundle is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
successes = (
|
successes = (
|
||||||
|
@ -415,10 +384,12 @@ class WrappedSocket(object):
|
||||||
trust_result = self._evaluate_trust(trust_bundle)
|
trust_result = self._evaluate_trust(trust_bundle)
|
||||||
if trust_result in successes:
|
if trust_result in successes:
|
||||||
return
|
return
|
||||||
reason = "error code: %d" % (trust_result,)
|
reason = f"error code: {int(trust_result)}"
|
||||||
|
exc = None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Do not trust on error
|
# Do not trust on error
|
||||||
reason = "exception: %r" % (e,)
|
reason = f"exception: {e!r}"
|
||||||
|
exc = e
|
||||||
|
|
||||||
# SecureTransport does not send an alert nor shuts down the connection.
|
# SecureTransport does not send an alert nor shuts down the connection.
|
||||||
rec = _build_tls_unknown_ca_alert(self.version())
|
rec = _build_tls_unknown_ca_alert(self.version())
|
||||||
|
@ -428,10 +399,10 @@ class WrappedSocket(object):
|
||||||
# l_linger = 0, linger for 0 seoncds
|
# l_linger = 0, linger for 0 seoncds
|
||||||
opts = struct.pack("ii", 1, 0)
|
opts = struct.pack("ii", 1, 0)
|
||||||
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, opts)
|
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, opts)
|
||||||
self.close()
|
self._real_close()
|
||||||
raise ssl.SSLError("certificate verify failed, %s" % reason)
|
raise ssl.SSLError(f"certificate verify failed, {reason}") from exc
|
||||||
|
|
||||||
def _evaluate_trust(self, trust_bundle):
|
def _evaluate_trust(self, trust_bundle: bytes) -> int:
|
||||||
# We want data in memory, so load it up.
|
# We want data in memory, so load it up.
|
||||||
if os.path.isfile(trust_bundle):
|
if os.path.isfile(trust_bundle):
|
||||||
with open(trust_bundle, "rb") as f:
|
with open(trust_bundle, "rb") as f:
|
||||||
|
@ -469,20 +440,20 @@ class WrappedSocket(object):
|
||||||
if cert_array is not None:
|
if cert_array is not None:
|
||||||
CoreFoundation.CFRelease(cert_array)
|
CoreFoundation.CFRelease(cert_array)
|
||||||
|
|
||||||
return trust_result.value
|
return trust_result.value # type: ignore[no-any-return]
|
||||||
|
|
||||||
def handshake(
|
def handshake(
|
||||||
self,
|
self,
|
||||||
server_hostname,
|
server_hostname: bytes | str | None,
|
||||||
verify,
|
verify: bool,
|
||||||
trust_bundle,
|
trust_bundle: bytes | None,
|
||||||
min_version,
|
min_version: int,
|
||||||
max_version,
|
max_version: int,
|
||||||
client_cert,
|
client_cert: str | None,
|
||||||
client_key,
|
client_key: str | None,
|
||||||
client_key_passphrase,
|
client_key_passphrase: typing.Any,
|
||||||
alpn_protocols,
|
alpn_protocols: list[bytes] | None,
|
||||||
):
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Actually performs the TLS handshake. This is run automatically by
|
Actually performs the TLS handshake. This is run automatically by
|
||||||
wrapped socket, and shouldn't be needed in user code.
|
wrapped socket, and shouldn't be needed in user code.
|
||||||
|
@ -510,6 +481,8 @@ class WrappedSocket(object):
|
||||||
_assert_no_error(result)
|
_assert_no_error(result)
|
||||||
|
|
||||||
# If we have a server hostname, we should set that too.
|
# If we have a server hostname, we should set that too.
|
||||||
|
# RFC6066 Section 3 tells us not to use SNI when the host is an IP, but we have
|
||||||
|
# to do it anyway to match server_hostname against the server certificate
|
||||||
if server_hostname:
|
if server_hostname:
|
||||||
if not isinstance(server_hostname, bytes):
|
if not isinstance(server_hostname, bytes):
|
||||||
server_hostname = server_hostname.encode("utf-8")
|
server_hostname = server_hostname.encode("utf-8")
|
||||||
|
@ -519,9 +492,6 @@ class WrappedSocket(object):
|
||||||
)
|
)
|
||||||
_assert_no_error(result)
|
_assert_no_error(result)
|
||||||
|
|
||||||
# Setup the ciphers.
|
|
||||||
self._set_ciphers()
|
|
||||||
|
|
||||||
# Setup the ALPN protocols.
|
# Setup the ALPN protocols.
|
||||||
self._set_alpn_protocols(alpn_protocols)
|
self._set_alpn_protocols(alpn_protocols)
|
||||||
|
|
||||||
|
@ -564,25 +534,27 @@ class WrappedSocket(object):
|
||||||
_assert_no_error(result)
|
_assert_no_error(result)
|
||||||
break
|
break
|
||||||
|
|
||||||
def fileno(self):
|
def fileno(self) -> int:
|
||||||
return self.socket.fileno()
|
return self.socket.fileno()
|
||||||
|
|
||||||
# Copy-pasted from Python 3.5 source code
|
# Copy-pasted from Python 3.5 source code
|
||||||
def _decref_socketios(self):
|
def _decref_socketios(self) -> None:
|
||||||
if self._makefile_refs > 0:
|
if self._io_refs > 0:
|
||||||
self._makefile_refs -= 1
|
self._io_refs -= 1
|
||||||
if self._closed:
|
if self._closed:
|
||||||
self.close()
|
self.close()
|
||||||
|
|
||||||
def recv(self, bufsiz):
|
def recv(self, bufsiz: int) -> bytes:
|
||||||
buffer = ctypes.create_string_buffer(bufsiz)
|
buffer = ctypes.create_string_buffer(bufsiz)
|
||||||
bytes_read = self.recv_into(buffer, bufsiz)
|
bytes_read = self.recv_into(buffer, bufsiz)
|
||||||
data = buffer[:bytes_read]
|
data = buffer[:bytes_read]
|
||||||
return data
|
return typing.cast(bytes, data)
|
||||||
|
|
||||||
def recv_into(self, buffer, nbytes=None):
|
def recv_into(
|
||||||
|
self, buffer: ctypes.Array[ctypes.c_char], nbytes: int | None = None
|
||||||
|
) -> int:
|
||||||
# Read short on EOF.
|
# Read short on EOF.
|
||||||
if self._closed:
|
if self._real_closed:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
if nbytes is None:
|
if nbytes is None:
|
||||||
|
@ -615,7 +587,7 @@ class WrappedSocket(object):
|
||||||
# well. Note that we don't actually return here because in
|
# well. Note that we don't actually return here because in
|
||||||
# principle this could actually be fired along with return data.
|
# principle this could actually be fired along with return data.
|
||||||
# It's unlikely though.
|
# It's unlikely though.
|
||||||
self.close()
|
self._real_close()
|
||||||
else:
|
else:
|
||||||
_assert_no_error(result)
|
_assert_no_error(result)
|
||||||
|
|
||||||
|
@ -623,13 +595,13 @@ class WrappedSocket(object):
|
||||||
# was actually read.
|
# was actually read.
|
||||||
return processed_bytes.value
|
return processed_bytes.value
|
||||||
|
|
||||||
def settimeout(self, timeout):
|
def settimeout(self, timeout: float) -> None:
|
||||||
self._timeout = timeout
|
self._timeout = timeout
|
||||||
|
|
||||||
def gettimeout(self):
|
def gettimeout(self) -> float | None:
|
||||||
return self._timeout
|
return self._timeout
|
||||||
|
|
||||||
def send(self, data):
|
def send(self, data: bytes) -> int:
|
||||||
processed_bytes = ctypes.c_size_t(0)
|
processed_bytes = ctypes.c_size_t(0)
|
||||||
|
|
||||||
with self._raise_on_error():
|
with self._raise_on_error():
|
||||||
|
@ -646,36 +618,38 @@ class WrappedSocket(object):
|
||||||
# We sent, and probably succeeded. Tell them how much we sent.
|
# We sent, and probably succeeded. Tell them how much we sent.
|
||||||
return processed_bytes.value
|
return processed_bytes.value
|
||||||
|
|
||||||
def sendall(self, data):
|
def sendall(self, data: bytes) -> None:
|
||||||
total_sent = 0
|
total_sent = 0
|
||||||
while total_sent < len(data):
|
while total_sent < len(data):
|
||||||
sent = self.send(data[total_sent : total_sent + SSL_WRITE_BLOCKSIZE])
|
sent = self.send(data[total_sent : total_sent + SSL_WRITE_BLOCKSIZE])
|
||||||
total_sent += sent
|
total_sent += sent
|
||||||
|
|
||||||
def shutdown(self):
|
def shutdown(self) -> None:
|
||||||
with self._raise_on_error():
|
with self._raise_on_error():
|
||||||
Security.SSLClose(self.context)
|
Security.SSLClose(self.context)
|
||||||
|
|
||||||
def close(self):
|
def close(self) -> None:
|
||||||
|
self._closed = True
|
||||||
# TODO: should I do clean shutdown here? Do I have to?
|
# TODO: should I do clean shutdown here? Do I have to?
|
||||||
if self._makefile_refs < 1:
|
if self._io_refs <= 0:
|
||||||
self._closed = True
|
self._real_close()
|
||||||
if self.context:
|
|
||||||
CoreFoundation.CFRelease(self.context)
|
|
||||||
self.context = None
|
|
||||||
if self._client_cert_chain:
|
|
||||||
CoreFoundation.CFRelease(self._client_cert_chain)
|
|
||||||
self._client_cert_chain = None
|
|
||||||
if self._keychain:
|
|
||||||
Security.SecKeychainDelete(self._keychain)
|
|
||||||
CoreFoundation.CFRelease(self._keychain)
|
|
||||||
shutil.rmtree(self._keychain_dir)
|
|
||||||
self._keychain = self._keychain_dir = None
|
|
||||||
return self.socket.close()
|
|
||||||
else:
|
|
||||||
self._makefile_refs -= 1
|
|
||||||
|
|
||||||
def getpeercert(self, binary_form=False):
|
def _real_close(self) -> None:
|
||||||
|
self._real_closed = True
|
||||||
|
if self.context:
|
||||||
|
CoreFoundation.CFRelease(self.context)
|
||||||
|
self.context = None
|
||||||
|
if self._client_cert_chain:
|
||||||
|
CoreFoundation.CFRelease(self._client_cert_chain)
|
||||||
|
self._client_cert_chain = None
|
||||||
|
if self._keychain:
|
||||||
|
Security.SecKeychainDelete(self._keychain)
|
||||||
|
CoreFoundation.CFRelease(self._keychain)
|
||||||
|
shutil.rmtree(self._keychain_dir)
|
||||||
|
self._keychain = self._keychain_dir = None
|
||||||
|
return self.socket.close()
|
||||||
|
|
||||||
|
def getpeercert(self, binary_form: bool = False) -> bytes | None:
|
||||||
# Urgh, annoying.
|
# Urgh, annoying.
|
||||||
#
|
#
|
||||||
# Here's how we do this:
|
# Here's how we do this:
|
||||||
|
@ -733,7 +707,7 @@ class WrappedSocket(object):
|
||||||
|
|
||||||
return der_bytes
|
return der_bytes
|
||||||
|
|
||||||
def version(self):
|
def version(self) -> str:
|
||||||
protocol = Security.SSLProtocol()
|
protocol = Security.SSLProtocol()
|
||||||
result = Security.SSLGetNegotiatedProtocolVersion(
|
result = Security.SSLGetNegotiatedProtocolVersion(
|
||||||
self.context, ctypes.byref(protocol)
|
self.context, ctypes.byref(protocol)
|
||||||
|
@ -752,55 +726,50 @@ class WrappedSocket(object):
|
||||||
elif protocol.value == SecurityConst.kSSLProtocol2:
|
elif protocol.value == SecurityConst.kSSLProtocol2:
|
||||||
return "SSLv2"
|
return "SSLv2"
|
||||||
else:
|
else:
|
||||||
raise ssl.SSLError("Unknown TLS version: %r" % protocol)
|
raise ssl.SSLError(f"Unknown TLS version: {protocol!r}")
|
||||||
|
|
||||||
def _reuse(self):
|
|
||||||
self._makefile_refs += 1
|
|
||||||
|
|
||||||
def _drop(self):
|
|
||||||
if self._makefile_refs < 1:
|
|
||||||
self.close()
|
|
||||||
else:
|
|
||||||
self._makefile_refs -= 1
|
|
||||||
|
|
||||||
|
|
||||||
if _fileobject: # Platform-specific: Python 2
|
def makefile(
|
||||||
|
self: socket_cls,
|
||||||
def makefile(self, mode, bufsize=-1):
|
mode: (
|
||||||
self._makefile_refs += 1
|
Literal["r"] | Literal["w"] | Literal["rw"] | Literal["wr"] | Literal[""]
|
||||||
return _fileobject(self, mode, bufsize, close=True)
|
) = "r",
|
||||||
|
buffering: int | None = None,
|
||||||
else: # Platform-specific: Python 3
|
*args: typing.Any,
|
||||||
|
**kwargs: typing.Any,
|
||||||
def makefile(self, mode="r", buffering=None, *args, **kwargs):
|
) -> typing.BinaryIO | typing.TextIO:
|
||||||
# We disable buffering with SecureTransport because it conflicts with
|
# We disable buffering with SecureTransport because it conflicts with
|
||||||
# the buffering that ST does internally (see issue #1153 for more).
|
# the buffering that ST does internally (see issue #1153 for more).
|
||||||
buffering = 0
|
buffering = 0
|
||||||
return backport_makefile(self, mode, buffering, *args, **kwargs)
|
return socket_cls.makefile(self, mode, buffering, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
WrappedSocket.makefile = makefile
|
WrappedSocket.makefile = makefile # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
|
||||||
class SecureTransportContext(object):
|
class SecureTransportContext:
|
||||||
"""
|
"""
|
||||||
I am a wrapper class for the SecureTransport library, to translate the
|
I am a wrapper class for the SecureTransport library, to translate the
|
||||||
interface of the standard library ``SSLContext`` object to calls into
|
interface of the standard library ``SSLContext`` object to calls into
|
||||||
SecureTransport.
|
SecureTransport.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, protocol):
|
def __init__(self, protocol: int) -> None:
|
||||||
self._min_version, self._max_version = _protocol_to_min_max[protocol]
|
self._minimum_version: int = ssl.TLSVersion.MINIMUM_SUPPORTED
|
||||||
|
self._maximum_version: int = ssl.TLSVersion.MAXIMUM_SUPPORTED
|
||||||
|
if protocol not in (None, ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLS_CLIENT):
|
||||||
|
self._min_version, self._max_version = _protocol_to_min_max[protocol]
|
||||||
|
|
||||||
self._options = 0
|
self._options = 0
|
||||||
self._verify = False
|
self._verify = False
|
||||||
self._trust_bundle = None
|
self._trust_bundle: bytes | None = None
|
||||||
self._client_cert = None
|
self._client_cert: str | None = None
|
||||||
self._client_key = None
|
self._client_key: str | None = None
|
||||||
self._client_key_passphrase = None
|
self._client_key_passphrase = None
|
||||||
self._alpn_protocols = None
|
self._alpn_protocols: list[bytes] | None = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def check_hostname(self):
|
def check_hostname(self) -> Literal[True]:
|
||||||
"""
|
"""
|
||||||
SecureTransport cannot have its hostname checking disabled. For more,
|
SecureTransport cannot have its hostname checking disabled. For more,
|
||||||
see the comment on getpeercert() in this file.
|
see the comment on getpeercert() in this file.
|
||||||
|
@ -808,15 +777,14 @@ class SecureTransportContext(object):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@check_hostname.setter
|
@check_hostname.setter
|
||||||
def check_hostname(self, value):
|
def check_hostname(self, value: typing.Any) -> None:
|
||||||
"""
|
"""
|
||||||
SecureTransport cannot have its hostname checking disabled. For more,
|
SecureTransport cannot have its hostname checking disabled. For more,
|
||||||
see the comment on getpeercert() in this file.
|
see the comment on getpeercert() in this file.
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def options(self):
|
def options(self) -> int:
|
||||||
# TODO: Well, crap.
|
# TODO: Well, crap.
|
||||||
#
|
#
|
||||||
# So this is the bit of the code that is the most likely to cause us
|
# So this is the bit of the code that is the most likely to cause us
|
||||||
|
@ -826,19 +794,19 @@ class SecureTransportContext(object):
|
||||||
return self._options
|
return self._options
|
||||||
|
|
||||||
@options.setter
|
@options.setter
|
||||||
def options(self, value):
|
def options(self, value: int) -> None:
|
||||||
# TODO: Update in line with above.
|
# TODO: Update in line with above.
|
||||||
self._options = value
|
self._options = value
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def verify_mode(self):
|
def verify_mode(self) -> int:
|
||||||
return ssl.CERT_REQUIRED if self._verify else ssl.CERT_NONE
|
return ssl.CERT_REQUIRED if self._verify else ssl.CERT_NONE
|
||||||
|
|
||||||
@verify_mode.setter
|
@verify_mode.setter
|
||||||
def verify_mode(self, value):
|
def verify_mode(self, value: int) -> None:
|
||||||
self._verify = True if value == ssl.CERT_REQUIRED else False
|
self._verify = value == ssl.CERT_REQUIRED
|
||||||
|
|
||||||
def set_default_verify_paths(self):
|
def set_default_verify_paths(self) -> None:
|
||||||
# So, this has to do something a bit weird. Specifically, what it does
|
# So, this has to do something a bit weird. Specifically, what it does
|
||||||
# is nothing.
|
# is nothing.
|
||||||
#
|
#
|
||||||
|
@ -850,15 +818,18 @@ class SecureTransportContext(object):
|
||||||
# ignoring it.
|
# ignoring it.
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def load_default_certs(self):
|
def load_default_certs(self) -> None:
|
||||||
return self.set_default_verify_paths()
|
return self.set_default_verify_paths()
|
||||||
|
|
||||||
def set_ciphers(self, ciphers):
|
def set_ciphers(self, ciphers: typing.Any) -> None:
|
||||||
# For now, we just require the default cipher string.
|
raise ValueError("SecureTransport doesn't support custom cipher strings")
|
||||||
if ciphers != util.ssl_.DEFAULT_CIPHERS:
|
|
||||||
raise ValueError("SecureTransport doesn't support custom cipher strings")
|
|
||||||
|
|
||||||
def load_verify_locations(self, cafile=None, capath=None, cadata=None):
|
def load_verify_locations(
|
||||||
|
self,
|
||||||
|
cafile: str | None = None,
|
||||||
|
capath: str | None = None,
|
||||||
|
cadata: bytes | None = None,
|
||||||
|
) -> None:
|
||||||
# OK, we only really support cadata and cafile.
|
# OK, we only really support cadata and cafile.
|
||||||
if capath is not None:
|
if capath is not None:
|
||||||
raise ValueError("SecureTransport does not support cert directories")
|
raise ValueError("SecureTransport does not support cert directories")
|
||||||
|
@ -868,14 +839,19 @@ class SecureTransportContext(object):
|
||||||
with open(cafile):
|
with open(cafile):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
self._trust_bundle = cafile or cadata
|
self._trust_bundle = cafile or cadata # type: ignore[assignment]
|
||||||
|
|
||||||
def load_cert_chain(self, certfile, keyfile=None, password=None):
|
def load_cert_chain(
|
||||||
|
self,
|
||||||
|
certfile: str,
|
||||||
|
keyfile: str | None = None,
|
||||||
|
password: str | None = None,
|
||||||
|
) -> None:
|
||||||
self._client_cert = certfile
|
self._client_cert = certfile
|
||||||
self._client_key = keyfile
|
self._client_key = keyfile
|
||||||
self._client_cert_passphrase = password
|
self._client_cert_passphrase = password
|
||||||
|
|
||||||
def set_alpn_protocols(self, protocols):
|
def set_alpn_protocols(self, protocols: list[str | bytes]) -> None:
|
||||||
"""
|
"""
|
||||||
Sets the ALPN protocols that will later be set on the context.
|
Sets the ALPN protocols that will later be set on the context.
|
||||||
|
|
||||||
|
@ -885,16 +861,16 @@ class SecureTransportContext(object):
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"SecureTransport supports ALPN only in macOS 10.12+"
|
"SecureTransport supports ALPN only in macOS 10.12+"
|
||||||
)
|
)
|
||||||
self._alpn_protocols = [six.ensure_binary(p) for p in protocols]
|
self._alpn_protocols = [util.util.to_bytes(p, "ascii") for p in protocols]
|
||||||
|
|
||||||
def wrap_socket(
|
def wrap_socket(
|
||||||
self,
|
self,
|
||||||
sock,
|
sock: socket_cls,
|
||||||
server_side=False,
|
server_side: bool = False,
|
||||||
do_handshake_on_connect=True,
|
do_handshake_on_connect: bool = True,
|
||||||
suppress_ragged_eofs=True,
|
suppress_ragged_eofs: bool = True,
|
||||||
server_hostname=None,
|
server_hostname: bytes | str | None = None,
|
||||||
):
|
) -> WrappedSocket:
|
||||||
# So, what do we do here? Firstly, we assert some properties. This is a
|
# So, what do we do here? Firstly, we assert some properties. This is a
|
||||||
# stripped down shim, so there is some functionality we don't support.
|
# stripped down shim, so there is some functionality we don't support.
|
||||||
# See PEP 543 for the real deal.
|
# See PEP 543 for the real deal.
|
||||||
|
@ -911,11 +887,27 @@ class SecureTransportContext(object):
|
||||||
server_hostname,
|
server_hostname,
|
||||||
self._verify,
|
self._verify,
|
||||||
self._trust_bundle,
|
self._trust_bundle,
|
||||||
self._min_version,
|
_tls_version_to_st[self._minimum_version],
|
||||||
self._max_version,
|
_tls_version_to_st[self._maximum_version],
|
||||||
self._client_cert,
|
self._client_cert,
|
||||||
self._client_key,
|
self._client_key,
|
||||||
self._client_key_passphrase,
|
self._client_key_passphrase,
|
||||||
self._alpn_protocols,
|
self._alpn_protocols,
|
||||||
)
|
)
|
||||||
return wrapped_socket
|
return wrapped_socket
|
||||||
|
|
||||||
|
@property
|
||||||
|
def minimum_version(self) -> int:
|
||||||
|
return self._minimum_version
|
||||||
|
|
||||||
|
@minimum_version.setter
|
||||||
|
def minimum_version(self, minimum_version: int) -> None:
|
||||||
|
self._minimum_version = minimum_version
|
||||||
|
|
||||||
|
@property
|
||||||
|
def maximum_version(self) -> int:
|
||||||
|
return self._maximum_version
|
||||||
|
|
||||||
|
@maximum_version.setter
|
||||||
|
def maximum_version(self, maximum_version: int) -> None:
|
||||||
|
self._maximum_version = maximum_version
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
"""
|
||||||
This module contains provisional support for SOCKS proxies from within
|
This module contains provisional support for SOCKS proxies from within
|
||||||
urllib3. This module supports SOCKS4, SOCKS4A (an extension of SOCKS4), and
|
urllib3. This module supports SOCKS4, SOCKS4A (an extension of SOCKS4), and
|
||||||
|
@ -38,10 +37,11 @@ with the proxy:
|
||||||
proxy_url="socks5h://<username>:<password>@proxy-host"
|
proxy_url="socks5h://<username>:<password>@proxy-host"
|
||||||
|
|
||||||
"""
|
"""
|
||||||
from __future__ import absolute_import
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import socks
|
import socks # type: ignore[import]
|
||||||
except ImportError:
|
except ImportError:
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
|
@ -51,13 +51,13 @@ except ImportError:
|
||||||
(
|
(
|
||||||
"SOCKS support in urllib3 requires the installation of optional "
|
"SOCKS support in urllib3 requires the installation of optional "
|
||||||
"dependencies: specifically, PySocks. For more information, see "
|
"dependencies: specifically, PySocks. For more information, see "
|
||||||
"https://urllib3.readthedocs.io/en/1.26.x/contrib.html#socks-proxies"
|
"https://urllib3.readthedocs.io/en/latest/contrib.html#socks-proxies"
|
||||||
),
|
),
|
||||||
DependencyWarning,
|
DependencyWarning,
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
from socket import error as SocketError
|
import typing
|
||||||
from socket import timeout as SocketTimeout
|
from socket import timeout as SocketTimeout
|
||||||
|
|
||||||
from ..connection import HTTPConnection, HTTPSConnection
|
from ..connection import HTTPConnection, HTTPSConnection
|
||||||
|
@ -69,7 +69,21 @@ from ..util.url import parse_url
|
||||||
try:
|
try:
|
||||||
import ssl
|
import ssl
|
||||||
except ImportError:
|
except ImportError:
|
||||||
ssl = None
|
ssl = None # type: ignore[assignment]
|
||||||
|
|
||||||
|
try:
|
||||||
|
from typing import TypedDict
|
||||||
|
|
||||||
|
class _TYPE_SOCKS_OPTIONS(TypedDict):
|
||||||
|
socks_version: int
|
||||||
|
proxy_host: str | None
|
||||||
|
proxy_port: str | None
|
||||||
|
username: str | None
|
||||||
|
password: str | None
|
||||||
|
rdns: bool
|
||||||
|
|
||||||
|
except ImportError: # Python 3.7
|
||||||
|
_TYPE_SOCKS_OPTIONS = typing.Dict[str, typing.Any] # type: ignore[misc, assignment]
|
||||||
|
|
||||||
|
|
||||||
class SOCKSConnection(HTTPConnection):
|
class SOCKSConnection(HTTPConnection):
|
||||||
|
@ -77,15 +91,20 @@ class SOCKSConnection(HTTPConnection):
|
||||||
A plain-text HTTP connection that connects via a SOCKS proxy.
|
A plain-text HTTP connection that connects via a SOCKS proxy.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(
|
||||||
self._socks_options = kwargs.pop("_socks_options")
|
self,
|
||||||
super(SOCKSConnection, self).__init__(*args, **kwargs)
|
_socks_options: _TYPE_SOCKS_OPTIONS,
|
||||||
|
*args: typing.Any,
|
||||||
|
**kwargs: typing.Any,
|
||||||
|
) -> None:
|
||||||
|
self._socks_options = _socks_options
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
def _new_conn(self):
|
def _new_conn(self) -> socks.socksocket:
|
||||||
"""
|
"""
|
||||||
Establish a new connection via the SOCKS proxy.
|
Establish a new connection via the SOCKS proxy.
|
||||||
"""
|
"""
|
||||||
extra_kw = {}
|
extra_kw: dict[str, typing.Any] = {}
|
||||||
if self.source_address:
|
if self.source_address:
|
||||||
extra_kw["source_address"] = self.source_address
|
extra_kw["source_address"] = self.source_address
|
||||||
|
|
||||||
|
@ -102,15 +121,14 @@ class SOCKSConnection(HTTPConnection):
|
||||||
proxy_password=self._socks_options["password"],
|
proxy_password=self._socks_options["password"],
|
||||||
proxy_rdns=self._socks_options["rdns"],
|
proxy_rdns=self._socks_options["rdns"],
|
||||||
timeout=self.timeout,
|
timeout=self.timeout,
|
||||||
**extra_kw
|
**extra_kw,
|
||||||
)
|
)
|
||||||
|
|
||||||
except SocketTimeout:
|
except SocketTimeout as e:
|
||||||
raise ConnectTimeoutError(
|
raise ConnectTimeoutError(
|
||||||
self,
|
self,
|
||||||
"Connection to %s timed out. (connect timeout=%s)"
|
f"Connection to {self.host} timed out. (connect timeout={self.timeout})",
|
||||||
% (self.host, self.timeout),
|
) from e
|
||||||
)
|
|
||||||
|
|
||||||
except socks.ProxyError as e:
|
except socks.ProxyError as e:
|
||||||
# This is fragile as hell, but it seems to be the only way to raise
|
# This is fragile as hell, but it seems to be the only way to raise
|
||||||
|
@ -120,22 +138,23 @@ class SOCKSConnection(HTTPConnection):
|
||||||
if isinstance(error, SocketTimeout):
|
if isinstance(error, SocketTimeout):
|
||||||
raise ConnectTimeoutError(
|
raise ConnectTimeoutError(
|
||||||
self,
|
self,
|
||||||
"Connection to %s timed out. (connect timeout=%s)"
|
f"Connection to {self.host} timed out. (connect timeout={self.timeout})",
|
||||||
% (self.host, self.timeout),
|
) from e
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
|
# Adding `from e` messes with coverage somehow, so it's omitted.
|
||||||
|
# See #2386.
|
||||||
raise NewConnectionError(
|
raise NewConnectionError(
|
||||||
self, "Failed to establish a new connection: %s" % error
|
self, f"Failed to establish a new connection: {error}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise NewConnectionError(
|
raise NewConnectionError(
|
||||||
self, "Failed to establish a new connection: %s" % e
|
self, f"Failed to establish a new connection: {e}"
|
||||||
)
|
) from e
|
||||||
|
|
||||||
except SocketError as e: # Defensive: PySocks should catch all these.
|
except OSError as e: # Defensive: PySocks should catch all these.
|
||||||
raise NewConnectionError(
|
raise NewConnectionError(
|
||||||
self, "Failed to establish a new connection: %s" % e
|
self, f"Failed to establish a new connection: {e}"
|
||||||
)
|
) from e
|
||||||
|
|
||||||
return conn
|
return conn
|
||||||
|
|
||||||
|
@ -169,12 +188,12 @@ class SOCKSProxyManager(PoolManager):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
proxy_url,
|
proxy_url: str,
|
||||||
username=None,
|
username: str | None = None,
|
||||||
password=None,
|
password: str | None = None,
|
||||||
num_pools=10,
|
num_pools: int = 10,
|
||||||
headers=None,
|
headers: typing.Mapping[str, str] | None = None,
|
||||||
**connection_pool_kw
|
**connection_pool_kw: typing.Any,
|
||||||
):
|
):
|
||||||
parsed = parse_url(proxy_url)
|
parsed = parse_url(proxy_url)
|
||||||
|
|
||||||
|
@ -195,7 +214,7 @@ class SOCKSProxyManager(PoolManager):
|
||||||
socks_version = socks.PROXY_TYPE_SOCKS4
|
socks_version = socks.PROXY_TYPE_SOCKS4
|
||||||
rdns = True
|
rdns = True
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unable to determine SOCKS version from %s" % proxy_url)
|
raise ValueError(f"Unable to determine SOCKS version from {proxy_url}")
|
||||||
|
|
||||||
self.proxy_url = proxy_url
|
self.proxy_url = proxy_url
|
||||||
|
|
||||||
|
@ -209,8 +228,6 @@ class SOCKSProxyManager(PoolManager):
|
||||||
}
|
}
|
||||||
connection_pool_kw["_socks_options"] = socks_options
|
connection_pool_kw["_socks_options"] = socks_options
|
||||||
|
|
||||||
super(SOCKSProxyManager, self).__init__(
|
super().__init__(num_pools, headers, **connection_pool_kw)
|
||||||
num_pools, headers, **connection_pool_kw
|
|
||||||
)
|
|
||||||
|
|
||||||
self.pool_classes_by_scheme = SOCKSProxyManager.pool_classes_by_scheme
|
self.pool_classes_by_scheme = SOCKSProxyManager.pool_classes_by_scheme
|
||||||
|
|
|
@ -1,6 +1,16 @@
|
||||||
from __future__ import absolute_import
|
from __future__ import annotations
|
||||||
|
|
||||||
from .packages.six.moves.http_client import IncompleteRead as httplib_IncompleteRead
|
import socket
|
||||||
|
import typing
|
||||||
|
import warnings
|
||||||
|
from email.errors import MessageDefect
|
||||||
|
from http.client import IncompleteRead as httplib_IncompleteRead
|
||||||
|
|
||||||
|
if typing.TYPE_CHECKING:
|
||||||
|
from .connection import HTTPConnection
|
||||||
|
from .connectionpool import ConnectionPool
|
||||||
|
from .response import HTTPResponse
|
||||||
|
from .util.retry import Retry
|
||||||
|
|
||||||
# Base Exceptions
|
# Base Exceptions
|
||||||
|
|
||||||
|
@ -8,23 +18,24 @@ from .packages.six.moves.http_client import IncompleteRead as httplib_Incomplete
|
||||||
class HTTPError(Exception):
|
class HTTPError(Exception):
|
||||||
"""Base exception used by this module."""
|
"""Base exception used by this module."""
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPWarning(Warning):
|
class HTTPWarning(Warning):
|
||||||
"""Base warning used by this module."""
|
"""Base warning used by this module."""
|
||||||
|
|
||||||
pass
|
|
||||||
|
_TYPE_REDUCE_RESULT = typing.Tuple[
|
||||||
|
typing.Callable[..., object], typing.Tuple[object, ...]
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class PoolError(HTTPError):
|
class PoolError(HTTPError):
|
||||||
"""Base exception for errors caused within a pool."""
|
"""Base exception for errors caused within a pool."""
|
||||||
|
|
||||||
def __init__(self, pool, message):
|
def __init__(self, pool: ConnectionPool, message: str) -> None:
|
||||||
self.pool = pool
|
self.pool = pool
|
||||||
HTTPError.__init__(self, "%s: %s" % (pool, message))
|
super().__init__(f"{pool}: {message}")
|
||||||
|
|
||||||
def __reduce__(self):
|
def __reduce__(self) -> _TYPE_REDUCE_RESULT:
|
||||||
# For pickling purposes.
|
# For pickling purposes.
|
||||||
return self.__class__, (None, None)
|
return self.__class__, (None, None)
|
||||||
|
|
||||||
|
@ -32,11 +43,11 @@ class PoolError(HTTPError):
|
||||||
class RequestError(PoolError):
|
class RequestError(PoolError):
|
||||||
"""Base exception for PoolErrors that have associated URLs."""
|
"""Base exception for PoolErrors that have associated URLs."""
|
||||||
|
|
||||||
def __init__(self, pool, url, message):
|
def __init__(self, pool: ConnectionPool, url: str, message: str) -> None:
|
||||||
self.url = url
|
self.url = url
|
||||||
PoolError.__init__(self, pool, message)
|
super().__init__(pool, message)
|
||||||
|
|
||||||
def __reduce__(self):
|
def __reduce__(self) -> _TYPE_REDUCE_RESULT:
|
||||||
# For pickling purposes.
|
# For pickling purposes.
|
||||||
return self.__class__, (None, self.url, None)
|
return self.__class__, (None, self.url, None)
|
||||||
|
|
||||||
|
@ -44,28 +55,25 @@ class RequestError(PoolError):
|
||||||
class SSLError(HTTPError):
|
class SSLError(HTTPError):
|
||||||
"""Raised when SSL certificate fails in an HTTPS connection."""
|
"""Raised when SSL certificate fails in an HTTPS connection."""
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ProxyError(HTTPError):
|
class ProxyError(HTTPError):
|
||||||
"""Raised when the connection to a proxy fails."""
|
"""Raised when the connection to a proxy fails."""
|
||||||
|
|
||||||
def __init__(self, message, error, *args):
|
# The original error is also available as __cause__.
|
||||||
super(ProxyError, self).__init__(message, error, *args)
|
original_error: Exception
|
||||||
|
|
||||||
|
def __init__(self, message: str, error: Exception) -> None:
|
||||||
|
super().__init__(message, error)
|
||||||
self.original_error = error
|
self.original_error = error
|
||||||
|
|
||||||
|
|
||||||
class DecodeError(HTTPError):
|
class DecodeError(HTTPError):
|
||||||
"""Raised when automatic decoding based on Content-Type fails."""
|
"""Raised when automatic decoding based on Content-Type fails."""
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ProtocolError(HTTPError):
|
class ProtocolError(HTTPError):
|
||||||
"""Raised when something unexpected happens mid-request/response."""
|
"""Raised when something unexpected happens mid-request/response."""
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
#: Renamed to ProtocolError but aliased for backwards compatibility.
|
#: Renamed to ProtocolError but aliased for backwards compatibility.
|
||||||
ConnectionError = ProtocolError
|
ConnectionError = ProtocolError
|
||||||
|
@ -79,33 +87,36 @@ class MaxRetryError(RequestError):
|
||||||
|
|
||||||
:param pool: The connection pool
|
:param pool: The connection pool
|
||||||
:type pool: :class:`~urllib3.connectionpool.HTTPConnectionPool`
|
:type pool: :class:`~urllib3.connectionpool.HTTPConnectionPool`
|
||||||
:param string url: The requested Url
|
:param str url: The requested Url
|
||||||
:param exceptions.Exception reason: The underlying error
|
:param reason: The underlying error
|
||||||
|
:type reason: :class:`Exception`
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, pool, url, reason=None):
|
def __init__(
|
||||||
|
self, pool: ConnectionPool, url: str, reason: Exception | None = None
|
||||||
|
) -> None:
|
||||||
self.reason = reason
|
self.reason = reason
|
||||||
|
|
||||||
message = "Max retries exceeded with url: %s (Caused by %r)" % (url, reason)
|
message = f"Max retries exceeded with url: {url} (Caused by {reason!r})"
|
||||||
|
|
||||||
RequestError.__init__(self, pool, url, message)
|
super().__init__(pool, url, message)
|
||||||
|
|
||||||
|
|
||||||
class HostChangedError(RequestError):
|
class HostChangedError(RequestError):
|
||||||
"""Raised when an existing pool gets a request for a foreign host."""
|
"""Raised when an existing pool gets a request for a foreign host."""
|
||||||
|
|
||||||
def __init__(self, pool, url, retries=3):
|
def __init__(
|
||||||
message = "Tried to open a foreign host with url: %s" % url
|
self, pool: ConnectionPool, url: str, retries: Retry | int = 3
|
||||||
RequestError.__init__(self, pool, url, message)
|
) -> None:
|
||||||
|
message = f"Tried to open a foreign host with url: {url}"
|
||||||
|
super().__init__(pool, url, message)
|
||||||
self.retries = retries
|
self.retries = retries
|
||||||
|
|
||||||
|
|
||||||
class TimeoutStateError(HTTPError):
|
class TimeoutStateError(HTTPError):
|
||||||
"""Raised when passing an invalid state to a timeout"""
|
"""Raised when passing an invalid state to a timeout"""
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class TimeoutError(HTTPError):
|
class TimeoutError(HTTPError):
|
||||||
"""Raised when a socket timeout error occurs.
|
"""Raised when a socket timeout error occurs.
|
||||||
|
@ -114,53 +125,66 @@ class TimeoutError(HTTPError):
|
||||||
<ReadTimeoutError>` and :exc:`ConnectTimeoutErrors <ConnectTimeoutError>`.
|
<ReadTimeoutError>` and :exc:`ConnectTimeoutErrors <ConnectTimeoutError>`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ReadTimeoutError(TimeoutError, RequestError):
|
class ReadTimeoutError(TimeoutError, RequestError):
|
||||||
"""Raised when a socket timeout occurs while receiving data from a server"""
|
"""Raised when a socket timeout occurs while receiving data from a server"""
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
# This timeout error does not have a URL attached and needs to inherit from the
|
# This timeout error does not have a URL attached and needs to inherit from the
|
||||||
# base HTTPError
|
# base HTTPError
|
||||||
class ConnectTimeoutError(TimeoutError):
|
class ConnectTimeoutError(TimeoutError):
|
||||||
"""Raised when a socket timeout occurs while connecting to a server"""
|
"""Raised when a socket timeout occurs while connecting to a server"""
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
class NewConnectionError(ConnectTimeoutError, HTTPError):
|
||||||
class NewConnectionError(ConnectTimeoutError, PoolError):
|
|
||||||
"""Raised when we fail to establish a new connection. Usually ECONNREFUSED."""
|
"""Raised when we fail to establish a new connection. Usually ECONNREFUSED."""
|
||||||
|
|
||||||
pass
|
def __init__(self, conn: HTTPConnection, message: str) -> None:
|
||||||
|
self.conn = conn
|
||||||
|
super().__init__(f"{conn}: {message}")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pool(self) -> HTTPConnection:
|
||||||
|
warnings.warn(
|
||||||
|
"The 'pool' property is deprecated and will be removed "
|
||||||
|
"in urllib3 v2.1.0. Use 'conn' instead.",
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.conn
|
||||||
|
|
||||||
|
|
||||||
|
class NameResolutionError(NewConnectionError):
|
||||||
|
"""Raised when host name resolution fails."""
|
||||||
|
|
||||||
|
def __init__(self, host: str, conn: HTTPConnection, reason: socket.gaierror):
|
||||||
|
message = f"Failed to resolve '{host}' ({reason})"
|
||||||
|
super().__init__(conn, message)
|
||||||
|
|
||||||
|
|
||||||
class EmptyPoolError(PoolError):
|
class EmptyPoolError(PoolError):
|
||||||
"""Raised when a pool runs out of connections and no more are allowed."""
|
"""Raised when a pool runs out of connections and no more are allowed."""
|
||||||
|
|
||||||
pass
|
|
||||||
|
class FullPoolError(PoolError):
|
||||||
|
"""Raised when we try to add a connection to a full pool in blocking mode."""
|
||||||
|
|
||||||
|
|
||||||
class ClosedPoolError(PoolError):
|
class ClosedPoolError(PoolError):
|
||||||
"""Raised when a request enters a pool after the pool has been closed."""
|
"""Raised when a request enters a pool after the pool has been closed."""
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class LocationValueError(ValueError, HTTPError):
|
class LocationValueError(ValueError, HTTPError):
|
||||||
"""Raised when there is something wrong with a given URL input."""
|
"""Raised when there is something wrong with a given URL input."""
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class LocationParseError(LocationValueError):
|
class LocationParseError(LocationValueError):
|
||||||
"""Raised when get_host or similar fails to parse the URL input."""
|
"""Raised when get_host or similar fails to parse the URL input."""
|
||||||
|
|
||||||
def __init__(self, location):
|
def __init__(self, location: str) -> None:
|
||||||
message = "Failed to parse: %s" % location
|
message = f"Failed to parse: {location}"
|
||||||
HTTPError.__init__(self, message)
|
super().__init__(message)
|
||||||
|
|
||||||
self.location = location
|
self.location = location
|
||||||
|
|
||||||
|
@ -168,9 +192,9 @@ class LocationParseError(LocationValueError):
|
||||||
class URLSchemeUnknown(LocationValueError):
|
class URLSchemeUnknown(LocationValueError):
|
||||||
"""Raised when a URL input has an unsupported scheme."""
|
"""Raised when a URL input has an unsupported scheme."""
|
||||||
|
|
||||||
def __init__(self, scheme):
|
def __init__(self, scheme: str):
|
||||||
message = "Not supported URL scheme %s" % scheme
|
message = f"Not supported URL scheme {scheme}"
|
||||||
super(URLSchemeUnknown, self).__init__(message)
|
super().__init__(message)
|
||||||
|
|
||||||
self.scheme = scheme
|
self.scheme = scheme
|
||||||
|
|
||||||
|
@ -185,38 +209,22 @@ class ResponseError(HTTPError):
|
||||||
class SecurityWarning(HTTPWarning):
|
class SecurityWarning(HTTPWarning):
|
||||||
"""Warned when performing security reducing actions"""
|
"""Warned when performing security reducing actions"""
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class SubjectAltNameWarning(SecurityWarning):
|
|
||||||
"""Warned when connecting to a host with a certificate missing a SAN."""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class InsecureRequestWarning(SecurityWarning):
|
class InsecureRequestWarning(SecurityWarning):
|
||||||
"""Warned when making an unverified HTTPS request."""
|
"""Warned when making an unverified HTTPS request."""
|
||||||
|
|
||||||
pass
|
|
||||||
|
class NotOpenSSLWarning(SecurityWarning):
|
||||||
|
"""Warned when using unsupported SSL library"""
|
||||||
|
|
||||||
|
|
||||||
class SystemTimeWarning(SecurityWarning):
|
class SystemTimeWarning(SecurityWarning):
|
||||||
"""Warned when system time is suspected to be wrong"""
|
"""Warned when system time is suspected to be wrong"""
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class InsecurePlatformWarning(SecurityWarning):
|
class InsecurePlatformWarning(SecurityWarning):
|
||||||
"""Warned when certain TLS/SSL configuration is not available on a platform."""
|
"""Warned when certain TLS/SSL configuration is not available on a platform."""
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class SNIMissingWarning(HTTPWarning):
|
|
||||||
"""Warned when making a HTTPS request without SNI available."""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class DependencyWarning(HTTPWarning):
|
class DependencyWarning(HTTPWarning):
|
||||||
"""
|
"""
|
||||||
|
@ -224,14 +232,10 @@ class DependencyWarning(HTTPWarning):
|
||||||
dependencies.
|
dependencies.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ResponseNotChunked(ProtocolError, ValueError):
|
class ResponseNotChunked(ProtocolError, ValueError):
|
||||||
"""Response needs to be chunked in order to read it as chunks."""
|
"""Response needs to be chunked in order to read it as chunks."""
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class BodyNotHttplibCompatible(HTTPError):
|
class BodyNotHttplibCompatible(HTTPError):
|
||||||
"""
|
"""
|
||||||
|
@ -239,8 +243,6 @@ class BodyNotHttplibCompatible(HTTPError):
|
||||||
(have an fp attribute which returns raw chunks) for read_chunked().
|
(have an fp attribute which returns raw chunks) for read_chunked().
|
||||||
"""
|
"""
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class IncompleteRead(HTTPError, httplib_IncompleteRead):
|
class IncompleteRead(HTTPError, httplib_IncompleteRead):
|
||||||
"""
|
"""
|
||||||
|
@ -250,12 +252,13 @@ class IncompleteRead(HTTPError, httplib_IncompleteRead):
|
||||||
for ``partial`` to avoid creating large objects on streamed reads.
|
for ``partial`` to avoid creating large objects on streamed reads.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, partial, expected):
|
def __init__(self, partial: int, expected: int) -> None:
|
||||||
super(IncompleteRead, self).__init__(partial, expected)
|
self.partial = partial # type: ignore[assignment]
|
||||||
|
self.expected = expected
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self) -> str:
|
||||||
return "IncompleteRead(%i bytes read, %i more expected)" % (
|
return "IncompleteRead(%i bytes read, %i more expected)" % (
|
||||||
self.partial,
|
self.partial, # type: ignore[str-format]
|
||||||
self.expected,
|
self.expected,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -263,14 +266,13 @@ class IncompleteRead(HTTPError, httplib_IncompleteRead):
|
||||||
class InvalidChunkLength(HTTPError, httplib_IncompleteRead):
|
class InvalidChunkLength(HTTPError, httplib_IncompleteRead):
|
||||||
"""Invalid chunk length in a chunked response."""
|
"""Invalid chunk length in a chunked response."""
|
||||||
|
|
||||||
def __init__(self, response, length):
|
def __init__(self, response: HTTPResponse, length: bytes) -> None:
|
||||||
super(InvalidChunkLength, self).__init__(
|
self.partial: int = response.tell() # type: ignore[assignment]
|
||||||
response.tell(), response.length_remaining
|
self.expected: int | None = response.length_remaining
|
||||||
)
|
|
||||||
self.response = response
|
self.response = response
|
||||||
self.length = length
|
self.length = length
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self) -> str:
|
||||||
return "InvalidChunkLength(got length %r, %i bytes read)" % (
|
return "InvalidChunkLength(got length %r, %i bytes read)" % (
|
||||||
self.length,
|
self.length,
|
||||||
self.partial,
|
self.partial,
|
||||||
|
@ -280,15 +282,13 @@ class InvalidChunkLength(HTTPError, httplib_IncompleteRead):
|
||||||
class InvalidHeader(HTTPError):
|
class InvalidHeader(HTTPError):
|
||||||
"""The header provided was somehow invalid."""
|
"""The header provided was somehow invalid."""
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ProxySchemeUnknown(AssertionError, URLSchemeUnknown):
|
class ProxySchemeUnknown(AssertionError, URLSchemeUnknown):
|
||||||
"""ProxyManager does not support the supplied scheme"""
|
"""ProxyManager does not support the supplied scheme"""
|
||||||
|
|
||||||
# TODO(t-8ch): Stop inheriting from AssertionError in v2.0.
|
# TODO(t-8ch): Stop inheriting from AssertionError in v2.0.
|
||||||
|
|
||||||
def __init__(self, scheme):
|
def __init__(self, scheme: str | None) -> None:
|
||||||
# 'localhost' is here because our URL parser parses
|
# 'localhost' is here because our URL parser parses
|
||||||
# localhost:8080 -> scheme=localhost, remove if we fix this.
|
# localhost:8080 -> scheme=localhost, remove if we fix this.
|
||||||
if scheme == "localhost":
|
if scheme == "localhost":
|
||||||
|
@ -296,28 +296,23 @@ class ProxySchemeUnknown(AssertionError, URLSchemeUnknown):
|
||||||
if scheme is None:
|
if scheme is None:
|
||||||
message = "Proxy URL had no scheme, should start with http:// or https://"
|
message = "Proxy URL had no scheme, should start with http:// or https://"
|
||||||
else:
|
else:
|
||||||
message = (
|
message = f"Proxy URL had unsupported scheme {scheme}, should use http:// or https://"
|
||||||
"Proxy URL had unsupported scheme %s, should use http:// or https://"
|
super().__init__(message)
|
||||||
% scheme
|
|
||||||
)
|
|
||||||
super(ProxySchemeUnknown, self).__init__(message)
|
|
||||||
|
|
||||||
|
|
||||||
class ProxySchemeUnsupported(ValueError):
|
class ProxySchemeUnsupported(ValueError):
|
||||||
"""Fetching HTTPS resources through HTTPS proxies is unsupported"""
|
"""Fetching HTTPS resources through HTTPS proxies is unsupported"""
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class HeaderParsingError(HTTPError):
|
class HeaderParsingError(HTTPError):
|
||||||
"""Raised by assert_header_parsing, but we convert it to a log.warning statement."""
|
"""Raised by assert_header_parsing, but we convert it to a log.warning statement."""
|
||||||
|
|
||||||
def __init__(self, defects, unparsed_data):
|
def __init__(
|
||||||
message = "%s, unparsed data: %r" % (defects or "Unknown", unparsed_data)
|
self, defects: list[MessageDefect], unparsed_data: bytes | str | None
|
||||||
super(HeaderParsingError, self).__init__(message)
|
) -> None:
|
||||||
|
message = f"{defects or 'Unknown'}, unparsed data: {unparsed_data!r}"
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
class UnrewindableBodyError(HTTPError):
|
class UnrewindableBodyError(HTTPError):
|
||||||
"""urllib3 encountered an error when trying to rewind a body"""
|
"""urllib3 encountered an error when trying to rewind a body"""
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
|
@ -1,55 +0,0 @@
|
||||||
from typing import TYPE_CHECKING, Any, Optional, Tuple, Union
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from urllib3.connectionpool import ConnectionPool
|
|
||||||
|
|
||||||
class HTTPError(Exception): ...
|
|
||||||
class HTTPWarning(Warning): ...
|
|
||||||
|
|
||||||
class PoolError(HTTPError):
|
|
||||||
pool: ConnectionPool
|
|
||||||
def __init__(self, pool: ConnectionPool, message: str) -> None: ...
|
|
||||||
def __reduce__(self) -> Union[str, Tuple[Any, ...]]: ...
|
|
||||||
|
|
||||||
class RequestError(PoolError):
|
|
||||||
url: str
|
|
||||||
def __init__(self, pool: ConnectionPool, url: str, message: str) -> None: ...
|
|
||||||
def __reduce__(self) -> Union[str, Tuple[Any, ...]]: ...
|
|
||||||
|
|
||||||
class SSLError(HTTPError): ...
|
|
||||||
class ProxyError(HTTPError): ...
|
|
||||||
class DecodeError(HTTPError): ...
|
|
||||||
class ProtocolError(HTTPError): ...
|
|
||||||
|
|
||||||
ConnectionError: ProtocolError
|
|
||||||
|
|
||||||
class MaxRetryError(RequestError):
|
|
||||||
reason: str
|
|
||||||
def __init__(
|
|
||||||
self, pool: ConnectionPool, url: str, reason: Optional[str]
|
|
||||||
) -> None: ...
|
|
||||||
|
|
||||||
class HostChangedError(RequestError):
|
|
||||||
retries: int
|
|
||||||
def __init__(self, pool: ConnectionPool, url: str, retries: int) -> None: ...
|
|
||||||
|
|
||||||
class TimeoutStateError(HTTPError): ...
|
|
||||||
class TimeoutError(HTTPError): ...
|
|
||||||
class ReadTimeoutError(TimeoutError, RequestError): ...
|
|
||||||
class ConnectTimeoutError(TimeoutError): ...
|
|
||||||
class EmptyPoolError(PoolError): ...
|
|
||||||
class ClosedPoolError(PoolError): ...
|
|
||||||
class LocationValueError(ValueError, HTTPError): ...
|
|
||||||
|
|
||||||
class LocationParseError(LocationValueError):
|
|
||||||
location: str
|
|
||||||
def __init__(self, location: str) -> None: ...
|
|
||||||
|
|
||||||
class ResponseError(HTTPError):
|
|
||||||
GENERIC_ERROR: Any
|
|
||||||
SPECIFIC_ERROR: Any
|
|
||||||
|
|
||||||
class SecurityWarning(HTTPWarning): ...
|
|
||||||
class InsecureRequestWarning(SecurityWarning): ...
|
|
||||||
class SystemTimeWarning(SecurityWarning): ...
|
|
||||||
class InsecurePlatformWarning(SecurityWarning): ...
|
|
|
@ -1,13 +1,20 @@
|
||||||
from __future__ import absolute_import
|
from __future__ import annotations
|
||||||
|
|
||||||
import email.utils
|
import email.utils
|
||||||
import mimetypes
|
import mimetypes
|
||||||
import re
|
import typing
|
||||||
|
|
||||||
from .packages import six
|
_TYPE_FIELD_VALUE = typing.Union[str, bytes]
|
||||||
|
_TYPE_FIELD_VALUE_TUPLE = typing.Union[
|
||||||
|
_TYPE_FIELD_VALUE,
|
||||||
|
typing.Tuple[str, _TYPE_FIELD_VALUE],
|
||||||
|
typing.Tuple[str, _TYPE_FIELD_VALUE, str],
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def guess_content_type(filename, default="application/octet-stream"):
|
def guess_content_type(
|
||||||
|
filename: str | None, default: str = "application/octet-stream"
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Guess the "Content-Type" of a file.
|
Guess the "Content-Type" of a file.
|
||||||
|
|
||||||
|
@ -21,7 +28,7 @@ def guess_content_type(filename, default="application/octet-stream"):
|
||||||
return default
|
return default
|
||||||
|
|
||||||
|
|
||||||
def format_header_param_rfc2231(name, value):
|
def format_header_param_rfc2231(name: str, value: _TYPE_FIELD_VALUE) -> str:
|
||||||
"""
|
"""
|
||||||
Helper function to format and quote a single header parameter using the
|
Helper function to format and quote a single header parameter using the
|
||||||
strategy defined in RFC 2231.
|
strategy defined in RFC 2231.
|
||||||
|
@ -34,14 +41,28 @@ def format_header_param_rfc2231(name, value):
|
||||||
The name of the parameter, a string expected to be ASCII only.
|
The name of the parameter, a string expected to be ASCII only.
|
||||||
:param value:
|
:param value:
|
||||||
The value of the parameter, provided as ``bytes`` or `str``.
|
The value of the parameter, provided as ``bytes`` or `str``.
|
||||||
:ret:
|
:returns:
|
||||||
An RFC-2231-formatted unicode string.
|
An RFC-2231-formatted unicode string.
|
||||||
|
|
||||||
|
.. deprecated:: 2.0.0
|
||||||
|
Will be removed in urllib3 v2.1.0. This is not valid for
|
||||||
|
``multipart/form-data`` header parameters.
|
||||||
"""
|
"""
|
||||||
if isinstance(value, six.binary_type):
|
import warnings
|
||||||
|
|
||||||
|
warnings.warn(
|
||||||
|
"'format_header_param_rfc2231' is deprecated and will be "
|
||||||
|
"removed in urllib3 v2.1.0. This is not valid for "
|
||||||
|
"multipart/form-data header parameters.",
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(value, bytes):
|
||||||
value = value.decode("utf-8")
|
value = value.decode("utf-8")
|
||||||
|
|
||||||
if not any(ch in value for ch in '"\\\r\n'):
|
if not any(ch in value for ch in '"\\\r\n'):
|
||||||
result = u'%s="%s"' % (name, value)
|
result = f'{name}="{value}"'
|
||||||
try:
|
try:
|
||||||
result.encode("ascii")
|
result.encode("ascii")
|
||||||
except (UnicodeEncodeError, UnicodeDecodeError):
|
except (UnicodeEncodeError, UnicodeDecodeError):
|
||||||
|
@ -49,81 +70,87 @@ def format_header_param_rfc2231(name, value):
|
||||||
else:
|
else:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
if six.PY2: # Python 2:
|
|
||||||
value = value.encode("utf-8")
|
|
||||||
|
|
||||||
# encode_rfc2231 accepts an encoded string and returns an ascii-encoded
|
|
||||||
# string in Python 2 but accepts and returns unicode strings in Python 3
|
|
||||||
value = email.utils.encode_rfc2231(value, "utf-8")
|
value = email.utils.encode_rfc2231(value, "utf-8")
|
||||||
value = "%s*=%s" % (name, value)
|
value = f"{name}*={value}"
|
||||||
|
|
||||||
if six.PY2: # Python 2:
|
|
||||||
value = value.decode("utf-8")
|
|
||||||
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
_HTML5_REPLACEMENTS = {
|
def format_multipart_header_param(name: str, value: _TYPE_FIELD_VALUE) -> str:
|
||||||
u"\u0022": u"%22",
|
|
||||||
# Replace "\" with "\\".
|
|
||||||
u"\u005C": u"\u005C\u005C",
|
|
||||||
}
|
|
||||||
|
|
||||||
# All control characters from 0x00 to 0x1F *except* 0x1B.
|
|
||||||
_HTML5_REPLACEMENTS.update(
|
|
||||||
{
|
|
||||||
six.unichr(cc): u"%{:02X}".format(cc)
|
|
||||||
for cc in range(0x00, 0x1F + 1)
|
|
||||||
if cc not in (0x1B,)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _replace_multiple(value, needles_and_replacements):
|
|
||||||
def replacer(match):
|
|
||||||
return needles_and_replacements[match.group(0)]
|
|
||||||
|
|
||||||
pattern = re.compile(
|
|
||||||
r"|".join([re.escape(needle) for needle in needles_and_replacements.keys()])
|
|
||||||
)
|
|
||||||
|
|
||||||
result = pattern.sub(replacer, value)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def format_header_param_html5(name, value):
|
|
||||||
"""
|
"""
|
||||||
Helper function to format and quote a single header parameter using the
|
Format and quote a single multipart header parameter.
|
||||||
HTML5 strategy.
|
|
||||||
|
|
||||||
Particularly useful for header parameters which might contain
|
This follows the `WHATWG HTML Standard`_ as of 2021/06/10, matching
|
||||||
non-ASCII values, like file names. This follows the `HTML5 Working Draft
|
the behavior of current browser and curl versions. Values are
|
||||||
Section 4.10.22.7`_ and matches the behavior of curl and modern browsers.
|
assumed to be UTF-8. The ``\\n``, ``\\r``, and ``"`` characters are
|
||||||
|
percent encoded.
|
||||||
|
|
||||||
.. _HTML5 Working Draft Section 4.10.22.7:
|
.. _WHATWG HTML Standard:
|
||||||
https://w3c.github.io/html/sec-forms.html#multipart-form-data
|
https://html.spec.whatwg.org/multipage/
|
||||||
|
form-control-infrastructure.html#multipart-form-data
|
||||||
|
|
||||||
:param name:
|
:param name:
|
||||||
The name of the parameter, a string expected to be ASCII only.
|
The name of the parameter, an ASCII-only ``str``.
|
||||||
:param value:
|
:param value:
|
||||||
The value of the parameter, provided as ``bytes`` or `str``.
|
The value of the parameter, a ``str`` or UTF-8 encoded
|
||||||
:ret:
|
``bytes``.
|
||||||
A unicode string, stripped of troublesome characters.
|
:returns:
|
||||||
|
A string ``name="value"`` with the escaped value.
|
||||||
|
|
||||||
|
.. versionchanged:: 2.0.0
|
||||||
|
Matches the WHATWG HTML Standard as of 2021/06/10. Control
|
||||||
|
characters are no longer percent encoded.
|
||||||
|
|
||||||
|
.. versionchanged:: 2.0.0
|
||||||
|
Renamed from ``format_header_param_html5`` and
|
||||||
|
``format_header_param``. The old names will be removed in
|
||||||
|
urllib3 v2.1.0.
|
||||||
"""
|
"""
|
||||||
if isinstance(value, six.binary_type):
|
if isinstance(value, bytes):
|
||||||
value = value.decode("utf-8")
|
value = value.decode("utf-8")
|
||||||
|
|
||||||
value = _replace_multiple(value, _HTML5_REPLACEMENTS)
|
# percent encode \n \r "
|
||||||
|
value = value.translate({10: "%0A", 13: "%0D", 34: "%22"})
|
||||||
return u'%s="%s"' % (name, value)
|
return f'{name}="{value}"'
|
||||||
|
|
||||||
|
|
||||||
# For backwards-compatibility.
|
def format_header_param_html5(name: str, value: _TYPE_FIELD_VALUE) -> str:
|
||||||
format_header_param = format_header_param_html5
|
"""
|
||||||
|
.. deprecated:: 2.0.0
|
||||||
|
Renamed to :func:`format_multipart_header_param`. Will be
|
||||||
|
removed in urllib3 v2.1.0.
|
||||||
|
"""
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
warnings.warn(
|
||||||
|
"'format_header_param_html5' has been renamed to "
|
||||||
|
"'format_multipart_header_param'. The old name will be "
|
||||||
|
"removed in urllib3 v2.1.0.",
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
return format_multipart_header_param(name, value)
|
||||||
|
|
||||||
|
|
||||||
class RequestField(object):
|
def format_header_param(name: str, value: _TYPE_FIELD_VALUE) -> str:
|
||||||
|
"""
|
||||||
|
.. deprecated:: 2.0.0
|
||||||
|
Renamed to :func:`format_multipart_header_param`. Will be
|
||||||
|
removed in urllib3 v2.1.0.
|
||||||
|
"""
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
warnings.warn(
|
||||||
|
"'format_header_param' has been renamed to "
|
||||||
|
"'format_multipart_header_param'. The old name will be "
|
||||||
|
"removed in urllib3 v2.1.0.",
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
return format_multipart_header_param(name, value)
|
||||||
|
|
||||||
|
|
||||||
|
class RequestField:
|
||||||
"""
|
"""
|
||||||
A data container for request body parameters.
|
A data container for request body parameters.
|
||||||
|
|
||||||
|
@ -135,29 +162,47 @@ class RequestField(object):
|
||||||
An optional filename of the request field. Must be unicode.
|
An optional filename of the request field. Must be unicode.
|
||||||
:param headers:
|
:param headers:
|
||||||
An optional dict-like object of headers to initially use for the field.
|
An optional dict-like object of headers to initially use for the field.
|
||||||
:param header_formatter:
|
|
||||||
An optional callable that is used to encode and format the headers. By
|
.. versionchanged:: 2.0.0
|
||||||
default, this is :func:`format_header_param_html5`.
|
The ``header_formatter`` parameter is deprecated and will
|
||||||
|
be removed in urllib3 v2.1.0.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
name,
|
name: str,
|
||||||
data,
|
data: _TYPE_FIELD_VALUE,
|
||||||
filename=None,
|
filename: str | None = None,
|
||||||
headers=None,
|
headers: typing.Mapping[str, str] | None = None,
|
||||||
header_formatter=format_header_param_html5,
|
header_formatter: typing.Callable[[str, _TYPE_FIELD_VALUE], str] | None = None,
|
||||||
):
|
):
|
||||||
self._name = name
|
self._name = name
|
||||||
self._filename = filename
|
self._filename = filename
|
||||||
self.data = data
|
self.data = data
|
||||||
self.headers = {}
|
self.headers: dict[str, str | None] = {}
|
||||||
if headers:
|
if headers:
|
||||||
self.headers = dict(headers)
|
self.headers = dict(headers)
|
||||||
self.header_formatter = header_formatter
|
|
||||||
|
if header_formatter is not None:
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
warnings.warn(
|
||||||
|
"The 'header_formatter' parameter is deprecated and "
|
||||||
|
"will be removed in urllib3 v2.1.0.",
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
self.header_formatter = header_formatter
|
||||||
|
else:
|
||||||
|
self.header_formatter = format_multipart_header_param
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_tuples(cls, fieldname, value, header_formatter=format_header_param_html5):
|
def from_tuples(
|
||||||
|
cls,
|
||||||
|
fieldname: str,
|
||||||
|
value: _TYPE_FIELD_VALUE_TUPLE,
|
||||||
|
header_formatter: typing.Callable[[str, _TYPE_FIELD_VALUE], str] | None = None,
|
||||||
|
) -> RequestField:
|
||||||
"""
|
"""
|
||||||
A :class:`~urllib3.fields.RequestField` factory from old-style tuple parameters.
|
A :class:`~urllib3.fields.RequestField` factory from old-style tuple parameters.
|
||||||
|
|
||||||
|
@ -174,11 +219,19 @@ class RequestField(object):
|
||||||
|
|
||||||
Field names and filenames must be unicode.
|
Field names and filenames must be unicode.
|
||||||
"""
|
"""
|
||||||
|
filename: str | None
|
||||||
|
content_type: str | None
|
||||||
|
data: _TYPE_FIELD_VALUE
|
||||||
|
|
||||||
if isinstance(value, tuple):
|
if isinstance(value, tuple):
|
||||||
if len(value) == 3:
|
if len(value) == 3:
|
||||||
filename, data, content_type = value
|
filename, data, content_type = typing.cast(
|
||||||
|
typing.Tuple[str, _TYPE_FIELD_VALUE, str], value
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
filename, data = value
|
filename, data = typing.cast(
|
||||||
|
typing.Tuple[str, _TYPE_FIELD_VALUE], value
|
||||||
|
)
|
||||||
content_type = guess_content_type(filename)
|
content_type = guess_content_type(filename)
|
||||||
else:
|
else:
|
||||||
filename = None
|
filename = None
|
||||||
|
@ -192,20 +245,29 @@ class RequestField(object):
|
||||||
|
|
||||||
return request_param
|
return request_param
|
||||||
|
|
||||||
def _render_part(self, name, value):
|
def _render_part(self, name: str, value: _TYPE_FIELD_VALUE) -> str:
|
||||||
"""
|
"""
|
||||||
Overridable helper function to format a single header parameter. By
|
Override this method to change how each multipart header
|
||||||
default, this calls ``self.header_formatter``.
|
parameter is formatted. By default, this calls
|
||||||
|
:func:`format_multipart_header_param`.
|
||||||
|
|
||||||
:param name:
|
:param name:
|
||||||
The name of the parameter, a string expected to be ASCII only.
|
The name of the parameter, an ASCII-only ``str``.
|
||||||
:param value:
|
:param value:
|
||||||
The value of the parameter, provided as a unicode string.
|
The value of the parameter, a ``str`` or UTF-8 encoded
|
||||||
"""
|
``bytes``.
|
||||||
|
|
||||||
|
:meta public:
|
||||||
|
"""
|
||||||
return self.header_formatter(name, value)
|
return self.header_formatter(name, value)
|
||||||
|
|
||||||
def _render_parts(self, header_parts):
|
def _render_parts(
|
||||||
|
self,
|
||||||
|
header_parts: (
|
||||||
|
dict[str, _TYPE_FIELD_VALUE | None]
|
||||||
|
| typing.Sequence[tuple[str, _TYPE_FIELD_VALUE | None]]
|
||||||
|
),
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Helper function to format and quote a single header.
|
Helper function to format and quote a single header.
|
||||||
|
|
||||||
|
@ -216,18 +278,21 @@ class RequestField(object):
|
||||||
A sequence of (k, v) tuples or a :class:`dict` of (k, v) to format
|
A sequence of (k, v) tuples or a :class:`dict` of (k, v) to format
|
||||||
as `k1="v1"; k2="v2"; ...`.
|
as `k1="v1"; k2="v2"; ...`.
|
||||||
"""
|
"""
|
||||||
|
iterable: typing.Iterable[tuple[str, _TYPE_FIELD_VALUE | None]]
|
||||||
|
|
||||||
parts = []
|
parts = []
|
||||||
iterable = header_parts
|
|
||||||
if isinstance(header_parts, dict):
|
if isinstance(header_parts, dict):
|
||||||
iterable = header_parts.items()
|
iterable = header_parts.items()
|
||||||
|
else:
|
||||||
|
iterable = header_parts
|
||||||
|
|
||||||
for name, value in iterable:
|
for name, value in iterable:
|
||||||
if value is not None:
|
if value is not None:
|
||||||
parts.append(self._render_part(name, value))
|
parts.append(self._render_part(name, value))
|
||||||
|
|
||||||
return u"; ".join(parts)
|
return "; ".join(parts)
|
||||||
|
|
||||||
def render_headers(self):
|
def render_headers(self) -> str:
|
||||||
"""
|
"""
|
||||||
Renders the headers for this request field.
|
Renders the headers for this request field.
|
||||||
"""
|
"""
|
||||||
|
@ -236,39 +301,45 @@ class RequestField(object):
|
||||||
sort_keys = ["Content-Disposition", "Content-Type", "Content-Location"]
|
sort_keys = ["Content-Disposition", "Content-Type", "Content-Location"]
|
||||||
for sort_key in sort_keys:
|
for sort_key in sort_keys:
|
||||||
if self.headers.get(sort_key, False):
|
if self.headers.get(sort_key, False):
|
||||||
lines.append(u"%s: %s" % (sort_key, self.headers[sort_key]))
|
lines.append(f"{sort_key}: {self.headers[sort_key]}")
|
||||||
|
|
||||||
for header_name, header_value in self.headers.items():
|
for header_name, header_value in self.headers.items():
|
||||||
if header_name not in sort_keys:
|
if header_name not in sort_keys:
|
||||||
if header_value:
|
if header_value:
|
||||||
lines.append(u"%s: %s" % (header_name, header_value))
|
lines.append(f"{header_name}: {header_value}")
|
||||||
|
|
||||||
lines.append(u"\r\n")
|
lines.append("\r\n")
|
||||||
return u"\r\n".join(lines)
|
return "\r\n".join(lines)
|
||||||
|
|
||||||
def make_multipart(
|
def make_multipart(
|
||||||
self, content_disposition=None, content_type=None, content_location=None
|
self,
|
||||||
):
|
content_disposition: str | None = None,
|
||||||
|
content_type: str | None = None,
|
||||||
|
content_location: str | None = None,
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Makes this request field into a multipart request field.
|
Makes this request field into a multipart request field.
|
||||||
|
|
||||||
This method overrides "Content-Disposition", "Content-Type" and
|
This method overrides "Content-Disposition", "Content-Type" and
|
||||||
"Content-Location" headers to the request parameter.
|
"Content-Location" headers to the request parameter.
|
||||||
|
|
||||||
|
:param content_disposition:
|
||||||
|
The 'Content-Disposition' of the request body. Defaults to 'form-data'
|
||||||
:param content_type:
|
:param content_type:
|
||||||
The 'Content-Type' of the request body.
|
The 'Content-Type' of the request body.
|
||||||
:param content_location:
|
:param content_location:
|
||||||
The 'Content-Location' of the request body.
|
The 'Content-Location' of the request body.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
self.headers["Content-Disposition"] = content_disposition or u"form-data"
|
content_disposition = (content_disposition or "form-data") + "; ".join(
|
||||||
self.headers["Content-Disposition"] += u"; ".join(
|
|
||||||
[
|
[
|
||||||
u"",
|
"",
|
||||||
self._render_parts(
|
self._render_parts(
|
||||||
((u"name", self._name), (u"filename", self._filename))
|
(("name", self._name), ("filename", self._filename))
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.headers["Content-Disposition"] = content_disposition
|
||||||
self.headers["Content-Type"] = content_type
|
self.headers["Content-Type"] = content_type
|
||||||
self.headers["Content-Location"] = content_location
|
self.headers["Content-Location"] = content_location
|
||||||
|
|
|
@ -1,28 +0,0 @@
|
||||||
# Stubs for requests.packages.urllib3.fields (Python 3.4)
|
|
||||||
|
|
||||||
from typing import Any, Callable, Mapping, Optional
|
|
||||||
|
|
||||||
def guess_content_type(filename: str, default: str) -> str: ...
|
|
||||||
def format_header_param_rfc2231(name: str, value: str) -> str: ...
|
|
||||||
def format_header_param_html5(name: str, value: str) -> str: ...
|
|
||||||
def format_header_param(name: str, value: str) -> str: ...
|
|
||||||
|
|
||||||
class RequestField:
|
|
||||||
data: Any
|
|
||||||
headers: Optional[Mapping[str, str]]
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
name: str,
|
|
||||||
data: Any,
|
|
||||||
filename: Optional[str],
|
|
||||||
headers: Optional[Mapping[str, str]],
|
|
||||||
header_formatter: Callable[[str, str], str],
|
|
||||||
) -> None: ...
|
|
||||||
@classmethod
|
|
||||||
def from_tuples(
|
|
||||||
cls, fieldname: str, value: str, header_formatter: Callable[[str, str], str]
|
|
||||||
) -> RequestField: ...
|
|
||||||
def render_headers(self) -> str: ...
|
|
||||||
def make_multipart(
|
|
||||||
self, content_disposition: str, content_type: str, content_location: str
|
|
||||||
) -> None: ...
|
|
|
@ -1,28 +1,32 @@
|
||||||
from __future__ import absolute_import
|
from __future__ import annotations
|
||||||
|
|
||||||
import binascii
|
import binascii
|
||||||
import codecs
|
import codecs
|
||||||
import os
|
import os
|
||||||
|
import typing
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
||||||
from .fields import RequestField
|
from .fields import _TYPE_FIELD_VALUE_TUPLE, RequestField
|
||||||
from .packages import six
|
|
||||||
from .packages.six import b
|
|
||||||
|
|
||||||
writer = codecs.lookup("utf-8")[3]
|
writer = codecs.lookup("utf-8")[3]
|
||||||
|
|
||||||
|
_TYPE_FIELDS_SEQUENCE = typing.Sequence[
|
||||||
|
typing.Union[typing.Tuple[str, _TYPE_FIELD_VALUE_TUPLE], RequestField]
|
||||||
|
]
|
||||||
|
_TYPE_FIELDS = typing.Union[
|
||||||
|
_TYPE_FIELDS_SEQUENCE,
|
||||||
|
typing.Mapping[str, _TYPE_FIELD_VALUE_TUPLE],
|
||||||
|
]
|
||||||
|
|
||||||
def choose_boundary():
|
|
||||||
|
def choose_boundary() -> str:
|
||||||
"""
|
"""
|
||||||
Our embarrassingly-simple replacement for mimetools.choose_boundary.
|
Our embarrassingly-simple replacement for mimetools.choose_boundary.
|
||||||
"""
|
"""
|
||||||
boundary = binascii.hexlify(os.urandom(16))
|
return binascii.hexlify(os.urandom(16)).decode()
|
||||||
if not six.PY2:
|
|
||||||
boundary = boundary.decode("ascii")
|
|
||||||
return boundary
|
|
||||||
|
|
||||||
|
|
||||||
def iter_field_objects(fields):
|
def iter_field_objects(fields: _TYPE_FIELDS) -> typing.Iterable[RequestField]:
|
||||||
"""
|
"""
|
||||||
Iterate over fields.
|
Iterate over fields.
|
||||||
|
|
||||||
|
@ -30,42 +34,29 @@ def iter_field_objects(fields):
|
||||||
:class:`~urllib3.fields.RequestField`.
|
:class:`~urllib3.fields.RequestField`.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if isinstance(fields, dict):
|
iterable: typing.Iterable[RequestField | tuple[str, _TYPE_FIELD_VALUE_TUPLE]]
|
||||||
i = six.iteritems(fields)
|
|
||||||
else:
|
|
||||||
i = iter(fields)
|
|
||||||
|
|
||||||
for field in i:
|
if isinstance(fields, typing.Mapping):
|
||||||
|
iterable = fields.items()
|
||||||
|
else:
|
||||||
|
iterable = fields
|
||||||
|
|
||||||
|
for field in iterable:
|
||||||
if isinstance(field, RequestField):
|
if isinstance(field, RequestField):
|
||||||
yield field
|
yield field
|
||||||
else:
|
else:
|
||||||
yield RequestField.from_tuples(*field)
|
yield RequestField.from_tuples(*field)
|
||||||
|
|
||||||
|
|
||||||
def iter_fields(fields):
|
def encode_multipart_formdata(
|
||||||
"""
|
fields: _TYPE_FIELDS, boundary: str | None = None
|
||||||
.. deprecated:: 1.6
|
) -> tuple[bytes, str]:
|
||||||
|
|
||||||
Iterate over fields.
|
|
||||||
|
|
||||||
The addition of :class:`~urllib3.fields.RequestField` makes this function
|
|
||||||
obsolete. Instead, use :func:`iter_field_objects`, which returns
|
|
||||||
:class:`~urllib3.fields.RequestField` objects.
|
|
||||||
|
|
||||||
Supports list of (k, v) tuples and dicts.
|
|
||||||
"""
|
|
||||||
if isinstance(fields, dict):
|
|
||||||
return ((k, v) for k, v in six.iteritems(fields))
|
|
||||||
|
|
||||||
return ((k, v) for k, v in fields)
|
|
||||||
|
|
||||||
|
|
||||||
def encode_multipart_formdata(fields, boundary=None):
|
|
||||||
"""
|
"""
|
||||||
Encode a dictionary of ``fields`` using the multipart/form-data MIME format.
|
Encode a dictionary of ``fields`` using the multipart/form-data MIME format.
|
||||||
|
|
||||||
:param fields:
|
:param fields:
|
||||||
Dictionary of fields or list of (key, :class:`~urllib3.fields.RequestField`).
|
Dictionary of fields or list of (key, :class:`~urllib3.fields.RequestField`).
|
||||||
|
Values are processed by :func:`urllib3.fields.RequestField.from_tuples`.
|
||||||
|
|
||||||
:param boundary:
|
:param boundary:
|
||||||
If not specified, then a random boundary will be generated using
|
If not specified, then a random boundary will be generated using
|
||||||
|
@ -76,7 +67,7 @@ def encode_multipart_formdata(fields, boundary=None):
|
||||||
boundary = choose_boundary()
|
boundary = choose_boundary()
|
||||||
|
|
||||||
for field in iter_field_objects(fields):
|
for field in iter_field_objects(fields):
|
||||||
body.write(b("--%s\r\n" % (boundary)))
|
body.write(f"--{boundary}\r\n".encode("latin-1"))
|
||||||
|
|
||||||
writer(body).write(field.render_headers())
|
writer(body).write(field.render_headers())
|
||||||
data = field.data
|
data = field.data
|
||||||
|
@ -84,15 +75,15 @@ def encode_multipart_formdata(fields, boundary=None):
|
||||||
if isinstance(data, int):
|
if isinstance(data, int):
|
||||||
data = str(data) # Backwards compatibility
|
data = str(data) # Backwards compatibility
|
||||||
|
|
||||||
if isinstance(data, six.text_type):
|
if isinstance(data, str):
|
||||||
writer(body).write(data)
|
writer(body).write(data)
|
||||||
else:
|
else:
|
||||||
body.write(data)
|
body.write(data)
|
||||||
|
|
||||||
body.write(b"\r\n")
|
body.write(b"\r\n")
|
||||||
|
|
||||||
body.write(b("--%s--\r\n" % (boundary)))
|
body.write(f"--{boundary}--\r\n".encode("latin-1"))
|
||||||
|
|
||||||
content_type = str("multipart/form-data; boundary=%s" % boundary)
|
content_type = f"multipart/form-data; boundary={boundary}"
|
||||||
|
|
||||||
return body.getvalue(), content_type
|
return body.getvalue(), content_type
|
||||||
|
|
|
@ -1,16 +0,0 @@
|
||||||
from typing import Any, Generator, List, Mapping, Optional, Tuple, Union
|
|
||||||
|
|
||||||
from . import fields
|
|
||||||
|
|
||||||
RequestField = fields.RequestField
|
|
||||||
Fields = Union[Mapping[str, str], List[Tuple[str]], List[RequestField]]
|
|
||||||
Iterator = Generator[Tuple[str], None, None]
|
|
||||||
|
|
||||||
writer: Any
|
|
||||||
|
|
||||||
def choose_boundary() -> str: ...
|
|
||||||
def iter_field_objects(fields: Fields) -> Iterator: ...
|
|
||||||
def iter_fields(fields: Fields) -> Iterator: ...
|
|
||||||
def encode_multipart_formdata(
|
|
||||||
fields: Fields, boundary: Optional[str]
|
|
||||||
) -> Tuple[str]: ...
|
|
|
@ -1,51 +0,0 @@
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
backports.makefile
|
|
||||||
~~~~~~~~~~~~~~~~~~
|
|
||||||
|
|
||||||
Backports the Python 3 ``socket.makefile`` method for use with anything that
|
|
||||||
wants to create a "fake" socket object.
|
|
||||||
"""
|
|
||||||
import io
|
|
||||||
from socket import SocketIO
|
|
||||||
|
|
||||||
|
|
||||||
def backport_makefile(
|
|
||||||
self, mode="r", buffering=None, encoding=None, errors=None, newline=None
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Backport of ``socket.makefile`` from Python 3.5.
|
|
||||||
"""
|
|
||||||
if not set(mode) <= {"r", "w", "b"}:
|
|
||||||
raise ValueError("invalid mode %r (only r, w, b allowed)" % (mode,))
|
|
||||||
writing = "w" in mode
|
|
||||||
reading = "r" in mode or not writing
|
|
||||||
assert reading or writing
|
|
||||||
binary = "b" in mode
|
|
||||||
rawmode = ""
|
|
||||||
if reading:
|
|
||||||
rawmode += "r"
|
|
||||||
if writing:
|
|
||||||
rawmode += "w"
|
|
||||||
raw = SocketIO(self, rawmode)
|
|
||||||
self._makefile_refs += 1
|
|
||||||
if buffering is None:
|
|
||||||
buffering = -1
|
|
||||||
if buffering < 0:
|
|
||||||
buffering = io.DEFAULT_BUFFER_SIZE
|
|
||||||
if buffering == 0:
|
|
||||||
if not binary:
|
|
||||||
raise ValueError("unbuffered streams must be binary")
|
|
||||||
return raw
|
|
||||||
if reading and writing:
|
|
||||||
buffer = io.BufferedRWPair(raw, raw, buffering)
|
|
||||||
elif reading:
|
|
||||||
buffer = io.BufferedReader(raw, buffering)
|
|
||||||
else:
|
|
||||||
assert writing
|
|
||||||
buffer = io.BufferedWriter(raw, buffering)
|
|
||||||
if binary:
|
|
||||||
return buffer
|
|
||||||
text = io.TextIOWrapper(buffer, encoding, errors, newline)
|
|
||||||
text.mode = mode
|
|
||||||
return text
|
|
|
@ -1,155 +0,0 @@
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
backports.weakref_finalize
|
|
||||||
~~~~~~~~~~~~~~~~~~
|
|
||||||
|
|
||||||
Backports the Python 3 ``weakref.finalize`` method.
|
|
||||||
"""
|
|
||||||
from __future__ import absolute_import
|
|
||||||
|
|
||||||
import itertools
|
|
||||||
import sys
|
|
||||||
from weakref import ref
|
|
||||||
|
|
||||||
__all__ = ["weakref_finalize"]
|
|
||||||
|
|
||||||
|
|
||||||
class weakref_finalize(object):
|
|
||||||
"""Class for finalization of weakrefable objects
|
|
||||||
finalize(obj, func, *args, **kwargs) returns a callable finalizer
|
|
||||||
object which will be called when obj is garbage collected. The
|
|
||||||
first time the finalizer is called it evaluates func(*arg, **kwargs)
|
|
||||||
and returns the result. After this the finalizer is dead, and
|
|
||||||
calling it just returns None.
|
|
||||||
When the program exits any remaining finalizers for which the
|
|
||||||
atexit attribute is true will be run in reverse order of creation.
|
|
||||||
By default atexit is true.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Finalizer objects don't have any state of their own. They are
|
|
||||||
# just used as keys to lookup _Info objects in the registry. This
|
|
||||||
# ensures that they cannot be part of a ref-cycle.
|
|
||||||
|
|
||||||
__slots__ = ()
|
|
||||||
_registry = {}
|
|
||||||
_shutdown = False
|
|
||||||
_index_iter = itertools.count()
|
|
||||||
_dirty = False
|
|
||||||
_registered_with_atexit = False
|
|
||||||
|
|
||||||
class _Info(object):
|
|
||||||
__slots__ = ("weakref", "func", "args", "kwargs", "atexit", "index")
|
|
||||||
|
|
||||||
def __init__(self, obj, func, *args, **kwargs):
|
|
||||||
if not self._registered_with_atexit:
|
|
||||||
# We may register the exit function more than once because
|
|
||||||
# of a thread race, but that is harmless
|
|
||||||
import atexit
|
|
||||||
|
|
||||||
atexit.register(self._exitfunc)
|
|
||||||
weakref_finalize._registered_with_atexit = True
|
|
||||||
info = self._Info()
|
|
||||||
info.weakref = ref(obj, self)
|
|
||||||
info.func = func
|
|
||||||
info.args = args
|
|
||||||
info.kwargs = kwargs or None
|
|
||||||
info.atexit = True
|
|
||||||
info.index = next(self._index_iter)
|
|
||||||
self._registry[self] = info
|
|
||||||
weakref_finalize._dirty = True
|
|
||||||
|
|
||||||
def __call__(self, _=None):
|
|
||||||
"""If alive then mark as dead and return func(*args, **kwargs);
|
|
||||||
otherwise return None"""
|
|
||||||
info = self._registry.pop(self, None)
|
|
||||||
if info and not self._shutdown:
|
|
||||||
return info.func(*info.args, **(info.kwargs or {}))
|
|
||||||
|
|
||||||
def detach(self):
|
|
||||||
"""If alive then mark as dead and return (obj, func, args, kwargs);
|
|
||||||
otherwise return None"""
|
|
||||||
info = self._registry.get(self)
|
|
||||||
obj = info and info.weakref()
|
|
||||||
if obj is not None and self._registry.pop(self, None):
|
|
||||||
return (obj, info.func, info.args, info.kwargs or {})
|
|
||||||
|
|
||||||
def peek(self):
|
|
||||||
"""If alive then return (obj, func, args, kwargs);
|
|
||||||
otherwise return None"""
|
|
||||||
info = self._registry.get(self)
|
|
||||||
obj = info and info.weakref()
|
|
||||||
if obj is not None:
|
|
||||||
return (obj, info.func, info.args, info.kwargs or {})
|
|
||||||
|
|
||||||
@property
|
|
||||||
def alive(self):
|
|
||||||
"""Whether finalizer is alive"""
|
|
||||||
return self in self._registry
|
|
||||||
|
|
||||||
@property
|
|
||||||
def atexit(self):
|
|
||||||
"""Whether finalizer should be called at exit"""
|
|
||||||
info = self._registry.get(self)
|
|
||||||
return bool(info) and info.atexit
|
|
||||||
|
|
||||||
@atexit.setter
|
|
||||||
def atexit(self, value):
|
|
||||||
info = self._registry.get(self)
|
|
||||||
if info:
|
|
||||||
info.atexit = bool(value)
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
info = self._registry.get(self)
|
|
||||||
obj = info and info.weakref()
|
|
||||||
if obj is None:
|
|
||||||
return "<%s object at %#x; dead>" % (type(self).__name__, id(self))
|
|
||||||
else:
|
|
||||||
return "<%s object at %#x; for %r at %#x>" % (
|
|
||||||
type(self).__name__,
|
|
||||||
id(self),
|
|
||||||
type(obj).__name__,
|
|
||||||
id(obj),
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _select_for_exit(cls):
|
|
||||||
# Return live finalizers marked for exit, oldest first
|
|
||||||
L = [(f, i) for (f, i) in cls._registry.items() if i.atexit]
|
|
||||||
L.sort(key=lambda item: item[1].index)
|
|
||||||
return [f for (f, i) in L]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _exitfunc(cls):
|
|
||||||
# At shutdown invoke finalizers for which atexit is true.
|
|
||||||
# This is called once all other non-daemonic threads have been
|
|
||||||
# joined.
|
|
||||||
reenable_gc = False
|
|
||||||
try:
|
|
||||||
if cls._registry:
|
|
||||||
import gc
|
|
||||||
|
|
||||||
if gc.isenabled():
|
|
||||||
reenable_gc = True
|
|
||||||
gc.disable()
|
|
||||||
pending = None
|
|
||||||
while True:
|
|
||||||
if pending is None or weakref_finalize._dirty:
|
|
||||||
pending = cls._select_for_exit()
|
|
||||||
weakref_finalize._dirty = False
|
|
||||||
if not pending:
|
|
||||||
break
|
|
||||||
f = pending.pop()
|
|
||||||
try:
|
|
||||||
# gc is disabled, so (assuming no daemonic
|
|
||||||
# threads) the following is the only line in
|
|
||||||
# this function which might trigger creation
|
|
||||||
# of a new finalizer
|
|
||||||
f()
|
|
||||||
except Exception:
|
|
||||||
sys.excepthook(*sys.exc_info())
|
|
||||||
assert f not in cls._registry
|
|
||||||
finally:
|
|
||||||
# prevent any more finalizers from executing during shutdown
|
|
||||||
weakref_finalize._shutdown = True
|
|
||||||
if reenable_gc:
|
|
||||||
gc.enable()
|
|
File diff suppressed because it is too large
Load diff
|
@ -1,24 +1,33 @@
|
||||||
from __future__ import absolute_import
|
from __future__ import annotations
|
||||||
|
|
||||||
import collections
|
|
||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
|
import typing
|
||||||
|
import warnings
|
||||||
|
from types import TracebackType
|
||||||
|
from urllib.parse import urljoin
|
||||||
|
|
||||||
from ._collections import RecentlyUsedContainer
|
from ._collections import RecentlyUsedContainer
|
||||||
|
from ._request_methods import RequestMethods
|
||||||
|
from .connection import ProxyConfig
|
||||||
from .connectionpool import HTTPConnectionPool, HTTPSConnectionPool, port_by_scheme
|
from .connectionpool import HTTPConnectionPool, HTTPSConnectionPool, port_by_scheme
|
||||||
from .exceptions import (
|
from .exceptions import (
|
||||||
LocationValueError,
|
LocationValueError,
|
||||||
MaxRetryError,
|
MaxRetryError,
|
||||||
ProxySchemeUnknown,
|
ProxySchemeUnknown,
|
||||||
ProxySchemeUnsupported,
|
|
||||||
URLSchemeUnknown,
|
URLSchemeUnknown,
|
||||||
)
|
)
|
||||||
from .packages import six
|
from .response import BaseHTTPResponse
|
||||||
from .packages.six.moves.urllib.parse import urljoin
|
from .util.connection import _TYPE_SOCKET_OPTIONS
|
||||||
from .request import RequestMethods
|
|
||||||
from .util.proxy import connection_requires_http_tunnel
|
from .util.proxy import connection_requires_http_tunnel
|
||||||
from .util.retry import Retry
|
from .util.retry import Retry
|
||||||
from .util.url import parse_url
|
from .util.timeout import Timeout
|
||||||
|
from .util.url import Url, parse_url
|
||||||
|
|
||||||
|
if typing.TYPE_CHECKING:
|
||||||
|
import ssl
|
||||||
|
|
||||||
|
from typing_extensions import Literal
|
||||||
|
|
||||||
__all__ = ["PoolManager", "ProxyManager", "proxy_from_url"]
|
__all__ = ["PoolManager", "ProxyManager", "proxy_from_url"]
|
||||||
|
|
||||||
|
@ -31,52 +40,61 @@ SSL_KEYWORDS = (
|
||||||
"cert_reqs",
|
"cert_reqs",
|
||||||
"ca_certs",
|
"ca_certs",
|
||||||
"ssl_version",
|
"ssl_version",
|
||||||
|
"ssl_minimum_version",
|
||||||
|
"ssl_maximum_version",
|
||||||
"ca_cert_dir",
|
"ca_cert_dir",
|
||||||
"ssl_context",
|
"ssl_context",
|
||||||
"key_password",
|
"key_password",
|
||||||
"server_hostname",
|
"server_hostname",
|
||||||
)
|
)
|
||||||
|
# Default value for `blocksize` - a new parameter introduced to
|
||||||
|
# http.client.HTTPConnection & http.client.HTTPSConnection in Python 3.7
|
||||||
|
_DEFAULT_BLOCKSIZE = 16384
|
||||||
|
|
||||||
# All known keyword arguments that could be provided to the pool manager, its
|
_SelfT = typing.TypeVar("_SelfT")
|
||||||
# pools, or the underlying connections. This is used to construct a pool key.
|
|
||||||
_key_fields = (
|
|
||||||
"key_scheme", # str
|
|
||||||
"key_host", # str
|
|
||||||
"key_port", # int
|
|
||||||
"key_timeout", # int or float or Timeout
|
|
||||||
"key_retries", # int or Retry
|
|
||||||
"key_strict", # bool
|
|
||||||
"key_block", # bool
|
|
||||||
"key_source_address", # str
|
|
||||||
"key_key_file", # str
|
|
||||||
"key_key_password", # str
|
|
||||||
"key_cert_file", # str
|
|
||||||
"key_cert_reqs", # str
|
|
||||||
"key_ca_certs", # str
|
|
||||||
"key_ssl_version", # str
|
|
||||||
"key_ca_cert_dir", # str
|
|
||||||
"key_ssl_context", # instance of ssl.SSLContext or urllib3.util.ssl_.SSLContext
|
|
||||||
"key_maxsize", # int
|
|
||||||
"key_headers", # dict
|
|
||||||
"key__proxy", # parsed proxy url
|
|
||||||
"key__proxy_headers", # dict
|
|
||||||
"key__proxy_config", # class
|
|
||||||
"key_socket_options", # list of (level (int), optname (int), value (int or str)) tuples
|
|
||||||
"key__socks_options", # dict
|
|
||||||
"key_assert_hostname", # bool or string
|
|
||||||
"key_assert_fingerprint", # str
|
|
||||||
"key_server_hostname", # str
|
|
||||||
)
|
|
||||||
|
|
||||||
#: The namedtuple class used to construct keys for the connection pool.
|
|
||||||
#: All custom key schemes should include the fields in this key at a minimum.
|
|
||||||
PoolKey = collections.namedtuple("PoolKey", _key_fields)
|
|
||||||
|
|
||||||
_proxy_config_fields = ("ssl_context", "use_forwarding_for_https")
|
|
||||||
ProxyConfig = collections.namedtuple("ProxyConfig", _proxy_config_fields)
|
|
||||||
|
|
||||||
|
|
||||||
def _default_key_normalizer(key_class, request_context):
|
class PoolKey(typing.NamedTuple):
|
||||||
|
"""
|
||||||
|
All known keyword arguments that could be provided to the pool manager, its
|
||||||
|
pools, or the underlying connections.
|
||||||
|
|
||||||
|
All custom key schemes should include the fields in this key at a minimum.
|
||||||
|
"""
|
||||||
|
|
||||||
|
key_scheme: str
|
||||||
|
key_host: str
|
||||||
|
key_port: int | None
|
||||||
|
key_timeout: Timeout | float | int | None
|
||||||
|
key_retries: Retry | bool | int | None
|
||||||
|
key_block: bool | None
|
||||||
|
key_source_address: tuple[str, int] | None
|
||||||
|
key_key_file: str | None
|
||||||
|
key_key_password: str | None
|
||||||
|
key_cert_file: str | None
|
||||||
|
key_cert_reqs: str | None
|
||||||
|
key_ca_certs: str | None
|
||||||
|
key_ssl_version: int | str | None
|
||||||
|
key_ssl_minimum_version: ssl.TLSVersion | None
|
||||||
|
key_ssl_maximum_version: ssl.TLSVersion | None
|
||||||
|
key_ca_cert_dir: str | None
|
||||||
|
key_ssl_context: ssl.SSLContext | None
|
||||||
|
key_maxsize: int | None
|
||||||
|
key_headers: frozenset[tuple[str, str]] | None
|
||||||
|
key__proxy: Url | None
|
||||||
|
key__proxy_headers: frozenset[tuple[str, str]] | None
|
||||||
|
key__proxy_config: ProxyConfig | None
|
||||||
|
key_socket_options: _TYPE_SOCKET_OPTIONS | None
|
||||||
|
key__socks_options: frozenset[tuple[str, str]] | None
|
||||||
|
key_assert_hostname: bool | str | None
|
||||||
|
key_assert_fingerprint: str | None
|
||||||
|
key_server_hostname: str | None
|
||||||
|
key_blocksize: int | None
|
||||||
|
|
||||||
|
|
||||||
|
def _default_key_normalizer(
|
||||||
|
key_class: type[PoolKey], request_context: dict[str, typing.Any]
|
||||||
|
) -> PoolKey:
|
||||||
"""
|
"""
|
||||||
Create a pool key out of a request context dictionary.
|
Create a pool key out of a request context dictionary.
|
||||||
|
|
||||||
|
@ -122,6 +140,10 @@ def _default_key_normalizer(key_class, request_context):
|
||||||
if field not in context:
|
if field not in context:
|
||||||
context[field] = None
|
context[field] = None
|
||||||
|
|
||||||
|
# Default key_blocksize to _DEFAULT_BLOCKSIZE if missing from the context
|
||||||
|
if context.get("key_blocksize") is None:
|
||||||
|
context["key_blocksize"] = _DEFAULT_BLOCKSIZE
|
||||||
|
|
||||||
return key_class(**context)
|
return key_class(**context)
|
||||||
|
|
||||||
|
|
||||||
|
@ -154,23 +176,36 @@ class PoolManager(RequestMethods):
|
||||||
Additional parameters are used to create fresh
|
Additional parameters are used to create fresh
|
||||||
:class:`urllib3.connectionpool.ConnectionPool` instances.
|
:class:`urllib3.connectionpool.ConnectionPool` instances.
|
||||||
|
|
||||||
Example::
|
Example:
|
||||||
|
|
||||||
>>> manager = PoolManager(num_pools=2)
|
.. code-block:: python
|
||||||
>>> r = manager.request('GET', 'http://google.com/')
|
|
||||||
>>> r = manager.request('GET', 'http://google.com/mail')
|
import urllib3
|
||||||
>>> r = manager.request('GET', 'http://yahoo.com/')
|
|
||||||
>>> len(manager.pools)
|
http = urllib3.PoolManager(num_pools=2)
|
||||||
2
|
|
||||||
|
resp1 = http.request("GET", "https://google.com/")
|
||||||
|
resp2 = http.request("GET", "https://google.com/mail")
|
||||||
|
resp3 = http.request("GET", "https://yahoo.com/")
|
||||||
|
|
||||||
|
print(len(http.pools))
|
||||||
|
# 2
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
proxy = None
|
proxy: Url | None = None
|
||||||
proxy_config = None
|
proxy_config: ProxyConfig | None = None
|
||||||
|
|
||||||
def __init__(self, num_pools=10, headers=None, **connection_pool_kw):
|
def __init__(
|
||||||
RequestMethods.__init__(self, headers)
|
self,
|
||||||
|
num_pools: int = 10,
|
||||||
|
headers: typing.Mapping[str, str] | None = None,
|
||||||
|
**connection_pool_kw: typing.Any,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(headers)
|
||||||
self.connection_pool_kw = connection_pool_kw
|
self.connection_pool_kw = connection_pool_kw
|
||||||
|
|
||||||
|
self.pools: RecentlyUsedContainer[PoolKey, HTTPConnectionPool]
|
||||||
self.pools = RecentlyUsedContainer(num_pools)
|
self.pools = RecentlyUsedContainer(num_pools)
|
||||||
|
|
||||||
# Locally set the pool classes and keys so other PoolManagers can
|
# Locally set the pool classes and keys so other PoolManagers can
|
||||||
|
@ -178,15 +213,26 @@ class PoolManager(RequestMethods):
|
||||||
self.pool_classes_by_scheme = pool_classes_by_scheme
|
self.pool_classes_by_scheme = pool_classes_by_scheme
|
||||||
self.key_fn_by_scheme = key_fn_by_scheme.copy()
|
self.key_fn_by_scheme = key_fn_by_scheme.copy()
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self: _SelfT) -> _SelfT:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(
|
||||||
|
self,
|
||||||
|
exc_type: type[BaseException] | None,
|
||||||
|
exc_val: BaseException | None,
|
||||||
|
exc_tb: TracebackType | None,
|
||||||
|
) -> Literal[False]:
|
||||||
self.clear()
|
self.clear()
|
||||||
# Return False to re-raise any potential exceptions
|
# Return False to re-raise any potential exceptions
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _new_pool(self, scheme, host, port, request_context=None):
|
def _new_pool(
|
||||||
|
self,
|
||||||
|
scheme: str,
|
||||||
|
host: str,
|
||||||
|
port: int,
|
||||||
|
request_context: dict[str, typing.Any] | None = None,
|
||||||
|
) -> HTTPConnectionPool:
|
||||||
"""
|
"""
|
||||||
Create a new :class:`urllib3.connectionpool.ConnectionPool` based on host, port, scheme, and
|
Create a new :class:`urllib3.connectionpool.ConnectionPool` based on host, port, scheme, and
|
||||||
any additional pool keyword arguments.
|
any additional pool keyword arguments.
|
||||||
|
@ -196,10 +242,15 @@ class PoolManager(RequestMethods):
|
||||||
connection pools handed out by :meth:`connection_from_url` and
|
connection pools handed out by :meth:`connection_from_url` and
|
||||||
companion methods. It is intended to be overridden for customization.
|
companion methods. It is intended to be overridden for customization.
|
||||||
"""
|
"""
|
||||||
pool_cls = self.pool_classes_by_scheme[scheme]
|
pool_cls: type[HTTPConnectionPool] = self.pool_classes_by_scheme[scheme]
|
||||||
if request_context is None:
|
if request_context is None:
|
||||||
request_context = self.connection_pool_kw.copy()
|
request_context = self.connection_pool_kw.copy()
|
||||||
|
|
||||||
|
# Default blocksize to _DEFAULT_BLOCKSIZE if missing or explicitly
|
||||||
|
# set to 'None' in the request_context.
|
||||||
|
if request_context.get("blocksize") is None:
|
||||||
|
request_context["blocksize"] = _DEFAULT_BLOCKSIZE
|
||||||
|
|
||||||
# Although the context has everything necessary to create the pool,
|
# Although the context has everything necessary to create the pool,
|
||||||
# this function has historically only used the scheme, host, and port
|
# this function has historically only used the scheme, host, and port
|
||||||
# in the positional args. When an API change is acceptable these can
|
# in the positional args. When an API change is acceptable these can
|
||||||
|
@ -213,7 +264,7 @@ class PoolManager(RequestMethods):
|
||||||
|
|
||||||
return pool_cls(host, port, **request_context)
|
return pool_cls(host, port, **request_context)
|
||||||
|
|
||||||
def clear(self):
|
def clear(self) -> None:
|
||||||
"""
|
"""
|
||||||
Empty our store of pools and direct them all to close.
|
Empty our store of pools and direct them all to close.
|
||||||
|
|
||||||
|
@ -222,7 +273,13 @@ class PoolManager(RequestMethods):
|
||||||
"""
|
"""
|
||||||
self.pools.clear()
|
self.pools.clear()
|
||||||
|
|
||||||
def connection_from_host(self, host, port=None, scheme="http", pool_kwargs=None):
|
def connection_from_host(
|
||||||
|
self,
|
||||||
|
host: str | None,
|
||||||
|
port: int | None = None,
|
||||||
|
scheme: str | None = "http",
|
||||||
|
pool_kwargs: dict[str, typing.Any] | None = None,
|
||||||
|
) -> HTTPConnectionPool:
|
||||||
"""
|
"""
|
||||||
Get a :class:`urllib3.connectionpool.ConnectionPool` based on the host, port, and scheme.
|
Get a :class:`urllib3.connectionpool.ConnectionPool` based on the host, port, and scheme.
|
||||||
|
|
||||||
|
@ -245,13 +302,23 @@ class PoolManager(RequestMethods):
|
||||||
|
|
||||||
return self.connection_from_context(request_context)
|
return self.connection_from_context(request_context)
|
||||||
|
|
||||||
def connection_from_context(self, request_context):
|
def connection_from_context(
|
||||||
|
self, request_context: dict[str, typing.Any]
|
||||||
|
) -> HTTPConnectionPool:
|
||||||
"""
|
"""
|
||||||
Get a :class:`urllib3.connectionpool.ConnectionPool` based on the request context.
|
Get a :class:`urllib3.connectionpool.ConnectionPool` based on the request context.
|
||||||
|
|
||||||
``request_context`` must at least contain the ``scheme`` key and its
|
``request_context`` must at least contain the ``scheme`` key and its
|
||||||
value must be a key in ``key_fn_by_scheme`` instance variable.
|
value must be a key in ``key_fn_by_scheme`` instance variable.
|
||||||
"""
|
"""
|
||||||
|
if "strict" in request_context:
|
||||||
|
warnings.warn(
|
||||||
|
"The 'strict' parameter is no longer needed on Python 3+. "
|
||||||
|
"This will raise an error in urllib3 v2.1.0.",
|
||||||
|
DeprecationWarning,
|
||||||
|
)
|
||||||
|
request_context.pop("strict")
|
||||||
|
|
||||||
scheme = request_context["scheme"].lower()
|
scheme = request_context["scheme"].lower()
|
||||||
pool_key_constructor = self.key_fn_by_scheme.get(scheme)
|
pool_key_constructor = self.key_fn_by_scheme.get(scheme)
|
||||||
if not pool_key_constructor:
|
if not pool_key_constructor:
|
||||||
|
@ -260,7 +327,9 @@ class PoolManager(RequestMethods):
|
||||||
|
|
||||||
return self.connection_from_pool_key(pool_key, request_context=request_context)
|
return self.connection_from_pool_key(pool_key, request_context=request_context)
|
||||||
|
|
||||||
def connection_from_pool_key(self, pool_key, request_context=None):
|
def connection_from_pool_key(
|
||||||
|
self, pool_key: PoolKey, request_context: dict[str, typing.Any]
|
||||||
|
) -> HTTPConnectionPool:
|
||||||
"""
|
"""
|
||||||
Get a :class:`urllib3.connectionpool.ConnectionPool` based on the provided pool key.
|
Get a :class:`urllib3.connectionpool.ConnectionPool` based on the provided pool key.
|
||||||
|
|
||||||
|
@ -284,7 +353,9 @@ class PoolManager(RequestMethods):
|
||||||
|
|
||||||
return pool
|
return pool
|
||||||
|
|
||||||
def connection_from_url(self, url, pool_kwargs=None):
|
def connection_from_url(
|
||||||
|
self, url: str, pool_kwargs: dict[str, typing.Any] | None = None
|
||||||
|
) -> HTTPConnectionPool:
|
||||||
"""
|
"""
|
||||||
Similar to :func:`urllib3.connectionpool.connection_from_url`.
|
Similar to :func:`urllib3.connectionpool.connection_from_url`.
|
||||||
|
|
||||||
|
@ -300,7 +371,9 @@ class PoolManager(RequestMethods):
|
||||||
u.host, port=u.port, scheme=u.scheme, pool_kwargs=pool_kwargs
|
u.host, port=u.port, scheme=u.scheme, pool_kwargs=pool_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
def _merge_pool_kwargs(self, override):
|
def _merge_pool_kwargs(
|
||||||
|
self, override: dict[str, typing.Any] | None
|
||||||
|
) -> dict[str, typing.Any]:
|
||||||
"""
|
"""
|
||||||
Merge a dictionary of override values for self.connection_pool_kw.
|
Merge a dictionary of override values for self.connection_pool_kw.
|
||||||
|
|
||||||
|
@ -320,7 +393,7 @@ class PoolManager(RequestMethods):
|
||||||
base_pool_kwargs[key] = value
|
base_pool_kwargs[key] = value
|
||||||
return base_pool_kwargs
|
return base_pool_kwargs
|
||||||
|
|
||||||
def _proxy_requires_url_absolute_form(self, parsed_url):
|
def _proxy_requires_url_absolute_form(self, parsed_url: Url) -> bool:
|
||||||
"""
|
"""
|
||||||
Indicates if the proxy requires the complete destination URL in the
|
Indicates if the proxy requires the complete destination URL in the
|
||||||
request. Normally this is only needed when not using an HTTP CONNECT
|
request. Normally this is only needed when not using an HTTP CONNECT
|
||||||
|
@ -333,24 +406,9 @@ class PoolManager(RequestMethods):
|
||||||
self.proxy, self.proxy_config, parsed_url.scheme
|
self.proxy, self.proxy_config, parsed_url.scheme
|
||||||
)
|
)
|
||||||
|
|
||||||
def _validate_proxy_scheme_url_selection(self, url_scheme):
|
def urlopen( # type: ignore[override]
|
||||||
"""
|
self, method: str, url: str, redirect: bool = True, **kw: typing.Any
|
||||||
Validates that were not attempting to do TLS in TLS connections on
|
) -> BaseHTTPResponse:
|
||||||
Python2 or with unsupported SSL implementations.
|
|
||||||
"""
|
|
||||||
if self.proxy is None or url_scheme != "https":
|
|
||||||
return
|
|
||||||
|
|
||||||
if self.proxy.scheme != "https":
|
|
||||||
return
|
|
||||||
|
|
||||||
if six.PY2 and not self.proxy_config.use_forwarding_for_https:
|
|
||||||
raise ProxySchemeUnsupported(
|
|
||||||
"Contacting HTTPS destinations through HTTPS proxies "
|
|
||||||
"'via CONNECT tunnels' is not supported in Python 2"
|
|
||||||
)
|
|
||||||
|
|
||||||
def urlopen(self, method, url, redirect=True, **kw):
|
|
||||||
"""
|
"""
|
||||||
Same as :meth:`urllib3.HTTPConnectionPool.urlopen`
|
Same as :meth:`urllib3.HTTPConnectionPool.urlopen`
|
||||||
with custom cross-host redirect logic and only sends the request-uri
|
with custom cross-host redirect logic and only sends the request-uri
|
||||||
|
@ -360,7 +418,16 @@ class PoolManager(RequestMethods):
|
||||||
:class:`urllib3.connectionpool.ConnectionPool` can be chosen for it.
|
:class:`urllib3.connectionpool.ConnectionPool` can be chosen for it.
|
||||||
"""
|
"""
|
||||||
u = parse_url(url)
|
u = parse_url(url)
|
||||||
self._validate_proxy_scheme_url_selection(u.scheme)
|
|
||||||
|
if u.scheme is None:
|
||||||
|
warnings.warn(
|
||||||
|
"URLs without a scheme (ie 'https://') are deprecated and will raise an error "
|
||||||
|
"in a future version of urllib3. To avoid this DeprecationWarning ensure all URLs "
|
||||||
|
"start with 'https://' or 'http://'. Read more in this issue: "
|
||||||
|
"https://github.com/urllib3/urllib3/issues/2920",
|
||||||
|
category=DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
|
||||||
conn = self.connection_from_host(u.host, port=u.port, scheme=u.scheme)
|
conn = self.connection_from_host(u.host, port=u.port, scheme=u.scheme)
|
||||||
|
|
||||||
|
@ -368,7 +435,7 @@ class PoolManager(RequestMethods):
|
||||||
kw["redirect"] = False
|
kw["redirect"] = False
|
||||||
|
|
||||||
if "headers" not in kw:
|
if "headers" not in kw:
|
||||||
kw["headers"] = self.headers.copy()
|
kw["headers"] = self.headers
|
||||||
|
|
||||||
if self._proxy_requires_url_absolute_form(u):
|
if self._proxy_requires_url_absolute_form(u):
|
||||||
response = conn.urlopen(method, url, **kw)
|
response = conn.urlopen(method, url, **kw)
|
||||||
|
@ -396,10 +463,11 @@ class PoolManager(RequestMethods):
|
||||||
if retries.remove_headers_on_redirect and not conn.is_same_host(
|
if retries.remove_headers_on_redirect and not conn.is_same_host(
|
||||||
redirect_location
|
redirect_location
|
||||||
):
|
):
|
||||||
headers = list(six.iterkeys(kw["headers"]))
|
new_headers = kw["headers"].copy()
|
||||||
for header in headers:
|
for header in kw["headers"]:
|
||||||
if header.lower() in retries.remove_headers_on_redirect:
|
if header.lower() in retries.remove_headers_on_redirect:
|
||||||
kw["headers"].pop(header, None)
|
new_headers.pop(header, None)
|
||||||
|
kw["headers"] = new_headers
|
||||||
|
|
||||||
try:
|
try:
|
||||||
retries = retries.increment(method, url, response=response, _pool=conn)
|
retries = retries.increment(method, url, response=response, _pool=conn)
|
||||||
|
@ -445,37 +513,51 @@ class ProxyManager(PoolManager):
|
||||||
private. IP address, target hostname, SNI, and port are always visible
|
private. IP address, target hostname, SNI, and port are always visible
|
||||||
to an HTTPS proxy even when this flag is disabled.
|
to an HTTPS proxy even when this flag is disabled.
|
||||||
|
|
||||||
|
:param proxy_assert_hostname:
|
||||||
|
The hostname of the certificate to verify against.
|
||||||
|
|
||||||
|
:param proxy_assert_fingerprint:
|
||||||
|
The fingerprint of the certificate to verify against.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> proxy = urllib3.ProxyManager('http://localhost:3128/')
|
|
||||||
>>> r1 = proxy.request('GET', 'http://google.com/')
|
.. code-block:: python
|
||||||
>>> r2 = proxy.request('GET', 'http://httpbin.org/')
|
|
||||||
>>> len(proxy.pools)
|
import urllib3
|
||||||
1
|
|
||||||
>>> r3 = proxy.request('GET', 'https://httpbin.org/')
|
proxy = urllib3.ProxyManager("https://localhost:3128/")
|
||||||
>>> r4 = proxy.request('GET', 'https://twitter.com/')
|
|
||||||
>>> len(proxy.pools)
|
resp1 = proxy.request("GET", "https://google.com/")
|
||||||
3
|
resp2 = proxy.request("GET", "https://httpbin.org/")
|
||||||
|
|
||||||
|
print(len(proxy.pools))
|
||||||
|
# 1
|
||||||
|
|
||||||
|
resp3 = proxy.request("GET", "https://httpbin.org/")
|
||||||
|
resp4 = proxy.request("GET", "https://twitter.com/")
|
||||||
|
|
||||||
|
print(len(proxy.pools))
|
||||||
|
# 3
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
proxy_url,
|
proxy_url: str,
|
||||||
num_pools=10,
|
num_pools: int = 10,
|
||||||
headers=None,
|
headers: typing.Mapping[str, str] | None = None,
|
||||||
proxy_headers=None,
|
proxy_headers: typing.Mapping[str, str] | None = None,
|
||||||
proxy_ssl_context=None,
|
proxy_ssl_context: ssl.SSLContext | None = None,
|
||||||
use_forwarding_for_https=False,
|
use_forwarding_for_https: bool = False,
|
||||||
**connection_pool_kw
|
proxy_assert_hostname: None | str | Literal[False] = None,
|
||||||
):
|
proxy_assert_fingerprint: str | None = None,
|
||||||
|
**connection_pool_kw: typing.Any,
|
||||||
|
) -> None:
|
||||||
if isinstance(proxy_url, HTTPConnectionPool):
|
if isinstance(proxy_url, HTTPConnectionPool):
|
||||||
proxy_url = "%s://%s:%i" % (
|
str_proxy_url = f"{proxy_url.scheme}://{proxy_url.host}:{proxy_url.port}"
|
||||||
proxy_url.scheme,
|
else:
|
||||||
proxy_url.host,
|
str_proxy_url = proxy_url
|
||||||
proxy_url.port,
|
proxy = parse_url(str_proxy_url)
|
||||||
)
|
|
||||||
proxy = parse_url(proxy_url)
|
|
||||||
|
|
||||||
if proxy.scheme not in ("http", "https"):
|
if proxy.scheme not in ("http", "https"):
|
||||||
raise ProxySchemeUnknown(proxy.scheme)
|
raise ProxySchemeUnknown(proxy.scheme)
|
||||||
|
@ -487,25 +569,38 @@ class ProxyManager(PoolManager):
|
||||||
self.proxy = proxy
|
self.proxy = proxy
|
||||||
self.proxy_headers = proxy_headers or {}
|
self.proxy_headers = proxy_headers or {}
|
||||||
self.proxy_ssl_context = proxy_ssl_context
|
self.proxy_ssl_context = proxy_ssl_context
|
||||||
self.proxy_config = ProxyConfig(proxy_ssl_context, use_forwarding_for_https)
|
self.proxy_config = ProxyConfig(
|
||||||
|
proxy_ssl_context,
|
||||||
|
use_forwarding_for_https,
|
||||||
|
proxy_assert_hostname,
|
||||||
|
proxy_assert_fingerprint,
|
||||||
|
)
|
||||||
|
|
||||||
connection_pool_kw["_proxy"] = self.proxy
|
connection_pool_kw["_proxy"] = self.proxy
|
||||||
connection_pool_kw["_proxy_headers"] = self.proxy_headers
|
connection_pool_kw["_proxy_headers"] = self.proxy_headers
|
||||||
connection_pool_kw["_proxy_config"] = self.proxy_config
|
connection_pool_kw["_proxy_config"] = self.proxy_config
|
||||||
|
|
||||||
super(ProxyManager, self).__init__(num_pools, headers, **connection_pool_kw)
|
super().__init__(num_pools, headers, **connection_pool_kw)
|
||||||
|
|
||||||
def connection_from_host(self, host, port=None, scheme="http", pool_kwargs=None):
|
def connection_from_host(
|
||||||
|
self,
|
||||||
|
host: str | None,
|
||||||
|
port: int | None = None,
|
||||||
|
scheme: str | None = "http",
|
||||||
|
pool_kwargs: dict[str, typing.Any] | None = None,
|
||||||
|
) -> HTTPConnectionPool:
|
||||||
if scheme == "https":
|
if scheme == "https":
|
||||||
return super(ProxyManager, self).connection_from_host(
|
return super().connection_from_host(
|
||||||
host, port, scheme, pool_kwargs=pool_kwargs
|
host, port, scheme, pool_kwargs=pool_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
return super(ProxyManager, self).connection_from_host(
|
return super().connection_from_host(
|
||||||
self.proxy.host, self.proxy.port, self.proxy.scheme, pool_kwargs=pool_kwargs
|
self.proxy.host, self.proxy.port, self.proxy.scheme, pool_kwargs=pool_kwargs # type: ignore[union-attr]
|
||||||
)
|
)
|
||||||
|
|
||||||
def _set_proxy_headers(self, url, headers=None):
|
def _set_proxy_headers(
|
||||||
|
self, url: str, headers: typing.Mapping[str, str] | None = None
|
||||||
|
) -> typing.Mapping[str, str]:
|
||||||
"""
|
"""
|
||||||
Sets headers needed by proxies: specifically, the Accept and Host
|
Sets headers needed by proxies: specifically, the Accept and Host
|
||||||
headers. Only sets headers not provided by the user.
|
headers. Only sets headers not provided by the user.
|
||||||
|
@ -520,7 +615,9 @@ class ProxyManager(PoolManager):
|
||||||
headers_.update(headers)
|
headers_.update(headers)
|
||||||
return headers_
|
return headers_
|
||||||
|
|
||||||
def urlopen(self, method, url, redirect=True, **kw):
|
def urlopen( # type: ignore[override]
|
||||||
|
self, method: str, url: str, redirect: bool = True, **kw: typing.Any
|
||||||
|
) -> BaseHTTPResponse:
|
||||||
"Same as HTTP(S)ConnectionPool.urlopen, ``url`` must be absolute."
|
"Same as HTTP(S)ConnectionPool.urlopen, ``url`` must be absolute."
|
||||||
u = parse_url(url)
|
u = parse_url(url)
|
||||||
if not connection_requires_http_tunnel(self.proxy, self.proxy_config, u.scheme):
|
if not connection_requires_http_tunnel(self.proxy, self.proxy_config, u.scheme):
|
||||||
|
@ -530,8 +627,8 @@ class ProxyManager(PoolManager):
|
||||||
headers = kw.get("headers", self.headers)
|
headers = kw.get("headers", self.headers)
|
||||||
kw["headers"] = self._set_proxy_headers(url, headers)
|
kw["headers"] = self._set_proxy_headers(url, headers)
|
||||||
|
|
||||||
return super(ProxyManager, self).urlopen(method, url, redirect=redirect, **kw)
|
return super().urlopen(method, url, redirect=redirect, **kw)
|
||||||
|
|
||||||
|
|
||||||
def proxy_from_url(url, **kw):
|
def proxy_from_url(url: str, **kw: typing.Any) -> ProxyManager:
|
||||||
return ProxyManager(proxy_url=url, **kw)
|
return ProxyManager(proxy_url=url, **kw)
|
||||||
|
|
2
lib/urllib3/py.typed
Normal file
2
lib/urllib3/py.typed
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
# Instruct type checkers to look for inline type annotations in this package.
|
||||||
|
# See PEP 561.
|
File diff suppressed because it is too large
Load diff
|
@ -1,46 +1,41 @@
|
||||||
from __future__ import absolute_import
|
|
||||||
|
|
||||||
# For backwards compatibility, provide imports that used to be here.
|
# For backwards compatibility, provide imports that used to be here.
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from .connection import is_connection_dropped
|
from .connection import is_connection_dropped
|
||||||
from .request import SKIP_HEADER, SKIPPABLE_HEADERS, make_headers
|
from .request import SKIP_HEADER, SKIPPABLE_HEADERS, make_headers
|
||||||
from .response import is_fp_closed
|
from .response import is_fp_closed
|
||||||
from .retry import Retry
|
from .retry import Retry
|
||||||
from .ssl_ import (
|
from .ssl_ import (
|
||||||
ALPN_PROTOCOLS,
|
ALPN_PROTOCOLS,
|
||||||
HAS_SNI,
|
|
||||||
IS_PYOPENSSL,
|
IS_PYOPENSSL,
|
||||||
IS_SECURETRANSPORT,
|
IS_SECURETRANSPORT,
|
||||||
PROTOCOL_TLS,
|
|
||||||
SSLContext,
|
SSLContext,
|
||||||
assert_fingerprint,
|
assert_fingerprint,
|
||||||
|
create_urllib3_context,
|
||||||
resolve_cert_reqs,
|
resolve_cert_reqs,
|
||||||
resolve_ssl_version,
|
resolve_ssl_version,
|
||||||
ssl_wrap_socket,
|
ssl_wrap_socket,
|
||||||
)
|
)
|
||||||
from .timeout import Timeout, current_time
|
from .timeout import Timeout
|
||||||
from .url import Url, get_host, parse_url, split_first
|
from .url import Url, parse_url
|
||||||
from .wait import wait_for_read, wait_for_write
|
from .wait import wait_for_read, wait_for_write
|
||||||
|
|
||||||
__all__ = (
|
__all__ = (
|
||||||
"HAS_SNI",
|
|
||||||
"IS_PYOPENSSL",
|
"IS_PYOPENSSL",
|
||||||
"IS_SECURETRANSPORT",
|
"IS_SECURETRANSPORT",
|
||||||
"SSLContext",
|
"SSLContext",
|
||||||
"PROTOCOL_TLS",
|
|
||||||
"ALPN_PROTOCOLS",
|
"ALPN_PROTOCOLS",
|
||||||
"Retry",
|
"Retry",
|
||||||
"Timeout",
|
"Timeout",
|
||||||
"Url",
|
"Url",
|
||||||
"assert_fingerprint",
|
"assert_fingerprint",
|
||||||
"current_time",
|
"create_urllib3_context",
|
||||||
"is_connection_dropped",
|
"is_connection_dropped",
|
||||||
"is_fp_closed",
|
"is_fp_closed",
|
||||||
"get_host",
|
|
||||||
"parse_url",
|
"parse_url",
|
||||||
"make_headers",
|
"make_headers",
|
||||||
"resolve_cert_reqs",
|
"resolve_cert_reqs",
|
||||||
"resolve_ssl_version",
|
"resolve_ssl_version",
|
||||||
"split_first",
|
|
||||||
"ssl_wrap_socket",
|
"ssl_wrap_socket",
|
||||||
"wait_for_read",
|
"wait_for_read",
|
||||||
"wait_for_write",
|
"wait_for_write",
|
||||||
|
|
|
@ -1,33 +1,23 @@
|
||||||
from __future__ import absolute_import
|
from __future__ import annotations
|
||||||
|
|
||||||
import socket
|
import socket
|
||||||
|
import typing
|
||||||
|
|
||||||
from ..contrib import _appengine_environ
|
|
||||||
from ..exceptions import LocationParseError
|
from ..exceptions import LocationParseError
|
||||||
from ..packages import six
|
from .timeout import _DEFAULT_TIMEOUT, _TYPE_TIMEOUT
|
||||||
from .wait import NoWayToWaitForSocketError, wait_for_read
|
|
||||||
|
_TYPE_SOCKET_OPTIONS = typing.Sequence[typing.Tuple[int, int, typing.Union[int, bytes]]]
|
||||||
|
|
||||||
|
if typing.TYPE_CHECKING:
|
||||||
|
from .._base_connection import BaseHTTPConnection
|
||||||
|
|
||||||
|
|
||||||
def is_connection_dropped(conn): # Platform-specific
|
def is_connection_dropped(conn: BaseHTTPConnection) -> bool: # Platform-specific
|
||||||
"""
|
"""
|
||||||
Returns True if the connection is dropped and should be closed.
|
Returns True if the connection is dropped and should be closed.
|
||||||
|
:param conn: :class:`urllib3.connection.HTTPConnection` object.
|
||||||
:param conn:
|
|
||||||
:class:`http.client.HTTPConnection` object.
|
|
||||||
|
|
||||||
Note: For platforms like AppEngine, this will always return ``False`` to
|
|
||||||
let the platform handle connection recycling transparently for us.
|
|
||||||
"""
|
"""
|
||||||
sock = getattr(conn, "sock", False)
|
return not conn.is_connected
|
||||||
if sock is False: # Platform-specific: AppEngine
|
|
||||||
return False
|
|
||||||
if sock is None: # Connection already closed (such as by httplib).
|
|
||||||
return True
|
|
||||||
try:
|
|
||||||
# Returns True if readable, which here means it's been dropped
|
|
||||||
return wait_for_read(sock, timeout=0.0)
|
|
||||||
except NoWayToWaitForSocketError: # Platform-specific: AppEngine
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
# This function is copied from socket.py in the Python 2.7 standard
|
# This function is copied from socket.py in the Python 2.7 standard
|
||||||
|
@ -35,11 +25,11 @@ def is_connection_dropped(conn): # Platform-specific
|
||||||
# One additional modification is that we avoid binding to IPv6 servers
|
# One additional modification is that we avoid binding to IPv6 servers
|
||||||
# discovered in DNS if the system doesn't have IPv6 functionality.
|
# discovered in DNS if the system doesn't have IPv6 functionality.
|
||||||
def create_connection(
|
def create_connection(
|
||||||
address,
|
address: tuple[str, int],
|
||||||
timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
|
timeout: _TYPE_TIMEOUT = _DEFAULT_TIMEOUT,
|
||||||
source_address=None,
|
source_address: tuple[str, int] | None = None,
|
||||||
socket_options=None,
|
socket_options: _TYPE_SOCKET_OPTIONS | None = None,
|
||||||
):
|
) -> socket.socket:
|
||||||
"""Connect to *address* and return the socket object.
|
"""Connect to *address* and return the socket object.
|
||||||
|
|
||||||
Convenience function. Connect to *address* (a 2-tuple ``(host,
|
Convenience function. Connect to *address* (a 2-tuple ``(host,
|
||||||
|
@ -65,9 +55,7 @@ def create_connection(
|
||||||
try:
|
try:
|
||||||
host.encode("idna")
|
host.encode("idna")
|
||||||
except UnicodeError:
|
except UnicodeError:
|
||||||
return six.raise_from(
|
raise LocationParseError(f"'{host}', label empty or too long") from None
|
||||||
LocationParseError(u"'%s', label empty or too long" % host), None
|
|
||||||
)
|
|
||||||
|
|
||||||
for res in socket.getaddrinfo(host, port, family, socket.SOCK_STREAM):
|
for res in socket.getaddrinfo(host, port, family, socket.SOCK_STREAM):
|
||||||
af, socktype, proto, canonname, sa = res
|
af, socktype, proto, canonname, sa = res
|
||||||
|
@ -78,26 +66,33 @@ def create_connection(
|
||||||
# If provided, set socket level options before connecting.
|
# If provided, set socket level options before connecting.
|
||||||
_set_socket_options(sock, socket_options)
|
_set_socket_options(sock, socket_options)
|
||||||
|
|
||||||
if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT:
|
if timeout is not _DEFAULT_TIMEOUT:
|
||||||
sock.settimeout(timeout)
|
sock.settimeout(timeout)
|
||||||
if source_address:
|
if source_address:
|
||||||
sock.bind(source_address)
|
sock.bind(source_address)
|
||||||
sock.connect(sa)
|
sock.connect(sa)
|
||||||
|
# Break explicitly a reference cycle
|
||||||
|
err = None
|
||||||
return sock
|
return sock
|
||||||
|
|
||||||
except socket.error as e:
|
except OSError as _:
|
||||||
err = e
|
err = _
|
||||||
if sock is not None:
|
if sock is not None:
|
||||||
sock.close()
|
sock.close()
|
||||||
sock = None
|
|
||||||
|
|
||||||
if err is not None:
|
if err is not None:
|
||||||
raise err
|
try:
|
||||||
|
raise err
|
||||||
raise socket.error("getaddrinfo returns an empty list")
|
finally:
|
||||||
|
# Break explicitly a reference cycle
|
||||||
|
err = None
|
||||||
|
else:
|
||||||
|
raise OSError("getaddrinfo returns an empty list")
|
||||||
|
|
||||||
|
|
||||||
def _set_socket_options(sock, options):
|
def _set_socket_options(
|
||||||
|
sock: socket.socket, options: _TYPE_SOCKET_OPTIONS | None
|
||||||
|
) -> None:
|
||||||
if options is None:
|
if options is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -105,7 +100,7 @@ def _set_socket_options(sock, options):
|
||||||
sock.setsockopt(*opt)
|
sock.setsockopt(*opt)
|
||||||
|
|
||||||
|
|
||||||
def allowed_gai_family():
|
def allowed_gai_family() -> socket.AddressFamily:
|
||||||
"""This function is designed to work in the context of
|
"""This function is designed to work in the context of
|
||||||
getaddrinfo, where family=socket.AF_UNSPEC is the default and
|
getaddrinfo, where family=socket.AF_UNSPEC is the default and
|
||||||
will perform a DNS search for both IPv6 and IPv4 records."""
|
will perform a DNS search for both IPv6 and IPv4 records."""
|
||||||
|
@ -116,18 +111,11 @@ def allowed_gai_family():
|
||||||
return family
|
return family
|
||||||
|
|
||||||
|
|
||||||
def _has_ipv6(host):
|
def _has_ipv6(host: str) -> bool:
|
||||||
"""Returns True if the system can bind an IPv6 address."""
|
"""Returns True if the system can bind an IPv6 address."""
|
||||||
sock = None
|
sock = None
|
||||||
has_ipv6 = False
|
has_ipv6 = False
|
||||||
|
|
||||||
# App Engine doesn't support IPV6 sockets and actually has a quota on the
|
|
||||||
# number of sockets that can be used, so just early out here instead of
|
|
||||||
# creating a socket needlessly.
|
|
||||||
# See https://github.com/urllib3/urllib3/issues/1446
|
|
||||||
if _appengine_environ.is_appengine_sandbox():
|
|
||||||
return False
|
|
||||||
|
|
||||||
if socket.has_ipv6:
|
if socket.has_ipv6:
|
||||||
# has_ipv6 returns true if cPython was compiled with IPv6 support.
|
# has_ipv6 returns true if cPython was compiled with IPv6 support.
|
||||||
# It does not tell us if the system has IPv6 support enabled. To
|
# It does not tell us if the system has IPv6 support enabled. To
|
||||||
|
|
|
@ -1,9 +1,18 @@
|
||||||
from .ssl_ import create_urllib3_context, resolve_cert_reqs, resolve_ssl_version
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import typing
|
||||||
|
|
||||||
|
from .url import Url
|
||||||
|
|
||||||
|
if typing.TYPE_CHECKING:
|
||||||
|
from ..connection import ProxyConfig
|
||||||
|
|
||||||
|
|
||||||
def connection_requires_http_tunnel(
|
def connection_requires_http_tunnel(
|
||||||
proxy_url=None, proxy_config=None, destination_scheme=None
|
proxy_url: Url | None = None,
|
||||||
):
|
proxy_config: ProxyConfig | None = None,
|
||||||
|
destination_scheme: str | None = None,
|
||||||
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Returns True if the connection requires an HTTP CONNECT through the proxy.
|
Returns True if the connection requires an HTTP CONNECT through the proxy.
|
||||||
|
|
||||||
|
@ -32,26 +41,3 @@ def connection_requires_http_tunnel(
|
||||||
|
|
||||||
# Otherwise always use a tunnel.
|
# Otherwise always use a tunnel.
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def create_proxy_ssl_context(
|
|
||||||
ssl_version, cert_reqs, ca_certs=None, ca_cert_dir=None, ca_cert_data=None
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Generates a default proxy ssl context if one hasn't been provided by the
|
|
||||||
user.
|
|
||||||
"""
|
|
||||||
ssl_context = create_urllib3_context(
|
|
||||||
ssl_version=resolve_ssl_version(ssl_version),
|
|
||||||
cert_reqs=resolve_cert_reqs(cert_reqs),
|
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
|
||||||
not ca_certs
|
|
||||||
and not ca_cert_dir
|
|
||||||
and not ca_cert_data
|
|
||||||
and hasattr(ssl_context, "load_default_certs")
|
|
||||||
):
|
|
||||||
ssl_context.load_default_certs()
|
|
||||||
|
|
||||||
return ssl_context
|
|
||||||
|
|
|
@ -1,22 +0,0 @@
|
||||||
import collections
|
|
||||||
|
|
||||||
from ..packages import six
|
|
||||||
from ..packages.six.moves import queue
|
|
||||||
|
|
||||||
if six.PY2:
|
|
||||||
# Queue is imported for side effects on MS Windows. See issue #229.
|
|
||||||
import Queue as _unused_module_Queue # noqa: F401
|
|
||||||
|
|
||||||
|
|
||||||
class LifoQueue(queue.Queue):
|
|
||||||
def _init(self, _):
|
|
||||||
self.queue = collections.deque()
|
|
||||||
|
|
||||||
def _qsize(self, len=len):
|
|
||||||
return len(self.queue)
|
|
||||||
|
|
||||||
def _put(self, item):
|
|
||||||
self.queue.append(item)
|
|
||||||
|
|
||||||
def _get(self):
|
|
||||||
return self.queue.pop()
|
|
|
@ -1,9 +1,15 @@
|
||||||
from __future__ import absolute_import
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import io
|
||||||
|
import typing
|
||||||
from base64 import b64encode
|
from base64 import b64encode
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
from ..exceptions import UnrewindableBodyError
|
from ..exceptions import UnrewindableBodyError
|
||||||
from ..packages.six import b, integer_types
|
from .util import to_bytes
|
||||||
|
|
||||||
|
if typing.TYPE_CHECKING:
|
||||||
|
from typing_extensions import Final
|
||||||
|
|
||||||
# Pass as a value within ``headers`` to skip
|
# Pass as a value within ``headers`` to skip
|
||||||
# emitting some HTTP headers that are added automatically.
|
# emitting some HTTP headers that are added automatically.
|
||||||
|
@ -15,25 +21,45 @@ SKIPPABLE_HEADERS = frozenset(["accept-encoding", "host", "user-agent"])
|
||||||
ACCEPT_ENCODING = "gzip,deflate"
|
ACCEPT_ENCODING = "gzip,deflate"
|
||||||
try:
|
try:
|
||||||
try:
|
try:
|
||||||
import brotlicffi as _unused_module_brotli # noqa: F401
|
import brotlicffi as _unused_module_brotli # type: ignore[import] # noqa: F401
|
||||||
except ImportError:
|
except ImportError:
|
||||||
import brotli as _unused_module_brotli # noqa: F401
|
import brotli as _unused_module_brotli # type: ignore[import] # noqa: F401
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
ACCEPT_ENCODING += ",br"
|
ACCEPT_ENCODING += ",br"
|
||||||
|
try:
|
||||||
|
import zstandard as _unused_module_zstd # type: ignore[import] # noqa: F401
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
ACCEPT_ENCODING += ",zstd"
|
||||||
|
|
||||||
_FAILEDTELL = object()
|
|
||||||
|
class _TYPE_FAILEDTELL(Enum):
|
||||||
|
token = 0
|
||||||
|
|
||||||
|
|
||||||
|
_FAILEDTELL: Final[_TYPE_FAILEDTELL] = _TYPE_FAILEDTELL.token
|
||||||
|
|
||||||
|
_TYPE_BODY_POSITION = typing.Union[int, _TYPE_FAILEDTELL]
|
||||||
|
|
||||||
|
# When sending a request with these methods we aren't expecting
|
||||||
|
# a body so don't need to set an explicit 'Content-Length: 0'
|
||||||
|
# The reason we do this in the negative instead of tracking methods
|
||||||
|
# which 'should' have a body is because unknown methods should be
|
||||||
|
# treated as if they were 'POST' which *does* expect a body.
|
||||||
|
_METHODS_NOT_EXPECTING_BODY = {"GET", "HEAD", "DELETE", "TRACE", "OPTIONS", "CONNECT"}
|
||||||
|
|
||||||
|
|
||||||
def make_headers(
|
def make_headers(
|
||||||
keep_alive=None,
|
keep_alive: bool | None = None,
|
||||||
accept_encoding=None,
|
accept_encoding: bool | list[str] | str | None = None,
|
||||||
user_agent=None,
|
user_agent: str | None = None,
|
||||||
basic_auth=None,
|
basic_auth: str | None = None,
|
||||||
proxy_basic_auth=None,
|
proxy_basic_auth: str | None = None,
|
||||||
disable_cache=None,
|
disable_cache: bool | None = None,
|
||||||
):
|
) -> dict[str, str]:
|
||||||
"""
|
"""
|
||||||
Shortcuts for generating request headers.
|
Shortcuts for generating request headers.
|
||||||
|
|
||||||
|
@ -42,7 +68,8 @@ def make_headers(
|
||||||
|
|
||||||
:param accept_encoding:
|
:param accept_encoding:
|
||||||
Can be a boolean, list, or string.
|
Can be a boolean, list, or string.
|
||||||
``True`` translates to 'gzip,deflate'.
|
``True`` translates to 'gzip,deflate'. If either the ``brotli`` or
|
||||||
|
``brotlicffi`` package is installed 'gzip,deflate,br' is used instead.
|
||||||
List will get joined by comma.
|
List will get joined by comma.
|
||||||
String will be used as provided.
|
String will be used as provided.
|
||||||
|
|
||||||
|
@ -61,14 +88,18 @@ def make_headers(
|
||||||
:param disable_cache:
|
:param disable_cache:
|
||||||
If ``True``, adds 'cache-control: no-cache' header.
|
If ``True``, adds 'cache-control: no-cache' header.
|
||||||
|
|
||||||
Example::
|
Example:
|
||||||
|
|
||||||
>>> make_headers(keep_alive=True, user_agent="Batman/1.0")
|
.. code-block:: python
|
||||||
{'connection': 'keep-alive', 'user-agent': 'Batman/1.0'}
|
|
||||||
>>> make_headers(accept_encoding=True)
|
import urllib3
|
||||||
{'accept-encoding': 'gzip,deflate'}
|
|
||||||
|
print(urllib3.util.make_headers(keep_alive=True, user_agent="Batman/1.0"))
|
||||||
|
# {'connection': 'keep-alive', 'user-agent': 'Batman/1.0'}
|
||||||
|
print(urllib3.util.make_headers(accept_encoding=True))
|
||||||
|
# {'accept-encoding': 'gzip,deflate'}
|
||||||
"""
|
"""
|
||||||
headers = {}
|
headers: dict[str, str] = {}
|
||||||
if accept_encoding:
|
if accept_encoding:
|
||||||
if isinstance(accept_encoding, str):
|
if isinstance(accept_encoding, str):
|
||||||
pass
|
pass
|
||||||
|
@ -85,12 +116,14 @@ def make_headers(
|
||||||
headers["connection"] = "keep-alive"
|
headers["connection"] = "keep-alive"
|
||||||
|
|
||||||
if basic_auth:
|
if basic_auth:
|
||||||
headers["authorization"] = "Basic " + b64encode(b(basic_auth)).decode("utf-8")
|
headers[
|
||||||
|
"authorization"
|
||||||
|
] = f"Basic {b64encode(basic_auth.encode('latin-1')).decode()}"
|
||||||
|
|
||||||
if proxy_basic_auth:
|
if proxy_basic_auth:
|
||||||
headers["proxy-authorization"] = "Basic " + b64encode(
|
headers[
|
||||||
b(proxy_basic_auth)
|
"proxy-authorization"
|
||||||
).decode("utf-8")
|
] = f"Basic {b64encode(proxy_basic_auth.encode('latin-1')).decode()}"
|
||||||
|
|
||||||
if disable_cache:
|
if disable_cache:
|
||||||
headers["cache-control"] = "no-cache"
|
headers["cache-control"] = "no-cache"
|
||||||
|
@ -98,7 +131,9 @@ def make_headers(
|
||||||
return headers
|
return headers
|
||||||
|
|
||||||
|
|
||||||
def set_file_position(body, pos):
|
def set_file_position(
|
||||||
|
body: typing.Any, pos: _TYPE_BODY_POSITION | None
|
||||||
|
) -> _TYPE_BODY_POSITION | None:
|
||||||
"""
|
"""
|
||||||
If a position is provided, move file to that point.
|
If a position is provided, move file to that point.
|
||||||
Otherwise, we'll attempt to record a position for future use.
|
Otherwise, we'll attempt to record a position for future use.
|
||||||
|
@ -108,7 +143,7 @@ def set_file_position(body, pos):
|
||||||
elif getattr(body, "tell", None) is not None:
|
elif getattr(body, "tell", None) is not None:
|
||||||
try:
|
try:
|
||||||
pos = body.tell()
|
pos = body.tell()
|
||||||
except (IOError, OSError):
|
except OSError:
|
||||||
# This differentiates from None, allowing us to catch
|
# This differentiates from None, allowing us to catch
|
||||||
# a failed `tell()` later when trying to rewind the body.
|
# a failed `tell()` later when trying to rewind the body.
|
||||||
pos = _FAILEDTELL
|
pos = _FAILEDTELL
|
||||||
|
@ -116,7 +151,7 @@ def set_file_position(body, pos):
|
||||||
return pos
|
return pos
|
||||||
|
|
||||||
|
|
||||||
def rewind_body(body, body_pos):
|
def rewind_body(body: typing.IO[typing.AnyStr], body_pos: _TYPE_BODY_POSITION) -> None:
|
||||||
"""
|
"""
|
||||||
Attempt to rewind body to a certain position.
|
Attempt to rewind body to a certain position.
|
||||||
Primarily used for request redirects and retries.
|
Primarily used for request redirects and retries.
|
||||||
|
@ -128,13 +163,13 @@ def rewind_body(body, body_pos):
|
||||||
Position to seek to in file.
|
Position to seek to in file.
|
||||||
"""
|
"""
|
||||||
body_seek = getattr(body, "seek", None)
|
body_seek = getattr(body, "seek", None)
|
||||||
if body_seek is not None and isinstance(body_pos, integer_types):
|
if body_seek is not None and isinstance(body_pos, int):
|
||||||
try:
|
try:
|
||||||
body_seek(body_pos)
|
body_seek(body_pos)
|
||||||
except (IOError, OSError):
|
except OSError as e:
|
||||||
raise UnrewindableBodyError(
|
raise UnrewindableBodyError(
|
||||||
"An error occurred when rewinding request body for redirect/retry."
|
"An error occurred when rewinding request body for redirect/retry."
|
||||||
)
|
) from e
|
||||||
elif body_pos is _FAILEDTELL:
|
elif body_pos is _FAILEDTELL:
|
||||||
raise UnrewindableBodyError(
|
raise UnrewindableBodyError(
|
||||||
"Unable to record file position for rewinding "
|
"Unable to record file position for rewinding "
|
||||||
|
@ -142,5 +177,80 @@ def rewind_body(body, body_pos):
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"body_pos must be of type integer, instead it was %s." % type(body_pos)
|
f"body_pos must be of type integer, instead it was {type(body_pos)}."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ChunksAndContentLength(typing.NamedTuple):
|
||||||
|
chunks: typing.Iterable[bytes] | None
|
||||||
|
content_length: int | None
|
||||||
|
|
||||||
|
|
||||||
|
def body_to_chunks(
|
||||||
|
body: typing.Any | None, method: str, blocksize: int
|
||||||
|
) -> ChunksAndContentLength:
|
||||||
|
"""Takes the HTTP request method, body, and blocksize and
|
||||||
|
transforms them into an iterable of chunks to pass to
|
||||||
|
socket.sendall() and an optional 'Content-Length' header.
|
||||||
|
|
||||||
|
A 'Content-Length' of 'None' indicates the length of the body
|
||||||
|
can't be determined so should use 'Transfer-Encoding: chunked'
|
||||||
|
for framing instead.
|
||||||
|
"""
|
||||||
|
|
||||||
|
chunks: typing.Iterable[bytes] | None
|
||||||
|
content_length: int | None
|
||||||
|
|
||||||
|
# No body, we need to make a recommendation on 'Content-Length'
|
||||||
|
# based on whether that request method is expected to have
|
||||||
|
# a body or not.
|
||||||
|
if body is None:
|
||||||
|
chunks = None
|
||||||
|
if method.upper() not in _METHODS_NOT_EXPECTING_BODY:
|
||||||
|
content_length = 0
|
||||||
|
else:
|
||||||
|
content_length = None
|
||||||
|
|
||||||
|
# Bytes or strings become bytes
|
||||||
|
elif isinstance(body, (str, bytes)):
|
||||||
|
chunks = (to_bytes(body),)
|
||||||
|
content_length = len(chunks[0])
|
||||||
|
|
||||||
|
# File-like object, TODO: use seek() and tell() for length?
|
||||||
|
elif hasattr(body, "read"):
|
||||||
|
|
||||||
|
def chunk_readable() -> typing.Iterable[bytes]:
|
||||||
|
nonlocal body, blocksize
|
||||||
|
encode = isinstance(body, io.TextIOBase)
|
||||||
|
while True:
|
||||||
|
datablock = body.read(blocksize)
|
||||||
|
if not datablock:
|
||||||
|
break
|
||||||
|
if encode:
|
||||||
|
datablock = datablock.encode("iso-8859-1")
|
||||||
|
yield datablock
|
||||||
|
|
||||||
|
chunks = chunk_readable()
|
||||||
|
content_length = None
|
||||||
|
|
||||||
|
# Otherwise we need to start checking via duck-typing.
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
# Check if the body implements the buffer API.
|
||||||
|
mv = memoryview(body)
|
||||||
|
except TypeError:
|
||||||
|
try:
|
||||||
|
# Check if the body is an iterable
|
||||||
|
chunks = iter(body)
|
||||||
|
content_length = None
|
||||||
|
except TypeError:
|
||||||
|
raise TypeError(
|
||||||
|
f"'body' must be a bytes-like object, file-like "
|
||||||
|
f"object, or iterable. Instead was {body!r}"
|
||||||
|
) from None
|
||||||
|
else:
|
||||||
|
# Since it implements the buffer API can be passed directly to socket.sendall()
|
||||||
|
chunks = (body,)
|
||||||
|
content_length = mv.nbytes
|
||||||
|
|
||||||
|
return ChunksAndContentLength(chunks=chunks, content_length=content_length)
|
||||||
|
|
|
@ -1,12 +1,12 @@
|
||||||
from __future__ import absolute_import
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import http.client as httplib
|
||||||
from email.errors import MultipartInvariantViolationDefect, StartBoundaryNotFoundDefect
|
from email.errors import MultipartInvariantViolationDefect, StartBoundaryNotFoundDefect
|
||||||
|
|
||||||
from ..exceptions import HeaderParsingError
|
from ..exceptions import HeaderParsingError
|
||||||
from ..packages.six.moves import http_client as httplib
|
|
||||||
|
|
||||||
|
|
||||||
def is_fp_closed(obj):
|
def is_fp_closed(obj: object) -> bool:
|
||||||
"""
|
"""
|
||||||
Checks whether a given file-like object is closed.
|
Checks whether a given file-like object is closed.
|
||||||
|
|
||||||
|
@ -17,27 +17,27 @@ def is_fp_closed(obj):
|
||||||
try:
|
try:
|
||||||
# Check `isclosed()` first, in case Python3 doesn't set `closed`.
|
# Check `isclosed()` first, in case Python3 doesn't set `closed`.
|
||||||
# GH Issue #928
|
# GH Issue #928
|
||||||
return obj.isclosed()
|
return obj.isclosed() # type: ignore[no-any-return, attr-defined]
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Check via the official file-like-object way.
|
# Check via the official file-like-object way.
|
||||||
return obj.closed
|
return obj.closed # type: ignore[no-any-return, attr-defined]
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Check if the object is a container for another file-like object that
|
# Check if the object is a container for another file-like object that
|
||||||
# gets released on exhaustion (e.g. HTTPResponse).
|
# gets released on exhaustion (e.g. HTTPResponse).
|
||||||
return obj.fp is None
|
return obj.fp is None # type: ignore[attr-defined]
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
raise ValueError("Unable to determine whether fp is closed.")
|
raise ValueError("Unable to determine whether fp is closed.")
|
||||||
|
|
||||||
|
|
||||||
def assert_header_parsing(headers):
|
def assert_header_parsing(headers: httplib.HTTPMessage) -> None:
|
||||||
"""
|
"""
|
||||||
Asserts whether all headers have been successfully parsed.
|
Asserts whether all headers have been successfully parsed.
|
||||||
Extracts encountered errors from the result of parsing headers.
|
Extracts encountered errors from the result of parsing headers.
|
||||||
|
@ -53,55 +53,49 @@ def assert_header_parsing(headers):
|
||||||
# This will fail silently if we pass in the wrong kind of parameter.
|
# This will fail silently if we pass in the wrong kind of parameter.
|
||||||
# To make debugging easier add an explicit check.
|
# To make debugging easier add an explicit check.
|
||||||
if not isinstance(headers, httplib.HTTPMessage):
|
if not isinstance(headers, httplib.HTTPMessage):
|
||||||
raise TypeError("expected httplib.Message, got {0}.".format(type(headers)))
|
raise TypeError(f"expected httplib.Message, got {type(headers)}.")
|
||||||
|
|
||||||
defects = getattr(headers, "defects", None)
|
|
||||||
get_payload = getattr(headers, "get_payload", None)
|
|
||||||
|
|
||||||
unparsed_data = None
|
unparsed_data = None
|
||||||
if get_payload:
|
|
||||||
# get_payload is actually email.message.Message.get_payload;
|
|
||||||
# we're only interested in the result if it's not a multipart message
|
|
||||||
if not headers.is_multipart():
|
|
||||||
payload = get_payload()
|
|
||||||
|
|
||||||
if isinstance(payload, (bytes, str)):
|
# get_payload is actually email.message.Message.get_payload;
|
||||||
unparsed_data = payload
|
# we're only interested in the result if it's not a multipart message
|
||||||
if defects:
|
if not headers.is_multipart():
|
||||||
# httplib is assuming a response body is available
|
payload = headers.get_payload()
|
||||||
# when parsing headers even when httplib only sends
|
|
||||||
# header data to parse_headers() This results in
|
|
||||||
# defects on multipart responses in particular.
|
|
||||||
# See: https://github.com/urllib3/urllib3/issues/800
|
|
||||||
|
|
||||||
# So we ignore the following defects:
|
if isinstance(payload, (bytes, str)):
|
||||||
# - StartBoundaryNotFoundDefect:
|
unparsed_data = payload
|
||||||
# The claimed start boundary was never found.
|
|
||||||
# - MultipartInvariantViolationDefect:
|
# httplib is assuming a response body is available
|
||||||
# A message claimed to be a multipart but no subparts were found.
|
# when parsing headers even when httplib only sends
|
||||||
defects = [
|
# header data to parse_headers() This results in
|
||||||
defect
|
# defects on multipart responses in particular.
|
||||||
for defect in defects
|
# See: https://github.com/urllib3/urllib3/issues/800
|
||||||
if not isinstance(
|
|
||||||
defect, (StartBoundaryNotFoundDefect, MultipartInvariantViolationDefect)
|
# So we ignore the following defects:
|
||||||
)
|
# - StartBoundaryNotFoundDefect:
|
||||||
]
|
# The claimed start boundary was never found.
|
||||||
|
# - MultipartInvariantViolationDefect:
|
||||||
|
# A message claimed to be a multipart but no subparts were found.
|
||||||
|
defects = [
|
||||||
|
defect
|
||||||
|
for defect in headers.defects
|
||||||
|
if not isinstance(
|
||||||
|
defect, (StartBoundaryNotFoundDefect, MultipartInvariantViolationDefect)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
if defects or unparsed_data:
|
if defects or unparsed_data:
|
||||||
raise HeaderParsingError(defects=defects, unparsed_data=unparsed_data)
|
raise HeaderParsingError(defects=defects, unparsed_data=unparsed_data)
|
||||||
|
|
||||||
|
|
||||||
def is_response_to_head(response):
|
def is_response_to_head(response: httplib.HTTPResponse) -> bool:
|
||||||
"""
|
"""
|
||||||
Checks whether the request of a response has been a HEAD-request.
|
Checks whether the request of a response has been a HEAD-request.
|
||||||
Handles the quirks of AppEngine.
|
|
||||||
|
|
||||||
:param http.client.HTTPResponse response:
|
:param http.client.HTTPResponse response:
|
||||||
Response to check if the originating request
|
Response to check if the originating request
|
||||||
used 'HEAD' as a method.
|
used 'HEAD' as a method.
|
||||||
"""
|
"""
|
||||||
# FIXME: Can we do this somehow without accessing private httplib _method?
|
# FIXME: Can we do this somehow without accessing private httplib _method?
|
||||||
method = response._method
|
method_str = response._method # type: str # type: ignore[attr-defined]
|
||||||
if isinstance(method, int): # Platform-specific: Appengine
|
return method_str.upper() == "HEAD"
|
||||||
return method == 3
|
|
||||||
return method.upper() == "HEAD"
|
|
||||||
|
|
|
@ -1,12 +1,13 @@
|
||||||
from __future__ import absolute_import
|
from __future__ import annotations
|
||||||
|
|
||||||
import email
|
import email
|
||||||
import logging
|
import logging
|
||||||
|
import random
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
import warnings
|
import typing
|
||||||
from collections import namedtuple
|
|
||||||
from itertools import takewhile
|
from itertools import takewhile
|
||||||
|
from types import TracebackType
|
||||||
|
|
||||||
from ..exceptions import (
|
from ..exceptions import (
|
||||||
ConnectTimeoutError,
|
ConnectTimeoutError,
|
||||||
|
@ -17,97 +18,49 @@ from ..exceptions import (
|
||||||
ReadTimeoutError,
|
ReadTimeoutError,
|
||||||
ResponseError,
|
ResponseError,
|
||||||
)
|
)
|
||||||
from ..packages import six
|
from .util import reraise
|
||||||
|
|
||||||
|
if typing.TYPE_CHECKING:
|
||||||
|
from ..connectionpool import ConnectionPool
|
||||||
|
from ..response import BaseHTTPResponse
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# Data structure for representing the metadata of requests that result in a retry.
|
# Data structure for representing the metadata of requests that result in a retry.
|
||||||
RequestHistory = namedtuple(
|
class RequestHistory(typing.NamedTuple):
|
||||||
"RequestHistory", ["method", "url", "error", "status", "redirect_location"]
|
method: str | None
|
||||||
)
|
url: str | None
|
||||||
|
error: Exception | None
|
||||||
|
status: int | None
|
||||||
|
redirect_location: str | None
|
||||||
|
|
||||||
|
|
||||||
# TODO: In v2 we can remove this sentinel and metaclass with deprecated options.
|
class Retry:
|
||||||
_Default = object()
|
|
||||||
|
|
||||||
|
|
||||||
class _RetryMeta(type):
|
|
||||||
@property
|
|
||||||
def DEFAULT_METHOD_WHITELIST(cls):
|
|
||||||
warnings.warn(
|
|
||||||
"Using 'Retry.DEFAULT_METHOD_WHITELIST' is deprecated and "
|
|
||||||
"will be removed in v2.0. Use 'Retry.DEFAULT_ALLOWED_METHODS' instead",
|
|
||||||
DeprecationWarning,
|
|
||||||
)
|
|
||||||
return cls.DEFAULT_ALLOWED_METHODS
|
|
||||||
|
|
||||||
@DEFAULT_METHOD_WHITELIST.setter
|
|
||||||
def DEFAULT_METHOD_WHITELIST(cls, value):
|
|
||||||
warnings.warn(
|
|
||||||
"Using 'Retry.DEFAULT_METHOD_WHITELIST' is deprecated and "
|
|
||||||
"will be removed in v2.0. Use 'Retry.DEFAULT_ALLOWED_METHODS' instead",
|
|
||||||
DeprecationWarning,
|
|
||||||
)
|
|
||||||
cls.DEFAULT_ALLOWED_METHODS = value
|
|
||||||
|
|
||||||
@property
|
|
||||||
def DEFAULT_REDIRECT_HEADERS_BLACKLIST(cls):
|
|
||||||
warnings.warn(
|
|
||||||
"Using 'Retry.DEFAULT_REDIRECT_HEADERS_BLACKLIST' is deprecated and "
|
|
||||||
"will be removed in v2.0. Use 'Retry.DEFAULT_REMOVE_HEADERS_ON_REDIRECT' instead",
|
|
||||||
DeprecationWarning,
|
|
||||||
)
|
|
||||||
return cls.DEFAULT_REMOVE_HEADERS_ON_REDIRECT
|
|
||||||
|
|
||||||
@DEFAULT_REDIRECT_HEADERS_BLACKLIST.setter
|
|
||||||
def DEFAULT_REDIRECT_HEADERS_BLACKLIST(cls, value):
|
|
||||||
warnings.warn(
|
|
||||||
"Using 'Retry.DEFAULT_REDIRECT_HEADERS_BLACKLIST' is deprecated and "
|
|
||||||
"will be removed in v2.0. Use 'Retry.DEFAULT_REMOVE_HEADERS_ON_REDIRECT' instead",
|
|
||||||
DeprecationWarning,
|
|
||||||
)
|
|
||||||
cls.DEFAULT_REMOVE_HEADERS_ON_REDIRECT = value
|
|
||||||
|
|
||||||
@property
|
|
||||||
def BACKOFF_MAX(cls):
|
|
||||||
warnings.warn(
|
|
||||||
"Using 'Retry.BACKOFF_MAX' is deprecated and "
|
|
||||||
"will be removed in v2.0. Use 'Retry.DEFAULT_BACKOFF_MAX' instead",
|
|
||||||
DeprecationWarning,
|
|
||||||
)
|
|
||||||
return cls.DEFAULT_BACKOFF_MAX
|
|
||||||
|
|
||||||
@BACKOFF_MAX.setter
|
|
||||||
def BACKOFF_MAX(cls, value):
|
|
||||||
warnings.warn(
|
|
||||||
"Using 'Retry.BACKOFF_MAX' is deprecated and "
|
|
||||||
"will be removed in v2.0. Use 'Retry.DEFAULT_BACKOFF_MAX' instead",
|
|
||||||
DeprecationWarning,
|
|
||||||
)
|
|
||||||
cls.DEFAULT_BACKOFF_MAX = value
|
|
||||||
|
|
||||||
|
|
||||||
@six.add_metaclass(_RetryMeta)
|
|
||||||
class Retry(object):
|
|
||||||
"""Retry configuration.
|
"""Retry configuration.
|
||||||
|
|
||||||
Each retry attempt will create a new Retry object with updated values, so
|
Each retry attempt will create a new Retry object with updated values, so
|
||||||
they can be safely reused.
|
they can be safely reused.
|
||||||
|
|
||||||
Retries can be defined as a default for a pool::
|
Retries can be defined as a default for a pool:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
retries = Retry(connect=5, read=2, redirect=5)
|
retries = Retry(connect=5, read=2, redirect=5)
|
||||||
http = PoolManager(retries=retries)
|
http = PoolManager(retries=retries)
|
||||||
response = http.request('GET', 'http://example.com/')
|
response = http.request("GET", "https://example.com/")
|
||||||
|
|
||||||
Or per-request (which overrides the default for the pool)::
|
Or per-request (which overrides the default for the pool):
|
||||||
|
|
||||||
response = http.request('GET', 'http://example.com/', retries=Retry(10))
|
.. code-block:: python
|
||||||
|
|
||||||
Retries can be disabled by passing ``False``::
|
response = http.request("GET", "https://example.com/", retries=Retry(10))
|
||||||
|
|
||||||
response = http.request('GET', 'http://example.com/', retries=False)
|
Retries can be disabled by passing ``False``:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
response = http.request("GET", "https://example.com/", retries=False)
|
||||||
|
|
||||||
Errors will be wrapped in :class:`~urllib3.exceptions.MaxRetryError` unless
|
Errors will be wrapped in :class:`~urllib3.exceptions.MaxRetryError` unless
|
||||||
retries are disabled, in which case the causing exception will be raised.
|
retries are disabled, in which case the causing exception will be raised.
|
||||||
|
@ -169,21 +122,16 @@ class Retry(object):
|
||||||
If ``total`` is not set, it's a good idea to set this to 0 to account
|
If ``total`` is not set, it's a good idea to set this to 0 to account
|
||||||
for unexpected edge cases and avoid infinite retry loops.
|
for unexpected edge cases and avoid infinite retry loops.
|
||||||
|
|
||||||
:param iterable allowed_methods:
|
:param Collection allowed_methods:
|
||||||
Set of uppercased HTTP method verbs that we should retry on.
|
Set of uppercased HTTP method verbs that we should retry on.
|
||||||
|
|
||||||
By default, we only retry on methods which are considered to be
|
By default, we only retry on methods which are considered to be
|
||||||
idempotent (multiple requests with the same parameters end with the
|
idempotent (multiple requests with the same parameters end with the
|
||||||
same state). See :attr:`Retry.DEFAULT_ALLOWED_METHODS`.
|
same state). See :attr:`Retry.DEFAULT_ALLOWED_METHODS`.
|
||||||
|
|
||||||
Set to a ``False`` value to retry on any verb.
|
Set to a ``None`` value to retry on any verb.
|
||||||
|
|
||||||
.. warning::
|
:param Collection status_forcelist:
|
||||||
|
|
||||||
Previously this parameter was named ``method_whitelist``, that
|
|
||||||
usage is deprecated in v1.26.0 and will be removed in v2.0.
|
|
||||||
|
|
||||||
:param iterable status_forcelist:
|
|
||||||
A set of integer HTTP status codes that we should force a retry on.
|
A set of integer HTTP status codes that we should force a retry on.
|
||||||
A retry is initiated if the request method is in ``allowed_methods``
|
A retry is initiated if the request method is in ``allowed_methods``
|
||||||
and the response status code is in ``status_forcelist``.
|
and the response status code is in ``status_forcelist``.
|
||||||
|
@ -195,13 +143,17 @@ class Retry(object):
|
||||||
(most errors are resolved immediately by a second try without a
|
(most errors are resolved immediately by a second try without a
|
||||||
delay). urllib3 will sleep for::
|
delay). urllib3 will sleep for::
|
||||||
|
|
||||||
{backoff factor} * (2 ** ({number of total retries} - 1))
|
{backoff factor} * (2 ** ({number of previous retries}))
|
||||||
|
|
||||||
seconds. If the backoff_factor is 0.1, then :func:`.sleep` will sleep
|
seconds. If `backoff_jitter` is non-zero, this sleep is extended by::
|
||||||
for [0.0s, 0.2s, 0.4s, ...] between retries. It will never be longer
|
|
||||||
than :attr:`Retry.DEFAULT_BACKOFF_MAX`.
|
|
||||||
|
|
||||||
By default, backoff is disabled (set to 0).
|
random.uniform(0, {backoff jitter})
|
||||||
|
|
||||||
|
seconds. For example, if the backoff_factor is 0.1, then :func:`Retry.sleep` will
|
||||||
|
sleep for [0.0s, 0.2s, 0.4s, 0.8s, ...] between retries. No backoff will ever
|
||||||
|
be longer than `backoff_max`.
|
||||||
|
|
||||||
|
By default, backoff is disabled (factor set to 0).
|
||||||
|
|
||||||
:param bool raise_on_redirect: Whether, if the number of redirects is
|
:param bool raise_on_redirect: Whether, if the number of redirects is
|
||||||
exhausted, to raise a MaxRetryError, or to return a response with a
|
exhausted, to raise a MaxRetryError, or to return a response with a
|
||||||
|
@ -220,7 +172,7 @@ class Retry(object):
|
||||||
Whether to respect Retry-After header on status codes defined as
|
Whether to respect Retry-After header on status codes defined as
|
||||||
:attr:`Retry.RETRY_AFTER_STATUS_CODES` or not.
|
:attr:`Retry.RETRY_AFTER_STATUS_CODES` or not.
|
||||||
|
|
||||||
:param iterable remove_headers_on_redirect:
|
:param Collection remove_headers_on_redirect:
|
||||||
Sequence of headers to remove from the request when a response
|
Sequence of headers to remove from the request when a response
|
||||||
indicating a redirect is returned before firing off the redirected
|
indicating a redirect is returned before firing off the redirected
|
||||||
request.
|
request.
|
||||||
|
@ -237,48 +189,33 @@ class Retry(object):
|
||||||
#: Default headers to be used for ``remove_headers_on_redirect``
|
#: Default headers to be used for ``remove_headers_on_redirect``
|
||||||
DEFAULT_REMOVE_HEADERS_ON_REDIRECT = frozenset(["Authorization"])
|
DEFAULT_REMOVE_HEADERS_ON_REDIRECT = frozenset(["Authorization"])
|
||||||
|
|
||||||
#: Maximum backoff time.
|
#: Default maximum backoff time.
|
||||||
DEFAULT_BACKOFF_MAX = 120
|
DEFAULT_BACKOFF_MAX = 120
|
||||||
|
|
||||||
|
# Backward compatibility; assigned outside of the class.
|
||||||
|
DEFAULT: typing.ClassVar[Retry]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
total=10,
|
total: bool | int | None = 10,
|
||||||
connect=None,
|
connect: int | None = None,
|
||||||
read=None,
|
read: int | None = None,
|
||||||
redirect=None,
|
redirect: bool | int | None = None,
|
||||||
status=None,
|
status: int | None = None,
|
||||||
other=None,
|
other: int | None = None,
|
||||||
allowed_methods=_Default,
|
allowed_methods: typing.Collection[str] | None = DEFAULT_ALLOWED_METHODS,
|
||||||
status_forcelist=None,
|
status_forcelist: typing.Collection[int] | None = None,
|
||||||
backoff_factor=0,
|
backoff_factor: float = 0,
|
||||||
raise_on_redirect=True,
|
backoff_max: float = DEFAULT_BACKOFF_MAX,
|
||||||
raise_on_status=True,
|
raise_on_redirect: bool = True,
|
||||||
history=None,
|
raise_on_status: bool = True,
|
||||||
respect_retry_after_header=True,
|
history: tuple[RequestHistory, ...] | None = None,
|
||||||
remove_headers_on_redirect=_Default,
|
respect_retry_after_header: bool = True,
|
||||||
# TODO: Deprecated, remove in v2.0
|
remove_headers_on_redirect: typing.Collection[
|
||||||
method_whitelist=_Default,
|
str
|
||||||
):
|
] = DEFAULT_REMOVE_HEADERS_ON_REDIRECT,
|
||||||
|
backoff_jitter: float = 0.0,
|
||||||
if method_whitelist is not _Default:
|
) -> None:
|
||||||
if allowed_methods is not _Default:
|
|
||||||
raise ValueError(
|
|
||||||
"Using both 'allowed_methods' and "
|
|
||||||
"'method_whitelist' together is not allowed. "
|
|
||||||
"Instead only use 'allowed_methods'"
|
|
||||||
)
|
|
||||||
warnings.warn(
|
|
||||||
"Using 'method_whitelist' with Retry is deprecated and "
|
|
||||||
"will be removed in v2.0. Use 'allowed_methods' instead",
|
|
||||||
DeprecationWarning,
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
allowed_methods = method_whitelist
|
|
||||||
if allowed_methods is _Default:
|
|
||||||
allowed_methods = self.DEFAULT_ALLOWED_METHODS
|
|
||||||
if remove_headers_on_redirect is _Default:
|
|
||||||
remove_headers_on_redirect = self.DEFAULT_REMOVE_HEADERS_ON_REDIRECT
|
|
||||||
|
|
||||||
self.total = total
|
self.total = total
|
||||||
self.connect = connect
|
self.connect = connect
|
||||||
self.read = read
|
self.read = read
|
||||||
|
@ -293,15 +230,17 @@ class Retry(object):
|
||||||
self.status_forcelist = status_forcelist or set()
|
self.status_forcelist = status_forcelist or set()
|
||||||
self.allowed_methods = allowed_methods
|
self.allowed_methods = allowed_methods
|
||||||
self.backoff_factor = backoff_factor
|
self.backoff_factor = backoff_factor
|
||||||
|
self.backoff_max = backoff_max
|
||||||
self.raise_on_redirect = raise_on_redirect
|
self.raise_on_redirect = raise_on_redirect
|
||||||
self.raise_on_status = raise_on_status
|
self.raise_on_status = raise_on_status
|
||||||
self.history = history or tuple()
|
self.history = history or ()
|
||||||
self.respect_retry_after_header = respect_retry_after_header
|
self.respect_retry_after_header = respect_retry_after_header
|
||||||
self.remove_headers_on_redirect = frozenset(
|
self.remove_headers_on_redirect = frozenset(
|
||||||
[h.lower() for h in remove_headers_on_redirect]
|
h.lower() for h in remove_headers_on_redirect
|
||||||
)
|
)
|
||||||
|
self.backoff_jitter = backoff_jitter
|
||||||
|
|
||||||
def new(self, **kw):
|
def new(self, **kw: typing.Any) -> Retry:
|
||||||
params = dict(
|
params = dict(
|
||||||
total=self.total,
|
total=self.total,
|
||||||
connect=self.connect,
|
connect=self.connect,
|
||||||
|
@ -309,36 +248,28 @@ class Retry(object):
|
||||||
redirect=self.redirect,
|
redirect=self.redirect,
|
||||||
status=self.status,
|
status=self.status,
|
||||||
other=self.other,
|
other=self.other,
|
||||||
|
allowed_methods=self.allowed_methods,
|
||||||
status_forcelist=self.status_forcelist,
|
status_forcelist=self.status_forcelist,
|
||||||
backoff_factor=self.backoff_factor,
|
backoff_factor=self.backoff_factor,
|
||||||
|
backoff_max=self.backoff_max,
|
||||||
raise_on_redirect=self.raise_on_redirect,
|
raise_on_redirect=self.raise_on_redirect,
|
||||||
raise_on_status=self.raise_on_status,
|
raise_on_status=self.raise_on_status,
|
||||||
history=self.history,
|
history=self.history,
|
||||||
remove_headers_on_redirect=self.remove_headers_on_redirect,
|
remove_headers_on_redirect=self.remove_headers_on_redirect,
|
||||||
respect_retry_after_header=self.respect_retry_after_header,
|
respect_retry_after_header=self.respect_retry_after_header,
|
||||||
|
backoff_jitter=self.backoff_jitter,
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: If already given in **kw we use what's given to us
|
|
||||||
# If not given we need to figure out what to pass. We decide
|
|
||||||
# based on whether our class has the 'method_whitelist' property
|
|
||||||
# and if so we pass the deprecated 'method_whitelist' otherwise
|
|
||||||
# we use 'allowed_methods'. Remove in v2.0
|
|
||||||
if "method_whitelist" not in kw and "allowed_methods" not in kw:
|
|
||||||
if "method_whitelist" in self.__dict__:
|
|
||||||
warnings.warn(
|
|
||||||
"Using 'method_whitelist' with Retry is deprecated and "
|
|
||||||
"will be removed in v2.0. Use 'allowed_methods' instead",
|
|
||||||
DeprecationWarning,
|
|
||||||
)
|
|
||||||
params["method_whitelist"] = self.allowed_methods
|
|
||||||
else:
|
|
||||||
params["allowed_methods"] = self.allowed_methods
|
|
||||||
|
|
||||||
params.update(kw)
|
params.update(kw)
|
||||||
return type(self)(**params)
|
return type(self)(**params) # type: ignore[arg-type]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_int(cls, retries, redirect=True, default=None):
|
def from_int(
|
||||||
|
cls,
|
||||||
|
retries: Retry | bool | int | None,
|
||||||
|
redirect: bool | int | None = True,
|
||||||
|
default: Retry | bool | int | None = None,
|
||||||
|
) -> Retry:
|
||||||
"""Backwards-compatibility for the old retries format."""
|
"""Backwards-compatibility for the old retries format."""
|
||||||
if retries is None:
|
if retries is None:
|
||||||
retries = default if default is not None else cls.DEFAULT
|
retries = default if default is not None else cls.DEFAULT
|
||||||
|
@ -351,7 +282,7 @@ class Retry(object):
|
||||||
log.debug("Converted retries value: %r -> %r", retries, new_retries)
|
log.debug("Converted retries value: %r -> %r", retries, new_retries)
|
||||||
return new_retries
|
return new_retries
|
||||||
|
|
||||||
def get_backoff_time(self):
|
def get_backoff_time(self) -> float:
|
||||||
"""Formula for computing the current backoff
|
"""Formula for computing the current backoff
|
||||||
|
|
||||||
:rtype: float
|
:rtype: float
|
||||||
|
@ -366,32 +297,28 @@ class Retry(object):
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
backoff_value = self.backoff_factor * (2 ** (consecutive_errors_len - 1))
|
backoff_value = self.backoff_factor * (2 ** (consecutive_errors_len - 1))
|
||||||
return min(self.DEFAULT_BACKOFF_MAX, backoff_value)
|
if self.backoff_jitter != 0.0:
|
||||||
|
backoff_value += random.random() * self.backoff_jitter
|
||||||
|
return float(max(0, min(self.backoff_max, backoff_value)))
|
||||||
|
|
||||||
def parse_retry_after(self, retry_after):
|
def parse_retry_after(self, retry_after: str) -> float:
|
||||||
|
seconds: float
|
||||||
# Whitespace: https://tools.ietf.org/html/rfc7230#section-3.2.4
|
# Whitespace: https://tools.ietf.org/html/rfc7230#section-3.2.4
|
||||||
if re.match(r"^\s*[0-9]+\s*$", retry_after):
|
if re.match(r"^\s*[0-9]+\s*$", retry_after):
|
||||||
seconds = int(retry_after)
|
seconds = int(retry_after)
|
||||||
else:
|
else:
|
||||||
retry_date_tuple = email.utils.parsedate_tz(retry_after)
|
retry_date_tuple = email.utils.parsedate_tz(retry_after)
|
||||||
if retry_date_tuple is None:
|
if retry_date_tuple is None:
|
||||||
raise InvalidHeader("Invalid Retry-After header: %s" % retry_after)
|
raise InvalidHeader(f"Invalid Retry-After header: {retry_after}")
|
||||||
if retry_date_tuple[9] is None: # Python 2
|
|
||||||
# Assume UTC if no timezone was specified
|
|
||||||
# On Python2.7, parsedate_tz returns None for a timezone offset
|
|
||||||
# instead of 0 if no timezone is given, where mktime_tz treats
|
|
||||||
# a None timezone offset as local time.
|
|
||||||
retry_date_tuple = retry_date_tuple[:9] + (0,) + retry_date_tuple[10:]
|
|
||||||
|
|
||||||
retry_date = email.utils.mktime_tz(retry_date_tuple)
|
retry_date = email.utils.mktime_tz(retry_date_tuple)
|
||||||
seconds = retry_date - time.time()
|
seconds = retry_date - time.time()
|
||||||
|
|
||||||
if seconds < 0:
|
seconds = max(seconds, 0)
|
||||||
seconds = 0
|
|
||||||
|
|
||||||
return seconds
|
return seconds
|
||||||
|
|
||||||
def get_retry_after(self, response):
|
def get_retry_after(self, response: BaseHTTPResponse) -> float | None:
|
||||||
"""Get the value of Retry-After in seconds."""
|
"""Get the value of Retry-After in seconds."""
|
||||||
|
|
||||||
retry_after = response.headers.get("Retry-After")
|
retry_after = response.headers.get("Retry-After")
|
||||||
|
@ -401,7 +328,7 @@ class Retry(object):
|
||||||
|
|
||||||
return self.parse_retry_after(retry_after)
|
return self.parse_retry_after(retry_after)
|
||||||
|
|
||||||
def sleep_for_retry(self, response=None):
|
def sleep_for_retry(self, response: BaseHTTPResponse) -> bool:
|
||||||
retry_after = self.get_retry_after(response)
|
retry_after = self.get_retry_after(response)
|
||||||
if retry_after:
|
if retry_after:
|
||||||
time.sleep(retry_after)
|
time.sleep(retry_after)
|
||||||
|
@ -409,13 +336,13 @@ class Retry(object):
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _sleep_backoff(self):
|
def _sleep_backoff(self) -> None:
|
||||||
backoff = self.get_backoff_time()
|
backoff = self.get_backoff_time()
|
||||||
if backoff <= 0:
|
if backoff <= 0:
|
||||||
return
|
return
|
||||||
time.sleep(backoff)
|
time.sleep(backoff)
|
||||||
|
|
||||||
def sleep(self, response=None):
|
def sleep(self, response: BaseHTTPResponse | None = None) -> None:
|
||||||
"""Sleep between retry attempts.
|
"""Sleep between retry attempts.
|
||||||
|
|
||||||
This method will respect a server's ``Retry-After`` response header
|
This method will respect a server's ``Retry-After`` response header
|
||||||
|
@ -431,7 +358,7 @@ class Retry(object):
|
||||||
|
|
||||||
self._sleep_backoff()
|
self._sleep_backoff()
|
||||||
|
|
||||||
def _is_connection_error(self, err):
|
def _is_connection_error(self, err: Exception) -> bool:
|
||||||
"""Errors when we're fairly sure that the server did not receive the
|
"""Errors when we're fairly sure that the server did not receive the
|
||||||
request, so it should be safe to retry.
|
request, so it should be safe to retry.
|
||||||
"""
|
"""
|
||||||
|
@ -439,33 +366,23 @@ class Retry(object):
|
||||||
err = err.original_error
|
err = err.original_error
|
||||||
return isinstance(err, ConnectTimeoutError)
|
return isinstance(err, ConnectTimeoutError)
|
||||||
|
|
||||||
def _is_read_error(self, err):
|
def _is_read_error(self, err: Exception) -> bool:
|
||||||
"""Errors that occur after the request has been started, so we should
|
"""Errors that occur after the request has been started, so we should
|
||||||
assume that the server began processing it.
|
assume that the server began processing it.
|
||||||
"""
|
"""
|
||||||
return isinstance(err, (ReadTimeoutError, ProtocolError))
|
return isinstance(err, (ReadTimeoutError, ProtocolError))
|
||||||
|
|
||||||
def _is_method_retryable(self, method):
|
def _is_method_retryable(self, method: str) -> bool:
|
||||||
"""Checks if a given HTTP method should be retried upon, depending if
|
"""Checks if a given HTTP method should be retried upon, depending if
|
||||||
it is included in the allowed_methods
|
it is included in the allowed_methods
|
||||||
"""
|
"""
|
||||||
# TODO: For now favor if the Retry implementation sets its own method_whitelist
|
if self.allowed_methods and method.upper() not in self.allowed_methods:
|
||||||
# property outside of our constructor to avoid breaking custom implementations.
|
|
||||||
if "method_whitelist" in self.__dict__:
|
|
||||||
warnings.warn(
|
|
||||||
"Using 'method_whitelist' with Retry is deprecated and "
|
|
||||||
"will be removed in v2.0. Use 'allowed_methods' instead",
|
|
||||||
DeprecationWarning,
|
|
||||||
)
|
|
||||||
allowed_methods = self.method_whitelist
|
|
||||||
else:
|
|
||||||
allowed_methods = self.allowed_methods
|
|
||||||
|
|
||||||
if allowed_methods and method.upper() not in allowed_methods:
|
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def is_retry(self, method, status_code, has_retry_after=False):
|
def is_retry(
|
||||||
|
self, method: str, status_code: int, has_retry_after: bool = False
|
||||||
|
) -> bool:
|
||||||
"""Is this method/status code retryable? (Based on allowlists and control
|
"""Is this method/status code retryable? (Based on allowlists and control
|
||||||
variables such as the number of total retries to allow, whether to
|
variables such as the number of total retries to allow, whether to
|
||||||
respect the Retry-After header, whether this header is present, and
|
respect the Retry-After header, whether this header is present, and
|
||||||
|
@ -478,24 +395,27 @@ class Retry(object):
|
||||||
if self.status_forcelist and status_code in self.status_forcelist:
|
if self.status_forcelist and status_code in self.status_forcelist:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return (
|
return bool(
|
||||||
self.total
|
self.total
|
||||||
and self.respect_retry_after_header
|
and self.respect_retry_after_header
|
||||||
and has_retry_after
|
and has_retry_after
|
||||||
and (status_code in self.RETRY_AFTER_STATUS_CODES)
|
and (status_code in self.RETRY_AFTER_STATUS_CODES)
|
||||||
)
|
)
|
||||||
|
|
||||||
def is_exhausted(self):
|
def is_exhausted(self) -> bool:
|
||||||
"""Are we out of retries?"""
|
"""Are we out of retries?"""
|
||||||
retry_counts = (
|
retry_counts = [
|
||||||
self.total,
|
x
|
||||||
self.connect,
|
for x in (
|
||||||
self.read,
|
self.total,
|
||||||
self.redirect,
|
self.connect,
|
||||||
self.status,
|
self.read,
|
||||||
self.other,
|
self.redirect,
|
||||||
)
|
self.status,
|
||||||
retry_counts = list(filter(None, retry_counts))
|
self.other,
|
||||||
|
)
|
||||||
|
if x
|
||||||
|
]
|
||||||
if not retry_counts:
|
if not retry_counts:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -503,18 +423,18 @@ class Retry(object):
|
||||||
|
|
||||||
def increment(
|
def increment(
|
||||||
self,
|
self,
|
||||||
method=None,
|
method: str | None = None,
|
||||||
url=None,
|
url: str | None = None,
|
||||||
response=None,
|
response: BaseHTTPResponse | None = None,
|
||||||
error=None,
|
error: Exception | None = None,
|
||||||
_pool=None,
|
_pool: ConnectionPool | None = None,
|
||||||
_stacktrace=None,
|
_stacktrace: TracebackType | None = None,
|
||||||
):
|
) -> Retry:
|
||||||
"""Return a new Retry object with incremented retry counters.
|
"""Return a new Retry object with incremented retry counters.
|
||||||
|
|
||||||
:param response: A response object, or None, if the server did not
|
:param response: A response object, or None, if the server did not
|
||||||
return a response.
|
return a response.
|
||||||
:type response: :class:`~urllib3.response.HTTPResponse`
|
:type response: :class:`~urllib3.response.BaseHTTPResponse`
|
||||||
:param Exception error: An error encountered during the request, or
|
:param Exception error: An error encountered during the request, or
|
||||||
None if the response was received successfully.
|
None if the response was received successfully.
|
||||||
|
|
||||||
|
@ -522,7 +442,7 @@ class Retry(object):
|
||||||
"""
|
"""
|
||||||
if self.total is False and error:
|
if self.total is False and error:
|
||||||
# Disabled, indicate to re-raise the error.
|
# Disabled, indicate to re-raise the error.
|
||||||
raise six.reraise(type(error), error, _stacktrace)
|
raise reraise(type(error), error, _stacktrace)
|
||||||
|
|
||||||
total = self.total
|
total = self.total
|
||||||
if total is not None:
|
if total is not None:
|
||||||
|
@ -540,14 +460,14 @@ class Retry(object):
|
||||||
if error and self._is_connection_error(error):
|
if error and self._is_connection_error(error):
|
||||||
# Connect retry?
|
# Connect retry?
|
||||||
if connect is False:
|
if connect is False:
|
||||||
raise six.reraise(type(error), error, _stacktrace)
|
raise reraise(type(error), error, _stacktrace)
|
||||||
elif connect is not None:
|
elif connect is not None:
|
||||||
connect -= 1
|
connect -= 1
|
||||||
|
|
||||||
elif error and self._is_read_error(error):
|
elif error and self._is_read_error(error):
|
||||||
# Read retry?
|
# Read retry?
|
||||||
if read is False or not self._is_method_retryable(method):
|
if read is False or method is None or not self._is_method_retryable(method):
|
||||||
raise six.reraise(type(error), error, _stacktrace)
|
raise reraise(type(error), error, _stacktrace)
|
||||||
elif read is not None:
|
elif read is not None:
|
||||||
read -= 1
|
read -= 1
|
||||||
|
|
||||||
|
@ -561,7 +481,9 @@ class Retry(object):
|
||||||
if redirect is not None:
|
if redirect is not None:
|
||||||
redirect -= 1
|
redirect -= 1
|
||||||
cause = "too many redirects"
|
cause = "too many redirects"
|
||||||
redirect_location = response.get_redirect_location()
|
response_redirect_location = response.get_redirect_location()
|
||||||
|
if response_redirect_location:
|
||||||
|
redirect_location = response_redirect_location
|
||||||
status = response.status
|
status = response.status
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
@ -589,31 +511,18 @@ class Retry(object):
|
||||||
)
|
)
|
||||||
|
|
||||||
if new_retry.is_exhausted():
|
if new_retry.is_exhausted():
|
||||||
raise MaxRetryError(_pool, url, error or ResponseError(cause))
|
reason = error or ResponseError(cause)
|
||||||
|
raise MaxRetryError(_pool, url, reason) from reason # type: ignore[arg-type]
|
||||||
|
|
||||||
log.debug("Incremented Retry for (url='%s'): %r", url, new_retry)
|
log.debug("Incremented Retry for (url='%s'): %r", url, new_retry)
|
||||||
|
|
||||||
return new_retry
|
return new_retry
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self) -> str:
|
||||||
return (
|
return (
|
||||||
"{cls.__name__}(total={self.total}, connect={self.connect}, "
|
f"{type(self).__name__}(total={self.total}, connect={self.connect}, "
|
||||||
"read={self.read}, redirect={self.redirect}, status={self.status})"
|
f"read={self.read}, redirect={self.redirect}, status={self.status})"
|
||||||
).format(cls=type(self), self=self)
|
)
|
||||||
|
|
||||||
def __getattr__(self, item):
|
|
||||||
if item == "method_whitelist":
|
|
||||||
# TODO: Remove this deprecated alias in v2.0
|
|
||||||
warnings.warn(
|
|
||||||
"Using 'method_whitelist' with Retry is deprecated and "
|
|
||||||
"will be removed in v2.0. Use 'allowed_methods' instead",
|
|
||||||
DeprecationWarning,
|
|
||||||
)
|
|
||||||
return self.allowed_methods
|
|
||||||
try:
|
|
||||||
return getattr(super(Retry, self), item)
|
|
||||||
except AttributeError:
|
|
||||||
return getattr(Retry, item)
|
|
||||||
|
|
||||||
|
|
||||||
# For backwards compatibility (equivalent to pre-v1.9):
|
# For backwards compatibility (equivalent to pre-v1.9):
|
||||||
|
|
|
@ -1,185 +1,152 @@
|
||||||
from __future__ import absolute_import
|
from __future__ import annotations
|
||||||
|
|
||||||
import hmac
|
import hmac
|
||||||
import os
|
import os
|
||||||
|
import socket
|
||||||
import sys
|
import sys
|
||||||
|
import typing
|
||||||
import warnings
|
import warnings
|
||||||
from binascii import hexlify, unhexlify
|
from binascii import unhexlify
|
||||||
from hashlib import md5, sha1, sha256
|
from hashlib import md5, sha1, sha256
|
||||||
|
|
||||||
from ..exceptions import (
|
from ..exceptions import ProxySchemeUnsupported, SSLError
|
||||||
InsecurePlatformWarning,
|
from .url import _BRACELESS_IPV6_ADDRZ_RE, _IPV4_RE
|
||||||
ProxySchemeUnsupported,
|
|
||||||
SNIMissingWarning,
|
|
||||||
SSLError,
|
|
||||||
)
|
|
||||||
from ..packages import six
|
|
||||||
from .url import BRACELESS_IPV6_ADDRZ_RE, IPV4_RE
|
|
||||||
|
|
||||||
SSLContext = None
|
SSLContext = None
|
||||||
SSLTransport = None
|
SSLTransport = None
|
||||||
HAS_SNI = False
|
HAS_NEVER_CHECK_COMMON_NAME = False
|
||||||
IS_PYOPENSSL = False
|
IS_PYOPENSSL = False
|
||||||
IS_SECURETRANSPORT = False
|
IS_SECURETRANSPORT = False
|
||||||
ALPN_PROTOCOLS = ["http/1.1"]
|
ALPN_PROTOCOLS = ["http/1.1"]
|
||||||
|
|
||||||
|
_TYPE_VERSION_INFO = typing.Tuple[int, int, int, str, int]
|
||||||
|
|
||||||
# Maps the length of a digest to a possible hash function producing this digest
|
# Maps the length of a digest to a possible hash function producing this digest
|
||||||
HASHFUNC_MAP = {32: md5, 40: sha1, 64: sha256}
|
HASHFUNC_MAP = {32: md5, 40: sha1, 64: sha256}
|
||||||
|
|
||||||
|
|
||||||
def _const_compare_digest_backport(a, b):
|
def _is_bpo_43522_fixed(
|
||||||
|
implementation_name: str,
|
||||||
|
version_info: _TYPE_VERSION_INFO,
|
||||||
|
pypy_version_info: _TYPE_VERSION_INFO | None,
|
||||||
|
) -> bool:
|
||||||
|
"""Return True for CPython 3.8.9+, 3.9.3+ or 3.10+ and PyPy 7.3.8+ where
|
||||||
|
setting SSLContext.hostname_checks_common_name to False works.
|
||||||
|
|
||||||
|
Outside of CPython and PyPy we don't know which implementations work
|
||||||
|
or not so we conservatively use our hostname matching as we know that works
|
||||||
|
on all implementations.
|
||||||
|
|
||||||
|
https://github.com/urllib3/urllib3/issues/2192#issuecomment-821832963
|
||||||
|
https://foss.heptapod.net/pypy/pypy/-/issues/3539
|
||||||
"""
|
"""
|
||||||
Compare two digests of equal length in constant time.
|
if implementation_name == "pypy":
|
||||||
|
# https://foss.heptapod.net/pypy/pypy/-/issues/3129
|
||||||
The digests must be of type str/bytes.
|
return pypy_version_info >= (7, 3, 8) and version_info >= (3, 8) # type: ignore[operator]
|
||||||
Returns True if the digests match, and False otherwise.
|
elif implementation_name == "cpython":
|
||||||
"""
|
major_minor = version_info[:2]
|
||||||
result = abs(len(a) - len(b))
|
micro = version_info[2]
|
||||||
for left, right in zip(bytearray(a), bytearray(b)):
|
return (
|
||||||
result |= left ^ right
|
(major_minor == (3, 8) and micro >= 9)
|
||||||
return result == 0
|
or (major_minor == (3, 9) and micro >= 3)
|
||||||
|
or major_minor >= (3, 10)
|
||||||
|
)
|
||||||
|
else: # Defensive:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
_const_compare_digest = getattr(hmac, "compare_digest", _const_compare_digest_backport)
|
def _is_has_never_check_common_name_reliable(
|
||||||
|
openssl_version: str,
|
||||||
|
openssl_version_number: int,
|
||||||
|
implementation_name: str,
|
||||||
|
version_info: _TYPE_VERSION_INFO,
|
||||||
|
pypy_version_info: _TYPE_VERSION_INFO | None,
|
||||||
|
) -> bool:
|
||||||
|
# As of May 2023, all released versions of LibreSSL fail to reject certificates with
|
||||||
|
# only common names, see https://github.com/urllib3/urllib3/pull/3024
|
||||||
|
is_openssl = openssl_version.startswith("OpenSSL ")
|
||||||
|
# Before fixing OpenSSL issue #14579, the SSL_new() API was not copying hostflags
|
||||||
|
# like X509_CHECK_FLAG_NEVER_CHECK_SUBJECT, which tripped up CPython.
|
||||||
|
# https://github.com/openssl/openssl/issues/14579
|
||||||
|
# This was released in OpenSSL 1.1.1l+ (>=0x101010cf)
|
||||||
|
is_openssl_issue_14579_fixed = openssl_version_number >= 0x101010CF
|
||||||
|
|
||||||
try: # Test for SSL features
|
return is_openssl and (
|
||||||
|
is_openssl_issue_14579_fixed
|
||||||
|
or _is_bpo_43522_fixed(implementation_name, version_info, pypy_version_info)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if typing.TYPE_CHECKING:
|
||||||
|
from ssl import VerifyMode
|
||||||
|
|
||||||
|
from typing_extensions import Literal, TypedDict
|
||||||
|
|
||||||
|
from .ssltransport import SSLTransport as SSLTransportType
|
||||||
|
|
||||||
|
class _TYPE_PEER_CERT_RET_DICT(TypedDict, total=False):
|
||||||
|
subjectAltName: tuple[tuple[str, str], ...]
|
||||||
|
subject: tuple[tuple[tuple[str, str], ...], ...]
|
||||||
|
serialNumber: str
|
||||||
|
|
||||||
|
|
||||||
|
# Mapping from 'ssl.PROTOCOL_TLSX' to 'TLSVersion.X'
|
||||||
|
_SSL_VERSION_TO_TLS_VERSION: dict[int, int] = {}
|
||||||
|
|
||||||
|
try: # Do we have ssl at all?
|
||||||
import ssl
|
import ssl
|
||||||
from ssl import CERT_REQUIRED, wrap_socket
|
from ssl import ( # type: ignore[assignment]
|
||||||
except ImportError:
|
CERT_REQUIRED,
|
||||||
pass
|
HAS_NEVER_CHECK_COMMON_NAME,
|
||||||
|
OP_NO_COMPRESSION,
|
||||||
try:
|
OP_NO_TICKET,
|
||||||
from ssl import HAS_SNI # Has SNI?
|
OPENSSL_VERSION,
|
||||||
except ImportError:
|
OPENSSL_VERSION_NUMBER,
|
||||||
pass
|
PROTOCOL_TLS,
|
||||||
|
PROTOCOL_TLS_CLIENT,
|
||||||
try:
|
OP_NO_SSLv2,
|
||||||
from .ssltransport import SSLTransport
|
OP_NO_SSLv3,
|
||||||
except ImportError:
|
SSLContext,
|
||||||
pass
|
TLSVersion,
|
||||||
|
)
|
||||||
|
|
||||||
try: # Platform-specific: Python 3.6
|
|
||||||
from ssl import PROTOCOL_TLS
|
|
||||||
|
|
||||||
PROTOCOL_SSLv23 = PROTOCOL_TLS
|
PROTOCOL_SSLv23 = PROTOCOL_TLS
|
||||||
except ImportError:
|
|
||||||
try:
|
|
||||||
from ssl import PROTOCOL_SSLv23 as PROTOCOL_TLS
|
|
||||||
|
|
||||||
PROTOCOL_SSLv23 = PROTOCOL_TLS
|
# Setting SSLContext.hostname_checks_common_name = False didn't work before CPython
|
||||||
except ImportError:
|
# 3.8.9, 3.9.3, and 3.10 (but OK on PyPy) or OpenSSL 1.1.1l+
|
||||||
PROTOCOL_SSLv23 = PROTOCOL_TLS = 2
|
if HAS_NEVER_CHECK_COMMON_NAME and not _is_has_never_check_common_name_reliable(
|
||||||
|
OPENSSL_VERSION,
|
||||||
|
OPENSSL_VERSION_NUMBER,
|
||||||
|
sys.implementation.name,
|
||||||
|
sys.version_info,
|
||||||
|
sys.pypy_version_info if sys.implementation.name == "pypy" else None, # type: ignore[attr-defined]
|
||||||
|
):
|
||||||
|
HAS_NEVER_CHECK_COMMON_NAME = False
|
||||||
|
|
||||||
try:
|
# Need to be careful here in case old TLS versions get
|
||||||
from ssl import PROTOCOL_TLS_CLIENT
|
# removed in future 'ssl' module implementations.
|
||||||
except ImportError:
|
for attr in ("TLSv1", "TLSv1_1", "TLSv1_2"):
|
||||||
PROTOCOL_TLS_CLIENT = PROTOCOL_TLS
|
try:
|
||||||
|
_SSL_VERSION_TO_TLS_VERSION[getattr(ssl, f"PROTOCOL_{attr}")] = getattr(
|
||||||
|
TLSVersion, attr
|
||||||
try:
|
|
||||||
from ssl import OP_NO_COMPRESSION, OP_NO_SSLv2, OP_NO_SSLv3
|
|
||||||
except ImportError:
|
|
||||||
OP_NO_SSLv2, OP_NO_SSLv3 = 0x1000000, 0x2000000
|
|
||||||
OP_NO_COMPRESSION = 0x20000
|
|
||||||
|
|
||||||
|
|
||||||
try: # OP_NO_TICKET was added in Python 3.6
|
|
||||||
from ssl import OP_NO_TICKET
|
|
||||||
except ImportError:
|
|
||||||
OP_NO_TICKET = 0x4000
|
|
||||||
|
|
||||||
|
|
||||||
# A secure default.
|
|
||||||
# Sources for more information on TLS ciphers:
|
|
||||||
#
|
|
||||||
# - https://wiki.mozilla.org/Security/Server_Side_TLS
|
|
||||||
# - https://www.ssllabs.com/projects/best-practices/index.html
|
|
||||||
# - https://hynek.me/articles/hardening-your-web-servers-ssl-ciphers/
|
|
||||||
#
|
|
||||||
# The general intent is:
|
|
||||||
# - prefer cipher suites that offer perfect forward secrecy (DHE/ECDHE),
|
|
||||||
# - prefer ECDHE over DHE for better performance,
|
|
||||||
# - prefer any AES-GCM and ChaCha20 over any AES-CBC for better performance and
|
|
||||||
# security,
|
|
||||||
# - prefer AES-GCM over ChaCha20 because hardware-accelerated AES is common,
|
|
||||||
# - disable NULL authentication, MD5 MACs, DSS, and other
|
|
||||||
# insecure ciphers for security reasons.
|
|
||||||
# - NOTE: TLS 1.3 cipher suites are managed through a different interface
|
|
||||||
# not exposed by CPython (yet!) and are enabled by default if they're available.
|
|
||||||
DEFAULT_CIPHERS = ":".join(
|
|
||||||
[
|
|
||||||
"ECDHE+AESGCM",
|
|
||||||
"ECDHE+CHACHA20",
|
|
||||||
"DHE+AESGCM",
|
|
||||||
"DHE+CHACHA20",
|
|
||||||
"ECDH+AESGCM",
|
|
||||||
"DH+AESGCM",
|
|
||||||
"ECDH+AES",
|
|
||||||
"DH+AES",
|
|
||||||
"RSA+AESGCM",
|
|
||||||
"RSA+AES",
|
|
||||||
"!aNULL",
|
|
||||||
"!eNULL",
|
|
||||||
"!MD5",
|
|
||||||
"!DSS",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
from ssl import SSLContext # Modern SSL?
|
|
||||||
except ImportError:
|
|
||||||
|
|
||||||
class SSLContext(object): # Platform-specific: Python 2
|
|
||||||
def __init__(self, protocol_version):
|
|
||||||
self.protocol = protocol_version
|
|
||||||
# Use default values from a real SSLContext
|
|
||||||
self.check_hostname = False
|
|
||||||
self.verify_mode = ssl.CERT_NONE
|
|
||||||
self.ca_certs = None
|
|
||||||
self.options = 0
|
|
||||||
self.certfile = None
|
|
||||||
self.keyfile = None
|
|
||||||
self.ciphers = None
|
|
||||||
|
|
||||||
def load_cert_chain(self, certfile, keyfile):
|
|
||||||
self.certfile = certfile
|
|
||||||
self.keyfile = keyfile
|
|
||||||
|
|
||||||
def load_verify_locations(self, cafile=None, capath=None, cadata=None):
|
|
||||||
self.ca_certs = cafile
|
|
||||||
|
|
||||||
if capath is not None:
|
|
||||||
raise SSLError("CA directories not supported in older Pythons")
|
|
||||||
|
|
||||||
if cadata is not None:
|
|
||||||
raise SSLError("CA data not supported in older Pythons")
|
|
||||||
|
|
||||||
def set_ciphers(self, cipher_suite):
|
|
||||||
self.ciphers = cipher_suite
|
|
||||||
|
|
||||||
def wrap_socket(self, socket, server_hostname=None, server_side=False):
|
|
||||||
warnings.warn(
|
|
||||||
"A true SSLContext object is not available. This prevents "
|
|
||||||
"urllib3 from configuring SSL appropriately and may cause "
|
|
||||||
"certain SSL connections to fail. You can upgrade to a newer "
|
|
||||||
"version of Python to solve this. For more information, see "
|
|
||||||
"https://urllib3.readthedocs.io/en/1.26.x/advanced-usage.html"
|
|
||||||
"#ssl-warnings",
|
|
||||||
InsecurePlatformWarning,
|
|
||||||
)
|
)
|
||||||
kwargs = {
|
except AttributeError: # Defensive:
|
||||||
"keyfile": self.keyfile,
|
continue
|
||||||
"certfile": self.certfile,
|
|
||||||
"ca_certs": self.ca_certs,
|
from .ssltransport import SSLTransport # type: ignore[assignment]
|
||||||
"cert_reqs": self.verify_mode,
|
except ImportError:
|
||||||
"ssl_version": self.protocol,
|
OP_NO_COMPRESSION = 0x20000 # type: ignore[assignment]
|
||||||
"server_side": server_side,
|
OP_NO_TICKET = 0x4000 # type: ignore[assignment]
|
||||||
}
|
OP_NO_SSLv2 = 0x1000000 # type: ignore[assignment]
|
||||||
return wrap_socket(socket, ciphers=self.ciphers, **kwargs)
|
OP_NO_SSLv3 = 0x2000000 # type: ignore[assignment]
|
||||||
|
PROTOCOL_SSLv23 = PROTOCOL_TLS = 2 # type: ignore[assignment]
|
||||||
|
PROTOCOL_TLS_CLIENT = 16 # type: ignore[assignment]
|
||||||
|
|
||||||
|
|
||||||
def assert_fingerprint(cert, fingerprint):
|
_TYPE_PEER_CERT_RET = typing.Union["_TYPE_PEER_CERT_RET_DICT", bytes, None]
|
||||||
|
|
||||||
|
|
||||||
|
def assert_fingerprint(cert: bytes | None, fingerprint: str) -> None:
|
||||||
"""
|
"""
|
||||||
Checks if given fingerprint matches the supplied certificate.
|
Checks if given fingerprint matches the supplied certificate.
|
||||||
|
|
||||||
|
@ -189,26 +156,27 @@ def assert_fingerprint(cert, fingerprint):
|
||||||
Fingerprint as string of hexdigits, can be interspersed by colons.
|
Fingerprint as string of hexdigits, can be interspersed by colons.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if cert is None:
|
||||||
|
raise SSLError("No certificate for the peer.")
|
||||||
|
|
||||||
fingerprint = fingerprint.replace(":", "").lower()
|
fingerprint = fingerprint.replace(":", "").lower()
|
||||||
digest_length = len(fingerprint)
|
digest_length = len(fingerprint)
|
||||||
hashfunc = HASHFUNC_MAP.get(digest_length)
|
hashfunc = HASHFUNC_MAP.get(digest_length)
|
||||||
if not hashfunc:
|
if not hashfunc:
|
||||||
raise SSLError("Fingerprint of invalid length: {0}".format(fingerprint))
|
raise SSLError(f"Fingerprint of invalid length: {fingerprint}")
|
||||||
|
|
||||||
# We need encode() here for py32; works on py2 and p33.
|
# We need encode() here for py32; works on py2 and p33.
|
||||||
fingerprint_bytes = unhexlify(fingerprint.encode())
|
fingerprint_bytes = unhexlify(fingerprint.encode())
|
||||||
|
|
||||||
cert_digest = hashfunc(cert).digest()
|
cert_digest = hashfunc(cert).digest()
|
||||||
|
|
||||||
if not _const_compare_digest(cert_digest, fingerprint_bytes):
|
if not hmac.compare_digest(cert_digest, fingerprint_bytes):
|
||||||
raise SSLError(
|
raise SSLError(
|
||||||
'Fingerprints did not match. Expected "{0}", got "{1}".'.format(
|
f'Fingerprints did not match. Expected "{fingerprint}", got "{cert_digest.hex()}"'
|
||||||
fingerprint, hexlify(cert_digest)
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def resolve_cert_reqs(candidate):
|
def resolve_cert_reqs(candidate: None | int | str) -> VerifyMode:
|
||||||
"""
|
"""
|
||||||
Resolves the argument to a numeric constant, which can be passed to
|
Resolves the argument to a numeric constant, which can be passed to
|
||||||
the wrap_socket function/method from the ssl module.
|
the wrap_socket function/method from the ssl module.
|
||||||
|
@ -226,12 +194,12 @@ def resolve_cert_reqs(candidate):
|
||||||
res = getattr(ssl, candidate, None)
|
res = getattr(ssl, candidate, None)
|
||||||
if res is None:
|
if res is None:
|
||||||
res = getattr(ssl, "CERT_" + candidate)
|
res = getattr(ssl, "CERT_" + candidate)
|
||||||
return res
|
return res # type: ignore[no-any-return]
|
||||||
|
|
||||||
return candidate
|
return candidate # type: ignore[return-value]
|
||||||
|
|
||||||
|
|
||||||
def resolve_ssl_version(candidate):
|
def resolve_ssl_version(candidate: None | int | str) -> int:
|
||||||
"""
|
"""
|
||||||
like resolve_cert_reqs
|
like resolve_cert_reqs
|
||||||
"""
|
"""
|
||||||
|
@ -242,35 +210,33 @@ def resolve_ssl_version(candidate):
|
||||||
res = getattr(ssl, candidate, None)
|
res = getattr(ssl, candidate, None)
|
||||||
if res is None:
|
if res is None:
|
||||||
res = getattr(ssl, "PROTOCOL_" + candidate)
|
res = getattr(ssl, "PROTOCOL_" + candidate)
|
||||||
return res
|
return typing.cast(int, res)
|
||||||
|
|
||||||
return candidate
|
return candidate
|
||||||
|
|
||||||
|
|
||||||
def create_urllib3_context(
|
def create_urllib3_context(
|
||||||
ssl_version=None, cert_reqs=None, options=None, ciphers=None
|
ssl_version: int | None = None,
|
||||||
):
|
cert_reqs: int | None = None,
|
||||||
"""All arguments have the same meaning as ``ssl_wrap_socket``.
|
options: int | None = None,
|
||||||
|
ciphers: str | None = None,
|
||||||
By default, this function does a lot of the same work that
|
ssl_minimum_version: int | None = None,
|
||||||
``ssl.create_default_context`` does on Python 3.4+. It:
|
ssl_maximum_version: int | None = None,
|
||||||
|
) -> ssl.SSLContext:
|
||||||
- Disables SSLv2, SSLv3, and compression
|
"""Creates and configures an :class:`ssl.SSLContext` instance for use with urllib3.
|
||||||
- Sets a restricted set of server ciphers
|
|
||||||
|
|
||||||
If you wish to enable SSLv3, you can do::
|
|
||||||
|
|
||||||
from urllib3.util import ssl_
|
|
||||||
context = ssl_.create_urllib3_context()
|
|
||||||
context.options &= ~ssl_.OP_NO_SSLv3
|
|
||||||
|
|
||||||
You can do the same to enable compression (substituting ``COMPRESSION``
|
|
||||||
for ``SSLv3`` in the last line above).
|
|
||||||
|
|
||||||
:param ssl_version:
|
:param ssl_version:
|
||||||
The desired protocol version to use. This will default to
|
The desired protocol version to use. This will default to
|
||||||
PROTOCOL_SSLv23 which will negotiate the highest protocol that both
|
PROTOCOL_SSLv23 which will negotiate the highest protocol that both
|
||||||
the server and your installation of OpenSSL support.
|
the server and your installation of OpenSSL support.
|
||||||
|
|
||||||
|
This parameter is deprecated instead use 'ssl_minimum_version'.
|
||||||
|
:param ssl_minimum_version:
|
||||||
|
The minimum version of TLS to be used. Use the 'ssl.TLSVersion' enum for specifying the value.
|
||||||
|
:param ssl_maximum_version:
|
||||||
|
The maximum version of TLS to be used. Use the 'ssl.TLSVersion' enum for specifying the value.
|
||||||
|
Not recommended to set to anything other than 'ssl.TLSVersion.MAXIMUM_SUPPORTED' which is the
|
||||||
|
default value.
|
||||||
:param cert_reqs:
|
:param cert_reqs:
|
||||||
Whether to require the certificate verification. This defaults to
|
Whether to require the certificate verification. This defaults to
|
||||||
``ssl.CERT_REQUIRED``.
|
``ssl.CERT_REQUIRED``.
|
||||||
|
@ -278,18 +244,60 @@ def create_urllib3_context(
|
||||||
Specific OpenSSL options. These default to ``ssl.OP_NO_SSLv2``,
|
Specific OpenSSL options. These default to ``ssl.OP_NO_SSLv2``,
|
||||||
``ssl.OP_NO_SSLv3``, ``ssl.OP_NO_COMPRESSION``, and ``ssl.OP_NO_TICKET``.
|
``ssl.OP_NO_SSLv3``, ``ssl.OP_NO_COMPRESSION``, and ``ssl.OP_NO_TICKET``.
|
||||||
:param ciphers:
|
:param ciphers:
|
||||||
Which cipher suites to allow the server to select.
|
Which cipher suites to allow the server to select. Defaults to either system configured
|
||||||
|
ciphers if OpenSSL 1.1.1+, otherwise uses a secure default set of ciphers.
|
||||||
:returns:
|
:returns:
|
||||||
Constructed SSLContext object with specified options
|
Constructed SSLContext object with specified options
|
||||||
:rtype: SSLContext
|
:rtype: SSLContext
|
||||||
"""
|
"""
|
||||||
# PROTOCOL_TLS is deprecated in Python 3.10
|
if SSLContext is None:
|
||||||
if not ssl_version or ssl_version == PROTOCOL_TLS:
|
raise TypeError("Can't create an SSLContext object without an ssl module")
|
||||||
ssl_version = PROTOCOL_TLS_CLIENT
|
|
||||||
|
|
||||||
context = SSLContext(ssl_version)
|
# This means 'ssl_version' was specified as an exact value.
|
||||||
|
if ssl_version not in (None, PROTOCOL_TLS, PROTOCOL_TLS_CLIENT):
|
||||||
|
# Disallow setting 'ssl_version' and 'ssl_minimum|maximum_version'
|
||||||
|
# to avoid conflicts.
|
||||||
|
if ssl_minimum_version is not None or ssl_maximum_version is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"Can't specify both 'ssl_version' and either "
|
||||||
|
"'ssl_minimum_version' or 'ssl_maximum_version'"
|
||||||
|
)
|
||||||
|
|
||||||
context.set_ciphers(ciphers or DEFAULT_CIPHERS)
|
# 'ssl_version' is deprecated and will be removed in the future.
|
||||||
|
else:
|
||||||
|
# Use 'ssl_minimum_version' and 'ssl_maximum_version' instead.
|
||||||
|
ssl_minimum_version = _SSL_VERSION_TO_TLS_VERSION.get(
|
||||||
|
ssl_version, TLSVersion.MINIMUM_SUPPORTED
|
||||||
|
)
|
||||||
|
ssl_maximum_version = _SSL_VERSION_TO_TLS_VERSION.get(
|
||||||
|
ssl_version, TLSVersion.MAXIMUM_SUPPORTED
|
||||||
|
)
|
||||||
|
|
||||||
|
# This warning message is pushing users to use 'ssl_minimum_version'
|
||||||
|
# instead of both min/max. Best practice is to only set the minimum version and
|
||||||
|
# keep the maximum version to be it's default value: 'TLSVersion.MAXIMUM_SUPPORTED'
|
||||||
|
warnings.warn(
|
||||||
|
"'ssl_version' option is deprecated and will be "
|
||||||
|
"removed in urllib3 v2.1.0. Instead use 'ssl_minimum_version'",
|
||||||
|
category=DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
# PROTOCOL_TLS is deprecated in Python 3.10 so we always use PROTOCOL_TLS_CLIENT
|
||||||
|
context = SSLContext(PROTOCOL_TLS_CLIENT)
|
||||||
|
|
||||||
|
if ssl_minimum_version is not None:
|
||||||
|
context.minimum_version = ssl_minimum_version
|
||||||
|
else: # Python <3.10 defaults to 'MINIMUM_SUPPORTED' so explicitly set TLSv1.2 here
|
||||||
|
context.minimum_version = TLSVersion.TLSv1_2
|
||||||
|
|
||||||
|
if ssl_maximum_version is not None:
|
||||||
|
context.maximum_version = ssl_maximum_version
|
||||||
|
|
||||||
|
# Unless we're given ciphers defer to either system ciphers in
|
||||||
|
# the case of OpenSSL 1.1.1+ or use our own secure default ciphers.
|
||||||
|
if ciphers:
|
||||||
|
context.set_ciphers(ciphers)
|
||||||
|
|
||||||
# Setting the default here, as we may have no ssl module on import
|
# Setting the default here, as we may have no ssl module on import
|
||||||
cert_reqs = ssl.CERT_REQUIRED if cert_reqs is None else cert_reqs
|
cert_reqs = ssl.CERT_REQUIRED if cert_reqs is None else cert_reqs
|
||||||
|
@ -322,26 +330,23 @@ def create_urllib3_context(
|
||||||
) is not None:
|
) is not None:
|
||||||
context.post_handshake_auth = True
|
context.post_handshake_auth = True
|
||||||
|
|
||||||
def disable_check_hostname():
|
|
||||||
if (
|
|
||||||
getattr(context, "check_hostname", None) is not None
|
|
||||||
): # Platform-specific: Python 3.2
|
|
||||||
# We do our own verification, including fingerprints and alternative
|
|
||||||
# hostnames. So disable it here
|
|
||||||
context.check_hostname = False
|
|
||||||
|
|
||||||
# The order of the below lines setting verify_mode and check_hostname
|
# The order of the below lines setting verify_mode and check_hostname
|
||||||
# matter due to safe-guards SSLContext has to prevent an SSLContext with
|
# matter due to safe-guards SSLContext has to prevent an SSLContext with
|
||||||
# check_hostname=True, verify_mode=NONE/OPTIONAL. This is made even more
|
# check_hostname=True, verify_mode=NONE/OPTIONAL.
|
||||||
# complex because we don't know whether PROTOCOL_TLS_CLIENT will be used
|
# We always set 'check_hostname=False' for pyOpenSSL so we rely on our own
|
||||||
# or not so we don't know the initial state of the freshly created SSLContext.
|
# 'ssl.match_hostname()' implementation.
|
||||||
if cert_reqs == ssl.CERT_REQUIRED:
|
if cert_reqs == ssl.CERT_REQUIRED and not IS_PYOPENSSL:
|
||||||
context.verify_mode = cert_reqs
|
context.verify_mode = cert_reqs
|
||||||
disable_check_hostname()
|
context.check_hostname = True
|
||||||
else:
|
else:
|
||||||
disable_check_hostname()
|
context.check_hostname = False
|
||||||
context.verify_mode = cert_reqs
|
context.verify_mode = cert_reqs
|
||||||
|
|
||||||
|
try:
|
||||||
|
context.hostname_checks_common_name = False
|
||||||
|
except AttributeError: # Defensive: for CPython < 3.8.9 and 3.9.3; for PyPy < 7.3.8
|
||||||
|
pass
|
||||||
|
|
||||||
# Enable logging of TLS session keys via defacto standard environment variable
|
# Enable logging of TLS session keys via defacto standard environment variable
|
||||||
# 'SSLKEYLOGFILE', if the feature is available (Python 3.8+). Skip empty values.
|
# 'SSLKEYLOGFILE', if the feature is available (Python 3.8+). Skip empty values.
|
||||||
if hasattr(context, "keylog_filename"):
|
if hasattr(context, "keylog_filename"):
|
||||||
|
@ -352,21 +357,59 @@ def create_urllib3_context(
|
||||||
return context
|
return context
|
||||||
|
|
||||||
|
|
||||||
|
@typing.overload
|
||||||
def ssl_wrap_socket(
|
def ssl_wrap_socket(
|
||||||
sock,
|
sock: socket.socket,
|
||||||
keyfile=None,
|
keyfile: str | None = ...,
|
||||||
certfile=None,
|
certfile: str | None = ...,
|
||||||
cert_reqs=None,
|
cert_reqs: int | None = ...,
|
||||||
ca_certs=None,
|
ca_certs: str | None = ...,
|
||||||
server_hostname=None,
|
server_hostname: str | None = ...,
|
||||||
ssl_version=None,
|
ssl_version: int | None = ...,
|
||||||
ciphers=None,
|
ciphers: str | None = ...,
|
||||||
ssl_context=None,
|
ssl_context: ssl.SSLContext | None = ...,
|
||||||
ca_cert_dir=None,
|
ca_cert_dir: str | None = ...,
|
||||||
key_password=None,
|
key_password: str | None = ...,
|
||||||
ca_cert_data=None,
|
ca_cert_data: None | str | bytes = ...,
|
||||||
tls_in_tls=False,
|
tls_in_tls: Literal[False] = ...,
|
||||||
):
|
) -> ssl.SSLSocket:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@typing.overload
|
||||||
|
def ssl_wrap_socket(
|
||||||
|
sock: socket.socket,
|
||||||
|
keyfile: str | None = ...,
|
||||||
|
certfile: str | None = ...,
|
||||||
|
cert_reqs: int | None = ...,
|
||||||
|
ca_certs: str | None = ...,
|
||||||
|
server_hostname: str | None = ...,
|
||||||
|
ssl_version: int | None = ...,
|
||||||
|
ciphers: str | None = ...,
|
||||||
|
ssl_context: ssl.SSLContext | None = ...,
|
||||||
|
ca_cert_dir: str | None = ...,
|
||||||
|
key_password: str | None = ...,
|
||||||
|
ca_cert_data: None | str | bytes = ...,
|
||||||
|
tls_in_tls: bool = ...,
|
||||||
|
) -> ssl.SSLSocket | SSLTransportType:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
def ssl_wrap_socket(
|
||||||
|
sock: socket.socket,
|
||||||
|
keyfile: str | None = None,
|
||||||
|
certfile: str | None = None,
|
||||||
|
cert_reqs: int | None = None,
|
||||||
|
ca_certs: str | None = None,
|
||||||
|
server_hostname: str | None = None,
|
||||||
|
ssl_version: int | None = None,
|
||||||
|
ciphers: str | None = None,
|
||||||
|
ssl_context: ssl.SSLContext | None = None,
|
||||||
|
ca_cert_dir: str | None = None,
|
||||||
|
key_password: str | None = None,
|
||||||
|
ca_cert_data: None | str | bytes = None,
|
||||||
|
tls_in_tls: bool = False,
|
||||||
|
) -> ssl.SSLSocket | SSLTransportType:
|
||||||
"""
|
"""
|
||||||
All arguments except for server_hostname, ssl_context, and ca_cert_dir have
|
All arguments except for server_hostname, ssl_context, and ca_cert_dir have
|
||||||
the same meaning as they do when using :func:`ssl.wrap_socket`.
|
the same meaning as they do when using :func:`ssl.wrap_socket`.
|
||||||
|
@ -392,19 +435,18 @@ def ssl_wrap_socket(
|
||||||
"""
|
"""
|
||||||
context = ssl_context
|
context = ssl_context
|
||||||
if context is None:
|
if context is None:
|
||||||
# Note: This branch of code and all the variables in it are no longer
|
# Note: This branch of code and all the variables in it are only used in tests.
|
||||||
# used by urllib3 itself. We should consider deprecating and removing
|
# We should consider deprecating and removing this code.
|
||||||
# this code.
|
|
||||||
context = create_urllib3_context(ssl_version, cert_reqs, ciphers=ciphers)
|
context = create_urllib3_context(ssl_version, cert_reqs, ciphers=ciphers)
|
||||||
|
|
||||||
if ca_certs or ca_cert_dir or ca_cert_data:
|
if ca_certs or ca_cert_dir or ca_cert_data:
|
||||||
try:
|
try:
|
||||||
context.load_verify_locations(ca_certs, ca_cert_dir, ca_cert_data)
|
context.load_verify_locations(ca_certs, ca_cert_dir, ca_cert_data)
|
||||||
except (IOError, OSError) as e:
|
except OSError as e:
|
||||||
raise SSLError(e)
|
raise SSLError(e) from e
|
||||||
|
|
||||||
elif ssl_context is None and hasattr(context, "load_default_certs"):
|
elif ssl_context is None and hasattr(context, "load_default_certs"):
|
||||||
# try to load OS default certs; works well on Windows (require Python3.4+)
|
# try to load OS default certs; works well on Windows.
|
||||||
context.load_default_certs()
|
context.load_default_certs()
|
||||||
|
|
||||||
# Attempt to detect if we get the goofy behavior of the
|
# Attempt to detect if we get the goofy behavior of the
|
||||||
|
@ -420,56 +462,30 @@ def ssl_wrap_socket(
|
||||||
context.load_cert_chain(certfile, keyfile, key_password)
|
context.load_cert_chain(certfile, keyfile, key_password)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if hasattr(context, "set_alpn_protocols"):
|
context.set_alpn_protocols(ALPN_PROTOCOLS)
|
||||||
context.set_alpn_protocols(ALPN_PROTOCOLS)
|
|
||||||
except NotImplementedError: # Defensive: in CI, we always have set_alpn_protocols
|
except NotImplementedError: # Defensive: in CI, we always have set_alpn_protocols
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# If we detect server_hostname is an IP address then the SNI
|
ssl_sock = _ssl_wrap_socket_impl(sock, context, tls_in_tls, server_hostname)
|
||||||
# extension should not be used according to RFC3546 Section 3.1
|
|
||||||
use_sni_hostname = server_hostname and not is_ipaddress(server_hostname)
|
|
||||||
# SecureTransport uses server_hostname in certificate verification.
|
|
||||||
send_sni = (use_sni_hostname and HAS_SNI) or (
|
|
||||||
IS_SECURETRANSPORT and server_hostname
|
|
||||||
)
|
|
||||||
# Do not warn the user if server_hostname is an invalid SNI hostname.
|
|
||||||
if not HAS_SNI and use_sni_hostname:
|
|
||||||
warnings.warn(
|
|
||||||
"An HTTPS request has been made, but the SNI (Server Name "
|
|
||||||
"Indication) extension to TLS is not available on this platform. "
|
|
||||||
"This may cause the server to present an incorrect TLS "
|
|
||||||
"certificate, which can cause validation failures. You can upgrade to "
|
|
||||||
"a newer version of Python to solve this. For more information, see "
|
|
||||||
"https://urllib3.readthedocs.io/en/1.26.x/advanced-usage.html"
|
|
||||||
"#ssl-warnings",
|
|
||||||
SNIMissingWarning,
|
|
||||||
)
|
|
||||||
|
|
||||||
if send_sni:
|
|
||||||
ssl_sock = _ssl_wrap_socket_impl(
|
|
||||||
sock, context, tls_in_tls, server_hostname=server_hostname
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
ssl_sock = _ssl_wrap_socket_impl(sock, context, tls_in_tls)
|
|
||||||
return ssl_sock
|
return ssl_sock
|
||||||
|
|
||||||
|
|
||||||
def is_ipaddress(hostname):
|
def is_ipaddress(hostname: str | bytes) -> bool:
|
||||||
"""Detects whether the hostname given is an IPv4 or IPv6 address.
|
"""Detects whether the hostname given is an IPv4 or IPv6 address.
|
||||||
Also detects IPv6 addresses with Zone IDs.
|
Also detects IPv6 addresses with Zone IDs.
|
||||||
|
|
||||||
:param str hostname: Hostname to examine.
|
:param str hostname: Hostname to examine.
|
||||||
:return: True if the hostname is an IP address, False otherwise.
|
:return: True if the hostname is an IP address, False otherwise.
|
||||||
"""
|
"""
|
||||||
if not six.PY2 and isinstance(hostname, bytes):
|
if isinstance(hostname, bytes):
|
||||||
# IDN A-label bytes are ASCII compatible.
|
# IDN A-label bytes are ASCII compatible.
|
||||||
hostname = hostname.decode("ascii")
|
hostname = hostname.decode("ascii")
|
||||||
return bool(IPV4_RE.match(hostname) or BRACELESS_IPV6_ADDRZ_RE.match(hostname))
|
return bool(_IPV4_RE.match(hostname) or _BRACELESS_IPV6_ADDRZ_RE.match(hostname))
|
||||||
|
|
||||||
|
|
||||||
def _is_key_file_encrypted(key_file):
|
def _is_key_file_encrypted(key_file: str) -> bool:
|
||||||
"""Detects if a key file is encrypted or not."""
|
"""Detects if a key file is encrypted or not."""
|
||||||
with open(key_file, "r") as f:
|
with open(key_file) as f:
|
||||||
for line in f:
|
for line in f:
|
||||||
# Look for Proc-Type: 4,ENCRYPTED
|
# Look for Proc-Type: 4,ENCRYPTED
|
||||||
if "ENCRYPTED" in line:
|
if "ENCRYPTED" in line:
|
||||||
|
@ -478,7 +494,12 @@ def _is_key_file_encrypted(key_file):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _ssl_wrap_socket_impl(sock, ssl_context, tls_in_tls, server_hostname=None):
|
def _ssl_wrap_socket_impl(
|
||||||
|
sock: socket.socket,
|
||||||
|
ssl_context: ssl.SSLContext,
|
||||||
|
tls_in_tls: bool,
|
||||||
|
server_hostname: str | None = None,
|
||||||
|
) -> ssl.SSLSocket | SSLTransportType:
|
||||||
if tls_in_tls:
|
if tls_in_tls:
|
||||||
if not SSLTransport:
|
if not SSLTransport:
|
||||||
# Import error, ssl is not available.
|
# Import error, ssl is not available.
|
||||||
|
@ -489,7 +510,4 @@ def _ssl_wrap_socket_impl(sock, ssl_context, tls_in_tls, server_hostname=None):
|
||||||
SSLTransport._validate_ssl_context_for_tls_in_tls(ssl_context)
|
SSLTransport._validate_ssl_context_for_tls_in_tls(ssl_context)
|
||||||
return SSLTransport(sock, ssl_context, server_hostname)
|
return SSLTransport(sock, ssl_context, server_hostname)
|
||||||
|
|
||||||
if server_hostname:
|
return ssl_context.wrap_socket(sock, server_hostname=server_hostname)
|
||||||
return ssl_context.wrap_socket(sock, server_hostname=server_hostname)
|
|
||||||
else:
|
|
||||||
return ssl_context.wrap_socket(sock)
|
|
||||||
|
|
|
@ -1,19 +1,18 @@
|
||||||
"""The match_hostname() function from Python 3.3.3, essential when using SSL."""
|
"""The match_hostname() function from Python 3.5, essential when using SSL."""
|
||||||
|
|
||||||
# Note: This file is under the PSF license as the code comes from the python
|
# Note: This file is under the PSF license as the code comes from the python
|
||||||
# stdlib. http://docs.python.org/3/license.html
|
# stdlib. http://docs.python.org/3/license.html
|
||||||
|
# It is modified to remove commonName support.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import ipaddress
|
||||||
import re
|
import re
|
||||||
import sys
|
import typing
|
||||||
|
from ipaddress import IPv4Address, IPv6Address
|
||||||
|
|
||||||
# ipaddress has been backported to 2.6+ in pypi. If it is installed on the
|
if typing.TYPE_CHECKING:
|
||||||
# system, use it to handle IPAddress ServerAltnames (this was added in
|
from .ssl_ import _TYPE_PEER_CERT_RET_DICT
|
||||||
# python-3.5) otherwise only do DNS matching. This allows
|
|
||||||
# util.ssl_match_hostname to continue to be used in Python 2.7.
|
|
||||||
try:
|
|
||||||
import ipaddress
|
|
||||||
except ImportError:
|
|
||||||
ipaddress = None
|
|
||||||
|
|
||||||
__version__ = "3.5.0.1"
|
__version__ = "3.5.0.1"
|
||||||
|
|
||||||
|
@ -22,7 +21,9 @@ class CertificateError(ValueError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _dnsname_match(dn, hostname, max_wildcards=1):
|
def _dnsname_match(
|
||||||
|
dn: typing.Any, hostname: str, max_wildcards: int = 1
|
||||||
|
) -> typing.Match[str] | None | bool:
|
||||||
"""Matching according to RFC 6125, section 6.4.3
|
"""Matching according to RFC 6125, section 6.4.3
|
||||||
|
|
||||||
http://tools.ietf.org/html/rfc6125#section-6.4.3
|
http://tools.ietf.org/html/rfc6125#section-6.4.3
|
||||||
|
@ -49,7 +50,7 @@ def _dnsname_match(dn, hostname, max_wildcards=1):
|
||||||
|
|
||||||
# speed up common case w/o wildcards
|
# speed up common case w/o wildcards
|
||||||
if not wildcards:
|
if not wildcards:
|
||||||
return dn.lower() == hostname.lower()
|
return bool(dn.lower() == hostname.lower())
|
||||||
|
|
||||||
# RFC 6125, section 6.4.3, subitem 1.
|
# RFC 6125, section 6.4.3, subitem 1.
|
||||||
# The client SHOULD NOT attempt to match a presented identifier in which
|
# The client SHOULD NOT attempt to match a presented identifier in which
|
||||||
|
@ -76,26 +77,26 @@ def _dnsname_match(dn, hostname, max_wildcards=1):
|
||||||
return pat.match(hostname)
|
return pat.match(hostname)
|
||||||
|
|
||||||
|
|
||||||
def _to_unicode(obj):
|
def _ipaddress_match(ipname: str, host_ip: IPv4Address | IPv6Address) -> bool:
|
||||||
if isinstance(obj, str) and sys.version_info < (3,):
|
|
||||||
# ignored flake8 # F821 to support python 2.7 function
|
|
||||||
obj = unicode(obj, encoding="ascii", errors="strict") # noqa: F821
|
|
||||||
return obj
|
|
||||||
|
|
||||||
|
|
||||||
def _ipaddress_match(ipname, host_ip):
|
|
||||||
"""Exact matching of IP addresses.
|
"""Exact matching of IP addresses.
|
||||||
|
|
||||||
RFC 6125 explicitly doesn't define an algorithm for this
|
RFC 9110 section 4.3.5: "A reference identity of IP-ID contains the decoded
|
||||||
(section 1.7.2 - "Out of Scope").
|
bytes of the IP address. An IP version 4 address is 4 octets, and an IP
|
||||||
|
version 6 address is 16 octets. [...] A reference identity of type IP-ID
|
||||||
|
matches if the address is identical to an iPAddress value of the
|
||||||
|
subjectAltName extension of the certificate."
|
||||||
"""
|
"""
|
||||||
# OpenSSL may add a trailing newline to a subjectAltName's IP address
|
# OpenSSL may add a trailing newline to a subjectAltName's IP address
|
||||||
# Divergence from upstream: ipaddress can't handle byte str
|
# Divergence from upstream: ipaddress can't handle byte str
|
||||||
ip = ipaddress.ip_address(_to_unicode(ipname).rstrip())
|
ip = ipaddress.ip_address(ipname.rstrip())
|
||||||
return ip == host_ip
|
return bool(ip.packed == host_ip.packed)
|
||||||
|
|
||||||
|
|
||||||
def match_hostname(cert, hostname):
|
def match_hostname(
|
||||||
|
cert: _TYPE_PEER_CERT_RET_DICT | None,
|
||||||
|
hostname: str,
|
||||||
|
hostname_checks_common_name: bool = False,
|
||||||
|
) -> None:
|
||||||
"""Verify that *cert* (in decoded format as returned by
|
"""Verify that *cert* (in decoded format as returned by
|
||||||
SSLSocket.getpeercert()) matches the *hostname*. RFC 2818 and RFC 6125
|
SSLSocket.getpeercert()) matches the *hostname*. RFC 2818 and RFC 6125
|
||||||
rules are followed, but IP addresses are not accepted for *hostname*.
|
rules are followed, but IP addresses are not accepted for *hostname*.
|
||||||
|
@ -111,21 +112,22 @@ def match_hostname(cert, hostname):
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
# Divergence from upstream: ipaddress can't handle byte str
|
# Divergence from upstream: ipaddress can't handle byte str
|
||||||
host_ip = ipaddress.ip_address(_to_unicode(hostname))
|
#
|
||||||
except (UnicodeError, ValueError):
|
# The ipaddress module shipped with Python < 3.9 does not support
|
||||||
# ValueError: Not an IP address (common case)
|
# scoped IPv6 addresses so we unconditionally strip the Zone IDs for
|
||||||
# UnicodeError: Divergence from upstream: Have to deal with ipaddress not taking
|
# now. Once we drop support for Python 3.9 we can remove this branch.
|
||||||
# byte strings. addresses should be all ascii, so we consider it not
|
if "%" in hostname:
|
||||||
# an ipaddress in this case
|
host_ip = ipaddress.ip_address(hostname[: hostname.rfind("%")])
|
||||||
|
else:
|
||||||
|
host_ip = ipaddress.ip_address(hostname)
|
||||||
|
|
||||||
|
except ValueError:
|
||||||
|
# Not an IP address (common case)
|
||||||
host_ip = None
|
host_ip = None
|
||||||
except AttributeError:
|
|
||||||
# Divergence from upstream: Make ipaddress library optional
|
|
||||||
if ipaddress is None:
|
|
||||||
host_ip = None
|
|
||||||
else: # Defensive
|
|
||||||
raise
|
|
||||||
dnsnames = []
|
dnsnames = []
|
||||||
san = cert.get("subjectAltName", ())
|
san: tuple[tuple[str, str], ...] = cert.get("subjectAltName", ())
|
||||||
|
key: str
|
||||||
|
value: str
|
||||||
for key, value in san:
|
for key, value in san:
|
||||||
if key == "DNS":
|
if key == "DNS":
|
||||||
if host_ip is None and _dnsname_match(value, hostname):
|
if host_ip is None and _dnsname_match(value, hostname):
|
||||||
|
@ -135,25 +137,23 @@ def match_hostname(cert, hostname):
|
||||||
if host_ip is not None and _ipaddress_match(value, host_ip):
|
if host_ip is not None and _ipaddress_match(value, host_ip):
|
||||||
return
|
return
|
||||||
dnsnames.append(value)
|
dnsnames.append(value)
|
||||||
if not dnsnames:
|
|
||||||
# The subject is only checked when there is no dNSName entry
|
# We only check 'commonName' if it's enabled and we're not verifying
|
||||||
# in subjectAltName
|
# an IP address. IP addresses aren't valid within 'commonName'.
|
||||||
|
if hostname_checks_common_name and host_ip is None and not dnsnames:
|
||||||
for sub in cert.get("subject", ()):
|
for sub in cert.get("subject", ()):
|
||||||
for key, value in sub:
|
for key, value in sub:
|
||||||
# XXX according to RFC 2818, the most specific Common Name
|
|
||||||
# must be used.
|
|
||||||
if key == "commonName":
|
if key == "commonName":
|
||||||
if _dnsname_match(value, hostname):
|
if _dnsname_match(value, hostname):
|
||||||
return
|
return
|
||||||
dnsnames.append(value)
|
dnsnames.append(value)
|
||||||
|
|
||||||
if len(dnsnames) > 1:
|
if len(dnsnames) > 1:
|
||||||
raise CertificateError(
|
raise CertificateError(
|
||||||
"hostname %r "
|
"hostname %r "
|
||||||
"doesn't match either of %s" % (hostname, ", ".join(map(repr, dnsnames)))
|
"doesn't match either of %s" % (hostname, ", ".join(map(repr, dnsnames)))
|
||||||
)
|
)
|
||||||
elif len(dnsnames) == 1:
|
elif len(dnsnames) == 1:
|
||||||
raise CertificateError("hostname %r doesn't match %r" % (hostname, dnsnames[0]))
|
raise CertificateError(f"hostname {hostname!r} doesn't match {dnsnames[0]!r}")
|
||||||
else:
|
else:
|
||||||
raise CertificateError(
|
raise CertificateError("no appropriate subjectAltName fields were found")
|
||||||
"no appropriate commonName or subjectAltName fields were found"
|
|
||||||
)
|
|
||||||
|
|
|
@ -1,9 +1,21 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import io
|
import io
|
||||||
import socket
|
import socket
|
||||||
import ssl
|
import ssl
|
||||||
|
import typing
|
||||||
|
|
||||||
from ..exceptions import ProxySchemeUnsupported
|
from ..exceptions import ProxySchemeUnsupported
|
||||||
from ..packages import six
|
|
||||||
|
if typing.TYPE_CHECKING:
|
||||||
|
from typing_extensions import Literal
|
||||||
|
|
||||||
|
from .ssl_ import _TYPE_PEER_CERT_RET, _TYPE_PEER_CERT_RET_DICT
|
||||||
|
|
||||||
|
|
||||||
|
_SelfT = typing.TypeVar("_SelfT", bound="SSLTransport")
|
||||||
|
_WriteBuffer = typing.Union[bytearray, memoryview]
|
||||||
|
_ReturnValue = typing.TypeVar("_ReturnValue")
|
||||||
|
|
||||||
SSL_BLOCKSIZE = 16384
|
SSL_BLOCKSIZE = 16384
|
||||||
|
|
||||||
|
@ -20,7 +32,7 @@ class SSLTransport:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _validate_ssl_context_for_tls_in_tls(ssl_context):
|
def _validate_ssl_context_for_tls_in_tls(ssl_context: ssl.SSLContext) -> None:
|
||||||
"""
|
"""
|
||||||
Raises a ProxySchemeUnsupported if the provided ssl_context can't be used
|
Raises a ProxySchemeUnsupported if the provided ssl_context can't be used
|
||||||
for TLS in TLS.
|
for TLS in TLS.
|
||||||
|
@ -30,20 +42,18 @@ class SSLTransport:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not hasattr(ssl_context, "wrap_bio"):
|
if not hasattr(ssl_context, "wrap_bio"):
|
||||||
if six.PY2:
|
raise ProxySchemeUnsupported(
|
||||||
raise ProxySchemeUnsupported(
|
"TLS in TLS requires SSLContext.wrap_bio() which isn't "
|
||||||
"TLS in TLS requires SSLContext.wrap_bio() which isn't "
|
"available on non-native SSLContext"
|
||||||
"supported on Python 2"
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ProxySchemeUnsupported(
|
|
||||||
"TLS in TLS requires SSLContext.wrap_bio() which isn't "
|
|
||||||
"available on non-native SSLContext"
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, socket, ssl_context, server_hostname=None, suppress_ragged_eofs=True
|
self,
|
||||||
):
|
socket: socket.socket,
|
||||||
|
ssl_context: ssl.SSLContext,
|
||||||
|
server_hostname: str | None = None,
|
||||||
|
suppress_ragged_eofs: bool = True,
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Create an SSLTransport around socket using the provided ssl_context.
|
Create an SSLTransport around socket using the provided ssl_context.
|
||||||
"""
|
"""
|
||||||
|
@ -60,33 +70,36 @@ class SSLTransport:
|
||||||
# Perform initial handshake.
|
# Perform initial handshake.
|
||||||
self._ssl_io_loop(self.sslobj.do_handshake)
|
self._ssl_io_loop(self.sslobj.do_handshake)
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self: _SelfT) -> _SelfT:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, *_):
|
def __exit__(self, *_: typing.Any) -> None:
|
||||||
self.close()
|
self.close()
|
||||||
|
|
||||||
def fileno(self):
|
def fileno(self) -> int:
|
||||||
return self.socket.fileno()
|
return self.socket.fileno()
|
||||||
|
|
||||||
def read(self, len=1024, buffer=None):
|
def read(self, len: int = 1024, buffer: typing.Any | None = None) -> int | bytes:
|
||||||
return self._wrap_ssl_read(len, buffer)
|
return self._wrap_ssl_read(len, buffer)
|
||||||
|
|
||||||
def recv(self, len=1024, flags=0):
|
def recv(self, buflen: int = 1024, flags: int = 0) -> int | bytes:
|
||||||
if flags != 0:
|
if flags != 0:
|
||||||
raise ValueError("non-zero flags not allowed in calls to recv")
|
raise ValueError("non-zero flags not allowed in calls to recv")
|
||||||
return self._wrap_ssl_read(len)
|
return self._wrap_ssl_read(buflen)
|
||||||
|
|
||||||
def recv_into(self, buffer, nbytes=None, flags=0):
|
def recv_into(
|
||||||
|
self,
|
||||||
|
buffer: _WriteBuffer,
|
||||||
|
nbytes: int | None = None,
|
||||||
|
flags: int = 0,
|
||||||
|
) -> None | int | bytes:
|
||||||
if flags != 0:
|
if flags != 0:
|
||||||
raise ValueError("non-zero flags not allowed in calls to recv_into")
|
raise ValueError("non-zero flags not allowed in calls to recv_into")
|
||||||
if buffer and (nbytes is None):
|
if nbytes is None:
|
||||||
nbytes = len(buffer)
|
nbytes = len(buffer)
|
||||||
elif nbytes is None:
|
|
||||||
nbytes = 1024
|
|
||||||
return self.read(nbytes, buffer)
|
return self.read(nbytes, buffer)
|
||||||
|
|
||||||
def sendall(self, data, flags=0):
|
def sendall(self, data: bytes, flags: int = 0) -> None:
|
||||||
if flags != 0:
|
if flags != 0:
|
||||||
raise ValueError("non-zero flags not allowed in calls to sendall")
|
raise ValueError("non-zero flags not allowed in calls to sendall")
|
||||||
count = 0
|
count = 0
|
||||||
|
@ -96,15 +109,20 @@ class SSLTransport:
|
||||||
v = self.send(byte_view[count:])
|
v = self.send(byte_view[count:])
|
||||||
count += v
|
count += v
|
||||||
|
|
||||||
def send(self, data, flags=0):
|
def send(self, data: bytes, flags: int = 0) -> int:
|
||||||
if flags != 0:
|
if flags != 0:
|
||||||
raise ValueError("non-zero flags not allowed in calls to send")
|
raise ValueError("non-zero flags not allowed in calls to send")
|
||||||
response = self._ssl_io_loop(self.sslobj.write, data)
|
return self._ssl_io_loop(self.sslobj.write, data)
|
||||||
return response
|
|
||||||
|
|
||||||
def makefile(
|
def makefile(
|
||||||
self, mode="r", buffering=None, encoding=None, errors=None, newline=None
|
self,
|
||||||
):
|
mode: str,
|
||||||
|
buffering: int | None = None,
|
||||||
|
*,
|
||||||
|
encoding: str | None = None,
|
||||||
|
errors: str | None = None,
|
||||||
|
newline: str | None = None,
|
||||||
|
) -> typing.BinaryIO | typing.TextIO | socket.SocketIO:
|
||||||
"""
|
"""
|
||||||
Python's httpclient uses makefile and buffered io when reading HTTP
|
Python's httpclient uses makefile and buffered io when reading HTTP
|
||||||
messages and we need to support it.
|
messages and we need to support it.
|
||||||
|
@ -113,7 +131,7 @@ class SSLTransport:
|
||||||
changes to point to the socket directly.
|
changes to point to the socket directly.
|
||||||
"""
|
"""
|
||||||
if not set(mode) <= {"r", "w", "b"}:
|
if not set(mode) <= {"r", "w", "b"}:
|
||||||
raise ValueError("invalid mode %r (only r, w, b allowed)" % (mode,))
|
raise ValueError(f"invalid mode {mode!r} (only r, w, b allowed)")
|
||||||
|
|
||||||
writing = "w" in mode
|
writing = "w" in mode
|
||||||
reading = "r" in mode or not writing
|
reading = "r" in mode or not writing
|
||||||
|
@ -124,8 +142,8 @@ class SSLTransport:
|
||||||
rawmode += "r"
|
rawmode += "r"
|
||||||
if writing:
|
if writing:
|
||||||
rawmode += "w"
|
rawmode += "w"
|
||||||
raw = socket.SocketIO(self, rawmode)
|
raw = socket.SocketIO(self, rawmode) # type: ignore[arg-type]
|
||||||
self.socket._io_refs += 1
|
self.socket._io_refs += 1 # type: ignore[attr-defined]
|
||||||
if buffering is None:
|
if buffering is None:
|
||||||
buffering = -1
|
buffering = -1
|
||||||
if buffering < 0:
|
if buffering < 0:
|
||||||
|
@ -134,8 +152,9 @@ class SSLTransport:
|
||||||
if not binary:
|
if not binary:
|
||||||
raise ValueError("unbuffered streams must be binary")
|
raise ValueError("unbuffered streams must be binary")
|
||||||
return raw
|
return raw
|
||||||
|
buffer: typing.BinaryIO
|
||||||
if reading and writing:
|
if reading and writing:
|
||||||
buffer = io.BufferedRWPair(raw, raw, buffering)
|
buffer = io.BufferedRWPair(raw, raw, buffering) # type: ignore[assignment]
|
||||||
elif reading:
|
elif reading:
|
||||||
buffer = io.BufferedReader(raw, buffering)
|
buffer = io.BufferedReader(raw, buffering)
|
||||||
else:
|
else:
|
||||||
|
@ -144,46 +163,56 @@ class SSLTransport:
|
||||||
if binary:
|
if binary:
|
||||||
return buffer
|
return buffer
|
||||||
text = io.TextIOWrapper(buffer, encoding, errors, newline)
|
text = io.TextIOWrapper(buffer, encoding, errors, newline)
|
||||||
text.mode = mode
|
text.mode = mode # type: ignore[misc]
|
||||||
return text
|
return text
|
||||||
|
|
||||||
def unwrap(self):
|
def unwrap(self) -> None:
|
||||||
self._ssl_io_loop(self.sslobj.unwrap)
|
self._ssl_io_loop(self.sslobj.unwrap)
|
||||||
|
|
||||||
def close(self):
|
def close(self) -> None:
|
||||||
self.socket.close()
|
self.socket.close()
|
||||||
|
|
||||||
def getpeercert(self, binary_form=False):
|
@typing.overload
|
||||||
return self.sslobj.getpeercert(binary_form)
|
def getpeercert(
|
||||||
|
self, binary_form: Literal[False] = ...
|
||||||
|
) -> _TYPE_PEER_CERT_RET_DICT | None:
|
||||||
|
...
|
||||||
|
|
||||||
def version(self):
|
@typing.overload
|
||||||
|
def getpeercert(self, binary_form: Literal[True]) -> bytes | None:
|
||||||
|
...
|
||||||
|
|
||||||
|
def getpeercert(self, binary_form: bool = False) -> _TYPE_PEER_CERT_RET:
|
||||||
|
return self.sslobj.getpeercert(binary_form) # type: ignore[return-value]
|
||||||
|
|
||||||
|
def version(self) -> str | None:
|
||||||
return self.sslobj.version()
|
return self.sslobj.version()
|
||||||
|
|
||||||
def cipher(self):
|
def cipher(self) -> tuple[str, str, int] | None:
|
||||||
return self.sslobj.cipher()
|
return self.sslobj.cipher()
|
||||||
|
|
||||||
def selected_alpn_protocol(self):
|
def selected_alpn_protocol(self) -> str | None:
|
||||||
return self.sslobj.selected_alpn_protocol()
|
return self.sslobj.selected_alpn_protocol()
|
||||||
|
|
||||||
def selected_npn_protocol(self):
|
def selected_npn_protocol(self) -> str | None:
|
||||||
return self.sslobj.selected_npn_protocol()
|
return self.sslobj.selected_npn_protocol()
|
||||||
|
|
||||||
def shared_ciphers(self):
|
def shared_ciphers(self) -> list[tuple[str, str, int]] | None:
|
||||||
return self.sslobj.shared_ciphers()
|
return self.sslobj.shared_ciphers()
|
||||||
|
|
||||||
def compression(self):
|
def compression(self) -> str | None:
|
||||||
return self.sslobj.compression()
|
return self.sslobj.compression()
|
||||||
|
|
||||||
def settimeout(self, value):
|
def settimeout(self, value: float | None) -> None:
|
||||||
self.socket.settimeout(value)
|
self.socket.settimeout(value)
|
||||||
|
|
||||||
def gettimeout(self):
|
def gettimeout(self) -> float | None:
|
||||||
return self.socket.gettimeout()
|
return self.socket.gettimeout()
|
||||||
|
|
||||||
def _decref_socketios(self):
|
def _decref_socketios(self) -> None:
|
||||||
self.socket._decref_socketios()
|
self.socket._decref_socketios() # type: ignore[attr-defined]
|
||||||
|
|
||||||
def _wrap_ssl_read(self, len, buffer=None):
|
def _wrap_ssl_read(self, len: int, buffer: bytearray | None = None) -> int | bytes:
|
||||||
try:
|
try:
|
||||||
return self._ssl_io_loop(self.sslobj.read, len, buffer)
|
return self._ssl_io_loop(self.sslobj.read, len, buffer)
|
||||||
except ssl.SSLError as e:
|
except ssl.SSLError as e:
|
||||||
|
@ -192,7 +221,32 @@ class SSLTransport:
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _ssl_io_loop(self, func, *args):
|
# func is sslobj.do_handshake or sslobj.unwrap
|
||||||
|
@typing.overload
|
||||||
|
def _ssl_io_loop(self, func: typing.Callable[[], None]) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
# func is sslobj.write, arg1 is data
|
||||||
|
@typing.overload
|
||||||
|
def _ssl_io_loop(self, func: typing.Callable[[bytes], int], arg1: bytes) -> int:
|
||||||
|
...
|
||||||
|
|
||||||
|
# func is sslobj.read, arg1 is len, arg2 is buffer
|
||||||
|
@typing.overload
|
||||||
|
def _ssl_io_loop(
|
||||||
|
self,
|
||||||
|
func: typing.Callable[[int, bytearray | None], bytes],
|
||||||
|
arg1: int,
|
||||||
|
arg2: bytearray | None,
|
||||||
|
) -> bytes:
|
||||||
|
...
|
||||||
|
|
||||||
|
def _ssl_io_loop(
|
||||||
|
self,
|
||||||
|
func: typing.Callable[..., _ReturnValue],
|
||||||
|
arg1: None | bytes | int = None,
|
||||||
|
arg2: bytearray | None = None,
|
||||||
|
) -> _ReturnValue:
|
||||||
"""Performs an I/O loop between incoming/outgoing and the socket."""
|
"""Performs an I/O loop between incoming/outgoing and the socket."""
|
||||||
should_loop = True
|
should_loop = True
|
||||||
ret = None
|
ret = None
|
||||||
|
@ -200,7 +254,12 @@ class SSLTransport:
|
||||||
while should_loop:
|
while should_loop:
|
||||||
errno = None
|
errno = None
|
||||||
try:
|
try:
|
||||||
ret = func(*args)
|
if arg1 is None and arg2 is None:
|
||||||
|
ret = func()
|
||||||
|
elif arg2 is None:
|
||||||
|
ret = func(arg1)
|
||||||
|
else:
|
||||||
|
ret = func(arg1, arg2)
|
||||||
except ssl.SSLError as e:
|
except ssl.SSLError as e:
|
||||||
if e.errno not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE):
|
if e.errno not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE):
|
||||||
# WANT_READ, and WANT_WRITE are expected, others are not.
|
# WANT_READ, and WANT_WRITE are expected, others are not.
|
||||||
|
@ -218,4 +277,4 @@ class SSLTransport:
|
||||||
self.incoming.write(buf)
|
self.incoming.write(buf)
|
||||||
else:
|
else:
|
||||||
self.incoming.write_eof()
|
self.incoming.write_eof()
|
||||||
return ret
|
return typing.cast(_ReturnValue, ret)
|
||||||
|
|
|
@ -1,44 +1,56 @@
|
||||||
from __future__ import absolute_import
|
from __future__ import annotations
|
||||||
|
|
||||||
import time
|
import time
|
||||||
|
import typing
|
||||||
# The default socket timeout, used by httplib to indicate that no timeout was; specified by the user
|
from enum import Enum
|
||||||
from socket import _GLOBAL_DEFAULT_TIMEOUT, getdefaulttimeout
|
from socket import getdefaulttimeout
|
||||||
|
|
||||||
from ..exceptions import TimeoutStateError
|
from ..exceptions import TimeoutStateError
|
||||||
|
|
||||||
# A sentinel value to indicate that no timeout was specified by the user in
|
if typing.TYPE_CHECKING:
|
||||||
# urllib3
|
from typing_extensions import Final
|
||||||
_Default = object()
|
|
||||||
|
|
||||||
|
|
||||||
# Use time.monotonic if available.
|
class _TYPE_DEFAULT(Enum):
|
||||||
current_time = getattr(time, "monotonic", time.time)
|
# This value should never be passed to socket.settimeout() so for safety we use a -1.
|
||||||
|
# socket.settimout() raises a ValueError for negative values.
|
||||||
|
token = -1
|
||||||
|
|
||||||
|
|
||||||
class Timeout(object):
|
_DEFAULT_TIMEOUT: Final[_TYPE_DEFAULT] = _TYPE_DEFAULT.token
|
||||||
|
|
||||||
|
_TYPE_TIMEOUT = typing.Optional[typing.Union[float, _TYPE_DEFAULT]]
|
||||||
|
|
||||||
|
|
||||||
|
class Timeout:
|
||||||
"""Timeout configuration.
|
"""Timeout configuration.
|
||||||
|
|
||||||
Timeouts can be defined as a default for a pool:
|
Timeouts can be defined as a default for a pool:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
timeout = Timeout(connect=2.0, read=7.0)
|
import urllib3
|
||||||
http = PoolManager(timeout=timeout)
|
|
||||||
response = http.request('GET', 'http://example.com/')
|
timeout = urllib3.util.Timeout(connect=2.0, read=7.0)
|
||||||
|
|
||||||
|
http = urllib3.PoolManager(timeout=timeout)
|
||||||
|
|
||||||
|
resp = http.request("GET", "https://example.com/")
|
||||||
|
|
||||||
|
print(resp.status)
|
||||||
|
|
||||||
Or per-request (which overrides the default for the pool):
|
Or per-request (which overrides the default for the pool):
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
response = http.request('GET', 'http://example.com/', timeout=Timeout(10))
|
response = http.request("GET", "https://example.com/", timeout=Timeout(10))
|
||||||
|
|
||||||
Timeouts can be disabled by setting all the parameters to ``None``:
|
Timeouts can be disabled by setting all the parameters to ``None``:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
no_timeout = Timeout(connect=None, read=None)
|
no_timeout = Timeout(connect=None, read=None)
|
||||||
response = http.request('GET', 'http://example.com/, timeout=no_timeout)
|
response = http.request("GET", "https://example.com/", timeout=no_timeout)
|
||||||
|
|
||||||
|
|
||||||
:param total:
|
:param total:
|
||||||
|
@ -96,31 +108,31 @@ class Timeout(object):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
#: A sentinel object representing the default timeout value
|
#: A sentinel object representing the default timeout value
|
||||||
DEFAULT_TIMEOUT = _GLOBAL_DEFAULT_TIMEOUT
|
DEFAULT_TIMEOUT: _TYPE_TIMEOUT = _DEFAULT_TIMEOUT
|
||||||
|
|
||||||
def __init__(self, total=None, connect=_Default, read=_Default):
|
def __init__(
|
||||||
|
self,
|
||||||
|
total: _TYPE_TIMEOUT = None,
|
||||||
|
connect: _TYPE_TIMEOUT = _DEFAULT_TIMEOUT,
|
||||||
|
read: _TYPE_TIMEOUT = _DEFAULT_TIMEOUT,
|
||||||
|
) -> None:
|
||||||
self._connect = self._validate_timeout(connect, "connect")
|
self._connect = self._validate_timeout(connect, "connect")
|
||||||
self._read = self._validate_timeout(read, "read")
|
self._read = self._validate_timeout(read, "read")
|
||||||
self.total = self._validate_timeout(total, "total")
|
self.total = self._validate_timeout(total, "total")
|
||||||
self._start_connect = None
|
self._start_connect: float | None = None
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self) -> str:
|
||||||
return "%s(connect=%r, read=%r, total=%r)" % (
|
return f"{type(self).__name__}(connect={self._connect!r}, read={self._read!r}, total={self.total!r})"
|
||||||
type(self).__name__,
|
|
||||||
self._connect,
|
|
||||||
self._read,
|
|
||||||
self.total,
|
|
||||||
)
|
|
||||||
|
|
||||||
# __str__ provided for backwards compatibility
|
# __str__ provided for backwards compatibility
|
||||||
__str__ = __repr__
|
__str__ = __repr__
|
||||||
|
|
||||||
@classmethod
|
@staticmethod
|
||||||
def resolve_default_timeout(cls, timeout):
|
def resolve_default_timeout(timeout: _TYPE_TIMEOUT) -> float | None:
|
||||||
return getdefaulttimeout() if timeout is cls.DEFAULT_TIMEOUT else timeout
|
return getdefaulttimeout() if timeout is _DEFAULT_TIMEOUT else timeout
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _validate_timeout(cls, value, name):
|
def _validate_timeout(cls, value: _TYPE_TIMEOUT, name: str) -> _TYPE_TIMEOUT:
|
||||||
"""Check that a timeout attribute is valid.
|
"""Check that a timeout attribute is valid.
|
||||||
|
|
||||||
:param value: The timeout value to validate
|
:param value: The timeout value to validate
|
||||||
|
@ -130,10 +142,7 @@ class Timeout(object):
|
||||||
:raises ValueError: If it is a numeric value less than or equal to
|
:raises ValueError: If it is a numeric value less than or equal to
|
||||||
zero, or the type is not an integer, float, or None.
|
zero, or the type is not an integer, float, or None.
|
||||||
"""
|
"""
|
||||||
if value is _Default:
|
if value is None or value is _DEFAULT_TIMEOUT:
|
||||||
return cls.DEFAULT_TIMEOUT
|
|
||||||
|
|
||||||
if value is None or value is cls.DEFAULT_TIMEOUT:
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
if isinstance(value, bool):
|
if isinstance(value, bool):
|
||||||
|
@ -147,7 +156,7 @@ class Timeout(object):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Timeout value %s was %s, but it must be an "
|
"Timeout value %s was %s, but it must be an "
|
||||||
"int, float or None." % (name, value)
|
"int, float or None." % (name, value)
|
||||||
)
|
) from None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if value <= 0:
|
if value <= 0:
|
||||||
|
@ -157,16 +166,15 @@ class Timeout(object):
|
||||||
"than or equal to 0." % (name, value)
|
"than or equal to 0." % (name, value)
|
||||||
)
|
)
|
||||||
except TypeError:
|
except TypeError:
|
||||||
# Python 3
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Timeout value %s was %s, but it must be an "
|
"Timeout value %s was %s, but it must be an "
|
||||||
"int, float or None." % (name, value)
|
"int, float or None." % (name, value)
|
||||||
)
|
) from None
|
||||||
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_float(cls, timeout):
|
def from_float(cls, timeout: _TYPE_TIMEOUT) -> Timeout:
|
||||||
"""Create a new Timeout from a legacy timeout value.
|
"""Create a new Timeout from a legacy timeout value.
|
||||||
|
|
||||||
The timeout value used by httplib.py sets the same timeout on the
|
The timeout value used by httplib.py sets the same timeout on the
|
||||||
|
@ -175,13 +183,13 @@ class Timeout(object):
|
||||||
passed to this function.
|
passed to this function.
|
||||||
|
|
||||||
:param timeout: The legacy timeout value.
|
:param timeout: The legacy timeout value.
|
||||||
:type timeout: integer, float, sentinel default object, or None
|
:type timeout: integer, float, :attr:`urllib3.util.Timeout.DEFAULT_TIMEOUT`, or None
|
||||||
:return: Timeout object
|
:return: Timeout object
|
||||||
:rtype: :class:`Timeout`
|
:rtype: :class:`Timeout`
|
||||||
"""
|
"""
|
||||||
return Timeout(read=timeout, connect=timeout)
|
return Timeout(read=timeout, connect=timeout)
|
||||||
|
|
||||||
def clone(self):
|
def clone(self) -> Timeout:
|
||||||
"""Create a copy of the timeout object
|
"""Create a copy of the timeout object
|
||||||
|
|
||||||
Timeout properties are stored per-pool but each request needs a fresh
|
Timeout properties are stored per-pool but each request needs a fresh
|
||||||
|
@ -195,7 +203,7 @@ class Timeout(object):
|
||||||
# detect the user default.
|
# detect the user default.
|
||||||
return Timeout(connect=self._connect, read=self._read, total=self.total)
|
return Timeout(connect=self._connect, read=self._read, total=self.total)
|
||||||
|
|
||||||
def start_connect(self):
|
def start_connect(self) -> float:
|
||||||
"""Start the timeout clock, used during a connect() attempt
|
"""Start the timeout clock, used during a connect() attempt
|
||||||
|
|
||||||
:raises urllib3.exceptions.TimeoutStateError: if you attempt
|
:raises urllib3.exceptions.TimeoutStateError: if you attempt
|
||||||
|
@ -203,10 +211,10 @@ class Timeout(object):
|
||||||
"""
|
"""
|
||||||
if self._start_connect is not None:
|
if self._start_connect is not None:
|
||||||
raise TimeoutStateError("Timeout timer has already been started.")
|
raise TimeoutStateError("Timeout timer has already been started.")
|
||||||
self._start_connect = current_time()
|
self._start_connect = time.monotonic()
|
||||||
return self._start_connect
|
return self._start_connect
|
||||||
|
|
||||||
def get_connect_duration(self):
|
def get_connect_duration(self) -> float:
|
||||||
"""Gets the time elapsed since the call to :meth:`start_connect`.
|
"""Gets the time elapsed since the call to :meth:`start_connect`.
|
||||||
|
|
||||||
:return: Elapsed time in seconds.
|
:return: Elapsed time in seconds.
|
||||||
|
@ -218,10 +226,10 @@ class Timeout(object):
|
||||||
raise TimeoutStateError(
|
raise TimeoutStateError(
|
||||||
"Can't get connect duration for timer that has not started."
|
"Can't get connect duration for timer that has not started."
|
||||||
)
|
)
|
||||||
return current_time() - self._start_connect
|
return time.monotonic() - self._start_connect
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def connect_timeout(self):
|
def connect_timeout(self) -> _TYPE_TIMEOUT:
|
||||||
"""Get the value to use when setting a connection timeout.
|
"""Get the value to use when setting a connection timeout.
|
||||||
|
|
||||||
This will be a positive float or integer, the value None
|
This will be a positive float or integer, the value None
|
||||||
|
@ -233,13 +241,13 @@ class Timeout(object):
|
||||||
if self.total is None:
|
if self.total is None:
|
||||||
return self._connect
|
return self._connect
|
||||||
|
|
||||||
if self._connect is None or self._connect is self.DEFAULT_TIMEOUT:
|
if self._connect is None or self._connect is _DEFAULT_TIMEOUT:
|
||||||
return self.total
|
return self.total
|
||||||
|
|
||||||
return min(self._connect, self.total)
|
return min(self._connect, self.total) # type: ignore[type-var]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def read_timeout(self):
|
def read_timeout(self) -> float | None:
|
||||||
"""Get the value for the read timeout.
|
"""Get the value for the read timeout.
|
||||||
|
|
||||||
This assumes some time has elapsed in the connection timeout and
|
This assumes some time has elapsed in the connection timeout and
|
||||||
|
@ -251,21 +259,21 @@ class Timeout(object):
|
||||||
raised.
|
raised.
|
||||||
|
|
||||||
:return: Value to use for the read timeout.
|
:return: Value to use for the read timeout.
|
||||||
:rtype: int, float, :attr:`Timeout.DEFAULT_TIMEOUT` or None
|
:rtype: int, float or None
|
||||||
:raises urllib3.exceptions.TimeoutStateError: If :meth:`start_connect`
|
:raises urllib3.exceptions.TimeoutStateError: If :meth:`start_connect`
|
||||||
has not yet been called on this object.
|
has not yet been called on this object.
|
||||||
"""
|
"""
|
||||||
if (
|
if (
|
||||||
self.total is not None
|
self.total is not None
|
||||||
and self.total is not self.DEFAULT_TIMEOUT
|
and self.total is not _DEFAULT_TIMEOUT
|
||||||
and self._read is not None
|
and self._read is not None
|
||||||
and self._read is not self.DEFAULT_TIMEOUT
|
and self._read is not _DEFAULT_TIMEOUT
|
||||||
):
|
):
|
||||||
# In case the connect timeout has not yet been established.
|
# In case the connect timeout has not yet been established.
|
||||||
if self._start_connect is None:
|
if self._start_connect is None:
|
||||||
return self._read
|
return self._read
|
||||||
return max(0, min(self.total - self.get_connect_duration(), self._read))
|
return max(0, min(self.total - self.get_connect_duration(), self._read))
|
||||||
elif self.total is not None and self.total is not self.DEFAULT_TIMEOUT:
|
elif self.total is not None and self.total is not _DEFAULT_TIMEOUT:
|
||||||
return max(0, self.total - self.get_connect_duration())
|
return max(0, self.total - self.get_connect_duration())
|
||||||
else:
|
else:
|
||||||
return self._read
|
return self.resolve_default_timeout(self._read)
|
||||||
|
|
|
@ -1,22 +1,20 @@
|
||||||
from __future__ import absolute_import
|
from __future__ import annotations
|
||||||
|
|
||||||
import re
|
import re
|
||||||
from collections import namedtuple
|
import typing
|
||||||
|
|
||||||
from ..exceptions import LocationParseError
|
from ..exceptions import LocationParseError
|
||||||
from ..packages import six
|
from .util import to_str
|
||||||
|
|
||||||
url_attrs = ["scheme", "auth", "host", "port", "path", "query", "fragment"]
|
|
||||||
|
|
||||||
# We only want to normalize urls with an HTTP(S) scheme.
|
# We only want to normalize urls with an HTTP(S) scheme.
|
||||||
# urllib3 infers URLs without a scheme (None) to be http.
|
# urllib3 infers URLs without a scheme (None) to be http.
|
||||||
NORMALIZABLE_SCHEMES = ("http", "https", None)
|
_NORMALIZABLE_SCHEMES = ("http", "https", None)
|
||||||
|
|
||||||
# Almost all of these patterns were derived from the
|
# Almost all of these patterns were derived from the
|
||||||
# 'rfc3986' module: https://github.com/python-hyper/rfc3986
|
# 'rfc3986' module: https://github.com/python-hyper/rfc3986
|
||||||
PERCENT_RE = re.compile(r"%[a-fA-F0-9]{2}")
|
_PERCENT_RE = re.compile(r"%[a-fA-F0-9]{2}")
|
||||||
SCHEME_RE = re.compile(r"^(?:[a-zA-Z][a-zA-Z0-9+-]*:|/)")
|
_SCHEME_RE = re.compile(r"^(?:[a-zA-Z][a-zA-Z0-9+-]*:|/)")
|
||||||
URI_RE = re.compile(
|
_URI_RE = re.compile(
|
||||||
r"^(?:([a-zA-Z][a-zA-Z0-9+.-]*):)?"
|
r"^(?:([a-zA-Z][a-zA-Z0-9+.-]*):)?"
|
||||||
r"(?://([^\\/?#]*))?"
|
r"(?://([^\\/?#]*))?"
|
||||||
r"([^?#]*)"
|
r"([^?#]*)"
|
||||||
|
@ -25,10 +23,10 @@ URI_RE = re.compile(
|
||||||
re.UNICODE | re.DOTALL,
|
re.UNICODE | re.DOTALL,
|
||||||
)
|
)
|
||||||
|
|
||||||
IPV4_PAT = r"(?:[0-9]{1,3}\.){3}[0-9]{1,3}"
|
_IPV4_PAT = r"(?:[0-9]{1,3}\.){3}[0-9]{1,3}"
|
||||||
HEX_PAT = "[0-9A-Fa-f]{1,4}"
|
_HEX_PAT = "[0-9A-Fa-f]{1,4}"
|
||||||
LS32_PAT = "(?:{hex}:{hex}|{ipv4})".format(hex=HEX_PAT, ipv4=IPV4_PAT)
|
_LS32_PAT = "(?:{hex}:{hex}|{ipv4})".format(hex=_HEX_PAT, ipv4=_IPV4_PAT)
|
||||||
_subs = {"hex": HEX_PAT, "ls32": LS32_PAT}
|
_subs = {"hex": _HEX_PAT, "ls32": _LS32_PAT}
|
||||||
_variations = [
|
_variations = [
|
||||||
# 6( h16 ":" ) ls32
|
# 6( h16 ":" ) ls32
|
||||||
"(?:%(hex)s:){6}%(ls32)s",
|
"(?:%(hex)s:){6}%(ls32)s",
|
||||||
|
@ -50,69 +48,78 @@ _variations = [
|
||||||
"(?:(?:%(hex)s:){0,6}%(hex)s)?::",
|
"(?:(?:%(hex)s:){0,6}%(hex)s)?::",
|
||||||
]
|
]
|
||||||
|
|
||||||
UNRESERVED_PAT = r"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789._\-~"
|
_UNRESERVED_PAT = r"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789._\-~"
|
||||||
IPV6_PAT = "(?:" + "|".join([x % _subs for x in _variations]) + ")"
|
_IPV6_PAT = "(?:" + "|".join([x % _subs for x in _variations]) + ")"
|
||||||
ZONE_ID_PAT = "(?:%25|%)(?:[" + UNRESERVED_PAT + "]|%[a-fA-F0-9]{2})+"
|
_ZONE_ID_PAT = "(?:%25|%)(?:[" + _UNRESERVED_PAT + "]|%[a-fA-F0-9]{2})+"
|
||||||
IPV6_ADDRZ_PAT = r"\[" + IPV6_PAT + r"(?:" + ZONE_ID_PAT + r")?\]"
|
_IPV6_ADDRZ_PAT = r"\[" + _IPV6_PAT + r"(?:" + _ZONE_ID_PAT + r")?\]"
|
||||||
REG_NAME_PAT = r"(?:[^\[\]%:/?#]|%[a-fA-F0-9]{2})*"
|
_REG_NAME_PAT = r"(?:[^\[\]%:/?#]|%[a-fA-F0-9]{2})*"
|
||||||
TARGET_RE = re.compile(r"^(/[^?#]*)(?:\?([^#]*))?(?:#.*)?$")
|
_TARGET_RE = re.compile(r"^(/[^?#]*)(?:\?([^#]*))?(?:#.*)?$")
|
||||||
|
|
||||||
IPV4_RE = re.compile("^" + IPV4_PAT + "$")
|
_IPV4_RE = re.compile("^" + _IPV4_PAT + "$")
|
||||||
IPV6_RE = re.compile("^" + IPV6_PAT + "$")
|
_IPV6_RE = re.compile("^" + _IPV6_PAT + "$")
|
||||||
IPV6_ADDRZ_RE = re.compile("^" + IPV6_ADDRZ_PAT + "$")
|
_IPV6_ADDRZ_RE = re.compile("^" + _IPV6_ADDRZ_PAT + "$")
|
||||||
BRACELESS_IPV6_ADDRZ_RE = re.compile("^" + IPV6_ADDRZ_PAT[2:-2] + "$")
|
_BRACELESS_IPV6_ADDRZ_RE = re.compile("^" + _IPV6_ADDRZ_PAT[2:-2] + "$")
|
||||||
ZONE_ID_RE = re.compile("(" + ZONE_ID_PAT + r")\]$")
|
_ZONE_ID_RE = re.compile("(" + _ZONE_ID_PAT + r")\]$")
|
||||||
|
|
||||||
_HOST_PORT_PAT = ("^(%s|%s|%s)(?::0*?(|0|[1-9][0-9]{0,4}))?$") % (
|
_HOST_PORT_PAT = ("^(%s|%s|%s)(?::0*?(|0|[1-9][0-9]{0,4}))?$") % (
|
||||||
REG_NAME_PAT,
|
_REG_NAME_PAT,
|
||||||
IPV4_PAT,
|
_IPV4_PAT,
|
||||||
IPV6_ADDRZ_PAT,
|
_IPV6_ADDRZ_PAT,
|
||||||
)
|
)
|
||||||
_HOST_PORT_RE = re.compile(_HOST_PORT_PAT, re.UNICODE | re.DOTALL)
|
_HOST_PORT_RE = re.compile(_HOST_PORT_PAT, re.UNICODE | re.DOTALL)
|
||||||
|
|
||||||
UNRESERVED_CHARS = set(
|
_UNRESERVED_CHARS = set(
|
||||||
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789._-~"
|
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789._-~"
|
||||||
)
|
)
|
||||||
SUB_DELIM_CHARS = set("!$&'()*+,;=")
|
_SUB_DELIM_CHARS = set("!$&'()*+,;=")
|
||||||
USERINFO_CHARS = UNRESERVED_CHARS | SUB_DELIM_CHARS | {":"}
|
_USERINFO_CHARS = _UNRESERVED_CHARS | _SUB_DELIM_CHARS | {":"}
|
||||||
PATH_CHARS = USERINFO_CHARS | {"@", "/"}
|
_PATH_CHARS = _USERINFO_CHARS | {"@", "/"}
|
||||||
QUERY_CHARS = FRAGMENT_CHARS = PATH_CHARS | {"?"}
|
_QUERY_CHARS = _FRAGMENT_CHARS = _PATH_CHARS | {"?"}
|
||||||
|
|
||||||
|
|
||||||
class Url(namedtuple("Url", url_attrs)):
|
class Url(
|
||||||
|
typing.NamedTuple(
|
||||||
|
"Url",
|
||||||
|
[
|
||||||
|
("scheme", typing.Optional[str]),
|
||||||
|
("auth", typing.Optional[str]),
|
||||||
|
("host", typing.Optional[str]),
|
||||||
|
("port", typing.Optional[int]),
|
||||||
|
("path", typing.Optional[str]),
|
||||||
|
("query", typing.Optional[str]),
|
||||||
|
("fragment", typing.Optional[str]),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Data structure for representing an HTTP URL. Used as a return value for
|
Data structure for representing an HTTP URL. Used as a return value for
|
||||||
:func:`parse_url`. Both the scheme and host are normalized as they are
|
:func:`parse_url`. Both the scheme and host are normalized as they are
|
||||||
both case-insensitive according to RFC 3986.
|
both case-insensitive according to RFC 3986.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__slots__ = ()
|
def __new__( # type: ignore[no-untyped-def]
|
||||||
|
|
||||||
def __new__(
|
|
||||||
cls,
|
cls,
|
||||||
scheme=None,
|
scheme: str | None = None,
|
||||||
auth=None,
|
auth: str | None = None,
|
||||||
host=None,
|
host: str | None = None,
|
||||||
port=None,
|
port: int | None = None,
|
||||||
path=None,
|
path: str | None = None,
|
||||||
query=None,
|
query: str | None = None,
|
||||||
fragment=None,
|
fragment: str | None = None,
|
||||||
):
|
):
|
||||||
if path and not path.startswith("/"):
|
if path and not path.startswith("/"):
|
||||||
path = "/" + path
|
path = "/" + path
|
||||||
if scheme is not None:
|
if scheme is not None:
|
||||||
scheme = scheme.lower()
|
scheme = scheme.lower()
|
||||||
return super(Url, cls).__new__(
|
return super().__new__(cls, scheme, auth, host, port, path, query, fragment)
|
||||||
cls, scheme, auth, host, port, path, query, fragment
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def hostname(self):
|
def hostname(self) -> str | None:
|
||||||
"""For backwards-compatibility with urlparse. We're nice like that."""
|
"""For backwards-compatibility with urlparse. We're nice like that."""
|
||||||
return self.host
|
return self.host
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def request_uri(self):
|
def request_uri(self) -> str:
|
||||||
"""Absolute path including the query string."""
|
"""Absolute path including the query string."""
|
||||||
uri = self.path or "/"
|
uri = self.path or "/"
|
||||||
|
|
||||||
|
@ -122,14 +129,37 @@ class Url(namedtuple("Url", url_attrs)):
|
||||||
return uri
|
return uri
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def netloc(self):
|
def authority(self) -> str | None:
|
||||||
"""Network location including host and port"""
|
"""
|
||||||
|
Authority component as defined in RFC 3986 3.2.
|
||||||
|
This includes userinfo (auth), host and port.
|
||||||
|
|
||||||
|
i.e.
|
||||||
|
userinfo@host:port
|
||||||
|
"""
|
||||||
|
userinfo = self.auth
|
||||||
|
netloc = self.netloc
|
||||||
|
if netloc is None or userinfo is None:
|
||||||
|
return netloc
|
||||||
|
else:
|
||||||
|
return f"{userinfo}@{netloc}"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def netloc(self) -> str | None:
|
||||||
|
"""
|
||||||
|
Network location including host and port.
|
||||||
|
|
||||||
|
If you need the equivalent of urllib.parse's ``netloc``,
|
||||||
|
use the ``authority`` property instead.
|
||||||
|
"""
|
||||||
|
if self.host is None:
|
||||||
|
return None
|
||||||
if self.port:
|
if self.port:
|
||||||
return "%s:%d" % (self.host, self.port)
|
return f"{self.host}:{self.port}"
|
||||||
return self.host
|
return self.host
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def url(self):
|
def url(self) -> str:
|
||||||
"""
|
"""
|
||||||
Convert self into a url
|
Convert self into a url
|
||||||
|
|
||||||
|
@ -138,88 +168,77 @@ class Url(namedtuple("Url", url_attrs)):
|
||||||
:func:`.parse_url`, but it should be equivalent by the RFC (e.g., urls
|
:func:`.parse_url`, but it should be equivalent by the RFC (e.g., urls
|
||||||
with a blank port will have : removed).
|
with a blank port will have : removed).
|
||||||
|
|
||||||
Example: ::
|
Example:
|
||||||
|
|
||||||
>>> U = parse_url('http://google.com/mail/')
|
.. code-block:: python
|
||||||
>>> U.url
|
|
||||||
'http://google.com/mail/'
|
import urllib3
|
||||||
>>> Url('http', 'username:password', 'host.com', 80,
|
|
||||||
... '/path', 'query', 'fragment').url
|
U = urllib3.util.parse_url("https://google.com/mail/")
|
||||||
'http://username:password@host.com:80/path?query#fragment'
|
|
||||||
|
print(U.url)
|
||||||
|
# "https://google.com/mail/"
|
||||||
|
|
||||||
|
print( urllib3.util.Url("https", "username:password",
|
||||||
|
"host.com", 80, "/path", "query", "fragment"
|
||||||
|
).url
|
||||||
|
)
|
||||||
|
# "https://username:password@host.com:80/path?query#fragment"
|
||||||
"""
|
"""
|
||||||
scheme, auth, host, port, path, query, fragment = self
|
scheme, auth, host, port, path, query, fragment = self
|
||||||
url = u""
|
url = ""
|
||||||
|
|
||||||
# We use "is not None" we want things to happen with empty strings (or 0 port)
|
# We use "is not None" we want things to happen with empty strings (or 0 port)
|
||||||
if scheme is not None:
|
if scheme is not None:
|
||||||
url += scheme + u"://"
|
url += scheme + "://"
|
||||||
if auth is not None:
|
if auth is not None:
|
||||||
url += auth + u"@"
|
url += auth + "@"
|
||||||
if host is not None:
|
if host is not None:
|
||||||
url += host
|
url += host
|
||||||
if port is not None:
|
if port is not None:
|
||||||
url += u":" + str(port)
|
url += ":" + str(port)
|
||||||
if path is not None:
|
if path is not None:
|
||||||
url += path
|
url += path
|
||||||
if query is not None:
|
if query is not None:
|
||||||
url += u"?" + query
|
url += "?" + query
|
||||||
if fragment is not None:
|
if fragment is not None:
|
||||||
url += u"#" + fragment
|
url += "#" + fragment
|
||||||
|
|
||||||
return url
|
return url
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self) -> str:
|
||||||
return self.url
|
return self.url
|
||||||
|
|
||||||
|
|
||||||
def split_first(s, delims):
|
@typing.overload
|
||||||
"""
|
def _encode_invalid_chars(
|
||||||
.. deprecated:: 1.25
|
component: str, allowed_chars: typing.Container[str]
|
||||||
|
) -> str: # Abstract
|
||||||
Given a string and an iterable of delimiters, split on the first found
|
...
|
||||||
delimiter. Return two split parts and the matched delimiter.
|
|
||||||
|
|
||||||
If not found, then the first part is the full input string.
|
|
||||||
|
|
||||||
Example::
|
|
||||||
|
|
||||||
>>> split_first('foo/bar?baz', '?/=')
|
|
||||||
('foo', 'bar?baz', '/')
|
|
||||||
>>> split_first('foo/bar?baz', '123')
|
|
||||||
('foo/bar?baz', '', None)
|
|
||||||
|
|
||||||
Scales linearly with number of delims. Not ideal for large number of delims.
|
|
||||||
"""
|
|
||||||
min_idx = None
|
|
||||||
min_delim = None
|
|
||||||
for d in delims:
|
|
||||||
idx = s.find(d)
|
|
||||||
if idx < 0:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if min_idx is None or idx < min_idx:
|
|
||||||
min_idx = idx
|
|
||||||
min_delim = d
|
|
||||||
|
|
||||||
if min_idx is None or min_idx < 0:
|
|
||||||
return s, "", None
|
|
||||||
|
|
||||||
return s[:min_idx], s[min_idx + 1 :], min_delim
|
|
||||||
|
|
||||||
|
|
||||||
def _encode_invalid_chars(component, allowed_chars, encoding="utf-8"):
|
@typing.overload
|
||||||
|
def _encode_invalid_chars(
|
||||||
|
component: None, allowed_chars: typing.Container[str]
|
||||||
|
) -> None: # Abstract
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
def _encode_invalid_chars(
|
||||||
|
component: str | None, allowed_chars: typing.Container[str]
|
||||||
|
) -> str | None:
|
||||||
"""Percent-encodes a URI component without reapplying
|
"""Percent-encodes a URI component without reapplying
|
||||||
onto an already percent-encoded component.
|
onto an already percent-encoded component.
|
||||||
"""
|
"""
|
||||||
if component is None:
|
if component is None:
|
||||||
return component
|
return component
|
||||||
|
|
||||||
component = six.ensure_text(component)
|
component = to_str(component)
|
||||||
|
|
||||||
# Normalize existing percent-encoded bytes.
|
# Normalize existing percent-encoded bytes.
|
||||||
# Try to see if the component we're encoding is already percent-encoded
|
# Try to see if the component we're encoding is already percent-encoded
|
||||||
# so we can skip all '%' characters but still encode all others.
|
# so we can skip all '%' characters but still encode all others.
|
||||||
component, percent_encodings = PERCENT_RE.subn(
|
component, percent_encodings = _PERCENT_RE.subn(
|
||||||
lambda match: match.group(0).upper(), component
|
lambda match: match.group(0).upper(), component
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -228,7 +247,7 @@ def _encode_invalid_chars(component, allowed_chars, encoding="utf-8"):
|
||||||
encoded_component = bytearray()
|
encoded_component = bytearray()
|
||||||
|
|
||||||
for i in range(0, len(uri_bytes)):
|
for i in range(0, len(uri_bytes)):
|
||||||
# Will return a single character bytestring on both Python 2 & 3
|
# Will return a single character bytestring
|
||||||
byte = uri_bytes[i : i + 1]
|
byte = uri_bytes[i : i + 1]
|
||||||
byte_ord = ord(byte)
|
byte_ord = ord(byte)
|
||||||
if (is_percent_encoded and byte == b"%") or (
|
if (is_percent_encoded and byte == b"%") or (
|
||||||
|
@ -238,10 +257,10 @@ def _encode_invalid_chars(component, allowed_chars, encoding="utf-8"):
|
||||||
continue
|
continue
|
||||||
encoded_component.extend(b"%" + (hex(byte_ord)[2:].encode().zfill(2).upper()))
|
encoded_component.extend(b"%" + (hex(byte_ord)[2:].encode().zfill(2).upper()))
|
||||||
|
|
||||||
return encoded_component.decode(encoding)
|
return encoded_component.decode()
|
||||||
|
|
||||||
|
|
||||||
def _remove_path_dot_segments(path):
|
def _remove_path_dot_segments(path: str) -> str:
|
||||||
# See http://tools.ietf.org/html/rfc3986#section-5.2.4 for pseudo-code
|
# See http://tools.ietf.org/html/rfc3986#section-5.2.4 for pseudo-code
|
||||||
segments = path.split("/") # Turn the path into a list of segments
|
segments = path.split("/") # Turn the path into a list of segments
|
||||||
output = [] # Initialize the variable to use to store output
|
output = [] # Initialize the variable to use to store output
|
||||||
|
@ -251,7 +270,7 @@ def _remove_path_dot_segments(path):
|
||||||
if segment == ".":
|
if segment == ".":
|
||||||
continue
|
continue
|
||||||
# Anything other than '..', should be appended to the output
|
# Anything other than '..', should be appended to the output
|
||||||
elif segment != "..":
|
if segment != "..":
|
||||||
output.append(segment)
|
output.append(segment)
|
||||||
# In this case segment == '..', if we can, we should pop the last
|
# In this case segment == '..', if we can, we should pop the last
|
||||||
# element
|
# element
|
||||||
|
@ -271,18 +290,25 @@ def _remove_path_dot_segments(path):
|
||||||
return "/".join(output)
|
return "/".join(output)
|
||||||
|
|
||||||
|
|
||||||
def _normalize_host(host, scheme):
|
@typing.overload
|
||||||
if host:
|
def _normalize_host(host: None, scheme: str | None) -> None:
|
||||||
if isinstance(host, six.binary_type):
|
...
|
||||||
host = six.ensure_str(host)
|
|
||||||
|
|
||||||
if scheme in NORMALIZABLE_SCHEMES:
|
|
||||||
is_ipv6 = IPV6_ADDRZ_RE.match(host)
|
@typing.overload
|
||||||
|
def _normalize_host(host: str, scheme: str | None) -> str:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_host(host: str | None, scheme: str | None) -> str | None:
|
||||||
|
if host:
|
||||||
|
if scheme in _NORMALIZABLE_SCHEMES:
|
||||||
|
is_ipv6 = _IPV6_ADDRZ_RE.match(host)
|
||||||
if is_ipv6:
|
if is_ipv6:
|
||||||
# IPv6 hosts of the form 'a::b%zone' are encoded in a URL as
|
# IPv6 hosts of the form 'a::b%zone' are encoded in a URL as
|
||||||
# such per RFC 6874: 'a::b%25zone'. Unquote the ZoneID
|
# such per RFC 6874: 'a::b%25zone'. Unquote the ZoneID
|
||||||
# separator as necessary to return a valid RFC 4007 scoped IP.
|
# separator as necessary to return a valid RFC 4007 scoped IP.
|
||||||
match = ZONE_ID_RE.search(host)
|
match = _ZONE_ID_RE.search(host)
|
||||||
if match:
|
if match:
|
||||||
start, end = match.span(1)
|
start, end = match.span(1)
|
||||||
zone_id = host[start:end]
|
zone_id = host[start:end]
|
||||||
|
@ -291,46 +317,56 @@ def _normalize_host(host, scheme):
|
||||||
zone_id = zone_id[3:]
|
zone_id = zone_id[3:]
|
||||||
else:
|
else:
|
||||||
zone_id = zone_id[1:]
|
zone_id = zone_id[1:]
|
||||||
zone_id = "%" + _encode_invalid_chars(zone_id, UNRESERVED_CHARS)
|
zone_id = _encode_invalid_chars(zone_id, _UNRESERVED_CHARS)
|
||||||
return host[:start].lower() + zone_id + host[end:]
|
return f"{host[:start].lower()}%{zone_id}{host[end:]}"
|
||||||
else:
|
else:
|
||||||
return host.lower()
|
return host.lower()
|
||||||
elif not IPV4_RE.match(host):
|
elif not _IPV4_RE.match(host):
|
||||||
return six.ensure_str(
|
return to_str(
|
||||||
b".".join([_idna_encode(label) for label in host.split(".")])
|
b".".join([_idna_encode(label) for label in host.split(".")]),
|
||||||
|
"ascii",
|
||||||
)
|
)
|
||||||
return host
|
return host
|
||||||
|
|
||||||
|
|
||||||
def _idna_encode(name):
|
def _idna_encode(name: str) -> bytes:
|
||||||
if name and any(ord(x) >= 128 for x in name):
|
if not name.isascii():
|
||||||
try:
|
try:
|
||||||
import idna
|
import idna
|
||||||
except ImportError:
|
except ImportError:
|
||||||
six.raise_from(
|
raise LocationParseError(
|
||||||
LocationParseError("Unable to parse URL without the 'idna' module"),
|
"Unable to parse URL without the 'idna' module"
|
||||||
None,
|
) from None
|
||||||
)
|
|
||||||
try:
|
try:
|
||||||
return idna.encode(name.lower(), strict=True, std3_rules=True)
|
return idna.encode(name.lower(), strict=True, std3_rules=True)
|
||||||
except idna.IDNAError:
|
except idna.IDNAError:
|
||||||
six.raise_from(
|
raise LocationParseError(
|
||||||
LocationParseError(u"Name '%s' is not a valid IDNA label" % name), None
|
f"Name '{name}' is not a valid IDNA label"
|
||||||
)
|
) from None
|
||||||
|
|
||||||
return name.lower().encode("ascii")
|
return name.lower().encode("ascii")
|
||||||
|
|
||||||
|
|
||||||
def _encode_target(target):
|
def _encode_target(target: str) -> str:
|
||||||
"""Percent-encodes a request target so that there are no invalid characters"""
|
"""Percent-encodes a request target so that there are no invalid characters
|
||||||
path, query = TARGET_RE.match(target).groups()
|
|
||||||
target = _encode_invalid_chars(path, PATH_CHARS)
|
Pre-condition for this function is that 'target' must start with '/'.
|
||||||
query = _encode_invalid_chars(query, QUERY_CHARS)
|
If that is the case then _TARGET_RE will always produce a match.
|
||||||
|
"""
|
||||||
|
match = _TARGET_RE.match(target)
|
||||||
|
if not match: # Defensive:
|
||||||
|
raise LocationParseError(f"{target!r} is not a valid request URI")
|
||||||
|
|
||||||
|
path, query = match.groups()
|
||||||
|
encoded_target = _encode_invalid_chars(path, _PATH_CHARS)
|
||||||
if query is not None:
|
if query is not None:
|
||||||
target += "?" + query
|
query = _encode_invalid_chars(query, _QUERY_CHARS)
|
||||||
return target
|
encoded_target += "?" + query
|
||||||
|
return encoded_target
|
||||||
|
|
||||||
|
|
||||||
def parse_url(url):
|
def parse_url(url: str) -> Url:
|
||||||
"""
|
"""
|
||||||
Given a url, return a parsed :class:`.Url` namedtuple. Best-effort is
|
Given a url, return a parsed :class:`.Url` namedtuple. Best-effort is
|
||||||
performed to parse incomplete urls. Fields not provided will be None.
|
performed to parse incomplete urls. Fields not provided will be None.
|
||||||
|
@ -341,28 +377,44 @@ def parse_url(url):
|
||||||
|
|
||||||
:param str url: URL to parse into a :class:`.Url` namedtuple.
|
:param str url: URL to parse into a :class:`.Url` namedtuple.
|
||||||
|
|
||||||
Partly backwards-compatible with :mod:`urlparse`.
|
Partly backwards-compatible with :mod:`urllib.parse`.
|
||||||
|
|
||||||
Example::
|
Example:
|
||||||
|
|
||||||
>>> parse_url('http://google.com/mail/')
|
.. code-block:: python
|
||||||
Url(scheme='http', host='google.com', port=None, path='/mail/', ...)
|
|
||||||
>>> parse_url('google.com:80')
|
import urllib3
|
||||||
Url(scheme=None, host='google.com', port=80, path=None, ...)
|
|
||||||
>>> parse_url('/foo?bar')
|
print( urllib3.util.parse_url('http://google.com/mail/'))
|
||||||
Url(scheme=None, host=None, port=None, path='/foo', query='bar', ...)
|
# Url(scheme='http', host='google.com', port=None, path='/mail/', ...)
|
||||||
|
|
||||||
|
print( urllib3.util.parse_url('google.com:80'))
|
||||||
|
# Url(scheme=None, host='google.com', port=80, path=None, ...)
|
||||||
|
|
||||||
|
print( urllib3.util.parse_url('/foo?bar'))
|
||||||
|
# Url(scheme=None, host=None, port=None, path='/foo', query='bar', ...)
|
||||||
"""
|
"""
|
||||||
if not url:
|
if not url:
|
||||||
# Empty
|
# Empty
|
||||||
return Url()
|
return Url()
|
||||||
|
|
||||||
source_url = url
|
source_url = url
|
||||||
if not SCHEME_RE.search(url):
|
if not _SCHEME_RE.search(url):
|
||||||
url = "//" + url
|
url = "//" + url
|
||||||
|
|
||||||
|
scheme: str | None
|
||||||
|
authority: str | None
|
||||||
|
auth: str | None
|
||||||
|
host: str | None
|
||||||
|
port: str | None
|
||||||
|
port_int: int | None
|
||||||
|
path: str | None
|
||||||
|
query: str | None
|
||||||
|
fragment: str | None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
scheme, authority, path, query, fragment = URI_RE.match(url).groups()
|
scheme, authority, path, query, fragment = _URI_RE.match(url).groups() # type: ignore[union-attr]
|
||||||
normalize_uri = scheme is None or scheme.lower() in NORMALIZABLE_SCHEMES
|
normalize_uri = scheme is None or scheme.lower() in _NORMALIZABLE_SCHEMES
|
||||||
|
|
||||||
if scheme:
|
if scheme:
|
||||||
scheme = scheme.lower()
|
scheme = scheme.lower()
|
||||||
|
@ -370,31 +422,33 @@ def parse_url(url):
|
||||||
if authority:
|
if authority:
|
||||||
auth, _, host_port = authority.rpartition("@")
|
auth, _, host_port = authority.rpartition("@")
|
||||||
auth = auth or None
|
auth = auth or None
|
||||||
host, port = _HOST_PORT_RE.match(host_port).groups()
|
host, port = _HOST_PORT_RE.match(host_port).groups() # type: ignore[union-attr]
|
||||||
if auth and normalize_uri:
|
if auth and normalize_uri:
|
||||||
auth = _encode_invalid_chars(auth, USERINFO_CHARS)
|
auth = _encode_invalid_chars(auth, _USERINFO_CHARS)
|
||||||
if port == "":
|
if port == "":
|
||||||
port = None
|
port = None
|
||||||
else:
|
else:
|
||||||
auth, host, port = None, None, None
|
auth, host, port = None, None, None
|
||||||
|
|
||||||
if port is not None:
|
if port is not None:
|
||||||
port = int(port)
|
port_int = int(port)
|
||||||
if not (0 <= port <= 65535):
|
if not (0 <= port_int <= 65535):
|
||||||
raise LocationParseError(url)
|
raise LocationParseError(url)
|
||||||
|
else:
|
||||||
|
port_int = None
|
||||||
|
|
||||||
host = _normalize_host(host, scheme)
|
host = _normalize_host(host, scheme)
|
||||||
|
|
||||||
if normalize_uri and path:
|
if normalize_uri and path:
|
||||||
path = _remove_path_dot_segments(path)
|
path = _remove_path_dot_segments(path)
|
||||||
path = _encode_invalid_chars(path, PATH_CHARS)
|
path = _encode_invalid_chars(path, _PATH_CHARS)
|
||||||
if normalize_uri and query:
|
if normalize_uri and query:
|
||||||
query = _encode_invalid_chars(query, QUERY_CHARS)
|
query = _encode_invalid_chars(query, _QUERY_CHARS)
|
||||||
if normalize_uri and fragment:
|
if normalize_uri and fragment:
|
||||||
fragment = _encode_invalid_chars(fragment, FRAGMENT_CHARS)
|
fragment = _encode_invalid_chars(fragment, _FRAGMENT_CHARS)
|
||||||
|
|
||||||
except (ValueError, AttributeError):
|
except (ValueError, AttributeError) as e:
|
||||||
return six.raise_from(LocationParseError(source_url), None)
|
raise LocationParseError(source_url) from e
|
||||||
|
|
||||||
# For the sake of backwards compatibility we put empty
|
# For the sake of backwards compatibility we put empty
|
||||||
# string values for path if there are any defined values
|
# string values for path if there are any defined values
|
||||||
|
@ -406,30 +460,12 @@ def parse_url(url):
|
||||||
else:
|
else:
|
||||||
path = None
|
path = None
|
||||||
|
|
||||||
# Ensure that each part of the URL is a `str` for
|
|
||||||
# backwards compatibility.
|
|
||||||
if isinstance(url, six.text_type):
|
|
||||||
ensure_func = six.ensure_text
|
|
||||||
else:
|
|
||||||
ensure_func = six.ensure_str
|
|
||||||
|
|
||||||
def ensure_type(x):
|
|
||||||
return x if x is None else ensure_func(x)
|
|
||||||
|
|
||||||
return Url(
|
return Url(
|
||||||
scheme=ensure_type(scheme),
|
scheme=scheme,
|
||||||
auth=ensure_type(auth),
|
auth=auth,
|
||||||
host=ensure_type(host),
|
host=host,
|
||||||
port=port,
|
port=port_int,
|
||||||
path=ensure_type(path),
|
path=path,
|
||||||
query=ensure_type(query),
|
query=query,
|
||||||
fragment=ensure_type(fragment),
|
fragment=fragment,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_host(url):
|
|
||||||
"""
|
|
||||||
Deprecated. Use :func:`parse_url` instead.
|
|
||||||
"""
|
|
||||||
p = parse_url(url)
|
|
||||||
return p.scheme or "http", p.hostname, p.port
|
|
||||||
|
|
|
@ -1,32 +0,0 @@
|
||||||
from typing import Any, List, Optional, Tuple, Union
|
|
||||||
|
|
||||||
from .. import exceptions
|
|
||||||
|
|
||||||
LocationParseError = exceptions.LocationParseError
|
|
||||||
|
|
||||||
url_attrs: List[str]
|
|
||||||
|
|
||||||
class Url:
|
|
||||||
slots: Any
|
|
||||||
def __new__(
|
|
||||||
cls,
|
|
||||||
scheme: Optional[str],
|
|
||||||
auth: Optional[str],
|
|
||||||
host: Optional[str],
|
|
||||||
port: Optional[str],
|
|
||||||
path: Optional[str],
|
|
||||||
query: Optional[str],
|
|
||||||
fragment: Optional[str],
|
|
||||||
) -> Url: ...
|
|
||||||
@property
|
|
||||||
def hostname(self) -> str: ...
|
|
||||||
@property
|
|
||||||
def request_uri(self) -> str: ...
|
|
||||||
@property
|
|
||||||
def netloc(self) -> str: ...
|
|
||||||
@property
|
|
||||||
def url(self) -> str: ...
|
|
||||||
|
|
||||||
def split_first(s: str, delims: str) -> Tuple[str, str, Optional[str]]: ...
|
|
||||||
def parse_url(url: str) -> Url: ...
|
|
||||||
def get_host(url: str) -> Union[str, Tuple[str]]: ...
|
|
42
lib/urllib3/util/util.py
Normal file
42
lib/urllib3/util/util.py
Normal file
|
@ -0,0 +1,42 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import typing
|
||||||
|
from types import TracebackType
|
||||||
|
|
||||||
|
|
||||||
|
def to_bytes(
|
||||||
|
x: str | bytes, encoding: str | None = None, errors: str | None = None
|
||||||
|
) -> bytes:
|
||||||
|
if isinstance(x, bytes):
|
||||||
|
return x
|
||||||
|
elif not isinstance(x, str):
|
||||||
|
raise TypeError(f"not expecting type {type(x).__name__}")
|
||||||
|
if encoding or errors:
|
||||||
|
return x.encode(encoding or "utf-8", errors=errors or "strict")
|
||||||
|
return x.encode()
|
||||||
|
|
||||||
|
|
||||||
|
def to_str(
|
||||||
|
x: str | bytes, encoding: str | None = None, errors: str | None = None
|
||||||
|
) -> str:
|
||||||
|
if isinstance(x, str):
|
||||||
|
return x
|
||||||
|
elif not isinstance(x, bytes):
|
||||||
|
raise TypeError(f"not expecting type {type(x).__name__}")
|
||||||
|
if encoding or errors:
|
||||||
|
return x.decode(encoding or "utf-8", errors=errors or "strict")
|
||||||
|
return x.decode()
|
||||||
|
|
||||||
|
|
||||||
|
def reraise(
|
||||||
|
tp: type[BaseException] | None,
|
||||||
|
value: BaseException,
|
||||||
|
tb: TracebackType | None = None,
|
||||||
|
) -> typing.NoReturn:
|
||||||
|
try:
|
||||||
|
if value.__traceback__ is not tb:
|
||||||
|
raise value.with_traceback(tb)
|
||||||
|
raise value
|
||||||
|
finally:
|
||||||
|
value = None # type: ignore[assignment]
|
||||||
|
tb = None
|
|
@ -1,18 +1,10 @@
|
||||||
import errno
|
from __future__ import annotations
|
||||||
|
|
||||||
import select
|
import select
|
||||||
import sys
|
import socket
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
try:
|
__all__ = ["wait_for_read", "wait_for_write"]
|
||||||
from time import monotonic
|
|
||||||
except ImportError:
|
|
||||||
from time import time as monotonic
|
|
||||||
|
|
||||||
__all__ = ["NoWayToWaitForSocketError", "wait_for_read", "wait_for_write"]
|
|
||||||
|
|
||||||
|
|
||||||
class NoWayToWaitForSocketError(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
# How should we wait on sockets?
|
# How should we wait on sockets?
|
||||||
|
@ -37,37 +29,13 @@ class NoWayToWaitForSocketError(Exception):
|
||||||
# So: on Windows we use select(), and everywhere else we use poll(). We also
|
# So: on Windows we use select(), and everywhere else we use poll(). We also
|
||||||
# fall back to select() in case poll() is somehow broken or missing.
|
# fall back to select() in case poll() is somehow broken or missing.
|
||||||
|
|
||||||
if sys.version_info >= (3, 5):
|
|
||||||
# Modern Python, that retries syscalls by default
|
|
||||||
def _retry_on_intr(fn, timeout):
|
|
||||||
return fn(timeout)
|
|
||||||
|
|
||||||
else:
|
def select_wait_for_socket(
|
||||||
# Old and broken Pythons.
|
sock: socket.socket,
|
||||||
def _retry_on_intr(fn, timeout):
|
read: bool = False,
|
||||||
if timeout is None:
|
write: bool = False,
|
||||||
deadline = float("inf")
|
timeout: float | None = None,
|
||||||
else:
|
) -> bool:
|
||||||
deadline = monotonic() + timeout
|
|
||||||
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
return fn(timeout)
|
|
||||||
# OSError for 3 <= pyver < 3.5, select.error for pyver <= 2.7
|
|
||||||
except (OSError, select.error) as e:
|
|
||||||
# 'e.args[0]' incantation works for both OSError and select.error
|
|
||||||
if e.args[0] != errno.EINTR:
|
|
||||||
raise
|
|
||||||
else:
|
|
||||||
timeout = deadline - monotonic()
|
|
||||||
if timeout < 0:
|
|
||||||
timeout = 0
|
|
||||||
if timeout == float("inf"):
|
|
||||||
timeout = None
|
|
||||||
continue
|
|
||||||
|
|
||||||
|
|
||||||
def select_wait_for_socket(sock, read=False, write=False, timeout=None):
|
|
||||||
if not read and not write:
|
if not read and not write:
|
||||||
raise RuntimeError("must specify at least one of read=True, write=True")
|
raise RuntimeError("must specify at least one of read=True, write=True")
|
||||||
rcheck = []
|
rcheck = []
|
||||||
|
@ -82,11 +50,16 @@ def select_wait_for_socket(sock, read=False, write=False, timeout=None):
|
||||||
# sockets for both conditions. (The stdlib selectors module does the same
|
# sockets for both conditions. (The stdlib selectors module does the same
|
||||||
# thing.)
|
# thing.)
|
||||||
fn = partial(select.select, rcheck, wcheck, wcheck)
|
fn = partial(select.select, rcheck, wcheck, wcheck)
|
||||||
rready, wready, xready = _retry_on_intr(fn, timeout)
|
rready, wready, xready = fn(timeout)
|
||||||
return bool(rready or wready or xready)
|
return bool(rready or wready or xready)
|
||||||
|
|
||||||
|
|
||||||
def poll_wait_for_socket(sock, read=False, write=False, timeout=None):
|
def poll_wait_for_socket(
|
||||||
|
sock: socket.socket,
|
||||||
|
read: bool = False,
|
||||||
|
write: bool = False,
|
||||||
|
timeout: float | None = None,
|
||||||
|
) -> bool:
|
||||||
if not read and not write:
|
if not read and not write:
|
||||||
raise RuntimeError("must specify at least one of read=True, write=True")
|
raise RuntimeError("must specify at least one of read=True, write=True")
|
||||||
mask = 0
|
mask = 0
|
||||||
|
@ -98,32 +71,33 @@ def poll_wait_for_socket(sock, read=False, write=False, timeout=None):
|
||||||
poll_obj.register(sock, mask)
|
poll_obj.register(sock, mask)
|
||||||
|
|
||||||
# For some reason, poll() takes timeout in milliseconds
|
# For some reason, poll() takes timeout in milliseconds
|
||||||
def do_poll(t):
|
def do_poll(t: float | None) -> list[tuple[int, int]]:
|
||||||
if t is not None:
|
if t is not None:
|
||||||
t *= 1000
|
t *= 1000
|
||||||
return poll_obj.poll(t)
|
return poll_obj.poll(t)
|
||||||
|
|
||||||
return bool(_retry_on_intr(do_poll, timeout))
|
return bool(do_poll(timeout))
|
||||||
|
|
||||||
|
|
||||||
def null_wait_for_socket(*args, **kwargs):
|
def _have_working_poll() -> bool:
|
||||||
raise NoWayToWaitForSocketError("no select-equivalent available")
|
|
||||||
|
|
||||||
|
|
||||||
def _have_working_poll():
|
|
||||||
# Apparently some systems have a select.poll that fails as soon as you try
|
# Apparently some systems have a select.poll that fails as soon as you try
|
||||||
# to use it, either due to strange configuration or broken monkeypatching
|
# to use it, either due to strange configuration or broken monkeypatching
|
||||||
# from libraries like eventlet/greenlet.
|
# from libraries like eventlet/greenlet.
|
||||||
try:
|
try:
|
||||||
poll_obj = select.poll()
|
poll_obj = select.poll()
|
||||||
_retry_on_intr(poll_obj.poll, 0)
|
poll_obj.poll(0)
|
||||||
except (AttributeError, OSError):
|
except (AttributeError, OSError):
|
||||||
return False
|
return False
|
||||||
else:
|
else:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def wait_for_socket(*args, **kwargs):
|
def wait_for_socket(
|
||||||
|
sock: socket.socket,
|
||||||
|
read: bool = False,
|
||||||
|
write: bool = False,
|
||||||
|
timeout: float | None = None,
|
||||||
|
) -> bool:
|
||||||
# We delay choosing which implementation to use until the first time we're
|
# We delay choosing which implementation to use until the first time we're
|
||||||
# called. We could do it at import time, but then we might make the wrong
|
# called. We could do it at import time, but then we might make the wrong
|
||||||
# decision if someone goes wild with monkeypatching select.poll after
|
# decision if someone goes wild with monkeypatching select.poll after
|
||||||
|
@ -133,19 +107,17 @@ def wait_for_socket(*args, **kwargs):
|
||||||
wait_for_socket = poll_wait_for_socket
|
wait_for_socket = poll_wait_for_socket
|
||||||
elif hasattr(select, "select"):
|
elif hasattr(select, "select"):
|
||||||
wait_for_socket = select_wait_for_socket
|
wait_for_socket = select_wait_for_socket
|
||||||
else: # Platform-specific: Appengine.
|
return wait_for_socket(sock, read, write, timeout)
|
||||||
wait_for_socket = null_wait_for_socket
|
|
||||||
return wait_for_socket(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def wait_for_read(sock, timeout=None):
|
def wait_for_read(sock: socket.socket, timeout: float | None = None) -> bool:
|
||||||
"""Waits for reading to be available on a given socket.
|
"""Waits for reading to be available on a given socket.
|
||||||
Returns True if the socket is readable, or False if the timeout expired.
|
Returns True if the socket is readable, or False if the timeout expired.
|
||||||
"""
|
"""
|
||||||
return wait_for_socket(sock, read=True, timeout=timeout)
|
return wait_for_socket(sock, read=True, timeout=timeout)
|
||||||
|
|
||||||
|
|
||||||
def wait_for_write(sock, timeout=None):
|
def wait_for_write(sock: socket.socket, timeout: float | None = None) -> bool:
|
||||||
"""Waits for writing to be available on a given socket.
|
"""Waits for writing to be available on a given socket.
|
||||||
Returns True if the socket is readable, or False if the timeout expired.
|
Returns True if the socket is readable, or False if the timeout expired.
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in a new issue