jfr.im git - yt-dlp.git/blobdiff - yt_dlp/networking/_urllib.py
[rh:requests] Add handler for `requests` HTTP library (#3668)
[yt-dlp.git] / yt_dlp / networking / _urllib.py
index 1f5871ae674f3373fff600be9a71b64ef22574b8..68bab2b087ae81eb37fab91c36422e3111d610e9 100644 (file)
@@ -1,79 +1,69 @@
+from __future__ import annotations
+
 import functools
-import gzip
 import http.client
 import io
-import socket
 import ssl
 import urllib.error
 import urllib.parse
 import urllib.request
 import urllib.response
 import zlib
+from urllib.request import (
+    DataHandler,
+    FileHandler,
+    FTPHandler,
+    HTTPCookieProcessor,
+    HTTPDefaultErrorHandler,
+    HTTPErrorProcessor,
+    UnknownHandler,
+)
 
 from ._helper import (
+    InstanceStoreMixin,
     add_accept_encoding_header,
+    create_connection,
+    create_socks_proxy_socket,
     get_redirect_method,
     make_socks_proxy_opts,
+    select_proxy,
+)
+from .common import Features, RequestHandler, Response, register_rh
+from .exceptions import (
+    CertificateVerifyError,
+    HTTPError,
+    IncompleteRead,
+    ProxyError,
+    RequestError,
+    SSLError,
+    TransportError,
 )
 from ..dependencies import brotli
-from ..socks import sockssocket
-from ..utils import escape_url, update_url_query
-from ..utils.networking import clean_headers, std_headers
+from ..socks import ProxyError as SocksProxyError
+from ..utils import update_url_query
+from ..utils.networking import normalize_url
 
 SUPPORTED_ENCODINGS = ['gzip', 'deflate']
+CONTENT_DECODE_ERRORS = [zlib.error, OSError]
 
 if brotli:
     SUPPORTED_ENCODINGS.append('br')
+    CONTENT_DECODE_ERRORS.append(brotli.error)
 
 
-def _create_http_connection(ydl_handler, http_class, is_https, *args, **kwargs):
+def _create_http_connection(http_class, source_address, *args, **kwargs):
     hc = http_class(*args, **kwargs)
-    source_address = ydl_handler._params.get('source_address')
+
+    if hasattr(hc, '_create_connection'):
+        hc._create_connection = create_connection
 
     if source_address is not None:
-        # This is to workaround _create_connection() from socket where it will try all
-        # address data from getaddrinfo() including IPv6. This filters the result from
-        # getaddrinfo() based on the source_address value.
-        # This is based on the cpython socket.create_connection() function.
-        # https://github.com/python/cpython/blob/master/Lib/socket.py#L691
-        def _create_connection(address, timeout=socket._GLOBAL_DEFAULT_TIMEOUT, source_address=None):
-            host, port = address
-            err = None
-            addrs = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM)
-            af = socket.AF_INET if '.' in source_address[0] else socket.AF_INET6
-            ip_addrs = [addr for addr in addrs if addr[0] == af]
-            if addrs and not ip_addrs:
-                ip_version = 'v4' if af == socket.AF_INET else 'v6'
-                raise OSError(
-                    "No remote IP%s addresses available for connect, can't use '%s' as source address"
-                    % (ip_version, source_address[0]))
-            for res in ip_addrs:
-                af, socktype, proto, canonname, sa = res
-                sock = None
-                try:
-                    sock = socket.socket(af, socktype, proto)
-                    if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT:
-                        sock.settimeout(timeout)
-                    sock.bind(source_address)
-                    sock.connect(sa)
-                    err = None  # Explicitly break reference cycle
-                    return sock
-                except OSError as _:
-                    err = _
-                    if sock is not None:
-                        sock.close()
-            if err is not None:
-                raise err
-            else:
-                raise OSError('getaddrinfo returns an empty list')
-        if hasattr(hc, '_create_connection'):
-            hc._create_connection = _create_connection
         hc.source_address = (source_address, 0)
 
     return hc
 
 
-class HTTPHandler(urllib.request.HTTPHandler):
+class HTTPHandler(urllib.request.AbstractHTTPHandler):
     """Handler for HTTP requests and responses.
 
     This class, when installed with an OpenerDirector, automatically adds
@@ -88,21 +78,30 @@ class HTTPHandler(urllib.request.HTTPHandler):
     public domain.
     """
 
