jfr.im git - yt-dlp.git/blobdiff - yt_dlp/networking/_urllib.py
[rh:requests] Add handler for `requests` HTTP library (#3668)
[yt-dlp.git] / yt_dlp / networking / _urllib.py
index ff3a22c8c18809163c2454dfb752edabe086b07d..68bab2b087ae81eb37fab91c36422e3111d610e9 100644 (file)
@@ -1,10 +1,8 @@
 from __future__ import annotations
 
 import functools
-import gzip
 import http.client
 import io
-import socket
 import ssl
 import urllib.error
 import urllib.parse
@@ -24,6 +22,8 @@
 from ._helper import (
     InstanceStoreMixin,
     add_accept_encoding_header,
+    create_connection,
+    create_socks_proxy_socket,
     get_redirect_method,
     make_socks_proxy_opts,
     select_proxy,
@@ -40,8 +40,8 @@
 )
 from ..dependencies import brotli
 from ..socks import ProxyError as SocksProxyError
-from ..socks import sockssocket
-from ..utils import escape_url, update_url_query
+from ..utils import update_url_query
+from ..utils.networking import normalize_url
 
 SUPPORTED_ENCODINGS = ['gzip', 'deflate']
 CONTENT_DECODE_ERRORS = [zlib.error, OSError]
 def _create_http_connection(http_class, source_address, *args, **kwargs):
     hc = http_class(*args, **kwargs)
 
+    if hasattr(hc, '_create_connection'):
+        hc._create_connection = create_connection
+
     if source_address is not None:
-        # This is to workaround _create_connection() from socket where it will try all
-        # address data from getaddrinfo() including IPv6. This filters the result from
-        # getaddrinfo() based on the source_address value.
-        # This is based on the cpython socket.create_connection() function.
-        # https://github.com/python/cpython/blob/master/Lib/socket.py#L691
-        def _create_connection(address, timeout=socket._GLOBAL_DEFAULT_TIMEOUT, source_address=None):
-            host, port = address
-            err = None
-            addrs = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM)
-            af = socket.AF_INET if '.' in source_address[0] else socket.AF_INET6
-            ip_addrs = [addr for addr in addrs if addr[0] == af]
-            if addrs and not ip_addrs:
-                ip_version = 'v4' if af == socket.AF_INET else 'v6'
-                raise OSError(
-                    "No remote IP%s addresses available for connect, can't use '%s' as source address"
-                    % (ip_version, source_address[0]))
-            for res in ip_addrs:
-                af, socktype, proto, canonname, sa = res
-                sock = None
-                try:
-                    sock = socket.socket(af, socktype, proto)
-                    if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT:
-                        sock.settimeout(timeout)
-                    sock.bind(source_address)
-                    sock.connect(sa)
-                    err = None  # Explicitly break reference cycle
-                    return sock
-                except OSError as _:
-                    err = _
-                    if sock is not None:
-                        sock.close()
-            if err is not None:
-                raise err
-            else:
-                raise OSError('getaddrinfo returns an empty list')
-        if hasattr(hc, '_create_connection'):
-            hc._create_connection = _create_connection
         hc.source_address = (source_address, 0)
 
     return hc
@@ -154,20 +120,11 @@ def brotli(data):
 
     @staticmethod
     def gz(data):
-        gz = gzip.GzipFile(fileobj=io.BytesIO(data), mode='rb')
-        try:
-            return gz.read()
-        except OSError as original_oserror:
-            # There may be junk add the end of the file
-            # See http://stackoverflow.com/q/4928560/35070 for details
-            for i in range(1, 1024):
-                try:
-                    gz = gzip.GzipFile(fileobj=io.BytesIO(data[:-i]), mode='rb')
-                    return gz.read()
-                except OSError:
-                    continue
-            else:
-                raise original_oserror
+        # There may be junk added at the end of the file
+        # We ignore it by only ever decoding a single gzip payload
+        if not data:
+            return data
+        return zlib.decompress(data, wbits=zlib.MAX_WBITS | 16)
 
     def http_request(self, req):
         # According to RFC 3986, URLs can not contain non-ASCII characters, however this is not
@@ -179,7 +136,7 @@ def http_request(self, req):
         # Since redirects are also affected (e.g. http://www.southpark.de/alle-episoden/s18e09)
         # the code of this workaround has been moved here from YoutubeDL.urlopen()
         url = req.get_full_url()
-        url_escaped = escape_url(url)
+        url_escaped = normalize_url(url)
 
         # Substitute URL if any change after escaping
         if url != url_escaped:
@@ -212,7 +169,7 @@ def http_response(self, req, resp):
             if location:
                 # As of RFC 2616 default charset is iso-8859-1 that is respected by python 3
                 location = location.encode('iso-8859-1').decode()
-                location_escaped = escape_url(location)
+                location_escaped = normalize_url(location)
                 if location != location_escaped:
                     del resp.headers['Location']
                     resp.headers['Location'] = location_escaped
@@ -229,13 +186,15 @@ def make_socks_conn_class(base_class, socks_proxy):
     proxy_args = make_socks_proxy_opts(socks_proxy)
 
     class SocksConnection(base_class):
-        def connect(self):
-            self.sock = sockssocket()
-            self.sock.setproxy(**proxy_args)
-            if type(self.timeout) in (int, float):  # noqa: E721
-                self.sock.settimeout(self.timeout)
-            self.sock.connect((self.host, self.port))
+        _create_connection = create_connection
 
+        def connect(self):
+            self.sock = create_connection(
+                (proxy_args['addr'], proxy_args['port']),
+                timeout=self.timeout,
+                source_address=self.source_address,
+                _create_socket_func=functools.partial(
+                    create_socks_proxy_socket, (self.host, self.port), proxy_args))
             if isinstance(self, http.client.HTTPSConnection):
                 self.sock = self._context.wrap_socket(self.sock, server_hostname=self.host)
 
@@ -364,7 +323,7 @@ def handle_sslerror(e: ssl.SSLError):
 
 def handle_response_read_exceptions(e):
     if isinstance(e, http.client.IncompleteRead):
-        raise IncompleteRead(partial=e.partial, cause=e, expected=e.expected) from e
+        raise IncompleteRead(partial=len(e.partial), cause=e, expected=e.expected) from e
     elif isinstance(e, ssl.SSLError):
         handle_sslerror(e)
     elif isinstance(e, (OSError, EOFError, http.client.HTTPException, *CONTENT_DECODE_ERRORS)):
@@ -385,6 +344,11 @@ def __init__(self, *, enable_file_urls: bool = False, **kwargs):
         if self.enable_file_urls:
             self._SUPPORTED_URL_SCHEMES = (*self._SUPPORTED_URL_SCHEMES, 'file')
 
+    def _check_extensions(self, extensions):
+        super()._check_extensions(extensions)
+        extensions.pop('cookiejar', None)
+        extensions.pop('timeout', None)
+
     def _create_instance(self, proxies, cookiejar):
         opener = urllib.request.OpenerDirector()
         handlers = [
@@ -433,7 +397,7 @@ def _send(self, request):
         except urllib.error.HTTPError as e:
             if isinstance(e.fp, (http.client.HTTPResponse, urllib.response.addinfourl)):
                 # Prevent file object from being closed when urllib.error.HTTPError is destroyed.
-                e._closer.file = None
+                e._closer.close_called = True
                 raise HTTPError(UrllibResponseAdapter(e.fp), redirect_loop='redirect error' in str(e)) from e
             raise  # unexpected
         except urllib.error.URLError as e: