from collections.abc import Iterable, Mapping
from email.message import Message
from http import HTTPStatus
-from http.cookiejar import CookieJar
from ._helper import make_ssl_context, wrap_request_errors
from .exceptions import (
UnsupportedRequest,
)
from ..compat.types import NoneType
+from ..cookies import YoutubeDLCookieJar
from ..utils import (
bug_reports_message,
classproperty,
deprecation_warning,
error_to_str,
- escape_url,
update_url_query,
)
-from ..utils.networking import HTTPHeaderDict
+from ..utils.networking import HTTPHeaderDict, normalize_url
-if typing.TYPE_CHECKING:
- RequestData = bytes | Iterable[bytes] | typing.IO | None
+
+def register_preference(*handlers: type[RequestHandler]):
+ """
+ Decorator factory that registers a preference function in _RH_PREFERENCES.
+
+ @param handlers: RequestHandler subclasses the preference applies to.
+ If none are given, the preference applies to every handler.
+
+ The decorated function is replaced by (and registered as) a wrapper that
+ calls it only when the handler passed in is an instance of one of `handlers`
+ (or when `handlers` is empty). For any other handler the wrapper returns 0
+ without calling the decorated function.
+ """
+ assert all(issubclass(handler, RequestHandler) for handler in handlers)
+
+ def outer(preference: Preference):
+ @functools.wraps(preference)
+ def inner(handler, *args, **kwargs):
+ # An empty `handlers` tuple means "no restriction on handler type"
+ if not handlers or isinstance(handler, handlers):
+ return preference(handler, *args, **kwargs)
+ # The preference does not apply to this handler
+ return 0
+ # Register the filtering wrapper rather than the raw preference function
+ _RH_PREFERENCES.add(inner)
+ return inner
+ return outer
class RequestDirector:
Helper class that, when given a request, forward it to a RequestHandler that supports it.
+ Preference functions in the form of func(handler, request) -> int
+ can be registered into the `preferences` set. These are used to sort handlers
+ in order of preference.
+
@param logger: Logger instance.
@param verbose: Print debug request information to stdout.
"""
def __init__(self, logger, verbose=False):
self.handlers: dict[str, RequestHandler] = {}
+ self.preferences: set[Preference] = set()
self.logger = logger # TODO(Grub4k): default logger
self.verbose = verbose
assert isinstance(handler, RequestHandler), 'handler must be a RequestHandler'
self.handlers[handler.RH_KEY] = handler
+ def _get_handlers(self, request: Request) -> list[RequestHandler]:
+ """
+ Sorts handlers by preference, given a request.
+
+ Each handler's score is the sum of every function in self.preferences,
+ called as pref(handler, request). Handlers are returned highest score first.
+ """
+ # Total score per handler; with no registered preferences every handler scores 0
+ preferences = {
+ rh: sum(pref(rh, request) for pref in self.preferences)
+ for rh in self.handlers.values()
+ }
+ self._print_verbose('Handler preferences for this request: %s' % ', '.join(
+ f'{rh.RH_NAME}={pref}' for rh, pref in preferences.items()))
+ # sorted() is stable even with reverse=True, so handlers with equal scores
+ # keep the order in which they appear in self.handlers
+ return sorted(self.handlers.values(), key=preferences.get, reverse=True)
+
def _print_verbose(self, msg):
if self.verbose:
self.logger.stdout(f'director: {msg}')
unexpected_errors = []
unsupported_errors = []
- # TODO (future): add a per-request preference system
- for handler in reversed(list(self.handlers.values())):
+ for handler in self._get_handlers(request):
self._print_verbose(f'Checking if "{handler.RH_NAME}" supports this request.')
try:
handler.validate(request)
self, *,
logger, # TODO(Grub4k): default logger
headers: HTTPHeaderDict = None,
- cookiejar: CookieJar = None,
+ cookiejar: YoutubeDLCookieJar = None,
timeout: float | int | None = None,
proxies: dict = None,
source_address: str = None,
self._logger = logger
self.headers = headers or {}
- self.cookiejar = cookiejar if cookiejar is not None else CookieJar()
+ self.cookiejar = cookiejar if cookiejar is not None else YoutubeDLCookieJar()
self.timeout = float(timeout or 20)
self.proxies = proxies or {}
self.source_address = source_address
# Skip proxy scheme checks
continue
- # Scheme-less proxies are not supported
- if urllib.request._parse_proxy(proxy_url)[0] is None:
- raise UnsupportedRequest(f'Proxy "{proxy_url}" missing scheme')
+ try:
+ if urllib.request._parse_proxy(proxy_url)[0] is None:
+ # Scheme-less proxies are not supported
+ raise UnsupportedRequest(f'Proxy "{proxy_url}" missing scheme')
+ except ValueError as e:
+ # parse_proxy may raise on some invalid proxy urls such as "/a/b/c"
+ raise UnsupportedRequest(f'Invalid proxy url "{proxy_url}": {e}')
scheme = urllib.parse.urlparse(proxy_url).scheme.lower()
if scheme not in self._SUPPORTED_PROXY_SCHEMES:
def _check_extensions(self, extensions):
"""Check extensions for unsupported extensions. Subclasses should extend this."""
- assert isinstance(extensions.get('cookiejar'), (CookieJar, NoneType))
+ assert isinstance(extensions.get('cookiejar'), (YoutubeDLCookieJar, NoneType))
assert isinstance(extensions.get('timeout'), (float, int, NoneType))
def _validate(self, request):
@abc.abstractmethod
def _send(self, request: Request):
"""Handle a request from start to finish. Redefine in subclasses."""
+ pass
def close(self):
pass
raise TypeError('url must be a string')
elif url.startswith('//'):
url = 'http:' + url
- self._url = escape_url(url)
+ self._url = normalize_url(url)
@property
def method(self):
def getheader(self, name, default=None):
deprecation_warning('Response.getheader() is deprecated, use Response.get_header', stacklevel=2)
return self.get_header(name, default)
+
+
+# Type aliases for annotations only; this branch is skipped at runtime
+# because typing.TYPE_CHECKING is False outside of static type checkers.
+if typing.TYPE_CHECKING:
+ RequestData = bytes | Iterable[bytes] | typing.IO | None
+ # A preference function: takes (handler, request) and returns an int score
+ Preference = typing.Callable[[RequestHandler, Request], int]
+
+# Module-level set that register_preference() adds its wrapped preference functions to.
+# NOTE(review): `Preference` only exists under TYPE_CHECKING, so this annotation
+# presumably relies on lazy annotations (e.g. `from __future__ import annotations`
+# at the top of the file, not visible here) - confirm.
+_RH_PREFERENCES: set[Preference] = set()