]> jfr.im git - yt-dlp.git/blame - yt_dlp/networking/_helper.py
[cleanup] Add more ruff rules (#10149)
[yt-dlp.git] / yt_dlp / networking / _helper.py
CommitLineData
c365dba8 1from __future__ import annotations
2
3import contextlib
227bf1a3 4import functools
79a451e5 5import os
20fbbd92 6import socket
c365dba8 7import ssl
8import sys
227bf1a3 9import typing
c365dba8 10import urllib.parse
227bf1a3 11import urllib.request
c365dba8 12
227bf1a3 13from .exceptions import RequestError, UnsupportedRequest
c365dba8 14from ..dependencies import certifi
8a8b5452 15from ..socks import ProxyType, sockssocket
227bf1a3 16from ..utils import format_field, traverse_obj
17
18if typing.TYPE_CHECKING:
19 from collections.abc import Iterable
20
21 from ..utils.networking import HTTPHeaderDict
c365dba8 22
23
def ssl_load_certs(context: ssl.SSLContext, use_certifi=True):
    """Load trusted CA certificates into *context*.

    When the optional ``certifi`` dependency is available and *use_certifi*
    is truthy, its CA bundle is loaded. Otherwise the default certificates
    are loaded via ``context.load_default_certs()``.
    """
    if certifi and use_certifi:
        context.load_verify_locations(cafile=certifi.where())
        return

    try:
        context.load_default_certs()
    except ssl.SSLError:
        # Work around the issue in load_default_certs when there are bad certificates. See:
        # https://github.com/yt-dlp/yt-dlp/issues/1060,
        # https://bugs.python.org/issue35665, https://bugs.python.org/issue45312
        # enum_certificates is not present in mingw python. See https://github.com/yt-dlp/yt-dlp/issues/1151
        can_enumerate_stores = sys.platform == 'win32' and hasattr(ssl, 'enum_certificates')
        if can_enumerate_stores:
            ssl_load_windows_store_certs(context, 'CA')
            ssl_load_windows_store_certs(context, 'ROOT')
        context.set_default_verify_paths()
39
40
def ssl_load_windows_store_certs(ssl_context, storename):
    """Load the usable certificates of the system store *storename* into *ssl_context*.

    Only ``x509_asn`` encoded entries that are either fully trusted
    (``trust is True``) or list the server-authentication OID are loaded.
    Returns silently if enumerating the store raises ``PermissionError``;
    individual certificates that fail to load with ``ssl.SSLError`` are skipped.
    """
    # Code adapted from _load_windows_store_certs in https://github.com/python/cpython/blob/main/Lib/ssl.py

    def usable_for_server_auth(encoding, trust):
        return encoding == 'x509_asn' and (
            trust is True or ssl.Purpose.SERVER_AUTH.oid in trust)

    try:
        usable = [
            der for der, encoding, trust in ssl.enum_certificates(storename)
            if usable_for_server_auth(encoding, trust)]
    except PermissionError:
        return

    for der in usable:
        try:
            ssl_context.load_verify_locations(cadata=der)
        except ssl.SSLError:
            continue
52
53
def make_socks_proxy_opts(socks_proxy):
    """Translate a SOCKS proxy URL into a dict of proxy options.

    The URL scheme (``socks5``, ``socks5h``, ``socks4`` or ``socks4a``) selects
    the ``proxytype`` and the ``rdns`` flag. The port falls back to 1080 and
    non-empty credentials are percent-decoded with ``unquote_plus``.

    Raises ValueError for any other scheme.
    """
    parsed = urllib.parse.urlparse(socks_proxy)
    scheme = parsed.scheme.lower()

    # scheme -> (proxy type, rdns flag)
    variants = {
        'socks5': (ProxyType.SOCKS5, False),
        'socks5h': (ProxyType.SOCKS5, True),
        'socks4': (ProxyType.SOCKS4, False),
        'socks4a': (ProxyType.SOCKS4A, True),
    }
    if scheme not in variants:
        raise ValueError(f'Unknown SOCKS proxy version: {scheme}')
    socks_type, rdns = variants[scheme]

    def decode_credential(value):
        # Leave None / empty credentials untouched
        return urllib.parse.unquote_plus(value) if value else value

    return {
        'proxytype': socks_type,
        'addr': parsed.hostname,
        'port': parsed.port or 1080,
        'rdns': rdns,
        'username': decode_credential(parsed.username),
        'password': decode_credential(parsed.password),
    }
83
84
def select_proxy(url, proxies):
    """Unified proxy selector for all backends

    Returns None when *proxies* has a ``'no'`` entry and the URL's host is
    bypassed either by that entry or by the system proxy settings. Otherwise
    the result of looking up *proxies* by the URL scheme (``'http'`` when the
    URL has none) and ``'all'`` is returned.
    """
    parsed = urllib.parse.urlparse(url)
    if 'no' in proxies:
        # NOTE(review): a URL without a hostname looks like it would make this
        # concatenation raise TypeError - confirm callers always pass full URLs
        port_suffix = format_field(parsed.port, None, ':%s')
        hostport = parsed.hostname + port_suffix
        bypassed = (
            urllib.request.proxy_bypass_environment(hostport, {'no': proxies['no']})
            or urllib.request.proxy_bypass(hostport))  # check system settings
        if bypassed:
            return None

    scheme = parsed.scheme or 'http'
    return traverse_obj(proxies, scheme, 'all')
96
97
def get_redirect_method(method, status):
    """Unified redirect method handling

    Returns the HTTP method to use for the request that follows a redirect
    response with the given *status*.
    """
    if status == 303:
        # A 303 must either use GET or HEAD for subsequent request
        # https://datatracker.ietf.org/doc/html/rfc7231#section-6.4.4
        return method if method == 'HEAD' else 'GET'

    if method == 'POST' and status in (301, 302):
        # 301 and 302 redirects are commonly turned into a GET from a POST
        # for subsequent requests by browsers, so we'll do the same.
        # https://datatracker.ietf.org/doc/html/rfc7231#section-6.4.2
        # https://datatracker.ietf.org/doc/html/rfc7231#section-6.4.3
        return 'GET'

    return method
112
113
def make_ssl_context(
    verify=True,
    client_certificate=None,
    client_certificate_key=None,
    client_certificate_password=None,
    legacy_support=False,
    use_certifi=True,
):
    """Build a client-side ``ssl.SSLContext``.

    @param verify                       Verify the server certificate and hostname;
                                        CA certificates are only loaded when this is truthy
    @param client_certificate           Path of a client certificate (chain) to load
    @param client_certificate_key       Key file passed along with the client certificate
    @param client_certificate_password  Password passed along with the client certificate
    @param legacy_support               Allow legacy server connects and use the 'DEFAULT' ciphers
    @param use_certifi                  Forwarded to ssl_load_certs
    @returns                            The configured ``ssl.SSLContext``
    @raises RequestError                If loading the client certificate fails with ``ssl.SSLError``
    """
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    # check_hostname has to be set before verify_mode can be lowered to CERT_NONE
    context.check_hostname = verify
    context.verify_mode = ssl.CERT_REQUIRED if verify else ssl.CERT_NONE
    # OpenSSL 1.1.1+ Python 3.8+ keylog file
    if hasattr(context, 'keylog_filename'):
        context.keylog_filename = os.environ.get('SSLKEYLOGFILE') or None

    # Some servers may reject requests if ALPN extension is not sent. See:
    # https://github.com/python/cpython/issues/85140
    # https://github.com/yt-dlp/yt-dlp/issues/3878
    with contextlib.suppress(NotImplementedError):
        context.set_alpn_protocols(['http/1.1'])
    if verify:
        ssl_load_certs(context, use_certifi)

    if legacy_support:
        context.options |= 4  # SSL_OP_LEGACY_SERVER_CONNECT
        context.set_ciphers('DEFAULT')  # compat

    elif ssl.OPENSSL_VERSION_INFO >= (1, 1, 1) and not ssl.OPENSSL_VERSION.startswith('LibreSSL'):
        # Use the default SSL ciphers and minimum TLS version settings from Python 3.10 [1].
        # This is to ensure consistent behavior across Python versions and libraries, and help avoid fingerprinting
        # in some situations [2][3].
        # Python 3.10 only supports OpenSSL 1.1.1+ [4]. Because this change is likely
        # untested on older versions, we only apply this to OpenSSL 1.1.1+ to be safe.
        # LibreSSL is excluded until further investigation due to cipher support issues [5][6].
        # 1. https://github.com/python/cpython/commit/e983252b516edb15d4338b0a47631b59ef1e2536
        # 2. https://github.com/yt-dlp/yt-dlp/issues/4627
        # 3. https://github.com/yt-dlp/yt-dlp/pull/5294
        # 4. https://peps.python.org/pep-0644/
        # 5. https://peps.python.org/pep-0644/#libressl-support
        # 6. https://github.com/yt-dlp/yt-dlp/commit/5b9f253fa0aee996cf1ed30185d4b502e00609c4#commitcomment-89054368
        context.set_ciphers(
            '@SECLEVEL=2:ECDH+AESGCM:ECDH+CHACHA20:ECDH+AES:DHE+AES:!aNULL:!eNULL:!aDSS:!SHA1:!AESCCM')
        context.minimum_version = ssl.TLSVersion.TLSv1_2

    if client_certificate:
        try:
            context.load_cert_chain(
                client_certificate, keyfile=client_certificate_key,
                password=client_certificate_password)
        except ssl.SSLError as err:
            # Chain explicitly so the underlying SSL failure stays attached as __cause__
            raise RequestError('Unable to load client certificate') from err

        # Only touch post_handshake_auth where the attribute exists and is supported
        if getattr(context, 'post_handshake_auth', None) is not None:
            context.post_handshake_auth = True
    return context
169
170
class InstanceStoreMixin:
    """Mixin that caches instances keyed by the keyword arguments used to create them.

    Subclasses provide ``_create_instance``; ``_get_instance`` returns the stored
    instance for equal keyword arguments or creates and stores a new one.
    """

    def __init__(self, **kwargs):
        # List of (kwargs, instance) pairs; kwargs are compared by equality
        self.__instances = []
        super().__init__(**kwargs)  # So that both MRO works

    @staticmethod
    def _create_instance(**kwargs):
        """Create a new instance for *kwargs*; must be overridden."""
        raise NotImplementedError

    def _get_instance(self, **kwargs):
        """Return the instance stored for *kwargs*, creating it on first use."""
        for stored_kwargs, instance in self.__instances:
            if stored_kwargs == kwargs:
                break
        else:
            instance = self._create_instance(**kwargs)
            self.__instances.append((kwargs, instance))
        return instance

    def _close_instance(self, instance):
        """Call ``instance.close()`` if the instance has a callable ``close``."""
        closer = getattr(instance, 'close', None)
        if callable(closer):
            closer()

    def _clear_instances(self):
        """Close every stored instance, then forget all of them."""
        for entry in self.__instances:
            self._close_instance(entry[1])
        self.__instances.clear()
197
198
def add_accept_encoding_header(headers: HTTPHeaderDict, supported_encodings: Iterable[str]):
    """Set ``Accept-Encoding`` on *headers* unless the header is already present.

    The value is the comma-separated *supported_encodings*, or ``identity``
    when there are none. *headers* is modified in place.
    """
    if 'Accept-Encoding' in headers:
        return
    joined = ', '.join(supported_encodings)
    headers['Accept-Encoding'] = joined or 'identity'
202
203
def wrap_request_errors(func):
    """Decorate a method so that UnsupportedRequest errors record their handler.

    If the wrapped method raises UnsupportedRequest with ``handler`` unset,
    ``handler`` is set to the instance the method was called on before the
    exception is re-raised.
    """
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        try:
            result = func(self, *args, **kwargs)
        except UnsupportedRequest as unsupported:
            # Keep a handler that was already recorded
            if unsupported.handler is None:
                unsupported.handler = self
            raise
        return result
    return wrapper
20fbbd92 214
215
def _socket_connect(ip_addr, timeout, source_address):
    """Open a socket for the getaddrinfo-style entry *ip_addr* and connect it.

    The timeout is applied unless it is the global-default sentinel, and the
    socket is bound to *source_address* when one is given. On ``OSError`` the
    socket is closed and the error re-raised.
    """
    family, kind, protocol, _canonname, remote = ip_addr
    conn = socket.socket(family, kind, protocol)
    try:
        if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT:
            conn.settimeout(timeout)
        if source_address:
            conn.bind(source_address)
        conn.connect(remote)
    except OSError:
        conn.close()
        raise
    return conn
229
230
def create_socks_proxy_socket(dest_addr, proxy_args, proxy_ip_addr, timeout, source_address):
    """Open a SOCKS socket through the proxy at *proxy_ip_addr* and connect it to *dest_addr*.

    *proxy_args* is copied and its ``addr``/``port`` entries are replaced with
    the address from *proxy_ip_addr* before being handed to ``setproxy``.
    On ``OSError`` the socket is closed and the error re-raised.
    """
    family, kind, protocol, _canonname, proxy_sockaddr = proxy_ip_addr
    conn = sockssocket(family, kind, protocol)
    try:
        resolved_proxy_args = proxy_args.copy()
        resolved_proxy_args['addr'] = proxy_sockaddr[0]
        resolved_proxy_args['port'] = proxy_sockaddr[1]
        conn.setproxy(**resolved_proxy_args)
        if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT:
            conn.settimeout(timeout)
        if source_address:
            conn.bind(source_address)
        conn.connect(dest_addr)
    except OSError:
        conn.close()
        raise
    return conn
247
248
def create_connection(
    address,
    timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
    source_address=None,
    *,
    _create_socket_func=_socket_connect,
):
    """Connect to *address* (a ``(host, port)`` pair) and return the socket.

    Every address returned by getaddrinfo is tried in order with
    *_create_socket_func*; the first success is returned. When *source_address*
    is given, only remote addresses of the same IP family are tried.

    Raises OSError if no address is available, or the last connection error
    if every attempt fails.
    """
    # Work around socket.create_connection() which tries all addresses from getaddrinfo() including IPv6.
    # This filters the addresses based on the given source_address.
    # Based on: https://github.com/python/cpython/blob/main/Lib/socket.py#L810
    host, port = address
    candidates = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM)
    if not candidates:
        raise OSError('getaddrinfo returns an empty list')

    if source_address is not None:
        # A colon in the source host marks it as IPv6
        family = socket.AF_INET6 if ':' in source_address[0] else socket.AF_INET
        candidates = [info for info in candidates if info[0] == family]
        if not candidates:
            raise OSError(
                f'No remote IPv{4 if family == socket.AF_INET else 6} addresses available for connect. '
                f'Can\'t use "{source_address[0]}" as source address')

    last_error = None
    for candidate in candidates:
        try:
            connected = _create_socket_func(candidate, timeout, source_address)
        except OSError as exc:
            last_error = exc
        else:
            # Explicitly break __traceback__ reference cycle
            # https://bugs.python.org/issue36820
            last_error = None
            return connected

    try:
        raise last_error
    finally:
        # Explicitly break __traceback__ reference cycle
        # https://bugs.python.org/issue36820
        last_error = None