]> jfr.im git - yt-dlp.git/blobdiff - yt_dlp/networking/_urllib.py
[rh:requests] Add handler for `requests` HTTP library (#3668)
[yt-dlp.git] / yt_dlp / networking / _urllib.py
index 5a804d99b472a30a05c4a5bafc8f73bac65387c1..68bab2b087ae81eb37fab91c36422e3111d610e9 100644 (file)
@@ -3,7 +3,6 @@
 import functools
 import http.client
 import io
-import socket
 import ssl
 import urllib.error
 import urllib.parse
@@ -23,6 +22,8 @@
 from ._helper import (
     InstanceStoreMixin,
     add_accept_encoding_header,
+    create_connection,
+    create_socks_proxy_socket,
     get_redirect_method,
     make_socks_proxy_opts,
     select_proxy,
@@ -39,7 +40,6 @@
 )
 from ..dependencies import brotli
 from ..socks import ProxyError as SocksProxyError
-from ..socks import sockssocket
 from ..utils import update_url_query
 from ..utils.networking import normalize_url
 
 def _create_http_connection(http_class, source_address, *args, **kwargs):
     hc = http_class(*args, **kwargs)
 
+    if hasattr(hc, '_create_connection'):
+        hc._create_connection = create_connection
+
     if source_address is not None:
-        # This is to workaround _create_connection() from socket where it will try all
-        # address data from getaddrinfo() including IPv6. This filters the result from
-        # getaddrinfo() based on the source_address value.
-        # This is based on the cpython socket.create_connection() function.
-        # https://github.com/python/cpython/blob/master/Lib/socket.py#L691
-        def _create_connection(address, timeout=socket._GLOBAL_DEFAULT_TIMEOUT, source_address=None):
-            host, port = address
-            err = None
-            addrs = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM)
-            af = socket.AF_INET if '.' in source_address[0] else socket.AF_INET6
-            ip_addrs = [addr for addr in addrs if addr[0] == af]
-            if addrs and not ip_addrs:
-                ip_version = 'v4' if af == socket.AF_INET else 'v6'
-                raise OSError(
-                    "No remote IP%s addresses available for connect, can't use '%s' as source address"
-                    % (ip_version, source_address[0]))
-            for res in ip_addrs:
-                af, socktype, proto, canonname, sa = res
-                sock = None
-                try:
-                    sock = socket.socket(af, socktype, proto)
-                    if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT:
-                        sock.settimeout(timeout)
-                    sock.bind(source_address)
-                    sock.connect(sa)
-                    err = None  # Explicitly break reference cycle
-                    return sock
-                except OSError as _:
-                    err = _
-                    if sock is not None:
-                        sock.close()
-            if err is not None:
-                raise err
-            else:
-                raise OSError('getaddrinfo returns an empty list')
-        if hasattr(hc, '_create_connection'):
-            hc._create_connection = _create_connection
         hc.source_address = (source_address, 0)
 
     return hc
@@ -156,6 +122,8 @@ def brotli(data):
     def gz(data):
         # There may be junk added the end of the file
         # We ignore it by only ever decoding a single gzip payload
+        if not data:
+            return data
         return zlib.decompress(data, wbits=zlib.MAX_WBITS | 16)
 
     def http_request(self, req):
@@ -218,13 +186,15 @@ def make_socks_conn_class(base_class, socks_proxy):
     proxy_args = make_socks_proxy_opts(socks_proxy)
 
     class SocksConnection(base_class):
-        def connect(self):
-            self.sock = sockssocket()
-            self.sock.setproxy(**proxy_args)
-            if type(self.timeout) in (int, float):  # noqa: E721
-                self.sock.settimeout(self.timeout)
-            self.sock.connect((self.host, self.port))
+        _create_connection = create_connection
 
+        def connect(self):
+            self.sock = create_connection(
+                (proxy_args['addr'], proxy_args['port']),
+                timeout=self.timeout,
+                source_address=self.source_address,
+                _create_socket_func=functools.partial(
+                    create_socks_proxy_socket, (self.host, self.port), proxy_args))
             if isinstance(self, http.client.HTTPSConnection):
                 self.sock = self._context.wrap_socket(self.sock, server_hostname=self.host)
 
@@ -353,7 +323,7 @@ def handle_sslerror(e: ssl.SSLError):
 
 def handle_response_read_exceptions(e):
     if isinstance(e, http.client.IncompleteRead):
-        raise IncompleteRead(partial=e.partial, cause=e, expected=e.expected) from e
+        raise IncompleteRead(partial=len(e.partial), cause=e, expected=e.expected) from e
     elif isinstance(e, ssl.SSLError):
         handle_sslerror(e)
     elif isinstance(e, (OSError, EOFError, http.client.HTTPException, *CONTENT_DECODE_ERRORS)):
@@ -427,7 +397,7 @@ def _send(self, request):
         except urllib.error.HTTPError as e:
             if isinstance(e.fp, (http.client.HTTPResponse, urllib.response.addinfourl)):
                 # Prevent file object from being closed when urllib.error.HTTPError is destroyed.
-                e._closer.file = None
+                e._closer.close_called = True
                 raise HTTPError(UrllibResponseAdapter(e.fp), redirect_loop='redirect error' in str(e)) from e
             raise  # unexpected
         except urllib.error.URLError as e: