Merge branch 'feature/UpdateUrllib3' into dev

This commit is contained in:
JackDandy 2024-06-05 08:56:16 +01:00
commit 851cb7786e
24 changed files with 579 additions and 1935 deletions

View file

@ -1,6 +1,7 @@
### 3.32.0 (2024-xx-xx xx:xx:00 UTC) ### 3.32.0 (2024-xx-xx xx:xx:00 UTC)
* Update Requests library 2.31.0 (8812812) to 2.32.3 (0e322af) * Update Requests library 2.31.0 (8812812) to 2.32.3 (0e322af)
* Update urllib3 2.0.7 (56f01e0) to 2.2.1 (54d6edf)
### 3.31.0 (2024-06-05 08:00:00 UTC) ### 3.31.0 (2024-06-05 08:00:00 UTC)

View file

@ -6,6 +6,7 @@ from __future__ import annotations
# Set default logging handler to avoid "No handler found" warnings. # Set default logging handler to avoid "No handler found" warnings.
import logging import logging
import sys
import typing import typing
import warnings import warnings
from logging import NullHandler from logging import NullHandler
@ -32,35 +33,18 @@ except ImportError:
else: else:
if not ssl.OPENSSL_VERSION.startswith("OpenSSL "): # Defensive: if not ssl.OPENSSL_VERSION.startswith("OpenSSL "): # Defensive:
warnings.warn( warnings.warn(
"urllib3 v2.0 only supports OpenSSL 1.1.1+, currently " "urllib3 v2 only supports OpenSSL 1.1.1+, currently "
f"the 'ssl' module is compiled with {ssl.OPENSSL_VERSION!r}. " f"the 'ssl' module is compiled with {ssl.OPENSSL_VERSION!r}. "
"See: https://github.com/urllib3/urllib3/issues/3020", "See: https://github.com/urllib3/urllib3/issues/3020",
exceptions.NotOpenSSLWarning, exceptions.NotOpenSSLWarning,
) )
elif ssl.OPENSSL_VERSION_INFO < (1, 1, 1): # Defensive: elif ssl.OPENSSL_VERSION_INFO < (1, 1, 1): # Defensive:
raise ImportError( raise ImportError(
"urllib3 v2.0 only supports OpenSSL 1.1.1+, currently " "urllib3 v2 only supports OpenSSL 1.1.1+, currently "
f"the 'ssl' module is compiled with {ssl.OPENSSL_VERSION!r}. " f"the 'ssl' module is compiled with {ssl.OPENSSL_VERSION!r}. "
"See: https://github.com/urllib3/urllib3/issues/2168" "See: https://github.com/urllib3/urllib3/issues/2168"
) )
# === NOTE TO REPACKAGERS AND VENDORS ===
# Please delete this block, this logic is only
# for urllib3 being distributed via PyPI.
# See: https://github.com/urllib3/urllib3/issues/2680
try:
import urllib3_secure_extra # type: ignore # noqa: F401
except ModuleNotFoundError:
pass
else:
warnings.warn(
"'urllib3[secure]' extra is deprecated and will be removed "
"in urllib3 v2.1.0. Read more in this issue: "
"https://github.com/urllib3/urllib3/issues/2680",
category=DeprecationWarning,
stacklevel=2,
)
__author__ = "Andrey Petrov (andrey.petrov@shazow.net)" __author__ = "Andrey Petrov (andrey.petrov@shazow.net)"
__license__ = "MIT" __license__ = "MIT"
__version__ = __version__ __version__ = __version__
@ -149,6 +133,61 @@ def request(
Therefore, its side effects could be shared across dependencies relying on it. Therefore, its side effects could be shared across dependencies relying on it.
To avoid side effects create a new ``PoolManager`` instance and use it instead. To avoid side effects create a new ``PoolManager`` instance and use it instead.
The method does not accept low-level ``**urlopen_kw`` keyword arguments. The method does not accept low-level ``**urlopen_kw`` keyword arguments.
:param method:
HTTP request method (such as GET, POST, PUT, etc.)
:param url:
The URL to perform the request on.
:param body:
Data to send in the request body, either :class:`str`, :class:`bytes`,
an iterable of :class:`str`/:class:`bytes`, or a file-like object.
:param fields:
Data to encode and send in the request body.
:param headers:
Dictionary of custom headers to send, such as User-Agent,
If-None-Match, etc.
:param bool preload_content:
If True, the response's body will be preloaded into memory.
:param bool decode_content:
If True, will attempt to decode the body based on the
'content-encoding' header.
:param redirect:
If True, automatically handle redirects (status codes 301, 302,
303, 307, 308). Each redirect counts as a retry. Disabling retries
will disable redirect, too.
:param retries:
Configure the number of retries to allow before raising a
:class:`~urllib3.exceptions.MaxRetryError` exception.
If ``None`` (default) will retry 3 times, see ``Retry.DEFAULT``. Pass a
:class:`~urllib3.util.retry.Retry` object for fine-grained control
over different types of retries.
Pass an integer number to retry connection errors that many times,
but no other types of errors. Pass zero to never retry.
If ``False``, then retries are disabled and any exception is raised
immediately. Also, instead of raising a MaxRetryError on redirects,
the redirect response will be returned.
:type retries: :class:`~urllib3.util.retry.Retry`, False, or an int.
:param timeout:
If specified, overrides the default timeout for this one
request. It may be a float (in seconds) or an instance of
:class:`urllib3.util.Timeout`.
:param json:
Data to encode and send as UTF-8 encoded JSON in the request body.
The ``"Content-Type"`` header will be set to ``"application/json"``
unless specified otherwise.
""" """
return _DEFAULT_POOL.request( return _DEFAULT_POOL.request(
@ -164,3 +203,9 @@ def request(
timeout=timeout, timeout=timeout,
json=json, json=json,
) )
if sys.platform == "emscripten":
from .contrib.emscripten import inject_into_urllib3 # noqa: 401
inject_into_urllib3()

View file

@ -28,8 +28,7 @@ class _ResponseOptions(typing.NamedTuple):
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
import ssl import ssl
from typing import Literal, Protocol
from typing_extensions import Literal, Protocol
from .response import BaseHTTPResponse from .response import BaseHTTPResponse

View file

@ -8,7 +8,9 @@ from threading import RLock
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
# We can only import Protocol if TYPE_CHECKING because it's a development # We can only import Protocol if TYPE_CHECKING because it's a development
# dependency, and is not available at runtime. # dependency, and is not available at runtime.
from typing_extensions import Protocol, Self from typing import Protocol
from typing_extensions import Self
class HasGettableStringKeys(Protocol): class HasGettableStringKeys(Protocol):
def keys(self) -> typing.Iterator[str]: def keys(self) -> typing.Iterator[str]:
@ -239,7 +241,7 @@ class HTTPHeaderDict(typing.MutableMapping[str, str]):
def __init__(self, headers: ValidHTTPHeaderSource | None = None, **kwargs: str): def __init__(self, headers: ValidHTTPHeaderSource | None = None, **kwargs: str):
super().__init__() super().__init__()
self._container = {} # 'dict' is insert-ordered in Python 3.7+ self._container = {} # 'dict' is insert-ordered
if headers is not None: if headers is not None:
if isinstance(headers, HTTPHeaderDict): if isinstance(headers, HTTPHeaderDict):
self._copy_from(headers) self._copy_from(headers)

View file

@ -85,6 +85,30 @@ class RequestMethods:
option to drop down to more specific methods when necessary, such as option to drop down to more specific methods when necessary, such as
:meth:`request_encode_url`, :meth:`request_encode_body`, :meth:`request_encode_url`, :meth:`request_encode_body`,
or even the lowest level :meth:`urlopen`. or even the lowest level :meth:`urlopen`.
:param method:
HTTP request method (such as GET, POST, PUT, etc.)
:param url:
The URL to perform the request on.
:param body:
Data to send in the request body, either :class:`str`, :class:`bytes`,
an iterable of :class:`str`/:class:`bytes`, or a file-like object.
:param fields:
Data to encode and send in the request body. Values are processed
by :func:`urllib.parse.urlencode`.
:param headers:
Dictionary of custom headers to send, such as User-Agent,
If-None-Match, etc. If None, pool headers are used. If provided,
these headers completely replace any pool-specific headers.
:param json:
Data to encode and send as UTF-8 encoded JSON in the request body.
The ``"Content-Type"`` header will be set to ``"application/json"``
unless specified otherwise.
""" """
method = method.upper() method = method.upper()
@ -95,9 +119,11 @@ class RequestMethods:
if json is not None: if json is not None:
if headers is None: if headers is None:
headers = self.headers.copy() # type: ignore headers = self.headers
if not ("content-type" in map(str.lower, headers.keys())): if not ("content-type" in map(str.lower, headers.keys())):
headers["Content-Type"] = "application/json" # type: ignore headers = HTTPHeaderDict(headers)
headers["Content-Type"] = "application/json"
body = _json.dumps(json, separators=(",", ":"), ensure_ascii=False).encode( body = _json.dumps(json, separators=(",", ":"), ensure_ascii=False).encode(
"utf-8" "utf-8"
@ -130,6 +156,20 @@ class RequestMethods:
""" """
Make a request using :meth:`urlopen` with the ``fields`` encoded in Make a request using :meth:`urlopen` with the ``fields`` encoded in
the url. This is useful for request methods like GET, HEAD, DELETE, etc. the url. This is useful for request methods like GET, HEAD, DELETE, etc.
:param method:
HTTP request method (such as GET, POST, PUT, etc.)
:param url:
The URL to perform the request on.
:param fields:
Data to encode and send in the request body.
:param headers:
Dictionary of custom headers to send, such as User-Agent,
If-None-Match, etc. If None, pool headers are used. If provided,
these headers completely replace any pool-specific headers.
""" """
if headers is None: if headers is None:
headers = self.headers headers = self.headers
@ -186,6 +226,28 @@ class RequestMethods:
be overwritten because it depends on the dynamic random boundary string be overwritten because it depends on the dynamic random boundary string
which is used to compose the body of the request. The random boundary which is used to compose the body of the request. The random boundary
string can be explicitly set with the ``multipart_boundary`` parameter. string can be explicitly set with the ``multipart_boundary`` parameter.
:param method:
HTTP request method (such as GET, POST, PUT, etc.)
:param url:
The URL to perform the request on.
:param fields:
Data to encode and send in the request body.
:param headers:
Dictionary of custom headers to send, such as User-Agent,
If-None-Match, etc. If None, pool headers are used. If provided,
these headers completely replace any pool-specific headers.
:param encode_multipart:
If True, encode the ``fields`` using the multipart/form-data MIME
format.
:param multipart_boundary:
If not specified, then a random boundary will be generated using
:func:`urllib3.filepost.choose_boundary`.
""" """
if headers is None: if headers is None:
headers = self.headers headers = self.headers

View file

@ -1,4 +1,4 @@
# This file is protected via CODEOWNERS # This file is protected via CODEOWNERS
from __future__ import annotations from __future__ import annotations
__version__ = "2.0.7" __version__ = "2.2.1"

View file

@ -14,7 +14,7 @@ from http.client import ResponseNotReady
from socket import timeout as SocketTimeout from socket import timeout as SocketTimeout
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from typing_extensions import Literal from typing import Literal
from .response import HTTPResponse from .response import HTTPResponse
from .util.ssl_ import _TYPE_PEER_CERT_RET_DICT from .util.ssl_ import _TYPE_PEER_CERT_RET_DICT
@ -73,7 +73,7 @@ port_by_scheme = {"http": 80, "https": 443}
# When it comes time to update this value as a part of regular maintenance # When it comes time to update this value as a part of regular maintenance
# (ie test_recent_date is failing) update it to ~6 months before the current date. # (ie test_recent_date is failing) update it to ~6 months before the current date.
RECENT_DATE = datetime.date(2022, 1, 1) RECENT_DATE = datetime.date(2023, 6, 1)
_CONTAINS_CONTROL_CHAR_RE = re.compile(r"[^-!#$%&'*+.^_`|~0-9a-zA-Z]") _CONTAINS_CONTROL_CHAR_RE = re.compile(r"[^-!#$%&'*+.^_`|~0-9a-zA-Z]")
@ -160,11 +160,6 @@ class HTTPConnection(_HTTPConnection):
self._tunnel_port: int | None = None self._tunnel_port: int | None = None
self._tunnel_scheme: str | None = None self._tunnel_scheme: str | None = None
# https://github.com/python/mypy/issues/4125
# Mypy treats this as LSP violation, which is considered a bug.
# If `host` is made a property it violates LSP, because a writeable attribute is overridden with a read-only one.
# However, there is also a `host` setter so LSP is not violated.
# Potentially, a `@host.deleter` might be needed depending on how this issue will be fixed.
@property @property
def host(self) -> str: def host(self) -> str:
""" """
@ -253,6 +248,9 @@ class HTTPConnection(_HTTPConnection):
# not using tunnelling. # not using tunnelling.
self._has_connected_to_proxy = bool(self.proxy) self._has_connected_to_proxy = bool(self.proxy)
if self._has_connected_to_proxy:
self.proxy_is_verified = False
@property @property
def is_closed(self) -> bool: def is_closed(self) -> bool:
return self.sock is None return self.sock is None
@ -267,6 +265,13 @@ class HTTPConnection(_HTTPConnection):
def has_connected_to_proxy(self) -> bool: def has_connected_to_proxy(self) -> bool:
return self._has_connected_to_proxy return self._has_connected_to_proxy
@property
def proxy_is_forwarding(self) -> bool:
"""
Return True if a forwarding proxy is configured, else return False
"""
return bool(self.proxy) and self._tunnel_host is None
def close(self) -> None: def close(self) -> None:
try: try:
super().close() super().close()
@ -302,7 +307,7 @@ class HTTPConnection(_HTTPConnection):
method, url, skip_host=skip_host, skip_accept_encoding=skip_accept_encoding method, url, skip_host=skip_host, skip_accept_encoding=skip_accept_encoding
) )
def putheader(self, header: str, *values: str) -> None: def putheader(self, header: str, *values: str) -> None: # type: ignore[override]
"""""" """"""
if not any(isinstance(v, str) and v == SKIP_HEADER for v in values): if not any(isinstance(v, str) and v == SKIP_HEADER for v in values):
super().putheader(header, *values) super().putheader(header, *values)
@ -616,8 +621,11 @@ class HTTPSConnection(HTTPConnection):
if self._tunnel_host is not None: if self._tunnel_host is not None:
# We're tunneling to an HTTPS origin so need to do TLS-in-TLS. # We're tunneling to an HTTPS origin so need to do TLS-in-TLS.
if self._tunnel_scheme == "https": if self._tunnel_scheme == "https":
# _connect_tls_proxy will verify and assign proxy_is_verified
self.sock = sock = self._connect_tls_proxy(self.host, sock) self.sock = sock = self._connect_tls_proxy(self.host, sock)
tls_in_tls = True tls_in_tls = True
elif self._tunnel_scheme == "http":
self.proxy_is_verified = False
# If we're tunneling it means we're connected to our proxy. # If we're tunneling it means we're connected to our proxy.
self._has_connected_to_proxy = True self._has_connected_to_proxy = True
@ -639,6 +647,9 @@ class HTTPSConnection(HTTPConnection):
SystemTimeWarning, SystemTimeWarning,
) )
# Remove trailing '.' from fqdn hostnames to allow certificate validation
server_hostname_rm_dot = server_hostname.rstrip(".")
sock_and_verified = _ssl_wrap_socket_and_match_hostname( sock_and_verified = _ssl_wrap_socket_and_match_hostname(
sock=sock, sock=sock,
cert_reqs=self.cert_reqs, cert_reqs=self.cert_reqs,
@ -651,13 +662,21 @@ class HTTPSConnection(HTTPConnection):
cert_file=self.cert_file, cert_file=self.cert_file,
key_file=self.key_file, key_file=self.key_file,
key_password=self.key_password, key_password=self.key_password,
server_hostname=server_hostname, server_hostname=server_hostname_rm_dot,
ssl_context=self.ssl_context, ssl_context=self.ssl_context,
tls_in_tls=tls_in_tls, tls_in_tls=tls_in_tls,
assert_hostname=self.assert_hostname, assert_hostname=self.assert_hostname,
assert_fingerprint=self.assert_fingerprint, assert_fingerprint=self.assert_fingerprint,
) )
self.sock = sock_and_verified.socket self.sock = sock_and_verified.socket
# Forwarding proxies can never have a verified target since
# the proxy is the one doing the verification. Should instead
# use a CONNECT tunnel in order to verify the target.
# See: https://github.com/urllib3/urllib3/issues/3267.
if self.proxy_is_forwarding:
self.is_verified = False
else:
self.is_verified = sock_and_verified.is_verified self.is_verified = sock_and_verified.is_verified
# If there's a proxy to be connected to we are fully connected. # If there's a proxy to be connected to we are fully connected.
@ -665,6 +684,11 @@ class HTTPSConnection(HTTPConnection):
# not using tunnelling. # not using tunnelling.
self._has_connected_to_proxy = bool(self.proxy) self._has_connected_to_proxy = bool(self.proxy)
# Set `self.proxy_is_verified` unless it's already set while
# establishing a tunnel.
if self._has_connected_to_proxy and self.proxy_is_verified is None:
self.proxy_is_verified = sock_and_verified.is_verified
def _connect_tls_proxy(self, hostname: str, sock: socket.socket) -> ssl.SSLSocket: def _connect_tls_proxy(self, hostname: str, sock: socket.socket) -> ssl.SSLSocket:
""" """
Establish a TLS connection to the proxy using the provided SSL context. Establish a TLS connection to the proxy using the provided SSL context.
@ -757,10 +781,9 @@ def _ssl_wrap_socket_and_match_hostname(
): ):
context.check_hostname = False context.check_hostname = False
# Try to load OS default certs if none are given. # Try to load OS default certs if none are given. We need to do the hasattr() check
# We need to do the hasattr() check for our custom # for custom pyOpenSSL SSLContext objects because they don't support
# pyOpenSSL and SecureTransport SSLContext objects # load_default_certs().
# because neither support load_default_certs().
if ( if (
not ca_certs not ca_certs
and not ca_cert_dir and not ca_cert_dir
@ -865,6 +888,7 @@ def _wrap_proxy_error(err: Exception, proxy_scheme: str | None) -> ProxyError:
is_likely_http_proxy = ( is_likely_http_proxy = (
"wrong version number" in error_normalized "wrong version number" in error_normalized
or "unknown protocol" in error_normalized or "unknown protocol" in error_normalized
or "record layer failure" in error_normalized
) )
http_proxy_warning = ( http_proxy_warning = (
". Your proxy appears to only use HTTP and not HTTPS, " ". Your proxy appears to only use HTTP and not HTTPS, "

View file

@ -53,8 +53,7 @@ from .util.util import to_str
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
import ssl import ssl
from typing import Literal
from typing_extensions import Literal
from ._base_connection import BaseHTTPConnection, BaseHTTPSConnection from ._base_connection import BaseHTTPConnection, BaseHTTPSConnection
@ -512,9 +511,10 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
pass pass
except OSError as e: except OSError as e:
# MacOS/Linux # MacOS/Linux
# EPROTOTYPE is needed on macOS # EPROTOTYPE and ECONNRESET are needed on macOS
# https://erickt.github.io/blog/2014/11/19/adventures-in-debugging-a-potential-osx-kernel-bug/ # https://erickt.github.io/blog/2014/11/19/adventures-in-debugging-a-potential-osx-kernel-bug/
if e.errno != errno.EPROTOTYPE: # Condition changed later to emit ECONNRESET instead of only EPROTOTYPE.
if e.errno != errno.EPROTOTYPE and e.errno != errno.ECONNRESET:
raise raise
# Reset the timeout for the recv() on the socket # Reset the timeout for the recv() on the socket
@ -544,6 +544,8 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
response._connection = response_conn # type: ignore[attr-defined] response._connection = response_conn # type: ignore[attr-defined]
response._pool = self # type: ignore[attr-defined] response._pool = self # type: ignore[attr-defined]
# emscripten connection doesn't have _http_vsn_str
http_version = getattr(conn, "_http_vsn_str", "HTTP/?")
log.debug( log.debug(
'%s://%s:%s "%s %s %s" %s %s', '%s://%s:%s "%s %s %s" %s %s',
self.scheme, self.scheme,
@ -552,9 +554,9 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
method, method,
url, url,
# HTTP version # HTTP version
conn._http_vsn_str, # type: ignore[attr-defined] http_version,
response.status, response.status,
response.length_remaining, # type: ignore[attr-defined] response.length_remaining,
) )
return response return response
@ -647,7 +649,7 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
Configure the number of retries to allow before raising a Configure the number of retries to allow before raising a
:class:`~urllib3.exceptions.MaxRetryError` exception. :class:`~urllib3.exceptions.MaxRetryError` exception.
Pass ``None`` to retry until you receive a response. Pass a If ``None`` (default) will retry 3 times, see ``Retry.DEFAULT``. Pass a
:class:`~urllib3.util.retry.Retry` object for fine-grained control :class:`~urllib3.util.retry.Retry` object for fine-grained control
over different types of retries. over different types of retries.
Pass an integer number to retry connection errors that many times, Pass an integer number to retry connection errors that many times,
@ -1096,7 +1098,8 @@ class HTTPSConnectionPool(HTTPConnectionPool):
if conn.is_closed: if conn.is_closed:
conn.connect() conn.connect()
if not conn.is_verified: # TODO revise this, see https://github.com/urllib3/urllib3/issues/2791
if not conn.is_verified and not conn.proxy_is_verified:
warnings.warn( warnings.warn(
( (
f"Unverified HTTPS request is being made to host '{conn.host}'. " f"Unverified HTTPS request is being made to host '{conn.host}'. "

View file

@ -1,430 +0,0 @@
# type: ignore
"""
This module uses ctypes to bind a whole bunch of functions and constants from
SecureTransport. The goal here is to provide the low-level API to
SecureTransport. These are essentially the C-level functions and constants, and
they're pretty gross to work with.
This code is a bastardised version of the code found in Will Bond's oscrypto
library. An enormous debt is owed to him for blazing this trail for us. For
that reason, this code should be considered to be covered both by urllib3's
license and by oscrypto's:
Copyright (c) 2015-2016 Will Bond <will@wbond.net>
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
import platform
from ctypes import (
CDLL,
CFUNCTYPE,
POINTER,
c_bool,
c_byte,
c_char_p,
c_int32,
c_long,
c_size_t,
c_uint32,
c_ulong,
c_void_p,
)
from ctypes.util import find_library
# These bindings wrap macOS frameworks, so importing anywhere else is refused.
if platform.system() != "Darwin":
    raise ImportError("Only macOS is supported")

# Turn the dotted release string reported by the platform module into a
# tuple of ints so it can be compared against minimum versions.
version = platform.mac_ver()[0]
version_info = tuple(int(component) for component in version.split("."))
if version_info < (10, 8):
    raise OSError(
        f"Only OS X 10.8 and newer are supported, not {version_info[0]}.{version_info[1]}"
    )
def load_cdll(name: str, macos10_16_path: str) -> CDLL:
    """Open the library called *name* and return it as a ``CDLL``.

    On macOS 10.16 and newer the fixed ``macos10_16_path`` is used; on older
    releases the path is looked up with ``find_library``.

    :param name: Library name, used for the lookup and in the error message.
    :param macos10_16_path: Path to load from on macOS 10.16+.
    :raises ImportError: If the library cannot be located or opened.
    """
    try:
        # Big Sur is technically version 11, but its beta was labelled 10.16,
        # so 10.16 is the threshold compared against here.
        use_fixed_path = version_info >= (10, 16)
        location: str | None = (
            macos10_16_path if use_fixed_path else find_library(name)
        )
        if location:
            return CDLL(location, use_errno=True)
        # Nothing was found: funnel into the same ImportError as a failed load.
        raise OSError
    except OSError:
        raise ImportError(f"The library {name} failed to load") from None
# Load the two frameworks through load_cdll(); each call passes the library
# name plus the fixed path load_cdll() uses on macOS 10.16+.
Security = load_cdll(
"Security", "/System/Library/Frameworks/Security.framework/Security"
)
CoreFoundation = load_cdll(
"CoreFoundation",
"/System/Library/Frameworks/CoreFoundation.framework/CoreFoundation",
)
# ctypes aliases named after the corresponding C types. Most object types are
# modelled as plain void pointers; the scalar ones map to fixed-width ints.
Boolean = c_bool
CFIndex = c_long
CFStringEncoding = c_uint32
CFData = c_void_p
CFString = c_void_p
CFArray = c_void_p
CFMutableArray = c_void_p
CFDictionary = c_void_p
CFError = c_void_p
CFType = c_void_p
CFTypeID = c_ulong
CFTypeRef = POINTER(CFType)
CFAllocatorRef = c_void_p
OSStatus = c_int32
# "Ref" aliases are pointers to the object aliases declared above.
CFDataRef = POINTER(CFData)
CFStringRef = POINTER(CFString)
CFArrayRef = POINTER(CFArray)
CFMutableArrayRef = POINTER(CFMutableArray)
CFDictionaryRef = POINTER(CFDictionary)
CFArrayCallBacks = c_void_p
CFDictionaryKeyCallBacks = c_void_p
CFDictionaryValueCallBacks = c_void_p
# Aliases for types used by the Security function signatures set up below.
SecCertificateRef = POINTER(c_void_p)
SecExternalFormat = c_uint32
SecExternalItemType = c_uint32
SecIdentityRef = POINTER(c_void_p)
SecItemImportExportFlags = c_uint32
SecItemImportExportKeyParameters = c_void_p
SecKeychainRef = POINTER(c_void_p)
SSLProtocol = c_uint32
SSLCipherSuite = c_uint32
SSLContextRef = POINTER(c_void_p)
SecTrustRef = POINTER(c_void_p)
SSLConnectionRef = c_uint32
SecTrustResultType = c_uint32
SecTrustOptionFlags = c_uint32
SSLProtocolSide = c_uint32
SSLConnectionType = c_uint32
SSLSessionOption = c_uint32
# Declare the argument and return types of every framework function used by
# the bindings. ctypes only honours the attributes named ``argtypes`` and
# ``restype``; any other attribute name is stored silently and ignored, so the
# spelling matters. A function missing from the loaded library surfaces as
# AttributeError, which is reported as ImportError at the end of the block.
try:
    Security.SecItemImport.argtypes = [
        CFDataRef,
        CFStringRef,
        POINTER(SecExternalFormat),
        POINTER(SecExternalItemType),
        SecItemImportExportFlags,
        POINTER(SecItemImportExportKeyParameters),
        SecKeychainRef,
        POINTER(CFArrayRef),
    ]
    Security.SecItemImport.restype = OSStatus

    Security.SecCertificateGetTypeID.argtypes = []
    Security.SecCertificateGetTypeID.restype = CFTypeID

    Security.SecIdentityGetTypeID.argtypes = []
    Security.SecIdentityGetTypeID.restype = CFTypeID

    Security.SecKeyGetTypeID.argtypes = []
    Security.SecKeyGetTypeID.restype = CFTypeID

    Security.SecCertificateCreateWithData.argtypes = [CFAllocatorRef, CFDataRef]
    Security.SecCertificateCreateWithData.restype = SecCertificateRef

    Security.SecCertificateCopyData.argtypes = [SecCertificateRef]
    Security.SecCertificateCopyData.restype = CFDataRef

    Security.SecCopyErrorMessageString.argtypes = [OSStatus, c_void_p]
    Security.SecCopyErrorMessageString.restype = CFStringRef

    Security.SecIdentityCreateWithCertificate.argtypes = [
        CFTypeRef,
        SecCertificateRef,
        POINTER(SecIdentityRef),
    ]
    Security.SecIdentityCreateWithCertificate.restype = OSStatus

    Security.SecKeychainCreate.argtypes = [
        c_char_p,
        c_uint32,
        c_void_p,
        Boolean,
        c_void_p,
        POINTER(SecKeychainRef),
    ]
    Security.SecKeychainCreate.restype = OSStatus

    Security.SecKeychainDelete.argtypes = [SecKeychainRef]
    Security.SecKeychainDelete.restype = OSStatus

    Security.SecPKCS12Import.argtypes = [
        CFDataRef,
        CFDictionaryRef,
        POINTER(CFArrayRef),
    ]
    Security.SecPKCS12Import.restype = OSStatus

    # Callback prototypes for the read/write functions given to SSLSetIOFuncs.
    SSLReadFunc = CFUNCTYPE(OSStatus, SSLConnectionRef, c_void_p, POINTER(c_size_t))
    SSLWriteFunc = CFUNCTYPE(
        OSStatus, SSLConnectionRef, POINTER(c_byte), POINTER(c_size_t)
    )

    Security.SSLSetIOFuncs.argtypes = [SSLContextRef, SSLReadFunc, SSLWriteFunc]
    Security.SSLSetIOFuncs.restype = OSStatus

    Security.SSLSetPeerID.argtypes = [SSLContextRef, c_char_p, c_size_t]
    Security.SSLSetPeerID.restype = OSStatus

    Security.SSLSetCertificate.argtypes = [SSLContextRef, CFArrayRef]
    Security.SSLSetCertificate.restype = OSStatus

    Security.SSLSetCertificateAuthorities.argtypes = [SSLContextRef, CFTypeRef, Boolean]
    Security.SSLSetCertificateAuthorities.restype = OSStatus

    Security.SSLSetConnection.argtypes = [SSLContextRef, SSLConnectionRef]
    Security.SSLSetConnection.restype = OSStatus

    Security.SSLSetPeerDomainName.argtypes = [SSLContextRef, c_char_p, c_size_t]
    Security.SSLSetPeerDomainName.restype = OSStatus

    Security.SSLHandshake.argtypes = [SSLContextRef]
    Security.SSLHandshake.restype = OSStatus

    Security.SSLRead.argtypes = [SSLContextRef, c_char_p, c_size_t, POINTER(c_size_t)]
    Security.SSLRead.restype = OSStatus

    Security.SSLWrite.argtypes = [SSLContextRef, c_char_p, c_size_t, POINTER(c_size_t)]
    Security.SSLWrite.restype = OSStatus

    Security.SSLClose.argtypes = [SSLContextRef]
    Security.SSLClose.restype = OSStatus

    Security.SSLGetNumberSupportedCiphers.argtypes = [SSLContextRef, POINTER(c_size_t)]
    Security.SSLGetNumberSupportedCiphers.restype = OSStatus

    Security.SSLGetSupportedCiphers.argtypes = [
        SSLContextRef,
        POINTER(SSLCipherSuite),
        POINTER(c_size_t),
    ]
    Security.SSLGetSupportedCiphers.restype = OSStatus

    Security.SSLSetEnabledCiphers.argtypes = [
        SSLContextRef,
        POINTER(SSLCipherSuite),
        c_size_t,
    ]
    Security.SSLSetEnabledCiphers.restype = OSStatus

    # Previously misspelled as ``argtype``, so the signature was never applied.
    Security.SSLGetNumberEnabledCiphers.argtypes = [SSLContextRef, POINTER(c_size_t)]
    Security.SSLGetNumberEnabledCiphers.restype = OSStatus

    Security.SSLGetEnabledCiphers.argtypes = [
        SSLContextRef,
        POINTER(SSLCipherSuite),
        POINTER(c_size_t),
    ]
    Security.SSLGetEnabledCiphers.restype = OSStatus

    Security.SSLGetNegotiatedCipher.argtypes = [SSLContextRef, POINTER(SSLCipherSuite)]
    Security.SSLGetNegotiatedCipher.restype = OSStatus

    Security.SSLGetNegotiatedProtocolVersion.argtypes = [
        SSLContextRef,
        POINTER(SSLProtocol),
    ]
    Security.SSLGetNegotiatedProtocolVersion.restype = OSStatus

    Security.SSLCopyPeerTrust.argtypes = [SSLContextRef, POINTER(SecTrustRef)]
    Security.SSLCopyPeerTrust.restype = OSStatus

    Security.SecTrustSetAnchorCertificates.argtypes = [SecTrustRef, CFArrayRef]
    Security.SecTrustSetAnchorCertificates.restype = OSStatus

    # Previously misspelled as ``argstypes``, so the signature was never applied.
    Security.SecTrustSetAnchorCertificatesOnly.argtypes = [SecTrustRef, Boolean]
    Security.SecTrustSetAnchorCertificatesOnly.restype = OSStatus

    Security.SecTrustEvaluate.argtypes = [SecTrustRef, POINTER(SecTrustResultType)]
    Security.SecTrustEvaluate.restype = OSStatus

    Security.SecTrustGetCertificateCount.argtypes = [SecTrustRef]
    Security.SecTrustGetCertificateCount.restype = CFIndex

    Security.SecTrustGetCertificateAtIndex.argtypes = [SecTrustRef, CFIndex]
    Security.SecTrustGetCertificateAtIndex.restype = SecCertificateRef

    Security.SSLCreateContext.argtypes = [
        CFAllocatorRef,
        SSLProtocolSide,
        SSLConnectionType,
    ]
    Security.SSLCreateContext.restype = SSLContextRef

    Security.SSLSetSessionOption.argtypes = [SSLContextRef, SSLSessionOption, Boolean]
    Security.SSLSetSessionOption.restype = OSStatus

    Security.SSLSetProtocolVersionMin.argtypes = [SSLContextRef, SSLProtocol]
    Security.SSLSetProtocolVersionMin.restype = OSStatus

    Security.SSLSetProtocolVersionMax.argtypes = [SSLContextRef, SSLProtocol]
    Security.SSLSetProtocolVersionMax.restype = OSStatus

    try:
        Security.SSLSetALPNProtocols.argtypes = [SSLContextRef, CFArrayRef]
        Security.SSLSetALPNProtocols.restype = OSStatus
    except AttributeError:
        # Supported only in 10.12+
        pass

    # Re-export the type aliases as attributes of the library object.
    Security.SSLReadFunc = SSLReadFunc
    Security.SSLWriteFunc = SSLWriteFunc
    Security.SSLContextRef = SSLContextRef
    Security.SSLProtocol = SSLProtocol
    Security.SSLCipherSuite = SSLCipherSuite
    Security.SecIdentityRef = SecIdentityRef
    Security.SecKeychainRef = SecKeychainRef
    Security.SecTrustRef = SecTrustRef
    Security.SecTrustResultType = SecTrustResultType
    Security.SecExternalFormat = SecExternalFormat
    Security.OSStatus = OSStatus

    Security.kSecImportExportPassphrase = CFStringRef.in_dll(
        Security, "kSecImportExportPassphrase"
    )
    Security.kSecImportItemIdentity = CFStringRef.in_dll(
        Security, "kSecImportItemIdentity"
    )

    # CoreFoundation time!
    CoreFoundation.CFRetain.argtypes = [CFTypeRef]
    CoreFoundation.CFRetain.restype = CFTypeRef

    CoreFoundation.CFRelease.argtypes = [CFTypeRef]
    CoreFoundation.CFRelease.restype = None

    CoreFoundation.CFGetTypeID.argtypes = [CFTypeRef]
    CoreFoundation.CFGetTypeID.restype = CFTypeID

    CoreFoundation.CFStringCreateWithCString.argtypes = [
        CFAllocatorRef,
        c_char_p,
        CFStringEncoding,
    ]
    CoreFoundation.CFStringCreateWithCString.restype = CFStringRef

    CoreFoundation.CFStringGetCStringPtr.argtypes = [CFStringRef, CFStringEncoding]
    CoreFoundation.CFStringGetCStringPtr.restype = c_char_p

    CoreFoundation.CFStringGetCString.argtypes = [
        CFStringRef,
        c_char_p,
        CFIndex,
        CFStringEncoding,
    ]
    CoreFoundation.CFStringGetCString.restype = c_bool

    CoreFoundation.CFDataCreate.argtypes = [CFAllocatorRef, c_char_p, CFIndex]
    CoreFoundation.CFDataCreate.restype = CFDataRef

    CoreFoundation.CFDataGetLength.argtypes = [CFDataRef]
    CoreFoundation.CFDataGetLength.restype = CFIndex

    CoreFoundation.CFDataGetBytePtr.argtypes = [CFDataRef]
    CoreFoundation.CFDataGetBytePtr.restype = c_void_p

    CoreFoundation.CFDictionaryCreate.argtypes = [
        CFAllocatorRef,
        POINTER(CFTypeRef),
        POINTER(CFTypeRef),
        CFIndex,
        CFDictionaryKeyCallBacks,
        CFDictionaryValueCallBacks,
    ]
    CoreFoundation.CFDictionaryCreate.restype = CFDictionaryRef

    CoreFoundation.CFDictionaryGetValue.argtypes = [CFDictionaryRef, CFTypeRef]
    CoreFoundation.CFDictionaryGetValue.restype = CFTypeRef

    CoreFoundation.CFArrayCreate.argtypes = [
        CFAllocatorRef,
        POINTER(CFTypeRef),
        CFIndex,
        CFArrayCallBacks,
    ]
    CoreFoundation.CFArrayCreate.restype = CFArrayRef

    CoreFoundation.CFArrayCreateMutable.argtypes = [
        CFAllocatorRef,
        CFIndex,
        CFArrayCallBacks,
    ]
    CoreFoundation.CFArrayCreateMutable.restype = CFMutableArrayRef

    CoreFoundation.CFArrayAppendValue.argtypes = [CFMutableArrayRef, c_void_p]
    CoreFoundation.CFArrayAppendValue.restype = None

    CoreFoundation.CFArrayGetCount.argtypes = [CFArrayRef]
    CoreFoundation.CFArrayGetCount.restype = CFIndex

    CoreFoundation.CFArrayGetValueAtIndex.argtypes = [CFArrayRef, CFIndex]
    CoreFoundation.CFArrayGetValueAtIndex.restype = c_void_p

    CoreFoundation.kCFAllocatorDefault = CFAllocatorRef.in_dll(
        CoreFoundation, "kCFAllocatorDefault"
    )
    CoreFoundation.kCFTypeArrayCallBacks = c_void_p.in_dll(
        CoreFoundation, "kCFTypeArrayCallBacks"
    )
    CoreFoundation.kCFTypeDictionaryKeyCallBacks = c_void_p.in_dll(
        CoreFoundation, "kCFTypeDictionaryKeyCallBacks"
    )
    CoreFoundation.kCFTypeDictionaryValueCallBacks = c_void_p.in_dll(
        CoreFoundation, "kCFTypeDictionaryValueCallBacks"
    )

    CoreFoundation.CFTypeRef = CFTypeRef
    CoreFoundation.CFArrayRef = CFArrayRef
    CoreFoundation.CFStringRef = CFStringRef
    CoreFoundation.CFDictionaryRef = CFDictionaryRef

except AttributeError:
    # A missing symbol means the bindings cannot work on this system.
    raise ImportError("Error initializing ctypes") from None
class CFConst:
    """Namespace class grouping the CoreFoundation constants used by the bindings."""

    # CoreFoundation's numeric identifier for the UTF-8 string encoding.
    kCFStringEncodingUTF8 = CFStringEncoding(0x08000100)

View file

@ -1,474 +0,0 @@
"""
Low-level helpers for the SecureTransport bindings.
These are Python functions that are not directly related to the high-level APIs
but are necessary to get them to work. They include a whole bunch of low-level
CoreFoundation messing about and memory management. The concerns in this module
are almost entirely about trying to avoid memory leaks and providing
appropriate and useful assistance to the higher-level code.
"""
from __future__ import annotations
import base64
import ctypes
import itertools
import os
import re
import ssl
import struct
import tempfile
import typing
from .bindings import ( # type: ignore[attr-defined]
CFArray,
CFConst,
CFData,
CFDictionary,
CFMutableArray,
CFString,
CFTypeRef,
CoreFoundation,
SecKeychainRef,
Security,
)
# This regular expression is used to grab PEM data out of a PEM bundle.
# Group 1 captures the payload between the BEGIN/END CERTIFICATE markers.
# re.DOTALL lets that payload span several lines, and the non-greedy match
# stops at the first END marker so each certificate is matched on its own.
_PEM_CERTS_RE = re.compile(
    b"-----BEGIN CERTIFICATE-----\n(.*?)\n-----END CERTIFICATE-----", re.DOTALL
)
def _cf_data_from_bytes(bytestring: bytes) -> CFData:
    """
    Wrap ``bytestring`` in a newly created CFData object.

    Ownership of the returned object passes to the caller, who must
    CFRelease it.
    """
    allocator = CoreFoundation.kCFAllocatorDefault
    length = len(bytestring)
    return CoreFoundation.CFDataCreate(allocator, bytestring, length)
def _cf_dictionary_from_tuples(
    tuples: list[tuple[typing.Any, typing.Any]]
) -> CFDictionary:
    """
    Build a CFDictionary from a list of ``(key, value)`` pairs.
    """
    count = len(tuples)
    array_type = CoreFoundation.CFTypeRef * count

    # Keys and values are handed over as two parallel C arrays, so both are
    # pulled out of the pairs in the same order.
    key_array = array_type(*[pair[0] for pair in tuples])
    value_array = array_type(*[pair[1] for pair in tuples])

    return CoreFoundation.CFDictionaryCreate(
        CoreFoundation.kCFAllocatorDefault,
        key_array,
        value_array,
        count,
        CoreFoundation.kCFTypeDictionaryKeyCallBacks,
        CoreFoundation.kCFTypeDictionaryValueCallBacks,
    )
def _cfstr(py_bstr: bytes) -> CFString:
    """
    Create a CFString from Python binary data, interpreted as UTF-8.

    The caller owns the returned string and must CFRelease it.
    """
    return CoreFoundation.CFStringCreateWithCString(
        CoreFoundation.kCFAllocatorDefault,
        ctypes.c_char_p(py_bstr),
        CFConst.kCFStringEncodingUTF8,
    )
def _create_cfstring_array(lst: list[bytes]) -> CFMutableArray:
    """
    Convert a list of Python binary data into a CFMutableArray of CFStrings.

    The caller owns the returned array and must CFRelease it.
    Any failure while building the array is reported as an ssl.SSLError.
    """
    array = None
    try:
        array = CoreFoundation.CFArrayCreateMutable(
            CoreFoundation.kCFAllocatorDefault,
            0,
            ctypes.byref(CoreFoundation.kCFTypeArrayCallBacks),
        )
        if not array:
            raise MemoryError("Unable to allocate memory!")

        for raw in lst:
            entry = _cfstr(raw)
            if not entry:
                raise MemoryError("Unable to allocate memory!")
            # Our own reference to the string is dropped whether or not the
            # append succeeds.
            try:
                CoreFoundation.CFArrayAppendValue(array, entry)
            finally:
                CoreFoundation.CFRelease(entry)
    except BaseException as err:
        # Don't hand a half-built array back: free it and report the failure.
        if array:
            CoreFoundation.CFRelease(array)
        raise ssl.SSLError(f"Unable to allocate array: {err}") from None
    else:
        return array
def _cf_string_to_unicode(value: CFString) -> str | None:
    """
    Convert a CFString object into a Python ``str``. Used for error reporting.

    The direct C-string pointer is tried first; when that yields nothing the
    contents are copied into a 1024-byte buffer instead. Raises OSError if
    that copy fails.
    """
    string_ref = ctypes.cast(value, ctypes.POINTER(ctypes.c_void_p))
    utf8 = CFConst.kCFStringEncodingUTF8

    raw = CoreFoundation.CFStringGetCStringPtr(string_ref, utf8)
    if raw is None:
        buffer = ctypes.create_string_buffer(1024)
        copied = CoreFoundation.CFStringGetCString(string_ref, buffer, 1024, utf8)
        if not copied:
            raise OSError("Error copying C string from CFStringRef")
        raw = buffer.value

    if raw is None:
        return None
    return raw.decode("utf-8")  # type: ignore[no-any-return]
def _assert_no_error(
    error: int, exception_class: type[BaseException] | None = None
) -> None:
    """
    Check an OSStatus return code and raise if it reports an error.

    :param error: The status code returned by a Security framework call.
        ``0`` means success and makes this function return immediately.
    :param exception_class: Exception type to raise; defaults to
        ``ssl.SSLError`` when not given.
    :raises exception_class: With the framework's message for ``error``, or
        ``"OSStatus <error>"`` when no message is available.
    """
    if error == 0:
        return

    output = None
    cf_error_string = Security.SecCopyErrorMessageString(error, None)
    # The framework may have no message for this code, in which case it hands
    # back NULL. A NULL reference must not be decoded or released.
    if cf_error_string:
        try:
            output = _cf_string_to_unicode(cf_error_string)
        finally:
            # Release even if decoding raised, so the string is not leaked.
            CoreFoundation.CFRelease(cf_error_string)

    if output is None or output == "":
        output = f"OSStatus {error}"

    if exception_class is None:
        exception_class = ssl.SSLError

    raise exception_class(output)
def _cert_array_from_pem(pem_bundle: bytes) -> CFArray:
    """
    Turn a PEM bundle into a CFArray of certificate objects that can be used
    to validate a cert chain.

    On success the caller owns the returned array. Raises ssl.SSLError when
    the bundle holds no certificates or an object cannot be created.
    """
    # Work on a copy with Unix line endings so the regex markers match.
    normalized = pem_bundle.replace(b"\r\n", b"\n")

    der_certs = []
    for match in _PEM_CERTS_RE.finditer(normalized):
        der_certs.append(base64.b64decode(match.group(1)))
    if not der_certs:
        raise ssl.SSLError("No root certificates specified")

    allocator = CoreFoundation.kCFAllocatorDefault
    cert_array = CoreFoundation.CFArrayCreateMutable(
        allocator,
        0,
        ctypes.byref(CoreFoundation.kCFTypeArrayCallBacks),
    )
    if not cert_array:
        raise ssl.SSLError("Unable to allocate memory!")

    try:
        for der in der_certs:
            der_data = _cf_data_from_bytes(der)
            if not der_data:
                raise ssl.SSLError("Unable to allocate memory!")
            certificate = Security.SecCertificateCreateWithData(
                CoreFoundation.kCFAllocatorDefault, der_data
            )
            CoreFoundation.CFRelease(der_data)
            if not certificate:
                raise ssl.SSLError("Unable to build cert object!")

            CoreFoundation.CFArrayAppendValue(cert_array, certificate)
            CoreFoundation.CFRelease(certificate)
    except Exception:
        # The array is only freed here on failure; on success it is the
        # caller's job to free it.
        CoreFoundation.CFRelease(cert_array)
        raise

    return cert_array
def _is_cert(item: CFTypeRef) -> bool:
    """
    Report whether the CoreFoundation type ID of ``item`` is the
    certificate type ID.
    """
    certificate_type_id = Security.SecCertificateGetTypeID()
    item_type_id = CoreFoundation.CFGetTypeID(item)
    return item_type_id == certificate_type_id  # type: ignore[no-any-return]
def _is_identity(item: CFTypeRef) -> bool:
    """
    Report whether the CoreFoundation type ID of ``item`` is the
    identity type ID.
    """
    identity_type_id = Security.SecIdentityGetTypeID()
    item_type_id = CoreFoundation.CFGetTypeID(item)
    return item_type_id == identity_type_id  # type: ignore[no-any-return]
def _temporary_keychain() -> tuple[SecKeychainRef, str]:
    """
    This function creates a temporary Mac keychain that we can use to work with
    credentials. This keychain uses a one-time password and a temporary file to
    store the data. We expect to have one keychain per socket. The returned
    SecKeychainRef must be freed by the caller, including calling
    SecKeychainDelete.

    Returns a tuple of the SecKeychainRef and the path to the temporary
    directory that contains it.

    If the keychain cannot be created, the temporary directory is removed
    again before the error propagates.
    """
    # Only needed on the failure path, so it is imported locally.
    import shutil

    # Unfortunately, SecKeychainCreate requires a path to a keychain. This
    # means we cannot use mkstemp to use a generic temporary file. Instead,
    # we're going to create a temporary directory and a filename to use there.
    # This filename will be 8 random bytes expanded into hex (base16). We also
    # need some random bytes to password-protect the keychain we're creating,
    # so we ask for 40 random bytes in total.
    random_bytes = os.urandom(40)
    filename = base64.b16encode(random_bytes[:8]).decode("utf-8")
    password = base64.b16encode(random_bytes[8:])  # Must be valid UTF-8
    tempdirectory = tempfile.mkdtemp()

    try:
        keychain_path = os.path.join(tempdirectory, filename).encode("utf-8")

        # We now want to create the keychain itself.
        keychain = Security.SecKeychainRef()
        status = Security.SecKeychainCreate(
            keychain_path, len(password), password, False, None, ctypes.byref(keychain)
        )
        _assert_no_error(status)
    except BaseException:
        # The caller never learns the directory path on failure, so nobody
        # else can clean it up.
        shutil.rmtree(tempdirectory, ignore_errors=True)
        raise

    # Having created the keychain, we want to pass it off to the caller.
    return keychain, tempdirectory
def _load_items_from_file(
    keychain: SecKeychainRef, path: str
) -> tuple[list[CFTypeRef], list[CFTypeRef]]:
    """
    Given a single file, loads all the trust objects from it into arrays and
    the keychain.

    Returns a tuple of lists: the first list is a list of identities, the
    second a list of certs. Every returned item has been CFRetained.

    Raises ssl.SSLError if the file contents cannot be wrapped in a CFData
    object or the import reports an error.
    """
    certificates = []
    identities = []
    result_array = None
    # Bound up front so the ``finally`` block can test it safely even when
    # creating the CFData object fails.
    filedata = None

    with open(path, "rb") as f:
        raw_filedata = f.read()

    try:
        filedata = CoreFoundation.CFDataCreate(
            CoreFoundation.kCFAllocatorDefault, raw_filedata, len(raw_filedata)
        )
        if not filedata:
            raise ssl.SSLError("Unable to allocate memory!")

        result_array = CoreFoundation.CFArrayRef()
        result = Security.SecItemImport(
            filedata,  # cert data
            None,  # Filename, leaving it out for now
            None,  # What the type of the file is, we don't care
            None,  # what's in the file, we don't care
            0,  # import flags
            None,  # key params, can include passphrase in the future
            keychain,  # The keychain to insert into
            ctypes.byref(result_array),  # Results
        )
        _assert_no_error(result)

        # A CFArray is not very useful to us as an intermediary
        # representation, so we are going to extract the objects we want
        # and then free the array. We don't need to keep hold of keys: the
        # keychain already has them!
        result_count = CoreFoundation.CFArrayGetCount(result_array)
        for index in range(result_count):
            item = CoreFoundation.CFArrayGetValueAtIndex(result_array, index)
            item = ctypes.cast(item, CoreFoundation.CFTypeRef)

            if _is_cert(item):
                CoreFoundation.CFRetain(item)
                certificates.append(item)
            elif _is_identity(item):
                CoreFoundation.CFRetain(item)
                identities.append(item)
    finally:
        # Only release what was actually created.
        if result_array:
            CoreFoundation.CFRelease(result_array)
        if filedata:
            CoreFoundation.CFRelease(filedata)

    return (identities, certificates)
def _load_client_cert_chain(keychain: SecKeychainRef, *paths: str | None) -> CFArray:
    """
    Load certificates and maybe keys from a number of files. Has the end goal
    of returning a CFArray containing one SecIdentityRef, and then zero or more
    SecCertificateRef objects, suitable for use as a client certificate trust
    chain.

    :param keychain: Keychain the imported items are loaded into.
    :param paths: Files to import; falsy entries are ignored.
    :raises ssl.SSLError: If the files yield neither an identity nor a
        certificate, if an identity cannot be derived from the first
        certificate, or if the result array cannot be allocated.
    """
    # Ok, the strategy.
    #
    # This relies on knowing that macOS will not give you a SecIdentityRef
    # unless you have imported a key into a keychain. This is a somewhat
    # artificial limitation of macOS (for example, it doesn't necessarily
    # affect iOS), but there is nothing inside Security.framework that lets you
    # get a SecIdentityRef without having a key in a keychain.
    #
    # So the policy here is we take all the files and iterate them in order.
    # Each one will use SecItemImport to have one or more objects loaded from
    # it. We will also point at a keychain that macOS can use to work with the
    # private key.
    #
    # Once we have all the objects, we'll check what we actually have. If we
    # already have a SecIdentityRef in hand, fab: we'll use that. Otherwise,
    # we'll take the first certificate (which we assume to be our leaf) and
    # ask the keychain to give us a SecIdentityRef with that cert's associated
    # key.
    #
    # We'll then return a CFArray containing the trust chain: one
    # SecIdentityRef and then zero-or-more SecCertificateRef objects. The
    # responsibility for freeing this CFArray will be with the caller. This
    # CFArray must remain alive for the entire connection, so in practice it
    # will be stored with a single SSLSocket, along with the reference to the
    # keychain.
    certificates = []
    identities = []

    # Filter out bad paths.
    filtered_paths = (path for path in paths if path)

    try:
        for file_path in filtered_paths:
            new_identities, new_certs = _load_items_from_file(keychain, file_path)
            identities.extend(new_identities)
            certificates.extend(new_certs)

        # Ok, we have everything. The question is: do we have an identity? If
        # not, we want to grab one from the first cert we have.
        if not identities:
            # With no certificate either there is nothing to derive an
            # identity from; fail clearly instead of with an IndexError.
            if not certificates:
                raise ssl.SSLError(
                    "No client certificate or identity could be loaded"
                )

            new_identity = Security.SecIdentityRef()
            status = Security.SecIdentityCreateWithCertificate(
                keychain, certificates[0], ctypes.byref(new_identity)
            )
            _assert_no_error(status)
            identities.append(new_identity)

            # We now want to release the original certificate, as we no longer
            # need it.
            CoreFoundation.CFRelease(certificates.pop(0))

        # We now need to build a new CFArray that holds the trust chain.
        trust_chain = CoreFoundation.CFArrayCreateMutable(
            CoreFoundation.kCFAllocatorDefault,
            0,
            ctypes.byref(CoreFoundation.kCFTypeArrayCallBacks),
        )
        if not trust_chain:
            raise ssl.SSLError("Unable to allocate memory!")

        for item in itertools.chain(identities, certificates):
            # ArrayAppendValue does a CFRetain on the item. That's fine,
            # because the finally block will release our other refs to them.
            CoreFoundation.CFArrayAppendValue(trust_chain, item)

        return trust_chain
    finally:
        for obj in itertools.chain(identities, certificates):
            CoreFoundation.CFRelease(obj)
# (major, minor) version bytes written into the record header for each
# protocol name.
TLS_PROTOCOL_VERSIONS = {
    "SSLv2": (0, 2),
    "SSLv3": (3, 0),
    "TLSv1": (3, 1),
    "TLSv1.1": (3, 2),
    "TLSv1.2": (3, 3),
}


def _build_tls_unknown_ca_alert(version: str) -> bytes:
    """
    Build the bytes of a fatal "unknown CA" TLS alert record for ``version``.

    ``version`` must be a key of ``TLS_PROTOCOL_VERSIONS``.
    """
    major, minor = TLS_PROTOCOL_VERSIONS[version]

    alert_level_fatal = 0x02
    alert_unknown_ca = 0x30
    content_type_alert = 0x15

    # Alert body: level followed by description, one byte each.
    payload = struct.pack(">BB", alert_level_fatal, alert_unknown_ca)
    # Record header: content type, version bytes, big-endian body length.
    header = struct.pack(">BBBH", content_type_alert, major, minor, len(payload))
    return header + payload
class SecurityConst:
    """
    Namespace class grouping the Security framework constants used by the
    SecureTransport code.
    """

    # --- Session options -------------------------------------------------
    kSSLSessionOptionBreakOnServerAuth = 0

    # --- Protocol versions -----------------------------------------------
    kSSLProtocol2 = 1
    kSSLProtocol3 = 2
    kTLSProtocol1 = 4
    kTLSProtocol11 = 7
    kTLSProtocol12 = 8
    # SecureTransport does not support TLS 1.3 even if there's a constant for it
    kTLSProtocol13 = 10
    kTLSProtocolMaxSupported = 999

    # --- Context creation ------------------------------------------------
    kSSLClientSide = 1
    kSSLStreamType = 0

    # --- Import formats --------------------------------------------------
    kSecFormatPEMSequence = 10

    # --- Trust evaluation results ----------------------------------------
    kSecTrustResultInvalid = 0
    kSecTrustResultProceed = 1
    # This gap is present on purpose: this was kSecTrustResultConfirm, which
    # is deprecated.
    kSecTrustResultDeny = 3
    kSecTrustResultUnspecified = 4
    kSecTrustResultRecoverableTrustFailure = 5
    kSecTrustResultFatalTrustFailure = 6
    kSecTrustResultOtherError = 7

    # --- SSL status codes ------------------------------------------------
    errSSLProtocol = -9800
    errSSLWouldBlock = -9803
    errSSLClosedGraceful = -9805
    errSSLClosedNoNotify = -9816
    errSSLClosedAbort = -9806

    errSSLXCertChainInvalid = -9807
    errSSLCrypto = -9809
    errSSLInternal = -9810
    errSSLCertExpired = -9814
    errSSLCertNotYetValid = -9815
    errSSLUnknownRootCert = -9812
    errSSLNoRootCert = -9813
    errSSLHostNameMismatch = -9843
    errSSLPeerHandshakeFail = -9824
    errSSLPeerUserCancelled = -9839
    errSSLWeakPeerEphemeralDHKey = -9850
    errSSLServerAuthCompleted = -9841
    errSSLRecordOverflow = -9847

    # --- Sec status codes ------------------------------------------------
    errSecVerifyFailed = -67808
    errSecNoTrustSettings = -25263
    errSecItemNotFound = -25300
    errSecInvalidTrustSettings = -25262

View file

@ -8,10 +8,10 @@ This needs the following packages installed:
* `pyOpenSSL`_ (tested with 16.0.0) * `pyOpenSSL`_ (tested with 16.0.0)
* `cryptography`_ (minimum 1.3.4, from pyopenssl) * `cryptography`_ (minimum 1.3.4, from pyopenssl)
* `idna`_ (minimum 2.0, from cryptography) * `idna`_ (minimum 2.0)
However, pyOpenSSL depends on cryptography, which depends on idna, so while we However, pyOpenSSL depends on cryptography, so while we use all three directly here we
use all three directly here we end up having relatively few packages required. end up having relatively few packages required.
You can install them with the following command: You can install them with the following command:
@ -40,7 +40,7 @@ like this:
from __future__ import annotations from __future__ import annotations
import OpenSSL.SSL # type: ignore[import] import OpenSSL.SSL # type: ignore[import-untyped]
from cryptography import x509 from cryptography import x509
try: try:
@ -61,13 +61,13 @@ from socket import timeout
from .. import util from .. import util
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from OpenSSL.crypto import X509 # type: ignore[import] from OpenSSL.crypto import X509 # type: ignore[import-untyped]
__all__ = ["inject_into_urllib3", "extract_from_urllib3"] __all__ = ["inject_into_urllib3", "extract_from_urllib3"]
# Map from urllib3 to PyOpenSSL compatible parameter-values. # Map from urllib3 to PyOpenSSL compatible parameter-values.
_openssl_versions = { _openssl_versions: dict[int, int] = {
util.ssl_.PROTOCOL_TLS: OpenSSL.SSL.SSLv23_METHOD, # type: ignore[attr-defined] util.ssl_.PROTOCOL_TLS: OpenSSL.SSL.SSLv23_METHOD, # type: ignore[attr-defined]
util.ssl_.PROTOCOL_TLS_CLIENT: OpenSSL.SSL.SSLv23_METHOD, # type: ignore[attr-defined] util.ssl_.PROTOCOL_TLS_CLIENT: OpenSSL.SSL.SSLv23_METHOD, # type: ignore[attr-defined]
ssl.PROTOCOL_TLSv1: OpenSSL.SSL.TLSv1_METHOD, ssl.PROTOCOL_TLSv1: OpenSSL.SSL.TLSv1_METHOD,

View file

@ -1,913 +0,0 @@
"""
SecureTransport support for urllib3 via ctypes.
This makes platform-native TLS available to urllib3 users on macOS without the
use of a compiler. This is an important feature because the Python Package
Index is moving to become a TLSv1.2-or-higher server, and the default OpenSSL
that ships with macOS is not capable of doing TLSv1.2. The only way to resolve
this is to give macOS users an alternative solution to the problem, and that
solution is to use SecureTransport.
We use ctypes here because this solution must not require a compiler. That's
because pip is not allowed to require a compiler either.
This is not intended to be a seriously long-term solution to this problem.
The hope is that PEP 543 will eventually solve this issue for us, at which
point we can retire this contrib module. But in the short term, we need to
solve the impending tire fire that is Python on Mac without this kind of
contrib module. So...here we are.
To use this module, simply import and inject it::
import urllib3.contrib.securetransport
urllib3.contrib.securetransport.inject_into_urllib3()
Happy TLSing!
This code is a bastardised version of the code found in Will Bond's oscrypto
library. An enormous debt is owed to him for blazing this trail for us. For
that reason, this code should be considered to be covered both by urllib3's
license and by oscrypto's:
.. code-block::
Copyright (c) 2015-2016 Will Bond <will@wbond.net>
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
import contextlib
import ctypes
import errno
import os.path
import shutil
import socket
import ssl
import struct
import threading
import typing
import warnings
import weakref
from socket import socket as socket_cls
from .. import util
from ._securetransport.bindings import ( # type: ignore[attr-defined]
CoreFoundation,
Security,
)
from ._securetransport.low_level import (
SecurityConst,
_assert_no_error,
_build_tls_unknown_ca_alert,
_cert_array_from_pem,
_create_cfstring_array,
_load_client_cert_chain,
_temporary_keychain,
)
# This call sits at module level, so it runs when the module is imported:
# merely importing this module raises the DeprecationWarning.
warnings.warn(
    "'urllib3.contrib.securetransport' module is deprecated and will be removed "
    "in urllib3 v2.1.0. Read more in this issue: "
    "https://github.com/urllib3/urllib3/issues/2681",
    category=DeprecationWarning,
    stacklevel=2,
)
if typing.TYPE_CHECKING:
from typing_extensions import Literal
__all__ = ["inject_into_urllib3", "extract_from_urllib3"]
# Remember the SSLContext that urllib3 was using at import time so that
# extract_from_urllib3() can put it back.
orig_util_SSLContext = util.ssl_.SSLContext
# This dictionary is used by the read and write callbacks to obtain a handle
# to the calling wrapped socket. This is a pretty silly approach, but for now
# it'll do. I feel like I should be able to smuggle a handle to the wrapped
# socket directly in the SSLConnectionRef, but for now this approach will work
# I guess.
#
# Values are held weakly, so an entry here does not by itself keep its
# WrappedSocket alive.
#
# We need to lock around this structure for inserts, but we don't do it for
# reads/writes in the callbacks. The reasoning here goes as follows:
#
#    1. It is not possible to call into the callbacks before the dictionary is
#       populated, so once in the callback the id must be in the dictionary.
#    2. The callbacks don't mutate the dictionary, they only read from it, and
#       so cannot conflict with any of the insertions.
#
# This is good: if we had to lock in the callbacks we'd drastically slow down
# the performance of this code.
_connection_refs: weakref.WeakValueDictionary[
    int, WrappedSocket
] = weakref.WeakValueDictionary()
_connection_ref_lock = threading.Lock()
# Limit writes to 16kB (16384 bytes). This is OpenSSL's limit, but we'll
# cargo-cult it over for no better reason than we need *a* limit, and this one
# is right there.
SSL_WRITE_BLOCKSIZE = 16384
# Maps an ssl protocol constant to the (minimum, maximum) SecureTransport
# protocol versions to allow. The generic PROTOCOL_TLS / PROTOCOL_TLS_CLIENT
# values become a range from TLSv1 up to TLSv1.2; every version-specific
# protocol is pinned to exactly that version.
# TLSv1 to 1.2 are supported on macOS 10.8+
_protocol_to_min_max = {
    util.ssl_.PROTOCOL_TLS: (SecurityConst.kTLSProtocol1, SecurityConst.kTLSProtocol12),  # type: ignore[attr-defined]
    util.ssl_.PROTOCOL_TLS_CLIENT: (  # type: ignore[attr-defined]
        SecurityConst.kTLSProtocol1,
        SecurityConst.kTLSProtocol12,
    ),
}

# Only register the version-specific protocols that this build of the ssl
# module actually exposes.
for _ssl_attr, _st_version in (
    ("PROTOCOL_SSLv2", SecurityConst.kSSLProtocol2),
    ("PROTOCOL_SSLv3", SecurityConst.kSSLProtocol3),
    ("PROTOCOL_TLSv1", SecurityConst.kTLSProtocol1),
    ("PROTOCOL_TLSv1_1", SecurityConst.kTLSProtocol11),
    ("PROTOCOL_TLSv1_2", SecurityConst.kTLSProtocol12),
):
    if hasattr(ssl, _ssl_attr):
        _protocol_to_min_max[getattr(ssl, _ssl_attr)] = (_st_version, _st_version)
# Don't leave the loop variables behind in the module namespace.
del _ssl_attr, _st_version
# Maps ssl.TLSVersion members to SecureTransport protocol constants.
# MINIMUM_SUPPORTED and MAXIMUM_SUPPORTED are clamped to TLSv1 and TLSv1.2
# respectively.
_tls_version_to_st: dict[int, int] = {
    ssl.TLSVersion.MINIMUM_SUPPORTED: SecurityConst.kTLSProtocol1,
    ssl.TLSVersion.TLSv1: SecurityConst.kTLSProtocol1,
    ssl.TLSVersion.TLSv1_1: SecurityConst.kTLSProtocol11,
    ssl.TLSVersion.TLSv1_2: SecurityConst.kTLSProtocol12,
    ssl.TLSVersion.MAXIMUM_SUPPORTED: SecurityConst.kTLSProtocol12,
}
def inject_into_urllib3() -> None:
    """
    Monkey-patch urllib3 so that it uses the SecureTransport-backed
    SSL context.
    """
    # The same two attributes are patched on both ``util`` and ``util.ssl_``.
    for namespace in (util, util.ssl_):
        namespace.SSLContext = SecureTransportContext  # type: ignore[assignment]
        namespace.IS_SECURETRANSPORT = True
def extract_from_urllib3() -> None:
    """
    Undo monkey-patching by :func:`inject_into_urllib3`.
    """
    # Restore the saved SSLContext and clear the flag on both namespaces.
    for namespace in (util, util.ssl_):
        namespace.SSLContext = orig_util_SSLContext
        namespace.IS_SECURETRANSPORT = False
def _read_callback(
    connection_id: int, data_buffer: int, data_length_pointer: bytearray
) -> int:
    """
    SecureTransport read callback. This is called by ST to request that data
    be returned from the socket.

    ``connection_id`` is looked up in ``_connection_refs`` to find the wrapped
    socket. ``data_buffer`` is the address to read into, and
    ``data_length_pointer[0]`` holds the requested length on entry and is
    overwritten with the number of bytes read.

    Returns 0 when the full requested length was read, or one of the
    ``SecurityConst`` status codes:

    * ``errSSLWouldBlock`` for a short read (including a timeout),
    * ``errSSLClosedGraceful`` when the peer closed before any data was read,
    * ``errSSLClosedAbort`` on ECONNRESET or EPIPE,
    * ``errSSLInternal`` when the connection is unknown or any other
      exception occurred; that exception is stored on the wrapped socket.
    """
    wrapped_socket = None
    try:
        wrapped_socket = _connection_refs.get(connection_id)
        if wrapped_socket is None:
            return SecurityConst.errSSLInternal
        base_socket = wrapped_socket.socket

        requested_length = data_length_pointer[0]

        timeout = wrapped_socket.gettimeout()
        error = None
        read_count = 0

        try:
            while read_count < requested_length:
                # Wait for readability ourselves unless the timeout is
                # negative.
                if timeout is None or timeout >= 0:
                    if not util.wait_for_read(base_socket, timeout):
                        raise OSError(errno.EAGAIN, "timed out")

                remaining = requested_length - read_count
                # View the unread tail of the caller's buffer as a ctypes
                # array so recv_into can write straight into it.
                buffer = (ctypes.c_char * remaining).from_address(
                    data_buffer + read_count
                )
                chunk_size = base_socket.recv_into(buffer, remaining)
                read_count += chunk_size
                if not chunk_size:
                    # Zero bytes means the peer closed the connection.
                    if not read_count:
                        return SecurityConst.errSSLClosedGraceful
                    break
        except OSError as e:
            error = e.errno

            # EAGAIN falls through and is reported as a short read below.
            if error is not None and error != errno.EAGAIN:
                data_length_pointer[0] = read_count
                if error == errno.ECONNRESET or error == errno.EPIPE:
                    return SecurityConst.errSSLClosedAbort
                # Anything else is handled by the outer ``except``.
                raise

        data_length_pointer[0] = read_count

        if read_count != requested_length:
            return SecurityConst.errSSLWouldBlock

        return 0
    except Exception as e:
        # Exceptions must not escape the callback; stash the error on the
        # wrapped socket instead.
        if wrapped_socket is not None:
            wrapped_socket._exception = e
        return SecurityConst.errSSLInternal
def _write_callback(
    connection_id: int, data_buffer: int, data_length_pointer: bytearray
) -> int:
    """
    SecureTransport write callback. This is called by ST to request that data
    actually be sent on the network.

    ``connection_id`` is looked up in ``_connection_refs`` to find the wrapped
    socket. ``data_buffer`` is the address of the data to send, and
    ``data_length_pointer[0]`` holds the number of bytes to write on entry and
    is overwritten with the number of bytes actually sent.

    Returns 0 when everything was sent, or one of the ``SecurityConst``
    status codes:

    * ``errSSLWouldBlock`` for a partial write (including a timeout),
    * ``errSSLClosedAbort`` on ECONNRESET or EPIPE,
    * ``errSSLInternal`` when the connection is unknown or any other
      exception occurred; that exception is stored on the wrapped socket.
    """
    wrapped_socket = None
    try:
        wrapped_socket = _connection_refs.get(connection_id)
        if wrapped_socket is None:
            return SecurityConst.errSSLInternal
        base_socket = wrapped_socket.socket

        bytes_to_write = data_length_pointer[0]
        # Copy the outgoing bytes out of the caller's buffer.
        data = ctypes.string_at(data_buffer, bytes_to_write)

        timeout = wrapped_socket.gettimeout()
        error = None
        sent = 0

        try:
            while sent < bytes_to_write:
                # Wait for writability ourselves unless the timeout is
                # negative.
                if timeout is None or timeout >= 0:
                    if not util.wait_for_write(base_socket, timeout):
                        raise OSError(errno.EAGAIN, "timed out")
                chunk_sent = base_socket.send(data)
                sent += chunk_sent

                # This has some needless copying here, but I'm not sure there's
                # much value in optimising this data path.
                data = data[chunk_sent:]
        except OSError as e:
            error = e.errno

            # EAGAIN falls through and is reported as a partial write below.
            if error is not None and error != errno.EAGAIN:
                data_length_pointer[0] = sent
                if error == errno.ECONNRESET or error == errno.EPIPE:
                    return SecurityConst.errSSLClosedAbort
                # Anything else is handled by the outer ``except``.
                raise

        data_length_pointer[0] = sent

        if sent != bytes_to_write:
            return SecurityConst.errSSLWouldBlock

        return 0
    except Exception as e:
        # Exceptions must not escape the callback; stash the error on the
        # wrapped socket instead.
        if wrapped_socket is not None:
            wrapped_socket._exception = e
        return SecurityConst.errSSLInternal
# We need to keep references to these two callback objects alive: if they get
# GC'd while in use then SecureTransport could attempt to call a function that
# is in freed memory. That would be...uh...bad. Yeah, that's the word. Bad.
# Binding them to module-level names keeps them alive.
_read_callback_pointer = Security.SSLReadFunc(_read_callback)
_write_callback_pointer = Security.SSLWriteFunc(_write_callback)
class WrappedSocket:
"""
API-compatibility wrapper for Python's OpenSSL wrapped socket object.
"""
def __init__(self, socket: socket_cls) -> None:
self.socket = socket
self.context = None
self._io_refs = 0
self._closed = False
self._real_closed = False
self._exception: Exception | None = None
self._keychain = None
self._keychain_dir: str | None = None
self._client_cert_chain = None
# We save off the previously-configured timeout and then set it to
# zero. This is done because we use select and friends to handle the
# timeouts, but if we leave the timeout set on the lower socket then
# Python will "kindly" call select on that socket again for us. Avoid
# that by forcing the timeout to zero.
self._timeout = self.socket.gettimeout()
self.socket.settimeout(0)
@contextlib.contextmanager
def _raise_on_error(self) -> typing.Generator[None, None, None]:
"""
A context manager that can be used to wrap calls that do I/O from
SecureTransport. If any of the I/O callbacks hit an exception, this
context manager will correctly propagate the exception after the fact.
This avoids silently swallowing those exceptions.
It also correctly forces the socket closed.
"""
self._exception = None
# We explicitly don't catch around this yield because in the unlikely
# event that an exception was hit in the block we don't want to swallow
# it.
yield
if self._exception is not None:
exception, self._exception = self._exception, None
self._real_close()
raise exception
def _set_alpn_protocols(self, protocols: list[bytes] | None) -> None:
"""
Sets up the ALPN protocols on the context.
"""
if not protocols:
return
protocols_arr = _create_cfstring_array(protocols)
try:
result = Security.SSLSetALPNProtocols(self.context, protocols_arr)
_assert_no_error(result)
finally:
CoreFoundation.CFRelease(protocols_arr)
def _custom_validate(self, verify: bool, trust_bundle: bytes | None) -> None:
    """
    Called when we have set custom validation. We do this in two cases:
    first, when cert validation is entirely disabled; and second, when
    using a custom trust DB.

    Returns normally when verification is disabled, when no trust bundle
    was given, or when trust evaluation succeeds. Otherwise a fatal
    "unknown CA" alert is sent, the connection is closed, and an SSLError
    is raised.

    Raises an SSLError if the connection is not trusted.
    """
    # If we disabled cert validation, just say: cool. The same applies when
    # there is no trust bundle to validate against.
    if not verify or trust_bundle is None:
        return

    successes = (
        SecurityConst.kSecTrustResultUnspecified,
        SecurityConst.kSecTrustResultProceed,
    )
    try:
        trust_result = self._evaluate_trust(trust_bundle)
        if trust_result in successes:
            return
        reason = f"error code: {int(trust_result)}"
        exc = None
    except Exception as e:
        # Do not trust on error
        reason = f"exception: {e!r}"
        exc = e

    # SecureTransport does not send an alert nor shuts down the connection.
    rec = _build_tls_unknown_ca_alert(self.version())
    self.socket.sendall(rec)
    # close the connection immediately
    # l_onoff = 1, activate linger
    # l_linger = 0, linger for 0 seconds
    opts = struct.pack("ii", 1, 0)
    self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, opts)
    self._real_close()
    # ``exc`` is None for a plain untrusted result, so no cause is chained
    # in that case.
    raise ssl.SSLError(f"certificate verify failed, {reason}") from exc
def _evaluate_trust(self, trust_bundle: bytes) -> int:
    """
    Evaluate the peer's trust object using only the certificates in
    ``trust_bundle`` as anchors, and return the trust result value.

    ``trust_bundle`` may be the PEM data itself or a path to a file
    containing it.
    """
    # We want data in memory, so load it up.
    if os.path.isfile(trust_bundle):
        with open(trust_bundle, "rb") as bundle_file:
            trust_bundle = bundle_file.read()

    anchors = None
    trust = Security.SecTrustRef()

    try:
        # Get a CFArray that contains the certs we want.
        anchors = _cert_array_from_pem(trust_bundle)

        # Take the SecTrustRef that ST created for this connection, install
        # our CAs as the only anchors, and ask whether a chain can be built.
        status = Security.SSLCopyPeerTrust(self.context, ctypes.byref(trust))
        _assert_no_error(status)
        if not trust:
            raise ssl.SSLError("Failed to copy trust reference")

        status = Security.SecTrustSetAnchorCertificates(trust, anchors)
        _assert_no_error(status)

        status = Security.SecTrustSetAnchorCertificatesOnly(trust, True)
        _assert_no_error(status)

        verdict = Security.SecTrustResultType()
        status = Security.SecTrustEvaluate(trust, ctypes.byref(verdict))
        _assert_no_error(status)
    finally:
        if trust:
            CoreFoundation.CFRelease(trust)

        if anchors is not None:
            CoreFoundation.CFRelease(anchors)

    return verdict.value  # type: ignore[no-any-return]
def handshake(
    self,
    server_hostname: bytes | str | None,
    verify: bool,
    trust_bundle: bytes | None,
    min_version: int,
    max_version: int,
    client_cert: str | None,
    client_key: str | None,
    client_key_passphrase: typing.Any,
    alpn_protocols: list[bytes] | None,
) -> None:
    """
    Actually performs the TLS handshake. This is run automatically by
    wrapped socket, and shouldn't be needed in user code.

    :param server_hostname: Peer name to set on the context; ``str`` values
        are encoded as UTF-8. Skipped when falsy.
    :param verify: Whether the server certificate should be validated.
    :param trust_bundle: Custom trust data handed to the custom validation
        step, or ``None``.
    :param min_version: Minimum protocol version to set on the context.
    :param max_version: Maximum protocol version to set on the context.
    :param client_cert: Client certificate file; when set, a temporary
        keychain is created and the chain is loaded from it and
        ``client_key``.
    :param client_key: Optional client key file, loaded with ``client_cert``.
    :param client_key_passphrase: Not used by this method.
    :param alpn_protocols: ALPN protocols to set, if any.
    :raises socket.timeout: If the handshake reports it would block.
    """
    # First, we do the initial bits of connection setup. We need to create
    # a context, set its I/O funcs, and set the connection reference.
    self.context = Security.SSLCreateContext(
        None, SecurityConst.kSSLClientSide, SecurityConst.kSSLStreamType
    )
    result = Security.SSLSetIOFuncs(
        self.context, _read_callback_pointer, _write_callback_pointer
    )
    _assert_no_error(result)

    # Here we need to compute the handle to use. We do this by taking the
    # id of self modulo 2**31 - 1. If this is already in the dictionary, we
    # just keep incrementing by one until we find a free space.
    with _connection_ref_lock:
        handle = id(self) % 2147483647
        while handle in _connection_refs:
            handle = (handle + 1) % 2147483647
        _connection_refs[handle] = self

    result = Security.SSLSetConnection(self.context, handle)
    _assert_no_error(result)

    # If we have a server hostname, we should set that too.
    # RFC6066 Section 3 tells us not to use SNI when the host is an IP, but we have
    # to do it anyway to match server_hostname against the server certificate
    if server_hostname:
        if not isinstance(server_hostname, bytes):
            server_hostname = server_hostname.encode("utf-8")

        result = Security.SSLSetPeerDomainName(
            self.context, server_hostname, len(server_hostname)
        )
        _assert_no_error(result)

    # Setup the ALPN protocols.
    self._set_alpn_protocols(alpn_protocols)

    # Set the minimum and maximum TLS versions.
    result = Security.SSLSetProtocolVersionMin(self.context, min_version)
    _assert_no_error(result)

    result = Security.SSLSetProtocolVersionMax(self.context, max_version)
    _assert_no_error(result)

    # If there's a trust DB, we need to use it. We do that by telling
    # SecureTransport to break on server auth. We also do that if we don't
    # want to validate the certs at all: we just won't actually do any
    # authing in that case.
    if not verify or trust_bundle is not None:
        result = Security.SSLSetSessionOption(
            self.context, SecurityConst.kSSLSessionOptionBreakOnServerAuth, True
        )
        _assert_no_error(result)

    # If there's a client cert, we need to use it.
    if client_cert:
        self._keychain, self._keychain_dir = _temporary_keychain()
        self._client_cert_chain = _load_client_cert_chain(
            self._keychain, client_cert, client_key
        )
        result = Security.SSLSetCertificate(self.context, self._client_cert_chain)
        _assert_no_error(result)

    # Drive the handshake. It is called again after the server-auth break
    # has been handled by the custom validation step.
    while True:
        with self._raise_on_error():
            result = Security.SSLHandshake(self.context)

            if result == SecurityConst.errSSLWouldBlock:
                raise socket.timeout("handshake timed out")
            elif result == SecurityConst.errSSLServerAuthCompleted:
                self._custom_validate(verify, trust_bundle)
                continue
            else:
                _assert_no_error(result)
                break
def fileno(self) -> int:
    """Return the file descriptor number of the wrapped socket."""
    underlying = self.socket
    return underlying.fileno()
# Copy-pasted from Python 3.5 source code
def _decref_socketios(self) -> None:
if self._io_refs > 0:
self._io_refs -= 1
if self._closed:
self.close()
def recv(self, bufsiz: int) -> bytes:
    """
    Read up to ``bufsiz`` bytes and return them as a ``bytes`` object.

    A scratch ctypes buffer of ``bufsiz`` bytes is filled through
    ``recv_into`` and truncated to the number of bytes actually read.
    """
    scratch = ctypes.create_string_buffer(bufsiz)
    count = self.recv_into(scratch, bufsiz)
    return typing.cast(bytes, scratch[:count])
def recv_into(
    self, buffer: ctypes.Array[ctypes.c_char], nbytes: int | None = None
) -> int:
    """
    Read decrypted data into ``buffer`` via ``Security.SSLRead``.

    :param buffer:
        Destination buffer; it is re-wrapped as a ``c_char`` array of
        ``nbytes`` elements before being handed to SecureTransport.
    :param nbytes:
        Maximum number of bytes to read. Defaults to ``len(buffer)``.
    :returns:
        The number of bytes written into ``buffer``; ``0`` once the
        connection has really been closed.
    :raises socket.timeout:
        If SecureTransport reports ``errSSLWouldBlock`` and no bytes were
        processed.
    """
    # Read short on EOF.
    if self._real_closed:
        return 0
    if nbytes is None:
        nbytes = len(buffer)
    buffer = (ctypes.c_char * nbytes).from_buffer(buffer)
    processed_bytes = ctypes.c_size_t(0)
    with self._raise_on_error():
        result = Security.SSLRead(
            self.context, buffer, nbytes, ctypes.byref(processed_bytes)
        )
    # There are some result codes that we want to treat as "not always
    # errors". Specifically, those are errSSLWouldBlock,
    # errSSLClosedGraceful, and errSSLClosedNoNotify.
    if result == SecurityConst.errSSLWouldBlock:
        # If we didn't process any bytes, then this was just a time out.
        # However, we can get errSSLWouldBlock in situations when we *did*
        # read some data, and in those cases we should just read "short"
        # and return.
        if processed_bytes.value == 0:
            # Timed out, no data read.
            raise socket.timeout("recv timed out")
    elif result in (
        SecurityConst.errSSLClosedGraceful,
        SecurityConst.errSSLClosedNoNotify,
    ):
        # The remote peer has closed this connection. We should do so as
        # well. Note that we don't actually return here because in
        # principle this could actually be fired along with return data.
        # It's unlikely though.
        self._real_close()
    else:
        _assert_no_error(result)
    # Ok, we read and probably succeeded. We should return whatever data
    # was actually read.
    return processed_bytes.value
def settimeout(self, timeout: float) -> None:
    """Remember ``timeout`` as this wrapped socket's timeout value."""
    self._timeout = timeout
def gettimeout(self) -> float | None:
    """Return the timeout value last stored by ``settimeout``."""
    current = self._timeout
    return current
def send(self, data: bytes) -> int:
    """
    Write ``data`` to the connection via ``Security.SSLWrite``.

    :param data: The plaintext bytes to hand to SecureTransport.
    :returns: The number of bytes SecureTransport reports as processed.
    :raises socket.timeout:
        If SecureTransport reports ``errSSLWouldBlock`` and no bytes were
        processed.
    """
    processed_bytes = ctypes.c_size_t(0)
    with self._raise_on_error():
        result = Security.SSLWrite(
            self.context, data, len(data), ctypes.byref(processed_bytes)
        )
    if result == SecurityConst.errSSLWouldBlock and processed_bytes.value == 0:
        # Timed out
        raise socket.timeout("send timed out")
    else:
        # Any other result code is checked by the shared error helper.
        _assert_no_error(result)
    # We sent, and probably succeeded. Tell them how much we sent.
    return processed_bytes.value
def sendall(self, data: bytes) -> None:
    """
    Send all of ``data``, at most ``SSL_WRITE_BLOCKSIZE`` bytes per call.

    Each slice is passed to ``send`` and the offset advances by however
    many bytes that call reports as sent.
    """
    offset = 0
    while offset < len(data):
        window = data[offset : offset + SSL_WRITE_BLOCKSIZE]
        offset += self.send(window)
def shutdown(self) -> None:
    """Call ``Security.SSLClose`` on the session context."""
    error_guard = self._raise_on_error()
    with error_guard:
        Security.SSLClose(self.context)
def close(self) -> None:
    """
    Mark the socket as closed and really close it once nothing uses it.

    The real close is skipped while I/O references are still outstanding.
    """
    self._closed = True
    # TODO: should I do clean shutdown here? Do I have to?
    if self._io_refs > 0:
        return
    self._real_close()
def _real_close(self) -> None:
self._real_closed = True
if self.context:
CoreFoundation.CFRelease(self.context)
self.context = None
if self._client_cert_chain:
CoreFoundation.CFRelease(self._client_cert_chain)
self._client_cert_chain = None
if self._keychain:
Security.SecKeychainDelete(self._keychain)
CoreFoundation.CFRelease(self._keychain)
shutil.rmtree(self._keychain_dir)
self._keychain = self._keychain_dir = None
return self.socket.close()
def getpeercert(self, binary_form: bool = False) -> bytes | None:
    """
    Return the peer's leaf certificate as DER-encoded bytes.

    :param binary_form:
        Must be true; only the binary (DER) form is supported.
    :returns:
        The DER bytes of the certificate at index 0 of the peer trust
        object, or ``None`` when no trust object or no certificates are
        available.
    :raises ValueError: If ``binary_form`` is false.
    """
    # Urgh, annoying.
    #
    # Here's how we do this:
    #
    # 1. Call SSLCopyPeerTrust to get hold of the trust object for this
    #    connection.
    # 2. Call SecTrustGetCertificateAtIndex for index 0 to get the leaf.
    # 3. To get the CN, call SecCertificateCopyCommonName and process that
    #    string so that it's of the appropriate type.
    # 4. To get the SAN, we need to do something a bit more complex:
    #    a. Call SecCertificateCopyValues to get the data, requesting
    #       kSecOIDSubjectAltName.
    #    b. Mess about with this dictionary to try to get the SANs out.
    #
    # This is gross. Really gross. It's going to be a few hundred LoC extra
    # just to repeat something that SecureTransport can *already do*. So my
    # operating assumption at this time is that what we want to do is
    # instead to just flag to urllib3 that it shouldn't do its own hostname
    # validation when using SecureTransport.
    if not binary_form:
        raise ValueError("SecureTransport only supports dumping binary certs")
    trust = Security.SecTrustRef()
    certdata = None
    der_bytes = None
    try:
        # Grab the trust store.
        result = Security.SSLCopyPeerTrust(self.context, ctypes.byref(trust))
        _assert_no_error(result)
        if not trust:
            # Probably we haven't done the handshake yet. No biggie.
            return None
        cert_count = Security.SecTrustGetCertificateCount(trust)
        if not cert_count:
            # Also a case that might happen if we haven't handshaked.
            # Handshook? Handshaken?
            return None
        leaf = Security.SecTrustGetCertificateAtIndex(trust, 0)
        assert leaf
        # Ok, now we want the DER bytes.
        certdata = Security.SecCertificateCopyData(leaf)
        assert certdata
        data_length = CoreFoundation.CFDataGetLength(certdata)
        data_buffer = CoreFoundation.CFDataGetBytePtr(certdata)
        # Copy the bytes out before the CFData object is released below.
        der_bytes = ctypes.string_at(data_buffer, data_length)
    finally:
        # Release whichever objects were obtained, on every exit path.
        if certdata:
            CoreFoundation.CFRelease(certdata)
        if trust:
            CoreFoundation.CFRelease(trust)
    return der_bytes
def version(self) -> str:
    """
    Return the negotiated protocol version as an ``ssl``-style string.

    :returns: One of ``"TLSv1.2"``, ``"TLSv1.1"``, ``"TLSv1"``,
        ``"SSLv3"`` or ``"SSLv2"``.
    :raises ssl.SSLError: If TLS 1.3 or an unrecognised protocol value is
        reported.
    """
    negotiated = Security.SSLProtocol()
    status = Security.SSLGetNegotiatedProtocolVersion(
        self.context, ctypes.byref(negotiated)
    )
    _assert_no_error(status)

    code = negotiated.value
    if code == SecurityConst.kTLSProtocol13:
        raise ssl.SSLError("SecureTransport does not support TLS 1.3")

    # Checked in order; the first matching constant wins.
    known_versions = (
        (SecurityConst.kTLSProtocol12, "TLSv1.2"),
        (SecurityConst.kTLSProtocol11, "TLSv1.1"),
        (SecurityConst.kTLSProtocol1, "TLSv1"),
        (SecurityConst.kSSLProtocol3, "SSLv3"),
        (SecurityConst.kSSLProtocol2, "SSLv2"),
    )
    for constant, label in known_versions:
        if code == constant:
            return label
    raise ssl.SSLError(f"Unknown TLS version: {negotiated!r}")
def makefile(
    self: socket_cls,
    mode: (
        Literal["r"] | Literal["w"] | Literal["rw"] | Literal["wr"] | Literal[""]
    ) = "r",
    buffering: int | None = None,
    *args: typing.Any,
    **kwargs: typing.Any,
) -> typing.BinaryIO | typing.TextIO:
    """
    Create a file object for the socket with buffering always disabled.

    The caller's ``buffering`` argument is ignored and ``0`` is passed on
    to ``socket_cls.makefile`` instead.
    """
    # We disable buffering with SecureTransport because it conflicts with
    # the buffering that ST does internally (see issue #1153 for more).
    unbuffered = 0
    return socket_cls.makefile(self, mode, unbuffered, *args, **kwargs)
WrappedSocket.makefile = makefile # type: ignore[attr-defined]
class SecureTransportContext:
    """
    I am a wrapper class for the SecureTransport library, to translate the
    interface of the standard library ``SSLContext`` object to calls into
    SecureTransport.
    """

    def __init__(self, protocol: int) -> None:
        """
        Create a context with default version bounds and no credentials.

        :param protocol:
            An ``ssl.PROTOCOL_*`` constant. Values other than ``None``,
            ``ssl.PROTOCOL_TLS`` and ``ssl.PROTOCOL_TLS_CLIENT`` are looked
            up in ``_protocol_to_min_max``.
        """
        self._minimum_version: int = ssl.TLSVersion.MINIMUM_SUPPORTED
        self._maximum_version: int = ssl.TLSVersion.MAXIMUM_SUPPORTED
        if protocol not in (None, ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLS_CLIENT):
            # NOTE(review): ``_min_version``/``_max_version`` are not read
            # anywhere in this class; ``wrap_socket`` uses
            # ``_minimum_version``/``_maximum_version``. Confirm whether a
            # pinned protocol is meant to narrow those bounds.
            self._min_version, self._max_version = _protocol_to_min_max[protocol]
        self._options = 0
        self._verify = False
        self._trust_bundle: bytes | None = None
        self._client_cert: str | None = None
        self._client_key: str | None = None
        self._client_key_passphrase = None
        self._alpn_protocols: list[bytes] | None = None

    @property
    def check_hostname(self) -> Literal[True]:
        """
        SecureTransport cannot have its hostname checking disabled. For more,
        see the comment on getpeercert() in this file.
        """
        return True

    @check_hostname.setter
    def check_hostname(self, value: typing.Any) -> None:
        """
        SecureTransport cannot have its hostname checking disabled. For more,
        see the comment on getpeercert() in this file.
        """

    @property
    def options(self) -> int:
        """Return the stored option flags; they are not interpreted here."""
        # TODO: Well, crap.
        #
        # So this is the bit of the code that is the most likely to cause us
        # trouble. Essentially we need to enumerate all of the SSL options that
        # users might want to use and try to see if we can sensibly translate
        # them, or whether we should just ignore them.
        return self._options

    @options.setter
    def options(self, value: int) -> None:
        # TODO: Update in line with above.
        self._options = value

    @property
    def verify_mode(self) -> int:
        """Return ``ssl.CERT_REQUIRED`` when verifying, else ``ssl.CERT_NONE``."""
        return ssl.CERT_REQUIRED if self._verify else ssl.CERT_NONE

    @verify_mode.setter
    def verify_mode(self, value: int) -> None:
        # Only CERT_REQUIRED turns verification on; anything else disables it.
        self._verify = value == ssl.CERT_REQUIRED

    def set_default_verify_paths(self) -> None:
        """Do nothing; see the comment below for why."""
        # So, this has to do something a bit weird. Specifically, what it does
        # is nothing.
        #
        # This means that, if we had previously had load_verify_locations
        # called, this does not undo that. We need to do that because it turns
        # out that the rest of the urllib3 code will attempt to load the
        # default verify paths if it hasn't been told about any paths, even if
        # the context itself was sometime earlier. We resolve that by just
        # ignoring it.
        pass

    def load_default_certs(self) -> None:
        """Alias for :meth:`set_default_verify_paths`."""
        return self.set_default_verify_paths()

    def set_ciphers(self, ciphers: typing.Any) -> None:
        """Always raise ``ValueError``: custom cipher strings are unsupported."""
        raise ValueError("SecureTransport doesn't support custom cipher strings")

    def load_verify_locations(
        self,
        cafile: str | None = None,
        capath: str | None = None,
        cadata: bytes | None = None,
    ) -> None:
        """
        Record the trust bundle to use for verification.

        :param cafile: Path to a CA bundle file; it must be openable.
        :param capath: Not supported; must be ``None``.
        :param cadata: In-memory CA data, used when ``cafile`` is falsy.
        :raises ValueError: If ``capath`` is given.
        :raises OSError: If ``cafile`` cannot be opened.
        """
        # OK, we only really support cadata and cafile.
        if capath is not None:
            raise ValueError("SecureTransport does not support cert directories")
        # Raise if cafile does not exist.
        if cafile is not None:
            with open(cafile):
                pass
        self._trust_bundle = cafile or cadata  # type: ignore[assignment]

    def load_cert_chain(
        self,
        certfile: str,
        keyfile: str | None = None,
        password: str | None = None,
    ) -> None:
        """
        Record the client certificate, key and key passphrase.

        :param certfile: Path to the client certificate.
        :param keyfile: Optional path to the private key.
        :param password: Optional passphrase for the private key.
        """
        self._client_cert = certfile
        self._client_key = keyfile
        # ``wrap_socket`` forwards ``_client_key_passphrase``, so the password
        # has to be stored under that name to reach the handshake.
        self._client_key_passphrase = password
        # Kept for backward compatibility with the previous attribute name.
        self._client_cert_passphrase = password

    def set_alpn_protocols(self, protocols: list[str | bytes]) -> None:
        """
        Sets the ALPN protocols that will later be set on the context.
        Raises a NotImplementedError if ALPN is not supported.
        """
        if not hasattr(Security, "SSLSetALPNProtocols"):
            raise NotImplementedError(
                "SecureTransport supports ALPN only in macOS 10.12+"
            )
        self._alpn_protocols = [util.util.to_bytes(p, "ascii") for p in protocols]

    def wrap_socket(
        self,
        sock: socket_cls,
        server_side: bool = False,
        do_handshake_on_connect: bool = True,
        suppress_ragged_eofs: bool = True,
        server_hostname: bytes | str | None = None,
    ) -> WrappedSocket:
        """
        Wrap ``sock`` in a ``WrappedSocket`` and perform the handshake.

        Only client-side sockets that handshake on connect and suppress
        ragged EOFs are supported; other combinations fail an assertion.
        """
        # So, what do we do here? Firstly, we assert some properties. This is a
        # stripped down shim, so there is some functionality we don't support.
        # See PEP 543 for the real deal.
        assert not server_side
        assert do_handshake_on_connect
        assert suppress_ragged_eofs
        # Ok, we're good to go. Now we want to create the wrapped socket object
        # and store it in the appropriate place.
        wrapped_socket = WrappedSocket(sock)
        # Now we can handshake
        wrapped_socket.handshake(
            server_hostname,
            self._verify,
            self._trust_bundle,
            _tls_version_to_st[self._minimum_version],
            _tls_version_to_st[self._maximum_version],
            self._client_cert,
            self._client_key,
            self._client_key_passphrase,
            self._alpn_protocols,
        )
        return wrapped_socket

    @property
    def minimum_version(self) -> int:
        """Lowest TLS version that ``wrap_socket`` passes to the handshake."""
        return self._minimum_version

    @minimum_version.setter
    def minimum_version(self, minimum_version: int) -> None:
        self._minimum_version = minimum_version

    @property
    def maximum_version(self) -> int:
        """Highest TLS version that ``wrap_socket`` passes to the handshake."""
        return self._maximum_version

    @maximum_version.setter
    def maximum_version(self, maximum_version: int) -> None:
        self._maximum_version = maximum_version

View file

@ -41,7 +41,7 @@ with the proxy:
from __future__ import annotations from __future__ import annotations
try: try:
import socks # type: ignore[import] import socks # type: ignore[import-not-found]
except ImportError: except ImportError:
import warnings import warnings
@ -51,7 +51,7 @@ except ImportError:
( (
"SOCKS support in urllib3 requires the installation of optional " "SOCKS support in urllib3 requires the installation of optional "
"dependencies: specifically, PySocks. For more information, see " "dependencies: specifically, PySocks. For more information, see "
"https://urllib3.readthedocs.io/en/latest/contrib.html#socks-proxies" "https://urllib3.readthedocs.io/en/latest/advanced-usage.html#socks-proxies"
), ),
DependencyWarning, DependencyWarning,
) )
@ -71,10 +71,10 @@ try:
except ImportError: except ImportError:
ssl = None # type: ignore[assignment] ssl = None # type: ignore[assignment]
try: from typing import TypedDict
from typing import TypedDict
class _TYPE_SOCKS_OPTIONS(TypedDict):
class _TYPE_SOCKS_OPTIONS(TypedDict):
socks_version: int socks_version: int
proxy_host: str | None proxy_host: str | None
proxy_port: str | None proxy_port: str | None
@ -82,9 +82,6 @@ try:
password: str | None password: str | None
rdns: bool rdns: bool
except ImportError: # Python 3.7
_TYPE_SOCKS_OPTIONS = typing.Dict[str, typing.Any] # type: ignore[misc, assignment]
class SOCKSConnection(HTTPConnection): class SOCKSConnection(HTTPConnection):
""" """

View file

@ -252,13 +252,16 @@ class IncompleteRead(HTTPError, httplib_IncompleteRead):
for ``partial`` to avoid creating large objects on streamed reads. for ``partial`` to avoid creating large objects on streamed reads.
""" """
partial: int # type: ignore[assignment]
expected: int
def __init__(self, partial: int, expected: int) -> None: def __init__(self, partial: int, expected: int) -> None:
self.partial = partial # type: ignore[assignment] self.partial = partial
self.expected = expected self.expected = expected
def __repr__(self) -> str: def __repr__(self) -> str:
return "IncompleteRead(%i bytes read, %i more expected)" % ( return "IncompleteRead(%i bytes read, %i more expected)" % (
self.partial, # type: ignore[str-format] self.partial,
self.expected, self.expected,
) )

View file

@ -225,13 +225,9 @@ class RequestField:
if isinstance(value, tuple): if isinstance(value, tuple):
if len(value) == 3: if len(value) == 3:
filename, data, content_type = typing.cast( filename, data, content_type = value
typing.Tuple[str, _TYPE_FIELD_VALUE, str], value
)
else: else:
filename, data = typing.cast( filename, data = value
typing.Tuple[str, _TYPE_FIELD_VALUE], value
)
content_type = guess_content_type(filename) content_type = guess_content_type(filename)
else: else:
filename = None filename = None

229
lib/urllib3/http2.py Normal file
View file

@ -0,0 +1,229 @@
from __future__ import annotations
import threading
import types
import typing
import h2.config # type: ignore[import-untyped]
import h2.connection # type: ignore[import-untyped]
import h2.events # type: ignore[import-untyped]
import urllib3.connection
import urllib3.util.ssl_
from urllib3.response import BaseHTTPResponse
from ._collections import HTTPHeaderDict
from .connection import HTTPSConnection
from .connectionpool import HTTPSConnectionPool
orig_HTTPSConnection = HTTPSConnection
T = typing.TypeVar("T")
class _LockedObject(typing.Generic[T]):
"""
A wrapper class that hides a specific object behind a lock.
The goal here is to provide a simple way to protect access to an object
that cannot safely be simultaneously accessed from multiple threads. The
intended use of this class is simple: take hold of it with a context
manager, which returns the protected object.
"""
def __init__(self, obj: T):
self.lock = threading.RLock()
self._obj = obj
def __enter__(self) -> T:
self.lock.acquire()
return self._obj
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: types.TracebackType | None,
) -> None:
self.lock.release()
class HTTP2Connection(HTTPSConnection):
    """
    HTTPS connection that speaks HTTP/2 through the ``h2`` state machine.

    Only requests without a body are supported: ``endheaders`` always ends
    the stream and ``send`` rejects non-empty data. The connection is
    closed after every response.
    """

    def __init__(
        self, host: str, port: int | None = None, **kwargs: typing.Any
    ) -> None:
        """
        :raises NotImplementedError:
            If a ``proxy`` or ``proxy_config`` keyword argument is passed.
        """
        self._h2_conn = self._new_h2_conn()
        self._h2_stream: int | None = None
        self._h2_headers: list[tuple[bytes, bytes]] = []

        if "proxy" in kwargs or "proxy_config" in kwargs:  # Defensive:
            raise NotImplementedError("Proxies aren't supported with HTTP/2")

        super().__init__(host, port, **kwargs)

    def _new_h2_conn(self) -> _LockedObject[h2.connection.H2Connection]:
        """Create a fresh client-side ``H2Connection`` behind a lock."""
        config = h2.config.H2Configuration(client_side=True)
        return _LockedObject(h2.connection.H2Connection(config=config))

    def connect(self) -> None:
        """Open the TLS connection, then send the HTTP/2 connection preamble."""
        super().connect()

        with self._h2_conn as h2_conn:
            h2_conn.initiate_connection()
            self.sock.sendall(h2_conn.data_to_send())

    def putrequest(
        self,
        method: str,
        url: str,
        skip_host: bool = False,
        skip_accept_encoding: bool = False,
    ) -> None:
        """
        Start a request: reserve a stream id and queue the pseudo-headers.

        ``skip_host`` and ``skip_accept_encoding`` are accepted but unused.
        """
        with self._h2_conn as h2_conn:
            self._request_url = url
            self._h2_stream = h2_conn.get_next_available_stream_id()

            # A colon in the host means an IPv6 literal, which needs brackets.
            if ":" in self.host:
                authority = f"[{self.host}]:{self.port or 443}"
            else:
                authority = f"{self.host}:{self.port or 443}"

            self._h2_headers.extend(
                (
                    (b":scheme", b"https"),
                    (b":method", method.encode()),
                    (b":authority", authority.encode()),
                    (b":path", url.encode()),
                )
            )

    def putheader(self, header: str, *values: str) -> None:  # type: ignore[override]
        """Queue one header entry per value, with the name lower-cased."""
        for value in values:
            self._h2_headers.append(
                (header.encode("utf-8").lower(), value.encode("utf-8"))
            )

    def endheaders(self) -> None:  # type: ignore[override]
        """Send the queued headers and end the stream (no request body)."""
        with self._h2_conn as h2_conn:
            h2_conn.send_headers(
                stream_id=self._h2_stream,
                headers=self._h2_headers,
                end_stream=True,
            )
            if data_to_send := h2_conn.data_to_send():
                self.sock.sendall(data_to_send)

    def send(self, data: bytes) -> None:  # type: ignore[override] # Defensive:
        """Accept empty data only; request bodies are not implemented."""
        if not data:
            return
        raise NotImplementedError("Sending data isn't supported yet")

    def getresponse(  # type: ignore[override]
        self,
    ) -> HTTP2Response:
        """
        Read from the socket until the stream ends and build the response.

        :raises ConnectionError:
            If the peer closes the connection before the stream has ended.
        """
        status = None
        data = bytearray()
        with self._h2_conn as h2_conn:
            end_stream = False
            while not end_stream:
                # TODO: Arbitrary read value.
                received_data = self.sock.recv(65535)
                if not received_data:
                    # An empty read means EOF. Without this check the loop
                    # would spin forever waiting for a StreamEnded event
                    # that can no longer arrive.
                    raise ConnectionError(
                        "Connection closed by peer before the HTTP/2 "
                        "response was complete"
                    )

                for event in h2_conn.receive_data(received_data):
                    if isinstance(event, h2.events.ResponseReceived):
                        headers = HTTPHeaderDict()
                        for header, value in event.headers:
                            if header == b":status":
                                status = int(value.decode())
                            else:
                                headers.add(
                                    header.decode("ascii"), value.decode("ascii")
                                )

                    elif isinstance(event, h2.events.DataReceived):
                        data += event.data
                        h2_conn.acknowledge_received_data(
                            event.flow_controlled_length, event.stream_id
                        )

                    elif isinstance(event, h2.events.StreamEnded):
                        end_stream = True

                if data_to_send := h2_conn.data_to_send():
                    self.sock.sendall(data_to_send)

        # We always close to not have to handle connection management.
        self.close()

        assert status is not None
        return HTTP2Response(
            status=status,
            headers=headers,
            request_url=self._request_url,
            data=bytes(data),
        )

    def close(self) -> None:
        """Try to close the HTTP/2 connection, reset state, close the socket."""
        with self._h2_conn as h2_conn:
            try:
                h2_conn.close_connection()
                if data := h2_conn.data_to_send():
                    self.sock.sendall(data)
            except Exception:
                # Best effort: failures while closing are ignored.
                pass

        # Reset all our HTTP/2 connection state.
        self._h2_conn = self._new_h2_conn()
        self._h2_stream = None
        self._h2_headers = []

        super().close()
class HTTP2Response(BaseHTTPResponse):
    """Response whose whole body is already in memory (non-streaming only)."""

    # TODO: This is a woefully incomplete response object, but works for non-streaming.
    def __init__(
        self,
        status: int,
        headers: HTTPHeaderDict,
        request_url: str,
        data: bytes,
        decode_content: bool = False,  # TODO: support decoding
    ) -> None:
        super().__init__(
            headers=headers,
            status=status,
            request_url=request_url,
            decode_content=decode_content,
            # No reason phrase in HTTP/2
            reason=None,
            # Following CPython, we map HTTP versions to major * 10 + minor integers
            version=20,
        )
        # The body was read in full before construction, so nothing remains.
        self.length_remaining = 0
        self._data = data

    @property
    def data(self) -> bytes:
        """The complete response body."""
        return self._data

    def get_redirect_location(self) -> None:
        """Always return ``None``."""
        return None

    def close(self) -> None:
        """Do nothing."""
        pass
def inject_into_urllib3() -> None:
    """Make urllib3 use ``HTTP2Connection`` for HTTPS and offer only ``h2``."""
    # TODO: Offer 'http/1.1' as well, but for testing purposes this is handy.
    urllib3.util.ssl_.ALPN_PROTOCOLS = ["h2"]
    urllib3.connection.HTTPSConnection = HTTP2Connection  # type: ignore[misc]
    HTTPSConnectionPool.ConnectionCls = HTTP2Connection
def extract_from_urllib3() -> None:
    """Restore the original HTTPS connection class and ``http/1.1`` ALPN."""
    urllib3.util.ssl_.ALPN_PROTOCOLS = ["http/1.1"]
    urllib3.connection.HTTPSConnection = orig_HTTPSConnection  # type: ignore[misc]
    HTTPSConnectionPool.ConnectionCls = orig_HTTPSConnection

View file

@ -26,8 +26,7 @@ from .util.url import Url, parse_url
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
import ssl import ssl
from typing import Literal
from typing_extensions import Literal
__all__ = ["PoolManager", "ProxyManager", "proxy_from_url"] __all__ = ["PoolManager", "ProxyManager", "proxy_from_url"]
@ -39,6 +38,7 @@ SSL_KEYWORDS = (
"cert_file", "cert_file",
"cert_reqs", "cert_reqs",
"ca_certs", "ca_certs",
"ca_cert_data",
"ssl_version", "ssl_version",
"ssl_minimum_version", "ssl_minimum_version",
"ssl_maximum_version", "ssl_maximum_version",
@ -74,6 +74,7 @@ class PoolKey(typing.NamedTuple):
key_cert_file: str | None key_cert_file: str | None
key_cert_reqs: str | None key_cert_reqs: str | None
key_ca_certs: str | None key_ca_certs: str | None
key_ca_cert_data: str | bytes | None
key_ssl_version: int | str | None key_ssl_version: int | str | None
key_ssl_minimum_version: ssl.TLSVersion | None key_ssl_minimum_version: ssl.TLSVersion | None
key_ssl_maximum_version: ssl.TLSVersion | None key_ssl_maximum_version: ssl.TLSVersion | None

View file

@ -14,16 +14,19 @@ from http.client import HTTPMessage as _HttplibHTTPMessage
from http.client import HTTPResponse as _HttplibHTTPResponse from http.client import HTTPResponse as _HttplibHTTPResponse
from socket import timeout as SocketTimeout from socket import timeout as SocketTimeout
if typing.TYPE_CHECKING:
from ._base_connection import BaseHTTPConnection
try: try:
try: try:
import brotlicffi as brotli # type: ignore[import] import brotlicffi as brotli # type: ignore[import-not-found]
except ImportError: except ImportError:
import brotli # type: ignore[import] import brotli # type: ignore[import-not-found]
except ImportError: except ImportError:
brotli = None brotli = None
try: try:
import zstandard as zstd # type: ignore[import] import zstandard as zstd # type: ignore[import-not-found]
# The package 'zstandard' added the 'eof' property starting # The package 'zstandard' added the 'eof' property starting
# in v0.18.0 which we require to ensure a complete and # in v0.18.0 which we require to ensure a complete and
@ -58,7 +61,7 @@ from .util.response import is_fp_closed, is_response_to_head
from .util.retry import Retry from .util.retry import Retry
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from typing_extensions import Literal from typing import Literal
from .connectionpool import HTTPConnectionPool from .connectionpool import HTTPConnectionPool
@ -208,7 +211,9 @@ def _get_decoder(mode: str) -> ContentDecoder:
if "," in mode: if "," in mode:
return MultiDecoder(mode) return MultiDecoder(mode)
if mode == "gzip": # According to RFC 9110 section 8.4.1.3, recipients should
# consider x-gzip equivalent to gzip
if mode in ("gzip", "x-gzip"):
return GzipDecoder() return GzipDecoder()
if brotli is not None and mode == "br": if brotli is not None and mode == "br":
@ -278,9 +283,23 @@ class BytesQueueBuffer:
return ret.getvalue() return ret.getvalue()
def get_all(self) -> bytes:
buffer = self.buffer
if not buffer:
assert self._size == 0
return b""
if len(buffer) == 1:
result = buffer.pop()
else:
ret = io.BytesIO()
ret.writelines(buffer.popleft() for _ in range(len(buffer)))
result = ret.getvalue()
self._size = 0
return result
class BaseHTTPResponse(io.IOBase): class BaseHTTPResponse(io.IOBase):
CONTENT_DECODERS = ["gzip", "deflate"] CONTENT_DECODERS = ["gzip", "x-gzip", "deflate"]
if brotli is not None: if brotli is not None:
CONTENT_DECODERS += ["br"] CONTENT_DECODERS += ["br"]
if zstd is not None: if zstd is not None:
@ -325,6 +344,7 @@ class BaseHTTPResponse(io.IOBase):
self.chunked = True self.chunked = True
self._decoder: ContentDecoder | None = None self._decoder: ContentDecoder | None = None
self.length_remaining: int | None
def get_redirect_location(self) -> str | None | Literal[False]: def get_redirect_location(self) -> str | None | Literal[False]:
""" """
@ -364,7 +384,7 @@ class BaseHTTPResponse(io.IOBase):
raise NotImplementedError() raise NotImplementedError()
@property @property
def connection(self) -> HTTPConnection | None: def connection(self) -> BaseHTTPConnection | None:
raise NotImplementedError() raise NotImplementedError()
@property @property
@ -391,6 +411,13 @@ class BaseHTTPResponse(io.IOBase):
) -> bytes: ) -> bytes:
raise NotImplementedError() raise NotImplementedError()
def read1(
self,
amt: int | None = None,
decode_content: bool | None = None,
) -> bytes:
raise NotImplementedError()
def read_chunked( def read_chunked(
self, self,
amt: int | None = None, amt: int | None = None,
@ -722,8 +749,18 @@ class HTTPResponse(BaseHTTPResponse):
raise ReadTimeoutError(self._pool, None, "Read timed out.") from e # type: ignore[arg-type] raise ReadTimeoutError(self._pool, None, "Read timed out.") from e # type: ignore[arg-type]
except IncompleteRead as e:
if (
e.expected is not None
and e.partial is not None
and e.expected == -e.partial
):
arg = "Response may not contain content."
else:
arg = f"Connection broken: {e!r}"
raise ProtocolError(arg, e) from e
except (HTTPException, OSError) as e: except (HTTPException, OSError) as e:
# This includes IncompleteRead.
raise ProtocolError(f"Connection broken: {e!r}", e) from e raise ProtocolError(f"Connection broken: {e!r}", e) from e
# If no exception is thrown, we should avoid cleaning up # If no exception is thrown, we should avoid cleaning up
@ -750,7 +787,12 @@ class HTTPResponse(BaseHTTPResponse):
if self._original_response and self._original_response.isclosed(): if self._original_response and self._original_response.isclosed():
self.release_conn() self.release_conn()
def _fp_read(self, amt: int | None = None) -> bytes: def _fp_read(
self,
amt: int | None = None,
*,
read1: bool = False,
) -> bytes:
""" """
Read a response with the thought that reading the number of bytes Read a response with the thought that reading the number of bytes
larger than can fit in a 32-bit int at a time via SSL in some larger than can fit in a 32-bit int at a time via SSL in some
@ -767,13 +809,15 @@ class HTTPResponse(BaseHTTPResponse):
assert self._fp assert self._fp
c_int_max = 2**31 - 1 c_int_max = 2**31 - 1
if ( if (
(
(amt and amt > c_int_max) (amt and amt > c_int_max)
or (self.length_remaining and self.length_remaining > c_int_max) or (
amt is None
and self.length_remaining
and self.length_remaining > c_int_max
) )
and not util.IS_SECURETRANSPORT ) and (util.IS_PYOPENSSL or sys.version_info < (3, 10)):
and (util.IS_PYOPENSSL or sys.version_info < (3, 10)) if read1:
): return self._fp.read1(c_int_max)
buffer = io.BytesIO() buffer = io.BytesIO()
# Besides `max_chunk_amt` being a maximum chunk size, it # Besides `max_chunk_amt` being a maximum chunk size, it
# affects memory overhead of reading a response by this # affects memory overhead of reading a response by this
@ -794,6 +838,8 @@ class HTTPResponse(BaseHTTPResponse):
buffer.write(data) buffer.write(data)
del data # to reduce peak memory usage by `max_chunk_amt`. del data # to reduce peak memory usage by `max_chunk_amt`.
return buffer.getvalue() return buffer.getvalue()
elif read1:
return self._fp.read1(amt) if amt is not None else self._fp.read1()
else: else:
# StringIO doesn't like amt=None # StringIO doesn't like amt=None
return self._fp.read(amt) if amt is not None else self._fp.read() return self._fp.read(amt) if amt is not None else self._fp.read()
@ -801,6 +847,8 @@ class HTTPResponse(BaseHTTPResponse):
def _raw_read( def _raw_read(
self, self,
amt: int | None = None, amt: int | None = None,
*,
read1: bool = False,
) -> bytes: ) -> bytes:
""" """
Reads `amt` of bytes from the socket. Reads `amt` of bytes from the socket.
@ -811,7 +859,7 @@ class HTTPResponse(BaseHTTPResponse):
fp_closed = getattr(self._fp, "closed", False) fp_closed = getattr(self._fp, "closed", False)
with self._error_catcher(): with self._error_catcher():
data = self._fp_read(amt) if not fp_closed else b"" data = self._fp_read(amt, read1=read1) if not fp_closed else b""
if amt is not None and amt != 0 and not data: if amt is not None and amt != 0 and not data:
# Platform-specific: Buggy versions of Python. # Platform-specific: Buggy versions of Python.
# Close the connection when no data is returned # Close the connection when no data is returned
@ -833,6 +881,14 @@ class HTTPResponse(BaseHTTPResponse):
# raised during streaming, so all calls with incorrect # raised during streaming, so all calls with incorrect
# Content-Length are caught. # Content-Length are caught.
raise IncompleteRead(self._fp_bytes_read, self.length_remaining) raise IncompleteRead(self._fp_bytes_read, self.length_remaining)
elif read1 and (
(amt != 0 and not data) or self.length_remaining == len(data)
):
# All data has been read, but `self._fp.read1` in
# CPython 3.12 and older doesn't always close
# `http.client.HTTPResponse`, so we close it here.
# See https://github.com/python/cpython/issues/113199
self._fp.close()
if data: if data:
self._fp_bytes_read += len(data) self._fp_bytes_read += len(data)
@ -911,6 +967,57 @@ class HTTPResponse(BaseHTTPResponse):
return data return data
def read1(
self,
amt: int | None = None,
decode_content: bool | None = None,
) -> bytes:
"""
Similar to ``http.client.HTTPResponse.read1`` and documented
in :meth:`io.BufferedReader.read1`, but with an additional parameter:
``decode_content``.
:param amt:
How much of the content to read.
:param decode_content:
If True, will attempt to decode the body based on the
'content-encoding' header.
"""
if decode_content is None:
decode_content = self.decode_content
# try and respond without going to the network
if self._has_decoded_content:
if not decode_content:
raise RuntimeError(
"Calling read1(decode_content=False) is not supported after "
"read1(decode_content=True) was called."
)
if len(self._decoded_buffer) > 0:
if amt is None:
return self._decoded_buffer.get_all()
return self._decoded_buffer.get(amt)
if amt == 0:
return b""
# FIXME, this method's type doesn't say returning None is possible
data = self._raw_read(amt, read1=True)
if not decode_content or data is None:
return data
self._init_decoder()
while True:
flush_decoder = not data
decoded_data = self._decode(data, decode_content, flush_decoder)
self._decoded_buffer.put(decoded_data)
if decoded_data or flush_decoder:
break
data = self._raw_read(8192, read1=True)
if amt is None:
return self._decoded_buffer.get_all()
return self._decoded_buffer.get(amt)
def stream( def stream(
self, amt: int | None = 2**16, decode_content: bool | None = None self, amt: int | None = 2**16, decode_content: bool | None = None
) -> typing.Generator[bytes, None, None]: ) -> typing.Generator[bytes, None, None]:
@ -1003,9 +1110,13 @@ class HTTPResponse(BaseHTTPResponse):
try: try:
self.chunk_left = int(line, 16) self.chunk_left = int(line, 16)
except ValueError: except ValueError:
# Invalid chunked protocol response, abort.
self.close() self.close()
if line:
# Invalid chunked protocol response, abort.
raise InvalidChunkLength(self, line) from None raise InvalidChunkLength(self, line) from None
else:
# Truncated at start of next chunk
raise ProtocolError("Response ended prematurely") from None
def _handle_chunk(self, amt: int | None) -> bytes: def _handle_chunk(self, amt: int | None) -> bytes:
returned_chunk = None returned_chunk = None

View file

@ -8,7 +8,6 @@ from .retry import Retry
from .ssl_ import ( from .ssl_ import (
ALPN_PROTOCOLS, ALPN_PROTOCOLS,
IS_PYOPENSSL, IS_PYOPENSSL,
IS_SECURETRANSPORT,
SSLContext, SSLContext,
assert_fingerprint, assert_fingerprint,
create_urllib3_context, create_urllib3_context,
@ -22,7 +21,6 @@ from .wait import wait_for_read, wait_for_write
__all__ = ( __all__ = (
"IS_PYOPENSSL", "IS_PYOPENSSL",
"IS_SECURETRANSPORT",
"SSLContext", "SSLContext",
"ALPN_PROTOCOLS", "ALPN_PROTOCOLS",
"Retry", "Retry",

View file

@ -9,7 +9,7 @@ from ..exceptions import UnrewindableBodyError
from .util import to_bytes from .util import to_bytes
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from typing_extensions import Final from typing import Final
# Pass as a value within ``headers`` to skip # Pass as a value within ``headers`` to skip
# emitting some HTTP headers that are added automatically. # emitting some HTTP headers that are added automatically.
@ -21,15 +21,15 @@ SKIPPABLE_HEADERS = frozenset(["accept-encoding", "host", "user-agent"])
ACCEPT_ENCODING = "gzip,deflate" ACCEPT_ENCODING = "gzip,deflate"
try: try:
try: try:
import brotlicffi as _unused_module_brotli # type: ignore[import] # noqa: F401 import brotlicffi as _unused_module_brotli # type: ignore[import-not-found] # noqa: F401
except ImportError: except ImportError:
import brotli as _unused_module_brotli # type: ignore[import] # noqa: F401 import brotli as _unused_module_brotli # type: ignore[import-not-found] # noqa: F401
except ImportError: except ImportError:
pass pass
else: else:
ACCEPT_ENCODING += ",br" ACCEPT_ENCODING += ",br"
try: try:
import zstandard as _unused_module_zstd # type: ignore[import] # noqa: F401 import zstandard as _unused_module_zstd # type: ignore[import-not-found] # noqa: F401
except ImportError: except ImportError:
pass pass
else: else:

View file

@ -16,7 +16,6 @@ SSLContext = None
SSLTransport = None SSLTransport = None
HAS_NEVER_CHECK_COMMON_NAME = False HAS_NEVER_CHECK_COMMON_NAME = False
IS_PYOPENSSL = False IS_PYOPENSSL = False
IS_SECURETRANSPORT = False
ALPN_PROTOCOLS = ["http/1.1"] ALPN_PROTOCOLS = ["http/1.1"]
_TYPE_VERSION_INFO = typing.Tuple[int, int, int, str, int] _TYPE_VERSION_INFO = typing.Tuple[int, int, int, str, int]
@ -42,7 +41,7 @@ def _is_bpo_43522_fixed(
""" """
if implementation_name == "pypy": if implementation_name == "pypy":
# https://foss.heptapod.net/pypy/pypy/-/issues/3129 # https://foss.heptapod.net/pypy/pypy/-/issues/3129
return pypy_version_info >= (7, 3, 8) and version_info >= (3, 8) # type: ignore[operator] return pypy_version_info >= (7, 3, 8) # type: ignore[operator]
elif implementation_name == "cpython": elif implementation_name == "cpython":
major_minor = version_info[:2] major_minor = version_info[:2]
micro = version_info[2] micro = version_info[2]
@ -79,8 +78,7 @@ def _is_has_never_check_common_name_reliable(
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from ssl import VerifyMode from ssl import VerifyMode
from typing import Literal, TypedDict
from typing_extensions import Literal, TypedDict
from .ssltransport import SSLTransport as SSLTransportType from .ssltransport import SSLTransport as SSLTransportType
@ -321,13 +319,9 @@ def create_urllib3_context(
# Enable post-handshake authentication for TLS 1.3, see GH #1634. PHA is # Enable post-handshake authentication for TLS 1.3, see GH #1634. PHA is
# necessary for conditional client cert authentication with TLS 1.3. # necessary for conditional client cert authentication with TLS 1.3.
# The attribute is None for OpenSSL <= 1.1.0 or does not exist in older # The attribute is None for OpenSSL <= 1.1.0 or does not exist when using
# versions of Python. We only enable on Python 3.7.4+ or if certificate # an SSLContext created by pyOpenSSL.
# verification is enabled to work around Python issue #37428 if getattr(context, "post_handshake_auth", None) is not None:
# See: https://bugs.python.org/issue37428
if (cert_reqs == ssl.CERT_REQUIRED or sys.version_info >= (3, 7, 4)) and getattr(
context, "post_handshake_auth", None
) is not None:
context.post_handshake_auth = True context.post_handshake_auth = True
# The order of the below lines setting verify_mode and check_hostname # The order of the below lines setting verify_mode and check_hostname

View file

@ -8,7 +8,7 @@ import typing
from ..exceptions import ProxySchemeUnsupported from ..exceptions import ProxySchemeUnsupported
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from typing_extensions import Literal from typing import Literal
from .ssl_ import _TYPE_PEER_CERT_RET, _TYPE_PEER_CERT_RET_DICT from .ssl_ import _TYPE_PEER_CERT_RET, _TYPE_PEER_CERT_RET_DICT

View file

@ -8,7 +8,7 @@ from socket import getdefaulttimeout
from ..exceptions import TimeoutStateError from ..exceptions import TimeoutStateError
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from typing_extensions import Final from typing import Final
class _TYPE_DEFAULT(Enum): class _TYPE_DEFAULT(Enum):
@ -101,10 +101,6 @@ class Timeout:
the case; if a server streams one byte every fifteen seconds, a timeout the case; if a server streams one byte every fifteen seconds, a timeout
of 20 seconds will not trigger, even though the request will take of 20 seconds will not trigger, even though the request will take
several minutes to complete. several minutes to complete.
If your goal is to cut off any request after a set amount of wall clock
time, consider having a second "watcher" thread to cut off a slow
request.
""" """
#: A sentinel object representing the default timeout value #: A sentinel object representing the default timeout value