-    def __init__(self, params, *args, **kwargs):
-        urllib.request.HTTPHandler.__init__(self, *args, **kwargs)
-        self._params = params
-
-    def http_open(self, req):
-        conn_class = http.client.HTTPConnection
+    def __init__(self, context=None, source_address=None, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._source_address = source_address
+        self._context = context
 
-        socks_proxy = req.headers.get('Ytdl-socks-proxy')
+    @staticmethod
+    def _make_conn_class(base, req):
+        conn_class = base
+        socks_proxy = req.headers.pop('Ytdl-socks-proxy', None)
         if socks_proxy:
             conn_class = make_socks_conn_class(conn_class, socks_proxy)
-            del req.headers['Ytdl-socks-proxy']
+        return conn_class
 
+    def http_open(self, req):
+        conn_class = self._make_conn_class(http.client.HTTPConnection, req)
         return self.do_open(functools.partial(
-            _create_http_connection, self, conn_class, False),
-            req)
+            _create_http_connection, conn_class, self._source_address), req)
+
+    def https_open(self, req):
+        conn_class = self._make_conn_class(http.client.HTTPSConnection, req)
+        return self.do_open(
+            functools.partial(
+                _create_http_connection, conn_class, self._source_address),
+            req, context=self._context)
 
     @staticmethod
     def deflate(data):
@@ -121,20 +120,11 @@ def brotli(data):
 
     @staticmethod
     def gz(data):
-        gz = gzip.GzipFile(fileobj=io.BytesIO(data), mode='rb')
-        try:
-            return gz.read()
-        except OSError as original_oserror:
-            # There may be junk add the end of the file
-            # See http://stackoverflow.com/q/4928560/35070 for details
-            for i in range(1, 1024):
-                try:
-                    gz = gzip.GzipFile(fileobj=io.BytesIO(data[:-i]), mode='rb')
-                    return gz.read()
-                except OSError:
-                    continue
-            else:
-                raise original_oserror
+        # There may be junk added at the end of the file
+        # We ignore it by only ever decoding a single gzip payload
+        if not data:
+            return data
+        return zlib.decompress(data, wbits=zlib.MAX_WBITS | 16)
 
     def http_request(self, req):
         # According to RFC 3986, URLs can not contain non-ASCII characters, however this is not
@@ -146,20 +136,12 @@ def http_request(self, req):
         # Since redirects are also affected (e.g. http://www.southpark.de/alle-episoden/s18e09)
         # the code of this workaround has been moved here from YoutubeDL.urlopen()
         url = req.get_full_url()
-        url_escaped = escape_url(url)
+        url_escaped = normalize_url(url)
 
         # Substitute URL if any change after escaping
         if url != url_escaped:
             req = update_Request(req, url=url_escaped)
 
-        for h, v in self._params.get('http_headers', std_headers).items():
-            # Capitalize is needed because of Python bug 2275: http://bugs.python.org/issue2275
-            # The dict keys are capitalized because of this bug by urllib
-            if h.capitalize() not in req.headers:
-                req.add_header(h, v)
-
-        clean_headers(req.headers)
-        add_accept_encoding_header(req.headers, SUPPORTED_ENCODINGS)
         return super().do_request_(req)
 
     def http_response(self, req, resp):
@@ -187,7 +169,7 @@ def http_response(self, req, resp):
             if location:
                 # As of RFC 2616 default charset is iso-8859-1 that is respected by python 3
                 location = location.encode('iso-8859-1').decode()
-                location_escaped = escape_url(location)
+                location_escaped = normalize_url(location)
                 if location != location_escaped:
                     del resp.headers['Location']
                     resp.headers['Location'] = location_escaped
@@ -204,19 +186,17 @@ def make_socks_conn_class(base_class, socks_proxy):
     proxy_args = make_socks_proxy_opts(socks_proxy)
 
     class SocksConnection(base_class):
-        def connect(self):
-            self.sock = sockssocket()
-            self.sock.setproxy(**proxy_args)
-            if isinstance(self.timeout, (int, float)):
-                self.sock.settimeout(self.timeout)
-            self.sock.connect((self.host, self.port))
+        _create_connection = create_connection
 
+        def connect(self):
+            self.sock = create_connection(
+                (proxy_args['addr'], proxy_args['port']),
+                timeout=self.timeout,
+                source_address=self.source_address,
+                _create_socket_func=functools.partial(
+                    create_socks_proxy_socket, (self.host, self.port), proxy_args))
             if isinstance(self, http.client.HTTPSConnection):
