]> jfr.im git - yt-dlp.git/blame - yt_dlp/networking/_helper.py
[rh:requests] Add handler for `requests` HTTP library (#3668)
[yt-dlp.git] / yt_dlp / networking / _helper.py
CommitLineData
c365dba8 1from __future__ import annotations
2
3import contextlib
227bf1a3 4import functools
20fbbd92 5import socket
c365dba8 6import ssl
7import sys
227bf1a3 8import typing
c365dba8 9import urllib.parse
227bf1a3 10import urllib.request
c365dba8 11
227bf1a3 12from .exceptions import RequestError, UnsupportedRequest
c365dba8 13from ..dependencies import certifi
8a8b5452 14from ..socks import ProxyType, sockssocket
227bf1a3 15from ..utils import format_field, traverse_obj
16
17if typing.TYPE_CHECKING:
18 from collections.abc import Iterable
19
20 from ..utils.networking import HTTPHeaderDict
c365dba8 21
22
def ssl_load_certs(context: ssl.SSLContext, use_certifi=True):
    """Load trusted CA certificates into *context*.

    When certifi is available and *use_certifi* is true, certifi's bundle is loaded.
    Otherwise the system defaults are loaded via ``load_default_certs()``. If that raises
    ``ssl.SSLError`` (bad certificates in the store), it falls back to loading the Windows
    "CA"/"ROOT" stores one certificate at a time (where supported), then to
    ``set_default_verify_paths()``.
    """
    if certifi and use_certifi:
        context.load_verify_locations(cafile=certifi.where())
        return

    try:
        context.load_default_certs()
    except ssl.SSLError:
        # load_default_certs() fails outright when the store contains bad certificates. See:
        # https://github.com/yt-dlp/yt-dlp/issues/1060,
        # https://bugs.python.org/issue35665, https://bugs.python.org/issue45312
        # enum_certificates is missing in mingw python, see https://github.com/yt-dlp/yt-dlp/issues/1151
        windows_store_available = sys.platform == 'win32' and hasattr(ssl, 'enum_certificates')
        if windows_store_available:
            for store in ('CA', 'ROOT'):
                ssl_load_windows_store_certs(context, store)
        context.set_default_verify_paths()
38
39
def ssl_load_windows_store_certs(ssl_context, storename):
    """Add server-auth-trusted X.509 certificates from a Windows certificate store to *ssl_context*.

    Adapted from ``_load_windows_store_certs`` in
    https://github.com/python/cpython/blob/main/Lib/ssl.py

    :param ssl_context: context whose ``load_verify_locations(cadata=...)`` receives each certificate.
    :param storename: Windows store name passed to ``ssl.enum_certificates`` (e.g. ``'CA'``, ``'ROOT'``).

    Returns silently if the store cannot be read (``PermissionError``). Individual
    certificates that fail to load (``ssl.SSLError``) are skipped.
    """
    try:
        wanted = []
        for der_bytes, encoding, trust in ssl.enum_certificates(storename):
            if encoding != 'x509_asn':
                continue
            # trust is either True (trusted for everything) or a set of purpose OIDs
            if trust is True or ssl.Purpose.SERVER_AUTH.oid in trust:
                wanted.append(der_bytes)
    except PermissionError:
        return

    for der_bytes in wanted:
        try:
            ssl_context.load_verify_locations(cadata=der_bytes)
        except ssl.SSLError:
            pass  # skip certificates OpenSSL refuses; keep loading the rest
51
52
def make_socks_proxy_opts(socks_proxy):
    """Translate a SOCKS proxy URL into the keyword options used by the SOCKS layer.

    Supported schemes (case-insensitive): ``socks5``, ``socks5h``, ``socks4`` and ``socks4a``.
    The ``h``/``a`` variants set ``rdns`` to True. The port defaults to 1080. Username and
    password are decoded with ``urllib.parse.unquote_plus`` (so ``+`` becomes a space);
    empty or missing values are passed through unchanged.

    :raises ValueError: for any other scheme (or, via ``urlparse().port``, an invalid port).
    :returns: dict with keys ``proxytype``, ``addr``, ``port``, ``rdns``, ``username``, ``password``.
    """
    parsed = urllib.parse.urlparse(socks_proxy)
    scheme = parsed.scheme.lower()

    # scheme -> (proxy type, rdns flag)
    scheme_table = {
        'socks5': (ProxyType.SOCKS5, False),
        'socks5h': (ProxyType.SOCKS5, True),
        'socks4': (ProxyType.SOCKS4, False),
        'socks4a': (ProxyType.SOCKS4A, True),
    }
    selected = scheme_table.get(scheme)
    if selected is None:
        raise ValueError(f'Unknown SOCKS proxy version: {scheme}')
    socks_type, rdns = selected

    def _decode(value):
        return urllib.parse.unquote_plus(value) if value else value

    return {
        'proxytype': socks_type,
        'addr': parsed.hostname,
        'port': parsed.port or 1080,
        'rdns': rdns,
        'username': _decode(parsed.username),
        'password': _decode(parsed.password),
    }
82
83
def select_proxy(url, proxies):
    """Unified proxy selector for all backends.

    :param url: the request URL.
    :param proxies: mapping of proxy settings. An optional ``'no'`` entry holds the
        no-proxy list; other keys are looked up by URL scheme or ``'all'``.
    :returns: ``None`` when *url* should bypass proxies. A bypass happens when the
        ``'no'`` list matches, or, if a ``'no'`` entry is present, when the system
        proxy-bypass settings match. Otherwise returns the proxy looked up via
        ``traverse_obj`` using the URL's scheme (``'http'`` if the URL has none) and
        then ``'all'``. This presumably means first match wins; confirm against
        traverse_obj's multi-path semantics.
    """
    url_components = urllib.parse.urlparse(url)
    if 'no' in proxies:
        # urlparse() reports hostname=None for URLs without a network location; treat that
        # as an empty host rather than failing with a TypeError on the concatenation.
        host = url_components.hostname or ''
        hostport = host + format_field(url_components.port, None, ':%s')
        if urllib.request.proxy_bypass_environment(hostport, {'no': proxies['no']}):
            return None
        if urllib.request.proxy_bypass(hostport):  # check system settings
            return None

    return traverse_obj(proxies, url_components.scheme or 'http', 'all')
95
96
def get_redirect_method(method, status):
    """Unified redirect method handling.

    Return the HTTP method to use when following a redirect with *status* that
    answered a request made with *method*.
    """
    if status == 303:
        # A 303 must either use GET or HEAD for the subsequent request
        # https://datatracker.ietf.org/doc/html/rfc7231#section-6.4.4
        return 'GET' if method != 'HEAD' else method

    if status in (301, 302):
        # Browsers commonly turn a POST into a GET after 301/302, so mirror that.
        # https://datatracker.ietf.org/doc/html/rfc7231#section-6.4.2
        # https://datatracker.ietf.org/doc/html/rfc7231#section-6.4.3
        return 'GET' if method == 'POST' else method

    return method
111
112
def make_ssl_context(
    verify=True,
    client_certificate=None,
    client_certificate_key=None,
    client_certificate_password=None,
    legacy_support=False,
    use_certifi=True,
):
    """Create a client-side ``ssl.SSLContext`` for the HTTP backends.

    :param verify: enable hostname checking and certificate verification (and load CA certs).
    :param client_certificate: path passed to ``load_cert_chain()``; skipped when falsy.
    :param client_certificate_key: optional key file for the client certificate.
    :param client_certificate_password: optional password for the client key.
    :param legacy_support: enable SSL_OP_LEGACY_SERVER_CONNECT and OpenSSL 'DEFAULT' ciphers.
    :param use_certifi: forwarded to ``ssl_load_certs()`` when *verify* is true.
    :raises RequestError: if the client certificate cannot be loaded.
    """
    ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    # Order matters: verify_mode cannot become CERT_NONE while check_hostname is still True.
    ctx.check_hostname = verify
    ctx.verify_mode = ssl.CERT_REQUIRED if verify else ssl.CERT_NONE

    # Some servers may reject requests if the ALPN extension is not sent. See:
    # https://github.com/python/cpython/issues/85140
    # https://github.com/yt-dlp/yt-dlp/issues/3878
    with contextlib.suppress(NotImplementedError):
        ctx.set_alpn_protocols(['http/1.1'])
    if verify:
        ssl_load_certs(ctx, use_certifi)

    modern_openssl = (
        ssl.OPENSSL_VERSION_INFO >= (1, 1, 1)
        and not ssl.OPENSSL_VERSION.startswith('LibreSSL'))

    if legacy_support:
        ctx.options |= 4  # SSL_OP_LEGACY_SERVER_CONNECT
        ctx.set_ciphers('DEFAULT')  # compat
    elif modern_openssl:
        # Apply Python 3.10's default cipher list and minimum TLS version [1] so behaviour is
        # consistent across Python versions/libraries and fingerprinting is harder [2][3].
        # Python 3.10 requires OpenSSL 1.1.1+ [4], so older OpenSSL is left alone to be safe.
        # LibreSSL is excluded pending investigation of cipher support issues [5][6].
        # 1. https://github.com/python/cpython/commit/e983252b516edb15d4338b0a47631b59ef1e2536
        # 2. https://github.com/yt-dlp/yt-dlp/issues/4627
        # 3. https://github.com/yt-dlp/yt-dlp/pull/5294
        # 4. https://peps.python.org/pep-0644/
        # 5. https://peps.python.org/pep-0644/#libressl-support
        # 6. https://github.com/yt-dlp/yt-dlp/commit/5b9f253fa0aee996cf1ed30185d4b502e00609c4#commitcomment-89054368
        ctx.set_ciphers(
            '@SECLEVEL=2:ECDH+AESGCM:ECDH+CHACHA20:ECDH+AES:DHE+AES:!aNULL:!eNULL:!aDSS:!SHA1:!AESCCM')
        ctx.minimum_version = ssl.TLSVersion.TLSv1_2

    if client_certificate:
        try:
            ctx.load_cert_chain(
                client_certificate,
                keyfile=client_certificate_key,
                password=client_certificate_password)
        except ssl.SSLError:
            raise RequestError('Unable to load client certificate')

    # Only touch post_handshake_auth where this Python/OpenSSL exposes it.
    if getattr(ctx, 'post_handshake_auth', None) is not None:
        ctx.post_handshake_auth = True
    return ctx
165
166
class InstanceStoreMixin:
    """Mixin that caches objects produced by ``_create_instance()``, keyed on its keyword arguments.

    Subclasses override ``_create_instance(**kwargs)``. ``_get_instance(**kwargs)`` returns a
    previously created instance whose kwargs compare equal (``==``) to the requested ones,
    or creates, remembers and returns a new one. ``_clear_instances()`` closes every cached
    instance (when it has a callable ``close``) and empties the cache.
    """

    def __init__(self, **kwargs):
        # Cache of (kwargs, instance) pairs searched linearly with ==. Presumably a list
        # rather than a dict so that unhashable kwargs values are supported.
        self.__instances = []
        super().__init__(**kwargs)  # So that both MRO works

    @staticmethod
    def _create_instance(**kwargs):
        """Build a new instance for *kwargs*; subclasses must override this."""
        raise NotImplementedError

    def _get_instance(self, **kwargs):
        """Return the cached instance for *kwargs*, creating and caching one on a miss."""
        for cached_kwargs, cached in self.__instances:
            if cached_kwargs == kwargs:
                return cached

        created = self._create_instance(**kwargs)
        self.__instances.append((kwargs, created))
        return created

    def _close_instance(self, instance):
        """Call ``instance.close()`` if *instance* has a callable ``close`` attribute."""
        if not callable(getattr(instance, 'close', None)):
            return
        instance.close()

    def _clear_instances(self):
        """Close every cached instance, then forget all of them."""
        for pair in self.__instances:
            self._close_instance(pair[1])
        self.__instances.clear()
193
194
def add_accept_encoding_header(headers: HTTPHeaderDict, supported_encodings: Iterable[str]):
    """Set ``Accept-Encoding`` on *headers* unless it is already present.

    The value is the comma-separated *supported_encodings*, or ``'identity'`` when that
    list is empty. An existing ``Accept-Encoding`` header is never overwritten.
    """
    if 'Accept-Encoding' in headers:
        return
    joined = ', '.join(supported_encodings)
    headers['Accept-Encoding'] = joined if joined else 'identity'
198
199
def wrap_request_errors(func):
    """Decorate a handler method so escaping ``UnsupportedRequest`` errors name the handler.

    If *func* raises ``UnsupportedRequest`` whose ``handler`` is still ``None``, the
    instance the method was called on (``self``) is recorded as the handler. The
    exception is then re-raised unchanged. Other exceptions pass through untouched.
    """
    @functools.wraps(func)
    def _attributed(self, *args, **kwargs):
        try:
            return func(self, *args, **kwargs)
        except UnsupportedRequest as exc:
            if exc.handler is None:
                exc.handler = self
            raise

    return _attributed
20fbbd92 210
211
def _socket_connect(ip_addr, timeout, source_address):
    """Create a socket for one ``getaddrinfo()`` result and connect it.

    :param ip_addr: ``(family, type, proto, canonname, sockaddr)`` tuple as returned by
        ``socket.getaddrinfo()``.
    :param timeout: timeout applied with ``settimeout()``, unless it is
        ``socket._GLOBAL_DEFAULT_TIMEOUT``, in which case the socket is left as created.
    :param source_address: optional local address to ``bind()`` to before connecting
        (ignored when falsy).
    :returns: the connected socket.
    :raises OSError: if binding or connecting fails. On any failure the socket is closed
        before the exception propagates.
    """
    af, socktype, proto, _canonname, sa = ip_addr
    sock = socket.socket(af, socktype, proto)
    try:
        if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT:
            sock.settimeout(timeout)
        if source_address:
            sock.bind(source_address)
        sock.connect(sa)
    except BaseException:
        # Close on *any* failure, not only OSError: e.g. a TypeError from a malformed
        # source_address or a KeyboardInterrupt must not leak the file descriptor.
        sock.close()
        raise
    return sock
225
226
def create_socks_proxy_socket(dest_addr, proxy_args, proxy_ip_addr, timeout, source_address):
    """Connect to *dest_addr* through the SOCKS proxy at the resolved *proxy_ip_addr*.

    :param dest_addr: final destination address, passed to ``sockssocket.connect()``.
    :param proxy_args: keyword arguments for ``sockssocket.setproxy()``. Its ``addr`` and
        ``port`` are overridden with the resolved proxy address; the mapping passed in
        is not modified.
    :param proxy_ip_addr: ``getaddrinfo()``-style 5-tuple for the proxy server.
    :param timeout: applied with ``settimeout()`` unless it is ``socket._GLOBAL_DEFAULT_TIMEOUT``.
    :param source_address: optional local address to ``bind()`` to (ignored when falsy).
    :returns: the connected ``sockssocket``.
    :raises OSError: on proxy or connection failure. On any failure the socket is closed
        before the exception propagates.
    """
    af, socktype, proto, _canonname, sa = proxy_ip_addr
    sock = sockssocket(af, socktype, proto)
    try:
        # Point the SOCKS layer at the concrete resolved proxy address.
        sock.setproxy(**{**proxy_args, 'addr': sa[0], 'port': sa[1]})
        if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT:
            sock.settimeout(timeout)
        if source_address:
            sock.bind(source_address)
        sock.connect(dest_addr)
    except BaseException:
        # Close on *any* failure (not only OSError) so the descriptor is never leaked.
        sock.close()
        raise
    return sock
243
244
def create_connection(
    address,
    timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
    source_address=None,
    *,
    _create_socket_func=_socket_connect
):
    """Connect to ``(host, port)`` *address*, trying each resolved address in turn.

    Works around ``socket.create_connection()``, which tries every ``getaddrinfo()``
    result including IPv6. When *source_address* is given, only remote addresses of the
    same family as the source (IPv6 if its host contains ``':'``, else IPv4) are tried.
    Based on: https://github.com/python/cpython/blob/main/Lib/socket.py#L810

    :param _create_socket_func: callable ``(ip_addr, timeout, source_address)`` that returns
        a connected socket or raises ``socket.error``.
    :returns: the first socket successfully returned by *_create_socket_func*.
    :raises OSError: if resolution yields nothing, no address matches the source family,
        or every attempt fails (the last attempt's error is re-raised).
    """
    host, port = address
    candidates = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM)
    if not candidates:
        raise socket.error('getaddrinfo returns an empty list')

    if source_address is not None:
        family = socket.AF_INET6 if ':' in source_address[0] else socket.AF_INET
        candidates = [candidate for candidate in candidates if candidate[0] == family]
        if not candidates:
            ip_version = 4 if family == socket.AF_INET else 6
            raise OSError(
                f'No remote IPv{ip_version} addresses available for connect. '
                f'Can\'t use "{source_address[0]}" as source address')

    last_error = None
    for candidate in candidates:
        try:
            connected = _create_socket_func(candidate, timeout, source_address)
        except socket.error as exc:
            last_error = exc
        else:
            # Explicitly break __traceback__ reference cycle
            # https://bugs.python.org/issue36820
            last_error = None
            return connected

    try:
        raise last_error
    finally:
        # Explicitly break __traceback__ reference cycle
        # https://bugs.python.org/issue36820
        last_error = None
280 finally:
281 # Explicitly break __traceback__ reference cycle
282 # https://bugs.python.org/issue36820
283 err = None