]> jfr.im git - yt-dlp.git/blobdiff - yt_dlp/networking/_helper.py
[rh:requests] Add handler for `requests` HTTP library (#3668)
[yt-dlp.git] / yt_dlp / networking / _helper.py
index 367f3f4447bd7c3d6b007cca98f33f2fd78915e1..a6fa3550bd1e27ec8407eeffcf08586494eaa7e3 100644 (file)
@@ -1,13 +1,23 @@
 from __future__ import annotations
 
 import contextlib
+import functools
+import socket
 import ssl
 import sys
+import typing
 import urllib.parse
+import urllib.request
 
+from .exceptions import RequestError, UnsupportedRequest
 from ..dependencies import certifi
-from ..socks import ProxyType
-from ..utils import YoutubeDLError
+from ..socks import ProxyType, sockssocket
+from ..utils import format_field, traverse_obj
+
+if typing.TYPE_CHECKING:
+    from collections.abc import Iterable
+
+    from ..utils.networking import HTTPHeaderDict
 
 
 def ssl_load_certs(context: ssl.SSLContext, use_certifi=True):
@@ -23,11 +33,11 @@ def ssl_load_certs(context: ssl.SSLContext, use_certifi=True):
             # enum_certificates is not present in mingw python. See https://github.com/yt-dlp/yt-dlp/issues/1151
             if sys.platform == 'win32' and hasattr(ssl, 'enum_certificates'):
                 for storename in ('CA', 'ROOT'):
-                    _ssl_load_windows_store_certs(context, storename)
+                    ssl_load_windows_store_certs(context, storename)
             context.set_default_verify_paths()
 
 
-def _ssl_load_windows_store_certs(ssl_context, storename):
+def ssl_load_windows_store_certs(ssl_context, storename):
     # Code adapted from _load_windows_store_certs in https://github.com/python/cpython/blob/main/Lib/ssl.py
     try:
         certs = [cert for cert, encoding, trust in ssl.enum_certificates(storename)
@@ -44,10 +54,18 @@ def make_socks_proxy_opts(socks_proxy):
     url_components = urllib.parse.urlparse(socks_proxy)
     if url_components.scheme.lower() == 'socks5':
         socks_type = ProxyType.SOCKS5
-    elif url_components.scheme.lower() in ('socks', 'socks4'):
+        rdns = False
+    elif url_components.scheme.lower() == 'socks5h':
+        socks_type = ProxyType.SOCKS5
+        rdns = True
+    elif url_components.scheme.lower() == 'socks4':
         socks_type = ProxyType.SOCKS4
+        rdns = False
     elif url_components.scheme.lower() == 'socks4a':
         socks_type = ProxyType.SOCKS4A
+        rdns = True
+    else:
+        raise ValueError(f'Unknown SOCKS proxy version: {url_components.scheme.lower()}')
 
     def unquote_if_non_empty(s):
         if not s:
@@ -57,12 +75,25 @@ def unquote_if_non_empty(s):
         'proxytype': socks_type,
         'addr': url_components.hostname,
         'port': url_components.port or 1080,
-        'rdns': True,
+        'rdns': rdns,
         'username': unquote_if_non_empty(url_components.username),
         'password': unquote_if_non_empty(url_components.password),
     }
 
 
+def select_proxy(url, proxies):
+    """Unified proxy selector for all backends; returns None when no_proxy or system settings bypass `url`"""
+    url_components = urllib.parse.urlparse(url)
+    if 'no' in proxies:  # host-less URLs (e.g. file:///x) have hostname None, so fall back to ''
+        hostport = (url_components.hostname or '') + format_field(url_components.port, None, ':%s')
+        if urllib.request.proxy_bypass_environment(hostport, {'no': proxies['no']}):
+            return None
+        if urllib.request.proxy_bypass(hostport):  # check system settings
+            return None
+
+    return traverse_obj(proxies, url_components.scheme or 'http', 'all')  # presumably: per-scheme proxy, else 'all'
+
+
 def get_redirect_method(method, status):
     """Unified redirect method handling"""
 
@@ -126,14 +157,127 @@ def make_ssl_context(
                 client_certificate, keyfile=client_certificate_key,
                 password=client_certificate_password)
         except ssl.SSLError:
-            raise YoutubeDLError('Unable to load client certificate')
+            raise RequestError('Unable to load client certificate')
 
+        if getattr(context, 'post_handshake_auth', None) is not None:
+            context.post_handshake_auth = True
     return context
 
 
-def add_accept_encoding_header(headers, supported_encodings):
-    if supported_encodings and 'Accept-Encoding' not in headers:
-        headers['Accept-Encoding'] = ', '.join(supported_encodings)
+class InstanceStoreMixin:
+    """Memoise objects made by `_create_instance`, keyed by the kwargs they were built from"""
+    def __init__(self, **kwargs):
+        self.__instances = []  # (kwargs, instance) pairs; a list, presumably because kwargs may be unhashable
+        super().__init__(**kwargs)  # cooperative init so the mixin works at any position in the MRO
+
+    @staticmethod
+    def _create_instance(**kwargs):
+        raise NotImplementedError  # to be overridden by the class using the mixin
+
+    def _get_instance(self, **kwargs):
+        for cached_kwargs, cached in self.__instances:
+            if cached_kwargs == kwargs:
+                return cached
+        created = self._create_instance(**kwargs)
+        self.__instances.append((kwargs, created))
+        return created
+
+    def _close_instance(self, instance):
+        if callable(getattr(instance, 'close', None)):  # only close objects that expose close()
+            instance.close()
+
+    def _clear_instances(self):
+        for _, cached in self.__instances:
+            self._close_instance(cached)
+        self.__instances.clear()
+
+
+def add_accept_encoding_header(headers: HTTPHeaderDict, supported_encodings: Iterable[str]):
+    if 'Accept-Encoding' not in headers:  # an Accept-Encoding already set by the caller is left untouched
+        headers['Accept-Encoding'] = ', '.join(supported_encodings) or 'identity'  # empty list -> 'identity'
+
 
-    elif 'Accept-Encoding' not in headers:
-        headers['Accept-Encoding'] = 'identity'
+def wrap_request_errors(func):
+    """Decorate a handler method so UnsupportedRequest errors it raises carry the handler (`self`)"""
+    @functools.wraps(func)
+    def _tagged(self, *args, **kwargs):
+        try:
+            return func(self, *args, **kwargs)
+        except UnsupportedRequest as exc:
+            exc.handler = self if exc.handler is None else exc.handler  # an already-set handler wins
+            raise
+    return _tagged
+
+
+def _socket_connect(ip_addr, timeout, source_address):  # connect to one getaddrinfo() entry; closed on failure
+    family, sock_type, protocol, _, sockaddr = ip_addr
+    sock = socket.socket(family, sock_type, protocol)
+    try:
+        if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT:  # sentinel: keep the socket module's default timeout
+            sock.settimeout(timeout)
+        if source_address:
+            sock.bind(source_address)
+        sock.connect(sockaddr)
+    except OSError:  # socket.error is an alias of OSError
+        sock.close()
+        raise
+    return sock
+
+
+def create_socks_proxy_socket(dest_addr, proxy_args, proxy_ip_addr, timeout, source_address):
+    """Connect to `dest_addr` through the SOCKS proxy at the getaddrinfo() entry `proxy_ip_addr`"""
+    family, sock_type, protocol, _, proxy_sockaddr = proxy_ip_addr
+    sock = sockssocket(family, sock_type, protocol)
+    try:
+        # reuse proxy_args, but target the resolved proxy address instead of the configured host
+        sock.setproxy(**{**proxy_args, 'addr': proxy_sockaddr[0], 'port': proxy_sockaddr[1]})
+        if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT:
+            sock.settimeout(timeout)
+        if source_address:
+            sock.bind(source_address)
+        sock.connect(dest_addr)
+    except OSError:
+        sock.close()
+        raise
+    return sock
+
+
+def create_connection(
+    address,  # (host, port) of the remote end
+    timeout=socket._GLOBAL_DEFAULT_TIMEOUT,  # sentinel: keep the socket module's default timeout
+    source_address=None,  # optional (host, port) to bind locally; its host also picks the address family
+    *,
+    _create_socket_func=_socket_connect  # f(ip_addr, timeout, source_address) -> connected socket
+):
+    # Like socket.create_connection(), which tries every address returned by getaddrinfo() (IPv4 and IPv6),
+    # but when source_address is given only remote addresses of the matching family are tried.
+    # Based on: https://github.com/python/cpython/blob/main/Lib/socket.py#L810
+    host, port = address
+    ip_addrs = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM)  # list of 5-tuples; [0] is the family
+    if not ip_addrs:
+        raise socket.error('getaddrinfo returns an empty list')
+    if source_address is not None:
+        af = socket.AF_INET if ':' not in source_address[0] else socket.AF_INET6  # ':' in host => IPv6
+        ip_addrs = [addr for addr in ip_addrs if addr[0] == af]
+        if not ip_addrs:
+            raise OSError(
+                f'No remote IPv{4 if af == socket.AF_INET else 6} addresses available for connect. '
+                f'Can\'t use "{source_address[0]}" as source address')
+
+    err = None  # most recent connection error; re-raised below if every address fails
+    for ip_addr in ip_addrs:
+        try:
+            sock = _create_socket_func(ip_addr, timeout, source_address)
+            # Explicitly break __traceback__ reference cycle
+            # https://bugs.python.org/issue36820
+            err = None
+            return sock
+        except socket.error as e:
+            err = e  # keep it past the except block and try the next address
+
+    try:
+        raise err  # ip_addrs is non-empty here and every attempt failed, so err is set
+    finally:
+        # Explicitly break __traceback__ reference cycle
+        # https://bugs.python.org/issue36820
+        err = None