]> jfr.im git - yt-dlp.git/blobdiff - yt_dlp/networking/common.py
[cleanup] Add more ruff rules (#10149)
[yt-dlp.git] / yt_dlp / networking / common.py
index e4b362827662fe78655212ec3b6a150fc584d63b..a6db167158213d8e9139318e1d65a40ba6d5de41 100644 (file)
@@ -12,7 +12,6 @@
 from collections.abc import Iterable, Mapping
 from email.message import Message
 from http import HTTPStatus
-from http.cookiejar import CookieJar
 
 from ._helper import make_ssl_context, wrap_request_errors
 from .exceptions import (
     TransportError,
     UnsupportedRequest,
 )
+from ..compat.types import NoneType
+from ..cookies import YoutubeDLCookieJar
 from ..utils import (
     bug_reports_message,
     classproperty,
+    deprecation_warning,
     error_to_str,
-    escape_url,
     update_url_query,
 )
-from ..utils.networking import HTTPHeaderDict
+from ..utils.networking import HTTPHeaderDict, normalize_url
+
+DEFAULT_TIMEOUT = 20
 
-if typing.TYPE_CHECKING:
-    RequestData = bytes | Iterable[bytes] | typing.IO | None
+
+def register_preference(*handlers: type[RequestHandler]):
+    assert all(issubclass(handler, RequestHandler) for handler in handlers)
+
+    def outer(preference: Preference):
+        @functools.wraps(preference)
+        def inner(handler, *args, **kwargs):
+            if not handlers or isinstance(handler, handlers):
+                return preference(handler, *args, **kwargs)
+            return 0
+        _RH_PREFERENCES.add(inner)
+        return inner
+    return outer
 
 
 class RequestDirector:
@@ -39,24 +53,40 @@ class RequestDirector:
 
     Helper class that, when given a request, forward it to a RequestHandler that supports it.
 
+    Preference functions in the form of func(handler, request) -> int
+    can be registered into the `preferences` set. These are used to sort handlers
+    in order of preference.
+
     @param logger: Logger instance.
     @param verbose: Print debug request information to stdout.
     """
 
     def __init__(self, logger, verbose=False):
         self.handlers: dict[str, RequestHandler] = {}
+        self.preferences: set[Preference] = set()
         self.logger = logger  # TODO(Grub4k): default logger
         self.verbose = verbose
 
     def close(self):
         for handler in self.handlers.values():
             handler.close()
+        self.handlers.clear()
 
     def add_handler(self, handler: RequestHandler):
         """Add a handler. If a handler of the same RH_KEY exists, it will overwrite it"""
         assert isinstance(handler, RequestHandler), 'handler must be a RequestHandler'
         self.handlers[handler.RH_KEY] = handler
 
+    def _get_handlers(self, request: Request) -> list[RequestHandler]:
+        """Sorts handlers by preference, given a request"""
+        preferences = {
+            rh: sum(pref(rh, request) for pref in self.preferences)
+            for rh in self.handlers.values()
+        }
+        self._print_verbose('Handler preferences for this request: {}'.format(', '.join(
+            f'{rh.RH_NAME}={pref}' for rh, pref in preferences.items())))
+        return sorted(self.handlers.values(), key=preferences.get, reverse=True)
+
     def _print_verbose(self, msg):
         if self.verbose:
             self.logger.stdout(f'director: {msg}')
@@ -72,8 +102,7 @@ def send(self, request: Request) -> Response:
 
         unexpected_errors = []
         unsupported_errors = []
-        # TODO (future): add a per-request preference system
-        for handler in reversed(list(self.handlers.values())):
+        for handler in self._get_handlers(request):
             self._print_verbose(f'Checking if "{handler.RH_NAME}" supports this request.')
             try:
                 handler.validate(request)
@@ -104,7 +133,7 @@ def send(self, request: Request) -> Response:
 _REQUEST_HANDLERS = {}
 
 
-def register(handler):
+def register_rh(handler):
     """Register a RequestHandler class"""
     assert issubclass(handler, RequestHandler), f'{handler} must be a subclass of RequestHandler'
     assert handler.RH_KEY not in _REQUEST_HANDLERS, f'RequestHandler {handler.RH_KEY} already registered'
@@ -146,6 +175,7 @@ class RequestHandler(abc.ABC):
         a proxy url with an url scheme not in this list will raise an UnsupportedRequest.
 
     - `_SUPPORTED_FEATURES`: a tuple of supported features, as defined in Features enum.
+
     The above may be set to None to disable the checks.
 
     Parameters:
@@ -168,9 +198,14 @@ class RequestHandler(abc.ABC):
     Requests may have additional optional parameters defined as extensions.
      RequestHandler subclasses may choose to support custom extensions.
 
+    If an extension is supported, subclasses should extend _check_extensions(extensions)
+    to pop and validate the extension.
+    - Extensions left in `extensions` are treated as unsupported and UnsupportedRequest will be raised.
+
     The following extensions are defined for RequestHandler:
-    - `cookiejar`: Cookiejar to use for this request
-    - `timeout`: socket timeout to use for this request
+    - `cookiejar`: Cookiejar to use for this request.
+    - `timeout`: socket timeout to use for this request.
+    To enable these, add extensions.pop('<extension>', None) to _check_extensions
 
     Apart from the url protocol, proxies dict may contain the following keys:
     - `all`: proxy to use for all protocols. Used as a fallback if no proxy is set for a specific protocol.
@@ -187,13 +222,13 @@ def __init__(
         self, *,
         logger,  # TODO(Grub4k): default logger
         headers: HTTPHeaderDict = None,
-        cookiejar: CookieJar = None,
+        cookiejar: YoutubeDLCookieJar = None,
         timeout: float | int | None = None,
-        proxies: dict = None,
-        source_address: str = None,
+        proxies: dict | None = None,
+        source_address: str | None = None,
         verbose: bool = False,
         prefer_system_certs: bool = False,
-        client_cert: dict[str, str | None] = None,
+        client_cert: dict[str, str | None] | None = None,
         verify: bool = True,
         legacy_ssl_support: bool = False,
         **_,
@@ -201,8 +236,8 @@ def __init__(
 
         self._logger = logger
         self.headers = headers or {}
-        self.cookiejar = cookiejar if cookiejar is not None else CookieJar()
-        self.timeout = float(timeout or 20)
+        self.cookiejar = cookiejar if cookiejar is not None else YoutubeDLCookieJar()
+        self.timeout = float(timeout or DEFAULT_TIMEOUT)
         self.proxies = proxies or {}
         self.source_address = source_address
         self.verbose = verbose
@@ -223,6 +258,15 @@ def _make_sslcontext(self):
     def _merge_headers(self, request_headers):
         return HTTPHeaderDict(self.headers, request_headers)
 
+    def _calculate_timeout(self, request):
+        return float(request.extensions.get('timeout') or self.timeout)
+
+    def _get_cookiejar(self, request):
+        return request.extensions.get('cookiejar') or self.cookiejar
+
+    def _get_proxies(self, request):
+        return (request.proxies or self.proxies).copy()
+
     def _check_url_scheme(self, request: Request):
         scheme = urllib.parse.urlparse(request.url).scheme.lower()
         if self._SUPPORTED_URL_SCHEMES is not None and scheme not in self._SUPPORTED_URL_SCHEMES:
@@ -254,34 +298,31 @@ def _check_proxies(self, proxies):
                 # Skip proxy scheme checks
                 continue
 
-            # Scheme-less proxies are not supported
-            if urllib.request._parse_proxy(proxy_url)[0] is None:
-                raise UnsupportedRequest(f'Proxy "{proxy_url}" missing scheme')
+            try:
+                if urllib.request._parse_proxy(proxy_url)[0] is None:
+                    # Scheme-less proxies are not supported
+                    raise UnsupportedRequest(f'Proxy "{proxy_url}" missing scheme')
+            except ValueError as e:
+                # parse_proxy may raise on some invalid proxy urls such as "/a/b/c"
+                raise UnsupportedRequest(f'Invalid proxy url "{proxy_url}": {e}')
 
             scheme = urllib.parse.urlparse(proxy_url).scheme.lower()
             if scheme not in self._SUPPORTED_PROXY_SCHEMES:
                 raise UnsupportedRequest(f'Unsupported proxy type: "{scheme}"')
 
-    def _check_cookiejar_extension(self, extensions):
-        if not extensions.get('cookiejar'):
-            return
-        if not isinstance(extensions['cookiejar'], CookieJar):
-            raise UnsupportedRequest('cookiejar is not a CookieJar')
-
-    def _check_timeout_extension(self, extensions):
-        if extensions.get('timeout') is None:
-            return
-        if not isinstance(extensions['timeout'], (float, int)):
-            raise UnsupportedRequest('timeout is not a float or int')
-
     def _check_extensions(self, extensions):
-        self._check_cookiejar_extension(extensions)
-        self._check_timeout_extension(extensions)
+        """Check extensions for unsupported extensions. Subclasses should extend this."""
+        assert isinstance(extensions.get('cookiejar'), (YoutubeDLCookieJar, NoneType))
+        assert isinstance(extensions.get('timeout'), (float, int, NoneType))
 
     def _validate(self, request):
         self._check_url_scheme(request)
         self._check_proxies(request.proxies or self.proxies)
-        self._check_extensions(request.extensions)
+        extensions = request.extensions.copy()
+        self._check_extensions(extensions)
+        if extensions:
+            # TODO: add support for optional extensions
+            raise UnsupportedRequest(f'Unsupported extensions: {", ".join(extensions.keys())}')
 
     @wrap_request_errors
     def validate(self, request: Request):
@@ -298,8 +339,9 @@ def send(self, request: Request) -> Response:
     @abc.abstractmethod
     def _send(self, request: Request):
         """Handle a request from start to finish. Redefine in subclasses."""
+        pass
 
-    def close(self):
+    def close(self):  # noqa: B027
         pass
 
     @classproperty
@@ -336,11 +378,11 @@ def __init__(
             self,
             url: str,
             data: RequestData = None,
-            headers: typing.Mapping = None,
-            proxies: dict = None,
-            query: dict = None,
-            method: str = None,
-            extensions: dict = None
+            headers: typing.Mapping | None = None,
+            proxies: dict | None = None,
+            query: dict | None = None,
+            method: str | None = None,
+            extensions: dict | None = None,
     ):
 
         self._headers = HTTPHeaderDict()
@@ -367,7 +409,7 @@ def url(self, url):
             raise TypeError('url must be a string')
         elif url.startswith('//'):
             url = 'http:' + url
-        self._url = escape_url(url)
+        self._url = normalize_url(url)
 
     @property
     def method(self):
@@ -415,7 +457,7 @@ def headers(self) -> HTTPHeaderDict:
 
     @headers.setter
     def headers(self, new_headers: Mapping):
-        """Replaces headers of the request. If not a CaseInsensitiveDict, it will be converted to one."""
+        """Replaces headers of the request. If not a HTTPHeaderDict, it will be converted to one."""
         if isinstance(new_headers, HTTPHeaderDict):
             self._headers = new_headers
         elif isinstance(new_headers, Mapping):
@@ -423,9 +465,10 @@ def headers(self, new_headers: Mapping):
         else:
             raise TypeError('headers must be a mapping')
 
-    def update(self, url=None, data=None, headers=None, query=None):
-        self.data = data or self.data
+    def update(self, url=None, data=None, headers=None, query=None, extensions=None):
+        self.data = data if data is not None else self.data
         self.headers.update(headers or {})
+        self.extensions.update(extensions or {})
         self.url = update_url_query(url or self.url, query or {})
 
     def copy(self):
@@ -456,15 +499,18 @@ class Response(io.IOBase):
     @param headers: response headers.
     @param status: Response HTTP status code. Default is 200 OK.
     @param reason: HTTP status reason. Will use built-in reasons based on status code if not provided.
+    @param extensions: Dictionary of handler-specific response extensions.
     """
 
     def __init__(
             self,
-            fp: typing.IO,
+            fp: io.IOBase,
             url: str,
             headers: Mapping[str, str],
             status: int = 200,
-            reason: str = None):
+            reason: str | None = None,
+            extensions: dict | None = None,
+    ):
 
         self.fp = fp
         self.headers = Message()