-                if hasattr(self, '_context'):  # Python > 2.6
-                    self.sock = self._context.wrap_socket(
-                        self.sock, server_hostname=self.host)
-                else:
-                    self.sock = ssl.wrap_socket(self.sock)
+                self.sock = self._context.wrap_socket(self.sock, server_hostname=self.host)
 
     return SocksConnection
 
@@ -260,29 +240,25 @@ def redirect_request(self, req, fp, code, msg, headers, newurl):
             unverifiable=True, method=new_method, data=new_data)
 
 
-class ProxyHandler(urllib.request.ProxyHandler):
+class ProxyHandler(urllib.request.BaseHandler):
+    handler_order = 100
+
     def __init__(self, proxies=None):
+        self.proxies = proxies
         # Set default handlers
-        for type in ('http', 'https'):
-            setattr(self, '%s_open' % type,
-                    lambda r, proxy='__noproxy__', type=type, meth=self.proxy_open:
-                        meth(r, proxy, type))
-        urllib.request.ProxyHandler.__init__(self, proxies)
-
-    def proxy_open(self, req, proxy, type):
-        req_proxy = req.headers.get('Ytdl-request-proxy')
-        if req_proxy is not None:
-            proxy = req_proxy
-            del req.headers['Ytdl-request-proxy']
-
-        if proxy == '__noproxy__':
-            return None  # No Proxy
-        if urllib.parse.urlparse(proxy).scheme.lower() in ('socks', 'socks4', 'socks4a', 'socks5'):
+        for type in ('http', 'https', 'ftp'):
+            setattr(self, '%s_open' % type, lambda r, meth=self.proxy_open: meth(r))
+
+    def proxy_open(self, req):
+        proxy = select_proxy(req.get_full_url(), self.proxies)
+        if proxy is None:
+            return
+        if urllib.parse.urlparse(proxy).scheme.lower() in ('socks4', 'socks4a', 'socks5', 'socks5h'):
             req.add_header('Ytdl-socks-proxy', proxy)
             # yt-dlp's http/https handlers do wrapping the socket with socks
             return None
         return urllib.request.ProxyHandler.proxy_open(
-            self, req, proxy, type)
+            self, req, proxy, None)
 
 
 class PUTRequest(urllib.request.Request):
@@ -298,7 +274,7 @@ def get_method(self):
 def update_Request(req, url=None, data=None, headers=None, query=None):
     req_headers = req.headers.copy()
     req_headers.update(headers or {})
-    req_data = data or req.data
+    req_data = data if data is not None else req.data
     req_url = update_url_query(url or req.get_full_url(), query)
     req_get_method = req.get_method()
     if req_get_method == 'HEAD':
@@ -313,3 +289,134 @@ def update_Request(req, url=None, data=None, headers=None, query=None):
     if hasattr(req, 'timeout'):
         new_req.timeout = req.timeout
     return new_req
