from collections.abc import Iterable, Mapping
from email.message import Message
from http import HTTPStatus
-from http.cookiejar import CookieJar
from ._helper import make_ssl_context, wrap_request_errors
from .exceptions import (
TransportError,
UnsupportedRequest,
)
+from ..compat.types import NoneType
+from ..cookies import YoutubeDLCookieJar
from ..utils import (
bug_reports_message,
classproperty,
deprecation_warning,
error_to_str,
- escape_url,
update_url_query,
)
-from ..utils.networking import HTTPHeaderDict
+from ..utils.networking import HTTPHeaderDict, normalize_url
def register_preference(*handlers: type[RequestHandler]):
    """Return a decorator that registers a request-handler preference function.

    Preference functions have the form ``func(handler, request) -> int``. The
    decorated function is wrapped so that it is only consulted for instances of
    ``handlers``; for any other handler the wrapper returns 0. When no
    ``handlers`` are given, the function applies to every handler. The wrapper
    is added to the module-level ``_RH_PREFERENCES`` set and returned in place
    of the decorated function.
    """
    assert all(issubclass(rh_class, RequestHandler) for rh_class in handlers)

    def decorator(preference: Preference):
        @functools.wraps(preference)
        def scoped_preference(handler, *args, **kwargs):
            # isinstance() against an empty tuple is always False, so "no
            # handlers given" is special-cased to mean "applies to all".
            applies = isinstance(handler, handlers) if handlers else True
            if not applies:
                return 0
            return preference(handler, *args, **kwargs)

        _RH_PREFERENCES.add(scoped_preference)
        return scoped_preference

    return decorator