@@ -476,11 +522,12 @@ def __init__(
             self.reason = reason or HTTPStatus(status).phrase
         except ValueError:
             self.reason = None
+        self.extensions = extensions or {}
 
     def readable(self):
         return self.fp.readable()
 
-    def read(self, amt: int = None) -> bytes:
+    def read(self, amt: int | None = None) -> bytes:
         # Expected errors raised here should be of type RequestError or subclasses.
         # Subclasses should redefine this method with more precise error handling.
         try:
@@ -507,16 +554,28 @@ def get_header(self, name, default=None):
     # The following methods are for compatability reasons and are deprecated
     @property
     def code(self):
+        deprecation_warning('Response.code is deprecated, use Response.status', stacklevel=2)
         return self.status
 
     def getcode(self):
+        deprecation_warning('Response.getcode() is deprecated, use Response.status', stacklevel=2)
         return self.status
 
     def geturl(self):
+        deprecation_warning('Response.geturl() is deprecated, use Response.url', stacklevel=2)
         return self.url
 
     def info(self):
+        deprecation_warning('Response.info() is deprecated, use Response.headers', stacklevel=2)
         return self.headers
 
     def getheader(self, name, default=None):
+        deprecation_warning('Response.getheader() is deprecated, use Response.get_header', stacklevel=2)
         return self.get_header(name, default)
+
+
+if typing.TYPE_CHECKING:
+    RequestData = bytes | Iterable[bytes] | typing.IO | None
+    Preference = typing.Callable[[RequestHandler, Request], int]
+
+_RH_PREFERENCES: set[Preference] = set()