+
+
+class UrllibResponseAdapter(Response):
+    """
+    HTTP Response adapter class for urllib addinfourl and http.client.HTTPResponse
+    """
+
+    def __init__(self, res: http.client.HTTPResponse | urllib.response.addinfourl):
+        # addinfourl: In Python 3.9+, .status was introduced and .getcode() was deprecated [1]
+        # HTTPResponse: .getcode() was deprecated, .status always existed [2]
+        # 1. https://docs.python.org/3/library/urllib.request.html#urllib.response.addinfourl.getcode
+        # 2. https://docs.python.org/3.10/library/http.client.html#http.client.HTTPResponse.status
+        super().__init__(
+            fp=res, headers=res.headers, url=res.url,
+            status=getattr(res, 'status', None) or res.getcode(), reason=getattr(res, 'reason', None))
+
+    def read(self, amt=None):
+        try:
+            return self.fp.read(amt)
+        except Exception as e:
+            handle_response_read_exceptions(e)
+            raise e
+
+
+def handle_sslerror(e: ssl.SSLError):
+    if not isinstance(e, ssl.SSLError):
+        return
+    if isinstance(e, ssl.SSLCertVerificationError):
+        raise CertificateVerifyError(cause=e) from e
+    raise SSLError(cause=e) from e
+
+
+def handle_response_read_exceptions(e):
+    if isinstance(e, http.client.IncompleteRead):
+        raise IncompleteRead(partial=len(e.partial), cause=e, expected=e.expected) from e
+    elif isinstance(e, ssl.SSLError):
+        handle_sslerror(e)
+    elif isinstance(e, (OSError, EOFError, http.client.HTTPException, *CONTENT_DECODE_ERRORS)):
+        # OSErrors raised here should mostly be network related
+        raise TransportError(cause=e) from e
+
+
+@register_rh
+class UrllibRH(RequestHandler, InstanceStoreMixin):
+    _SUPPORTED_URL_SCHEMES = ('http', 'https', 'data', 'ftp')
+    _SUPPORTED_PROXY_SCHEMES = ('http', 'socks4', 'socks4a', 'socks5', 'socks5h')
+    _SUPPORTED_FEATURES = (Features.NO_PROXY, Features.ALL_PROXY)
+    RH_NAME = 'urllib'
+
+    def __init__(self, *, enable_file_urls: bool = False, **kwargs):
+        super().__init__(**kwargs)
+        self.enable_file_urls = enable_file_urls
+        if self.enable_file_urls:
+            self._SUPPORTED_URL_SCHEMES = (*self._SUPPORTED_URL_SCHEMES, 'file')
+
+    def _check_extensions(self, extensions):
+        super()._check_extensions(extensions)
+        extensions.pop('cookiejar', None)
+        extensions.pop('timeout', None)
+
+    def _create_instance(self, proxies, cookiejar):
+        opener = urllib.request.OpenerDirector()
+        handlers = [
+            ProxyHandler(proxies),
+            HTTPHandler(
+                debuglevel=int(bool(self.verbose)),
+                context=self._make_sslcontext(),
+                source_address=self.source_address),
+            HTTPCookieProcessor(cookiejar),
+            DataHandler(),
+            UnknownHandler(),
+            HTTPDefaultErrorHandler(),
+            FTPHandler(),
+            HTTPErrorProcessor(),
+            RedirectHandler(),
+        ]
+
+        if self.enable_file_urls:
+            handlers.append(FileHandler())
+
+        for handler in handlers:
+            opener.add_handler(handler)
+
+        # Delete the default user-agent header, which would otherwise apply in
+        # cases where our custom HTTP handler doesn't come into play
+        # (See https://github.com/ytdl-org/youtube-dl/issues/1309 for details)
+        opener.addheaders = []
+        return opener
+
+    def _send(self, request):
+        headers = self._merge_headers(request.headers)
+        add_accept_encoding_header(headers, SUPPORTED_ENCODINGS)
+        urllib_req = urllib.request.Request(
+            url=request.url,
+            data=request.data,
+            headers=dict(headers),
+            method=request.method
+        )
+
+        opener = self._get_instance(
+            proxies=request.proxies or self.proxies,
+            cookiejar=request.extensions.get('cookiejar') or self.cookiejar
+        )
+        try:
+            res = opener.open(urllib_req, timeout=float(request.extensions.get('timeout') or self.timeout))
+        except urllib.error.HTTPError as e:
+            if isinstance(e.fp, (http.client.HTTPResponse, urllib.response.addinfourl)):
+                # Prevent file object from being closed when urllib.error.HTTPError is destroyed.
+                e._closer.close_called = True
+                raise HTTPError(UrllibResponseAdapter(e.fp), redirect_loop='redirect error' in str(e)) from e
+            raise  # unexpected
+        except urllib.error.URLError as e:
+            cause = e.reason  # NOTE: cause may be a string
+
+            # proxy errors
+            if 'tunnel connection failed' in str(cause).lower() or isinstance(cause, SocksProxyError):
+                raise ProxyError(cause=e) from e
+
+            handle_response_read_exceptions(cause)
+            raise TransportError(cause=e) from e
+        except (http.client.InvalidURL, ValueError) as e:
+            # Validation errors
+            # http.client.HTTPConnection raises ValueError in some validation cases
+            # such as if request method contains illegal control characters [1]
+            # 1. https://github.com/python/cpython/blob/987b712b4aeeece336eed24fcc87a950a756c3e2/Lib/http/client.py#L1256
+            raise RequestError(cause=e) from e
+        except Exception as e:
+            handle_response_read_exceptions(e)
+            raise  # unexpected
+
+        return UrllibResponseAdapter(res)