class RequestDirector:
Helper class that, when given a request, forwards it to a RequestHandler that supports it.
+ Preference functions in the form of func(handler, request) -> int
+ can be registered into the `preferences` set. These are used to sort handlers
+ in order of preference.
+
@param logger: Logger instance.
@param verbose: Print debug request information to stdout.
"""
def __init__(self, logger, verbose=False):
self.handlers: dict[str, RequestHandler] = {}
+ self.preferences: set[Preference] = set()
self.logger = logger # TODO(Grub4k): default logger
self.verbose = verbose
assert isinstance(handler, RequestHandler), 'handler must be a RequestHandler'
self.handlers[handler.RH_KEY] = handler
+ def _get_handlers(self, request: Request) -> list[RequestHandler]:
+ """Sorts handlers by preference, given a request"""
+ preferences = {
+ rh: sum(pref(rh, request) for pref in self.preferences)
+ for rh in self.handlers.values()
+ }
+ self._print_verbose('Handler preferences for this request: %s' % ', '.join(
+ f'{rh.RH_NAME}={pref}' for rh, pref in preferences.items()))
+ return sorted(self.handlers.values(), key=preferences.get, reverse=True)
+
def _print_verbose(self, msg):
if self.verbose:
self.logger.stdout(f'director: {msg}')
unexpected_errors = []
unsupported_errors = []
- # TODO (future): add a per-request preference system
- for handler in reversed(list(self.handlers.values())):
+ for handler in self._get_handlers(request):
self._print_verbose(f'Checking if "{handler.RH_NAME}" supports this request.')
try:
handler.validate(request)
# Registry of RequestHandler classes. register_rh() rejects a class whose
# RH_KEY is already present, so entries are presumably keyed by RH_KEY
# (TODO confirm against the full body of register_rh).
_REQUEST_HANDLERS = {}
-def register(handler):
+def register_rh(handler):
"""Register a RequestHandler class"""
assert issubclass(handler, RequestHandler), f'{handler} must be a subclass of RequestHandler'
assert handler.RH_KEY not in _REQUEST_HANDLERS, f'RequestHandler {handler.RH_KEY} already registered'
a proxy URL with a URL scheme not in this list will raise an UnsupportedRequest.
- `_SUPPORTED_FEATURES`: a tuple of supported features, as defined in Features enum.
+
The above may be set to None to disable the checks.
Parameters:
Requests may have additional optional parameters defined as extensions.
RequestHandler subclasses may choose to support custom extensions.
+ If an extension is supported, subclasses should extend _check_extensions(extensions)
+ to pop and validate the extension.
+ - Extensions left in `extensions` are treated as unsupported and UnsupportedRequest will be raised.
+
The following extensions are defined for RequestHandler:
- - `cookiejar`: Cookiejar to use for this request
- - `timeout`: socket timeout to use for this request
+ - `cookiejar`: Cookiejar to use for this request.
+ - `timeout`: socket timeout to use for this request.
+ To support these, subclasses should add extensions.pop('<extension>', None) to their _check_extensions.
Apart from the url protocol, proxies dict may contain the following keys:
- `all`: proxy to use for all protocols. Used as a fallback if no proxy is set for a specific protocol.
self, *,
logger, # TODO(Grub4k): default logger
headers: HTTPHeaderDict = None,
- cookiejar: CookieJar = None,
+ cookiejar: YoutubeDLCookieJar = None,
timeout: float | int | None = None,
proxies: dict = None,
source_address: str = None,
self._logger = logger
self.headers = headers or {}
- self.cookiejar = cookiejar if cookiejar is not None else CookieJar()
+ self.cookiejar = cookiejar if cookiejar is not None else YoutubeDLCookieJar()
self.timeout = float(timeout or 20)
self.proxies = proxies or {}
self.source_address = source_address
# Skip proxy scheme checks
continue
- # Scheme-less proxies are not supported
- if urllib.request._parse_proxy(proxy_url)[0] is None:
- raise UnsupportedRequest(f'Proxy "{proxy_url}" missing scheme')
+ try:
+ if urllib.request._parse_proxy(proxy_url)[0] is None:
+ # Scheme-less proxies are not supported
+ raise UnsupportedRequest(f'Proxy "{proxy_url}" missing scheme')
+ except ValueError as e:
+ # parse_proxy may raise on some invalid proxy urls such as "/a/b/c"
+ raise UnsupportedRequest(f'Invalid proxy url "{proxy_url}": {e}')
scheme = urllib.parse.urlparse(proxy_url).scheme.lower()
if scheme not in self._SUPPORTED_PROXY_SCHEMES:
raise UnsupportedRequest(f'Unsupported proxy type: "{scheme}"')
def _check_extensions(self, extensions):
    """Sanity-check the extensions shared by all RequestHandlers.

    The base implementation only asserts that ``cookiejar`` (if present) is a
    YoutubeDLCookieJar and ``timeout`` (if present) is a float or int. It does
    not remove anything from ``extensions``. Subclasses that support an
    extension should extend this method and pop that extension from
    ``extensions``; ``_validate`` treats any leftover key as unsupported.
    """
    # NOTE(review): these are asserts, so they are stripped under ``python -O``;
    # the AssertionError type is kept because callers may rely on it.
    cookiejar = extensions.get('cookiejar')
    assert isinstance(cookiejar, (YoutubeDLCookieJar, NoneType))
    timeout = extensions.get('timeout')
    assert isinstance(timeout, (float, int, NoneType))
def _validate(self, request):
self._check_url_scheme(request)
self._check_proxies(request.proxies or self.proxies)
- self._check_extensions(request.extensions)
+ extensions = request.extensions.copy()
+ self._check_extensions(extensions)
+ if extensions:
+ # TODO: add support for optional extensions
+ raise UnsupportedRequest(f'Unsupported extensions: {", ".join(extensions.keys())}')
@wrap_request_errors
def validate(self, request: Request):
@abc.abstractmethod
def _send(self, request: Request):
"""Handle a request from start to finish. Redefine in subclasses."""
+ pass
def close(self):
    """Release handler resources; does nothing in this base implementation."""
raise TypeError('url must be a string')
elif url.startswith('//'):
url = 'http:' + url
- self._url = escape_url(url)
+ self._url = normalize_url(url)
@property
def method(self):
raise TypeError('headers must be a mapping')
def update(self, url=None, data=None, headers=None, query=None):
    """Update this request in place.

    @param url: New URL; a falsy value keeps the current one.
    @param data: New body; only ``None`` keeps the current one, so falsy bodies such as ``b''`` are applied.
    @param headers: Headers merged into the existing headers.
    @param query: Query parameters merged into the (possibly new) URL via update_url_query.
    """
    # self.data is re-assigned even when unchanged, exactly as before.
    # NOTE(review): data is likely a property whose setter has side effects; keep this.
    self.data = self.data if data is None else data
    self.headers.update(headers if headers else {})
    base_url = url if url else self.url
    self.url = update_url_query(base_url, query if query else {})
def getheader(self, name, default=None):
    """Deprecated alias for get_header(); emits a deprecation warning before delegating."""
    deprecation_warning('Response.getheader() is deprecated, use Response.get_header', stacklevel=2)
    header_value = self.get_header(name, default)
    return header_value
+
+
if typing.TYPE_CHECKING:
    # Type aliases that exist only for static type checkers (never at runtime).
    # Presumably the accepted types for a request body — TODO confirm against Request.data.
    RequestData = bytes | Iterable[bytes] | typing.IO | None
    # A preference function: func(handler, request) -> int score.
    Preference = typing.Callable[[RequestHandler, Request], int]

# Module-level set of preference functions; register_preference() adds its wrappers here.
_RH_PREFERENCES: set[Preference] = set()