]> jfr.im git - yt-dlp.git/blobdiff - yt_dlp/networking/common.py
[ie/cbc.ca:player] Improve `_VALID_URL` (#9866)
[yt-dlp.git] / yt_dlp / networking / common.py
index 8fba8c1c5ae86c07c159e68eef23b3718f6da85a..a2217034c90cb3812c47dc94cd57087b42feae82 100644 (file)
 )
 from ..utils.networking import HTTPHeaderDict, normalize_url
 
-if typing.TYPE_CHECKING:
-    RequestData = bytes | Iterable[bytes] | typing.IO | None
+
+def register_preference(*handlers: type[RequestHandler]):
+    assert all(issubclass(handler, RequestHandler) for handler in handlers)
+
+    def outer(preference: Preference):
+        @functools.wraps(preference)
+        def inner(handler, *args, **kwargs):
+            if not handlers or isinstance(handler, handlers):
+                return preference(handler, *args, **kwargs)
+            return 0
+        _RH_PREFERENCES.add(inner)
+        return inner
+    return outer
 
 
 class RequestDirector:
@@ -40,24 +51,40 @@ class RequestDirector:
 
     Helper class that, when given a request, forward it to a RequestHandler that supports it.
 
+    Preference functions in the form of func(handler, request) -> int
+    can be registered into the `preferences` set. These are used to sort handlers
+    in order of preference.
+
     @param logger: Logger instance.
     @param verbose: Print debug request information to stdout.
     """
 
     def __init__(self, logger, verbose=False):
         self.handlers: dict[str, RequestHandler] = {}
+        self.preferences: set[Preference] = set()
         self.logger = logger  # TODO(Grub4k): default logger
         self.verbose = verbose
 
     def close(self):
         for handler in self.handlers.values():
             handler.close()
+        self.handlers.clear()
 
     def add_handler(self, handler: RequestHandler):
         """Add a handler. If a handler of the same RH_KEY exists, it will overwrite it"""
         assert isinstance(handler, RequestHandler), 'handler must be a RequestHandler'
         self.handlers[handler.RH_KEY] = handler
 
+    def _get_handlers(self, request: Request) -> list[RequestHandler]:
+        """Sorts handlers by preference, given a request"""
+        preferences = {
+            rh: sum(pref(rh, request) for pref in self.preferences)
+            for rh in self.handlers.values()
+        }
+        self._print_verbose('Handler preferences for this request: %s' % ', '.join(
+            f'{rh.RH_NAME}={pref}' for rh, pref in preferences.items()))
+        return sorted(self.handlers.values(), key=preferences.get, reverse=True)
+
     def _print_verbose(self, msg):
         if self.verbose:
             self.logger.stdout(f'director: {msg}')
@@ -73,8 +100,7 @@ def send(self, request: Request) -> Response:
 
         unexpected_errors = []
         unsupported_errors = []
-        # TODO (future): add a per-request preference system
-        for handler in reversed(list(self.handlers.values())):
+        for handler in self._get_handlers(request):
             self._print_verbose(f'Checking if "{handler.RH_NAME}" supports this request.')
             try:
                 handler.validate(request)
@@ -230,6 +256,15 @@ def _make_sslcontext(self):
     def _merge_headers(self, request_headers):
         return HTTPHeaderDict(self.headers, request_headers)
 
+    def _calculate_timeout(self, request):
+        return float(request.extensions.get('timeout') or self.timeout)
+
+    def _get_cookiejar(self, request):
+        return request.extensions.get('cookiejar') or self.cookiejar
+
+    def _get_proxies(self, request):
+        return (request.proxies or self.proxies).copy()
+
     def _check_url_scheme(self, request: Request):
         scheme = urllib.parse.urlparse(request.url).scheme.lower()
         if self._SUPPORTED_URL_SCHEMES is not None and scheme not in self._SUPPORTED_URL_SCHEMES:
@@ -420,7 +455,7 @@ def headers(self) -> HTTPHeaderDict:
 
     @headers.setter
     def headers(self, new_headers: Mapping):
-        """Replaces headers of the request. If not a CaseInsensitiveDict, it will be converted to one."""
+        """Replaces headers of the request. If not a HTTPHeaderDict, it will be converted to one."""
         if isinstance(new_headers, HTTPHeaderDict):
             self._headers = new_headers
         elif isinstance(new_headers, Mapping):
@@ -428,9 +463,10 @@ def headers(self, new_headers: Mapping):
         else:
             raise TypeError('headers must be a mapping')
 
-    def update(self, url=None, data=None, headers=None, query=None):
+    def update(self, url=None, data=None, headers=None, query=None, extensions=None):
         self.data = data if data is not None else self.data
         self.headers.update(headers or {})
+        self.extensions.update(extensions or {})
         self.url = update_url_query(url or self.url, query or {})
 
     def copy(self):
@@ -461,15 +497,18 @@ class Response(io.IOBase):
     @param headers: response headers.
     @param status: Response HTTP status code. Default is 200 OK.
     @param reason: HTTP status reason. Will use built-in reasons based on status code if not provided.
+    @param extensions: Dictionary of handler-specific response extensions.
     """
 
     def __init__(
             self,
-            fp: typing.IO,
+            fp: io.IOBase,
             url: str,
             headers: Mapping[str, str],
             status: int = 200,
-            reason: str = None):
+            reason: str = None,
+            extensions: dict = None
+    ):
 
         self.fp = fp
         self.headers = Message()
@@ -481,6 +520,7 @@ def __init__(
             self.reason = reason or HTTPStatus(status).phrase
         except ValueError:
             self.reason = None
+        self.extensions = extensions or {}
 
     def readable(self):
         return self.fp.readable()
@@ -530,3 +570,10 @@ def info(self):
     def getheader(self, name, default=None):
         deprecation_warning('Response.getheader() is deprecated, use Response.get_header', stacklevel=2)
         return self.get_header(name, default)
+
+
+if typing.TYPE_CHECKING:
+    RequestData = bytes | Iterable[bytes] | typing.IO | None
+    Preference = typing.Callable[[RequestHandler, Request], int]
+
+_RH_PREFERENCES: set[Preference] = set()