]> jfr.im git - yt-dlp.git/blobdiff - yt_dlp/networking/common.py
[networking] Fix various socks proxy bugs (#8065)
[yt-dlp.git] / yt_dlp / networking / common.py
index 8fba8c1c5ae86c07c159e68eef23b3718f6da85a..584c7bb4db4b6792f484c6b95376402cd548c21f 100644 (file)
 )
 from ..utils.networking import HTTPHeaderDict, normalize_url
 
-if typing.TYPE_CHECKING:
-    RequestData = bytes | Iterable[bytes] | typing.IO | None
+
+def register_preference(*handlers: type[RequestHandler]):  # decorator factory; `handlers` optionally limits the preference to those RequestHandler types
+    assert all(issubclass(handler, RequestHandler) for handler in handlers)  # every argument must be a RequestHandler subclass
+
+    def outer(preference: Preference):
+        @functools.wraps(preference)
+        def inner(handler, *args, **kwargs):
+            if not handlers or isinstance(handler, handlers):  # no restriction given, or handler is an instance of a listed type
+                return preference(handler, *args, **kwargs)
+            return 0  # handler is outside the listed types: report a score of 0
+        _RH_PREFERENCES.add(inner)  # the wrapped function (not the raw one) is stored in the module-level set
+        return inner
+    return outer
 
 
 class RequestDirector:
@@ -40,12 +51,17 @@ class RequestDirector:
 
     Helper class that, when given a request, forward it to a RequestHandler that supports it.
 
+    Preference functions in the form of func(handler, request) -> int
+    can be registered into the `preferences` set. These are used to sort handlers
+    in order of preference.
+
     @param logger: Logger instance.
     @param verbose: Print debug request information to stdout.
     """
 
     def __init__(self, logger, verbose=False):
         self.handlers: dict[str, RequestHandler] = {}  # keyed by each handler's RH_KEY (see add_handler)
+        self.preferences: set[Preference] = set()  # preference functions; _get_handlers sums their results per handler
         self.logger = logger  # TODO(Grub4k): default logger
         self.verbose = verbose  # gates the output of _print_verbose
 
@@ -58,6 +74,16 @@ def add_handler(self, handler: RequestHandler):
         assert isinstance(handler, RequestHandler), 'handler must be a RequestHandler'
         self.handlers[handler.RH_KEY] = handler
 
+    def _get_handlers(self, request: Request) -> list[RequestHandler]:
+        """Sorts handlers by preference, given a request"""
+        preferences = {
+            rh: sum(pref(rh, request) for pref in self.preferences)  # total score: sum over every function in self.preferences (0 when the set is empty)
+            for rh in self.handlers.values()
+        }
+        self._print_verbose('Handler preferences for this request: %s' % ', '.join(
+            f'{rh.RH_NAME}={pref}' for rh, pref in preferences.items()))  # NOTE: the message is formatted before _print_verbose checks self.verbose
+        return sorted(self.handlers.values(), key=preferences.get, reverse=True)  # highest total first; sorted() is stable, so ties keep self.handlers insertion order
+
     def _print_verbose(self, msg):  # debug helper: writes msg through the logger only when verbose is enabled
         if self.verbose:
             self.logger.stdout(f'director: {msg}')  # every message is prefixed with 'director: '
@@ -73,8 +99,7 @@ def send(self, request: Request) -> Response:
 
         unexpected_errors = []
         unsupported_errors = []
-        # TODO (future): add a per-request preference system
-        for handler in reversed(list(self.handlers.values())):
+        for handler in self._get_handlers(request):
             self._print_verbose(f'Checking if "{handler.RH_NAME}" supports this request.')
             try:
                 handler.validate(request)
@@ -530,3 +555,10 @@ def info(self):
     def getheader(self, name, default=None):  # deprecated alias that forwards to get_header()
         deprecation_warning('Response.getheader() is deprecated, use Response.get_header', stacklevel=2)  # NOTE(review): stacklevel=2 presumably points the warning at the caller - confirm against deprecation_warning
         return self.get_header(name, default)
+
+
+if typing.TYPE_CHECKING:  # the aliases below exist only for static type checkers; they are not bound at runtime
+    RequestData = bytes | Iterable[bytes] | typing.IO | None
+    Preference = typing.Callable[[RequestHandler, Request], int]  # signature: (handler, request) -> int
+
+_RH_PREFERENCES: set[Preference] = set()  # filled by register_preference; NOTE(review): the annotation names a TYPE_CHECKING-only alias, so this assumes lazy annotations (from __future__ import annotations) - confirm