from collections.abc import Iterable, Mapping
from email.message import Message
from http import HTTPStatus
-from http.cookiejar import CookieJar
from ._helper import make_ssl_context, wrap_request_errors
from .exceptions import (
TransportError,
UnsupportedRequest,
)
+from ..compat.types import NoneType
+from ..cookies import YoutubeDLCookieJar
from ..utils import (
bug_reports_message,
classproperty,
+ deprecation_warning,
error_to_str,
- escape_url,
update_url_query,
)
-from ..utils.networking import HTTPHeaderDict
+from ..utils.networking import HTTPHeaderDict, normalize_url
+
# Fallback timeout used when a request handler is constructed without one.
# Presumably seconds, since it ends up as the socket timeout (TODO confirm).
DEFAULT_TIMEOUT = 20
-if typing.TYPE_CHECKING:
- RequestData = bytes | Iterable[bytes] | typing.IO | None
+
def register_preference(*handlers: type[RequestHandler]):
    """Build a decorator that registers a request-handler preference function.

    The decorated function is wrapped and added to the module-level
    ``_RH_PREFERENCES`` set. The wrapper forwards to the preference function
    only when the handler is an instance of one of ``handlers``. When no
    classes were given, it forwards for every handler. Otherwise it scores 0.

    @param handlers: RequestHandler subclasses the preference applies to.
    @returns: a decorator that returns the registered wrapper.
    """
    assert all(issubclass(handler, RequestHandler) for handler in handlers)

    def decorator(preference: Preference):
        @functools.wraps(preference)
        def wrapper(handler, *args, **kwargs):
            # An empty `handlers` tuple means "applies to every handler"
            applies = not handlers or isinstance(handler, handlers)
            return preference(handler, *args, **kwargs) if applies else 0

        _RH_PREFERENCES.add(wrapper)
        return wrapper

    return decorator
class RequestDirector:
Helper class that, when given a request, forwards it to a RequestHandler that supports it.
+ Preference functions in the form of func(handler, request) -> int
+ can be registered into the `preferences` set. These are used to sort handlers
+ in order of preference.
+
@param logger: Logger instance.
@param verbose: Print debug request information to stdout.
"""
def __init__(self, logger, verbose=False):
    """Create a director with no handlers and no preference functions.

    @param logger: Logger instance; its stdout() receives verbose messages.
    @param verbose: Print debug request information to stdout.
    """
    self.logger = logger  # TODO(Grub4k): default logger
    self.verbose = verbose
    # RH_KEY -> handler; add_handler() overwrites an entry with the same key
    self.handlers: dict[str, RequestHandler] = {}
    # func(handler, request) -> int; scores are summed per handler in _get_handlers()
    self.preferences: set[Preference] = set()
def close(self):
    """Close every registered handler, then drop them all from the director."""
    handlers = self.handlers
    for rh_key in handlers:
        handlers[rh_key].close()
    handlers.clear()
def add_handler(self, handler: RequestHandler):
    """Register ``handler`` under its RH_KEY, replacing any handler already stored under that key."""
    assert isinstance(handler, RequestHandler), 'handler must be a RequestHandler'
    key = handler.RH_KEY
    self.handlers[key] = handler
def _get_handlers(self, request: Request) -> list[RequestHandler]:
    """Return the registered handlers ordered from most to least preferred for ``request``.

    A handler's score is the sum of ``pref(handler, request)`` over every
    function in ``self.preferences``. Handlers with equal scores keep their
    registration order, because sorted() stays stable even with reverse=True.
    """
    scores = {}
    for rh in self.handlers.values():
        scores[rh] = sum(pref(rh, request) for pref in self.preferences)
    summary = ', '.join(f'{rh.RH_NAME}={score}' for rh, score in scores.items())
    self._print_verbose(f'Handler preferences for this request: {summary}')
    return sorted(self.handlers.values(), key=lambda rh: scores[rh], reverse=True)
def _print_verbose(self, msg):
    """Send ``msg``, prefixed with 'director: ', to the logger's stdout when verbose is enabled."""
    if not self.verbose:
        return
    self.logger.stdout(f'director: {msg}')
unexpected_errors = []
unsupported_errors = []
- # TODO (future): add a per-request preference system
- for handler in reversed(list(self.handlers.values())):
+ for handler in self._get_handlers(request):
self._print_verbose(f'Checking if "{handler.RH_NAME}" supports this request.')
try:
handler.validate(request)
_REQUEST_HANDLERS = {}
-def register(handler):
+def register_rh(handler):
"""Register a RequestHandler class"""
assert issubclass(handler, RequestHandler), f'{handler} must be a subclass of RequestHandler'
assert handler.RH_KEY not in _REQUEST_HANDLERS, f'RequestHandler {handler.RH_KEY} already registered'
a proxy URL with a URL scheme not in this list will raise an UnsupportedRequest.
- `_SUPPORTED_FEATURES`: a tuple of supported features, as defined in Features enum.
+
The above may be set to None to disable the checks.
Parameters:
Requests may have additional optional parameters defined as extensions.
RequestHandler subclasses may choose to support custom extensions.
+ If an extension is supported, subclasses should extend _check_extensions(extensions)
+ to pop and validate the extension.
+ - Extensions left in `extensions` are treated as unsupported and UnsupportedRequest will be raised.
+
The following extensions are defined for RequestHandler:
- - `cookiejar`: Cookiejar to use for this request
- - `timeout`: socket timeout to use for this request
+ - `cookiejar`: Cookiejar to use for this request.
+ - `timeout`: socket timeout to use for this request.
+ To enable these, add extensions.pop('<extension>', None) to _check_extensions
Apart from the url protocol, proxies dict may contain the following keys:
- `all`: proxy to use for all protocols. Used as a fallback if no proxy is set for a specific protocol.
self, *,
logger, # TODO(Grub4k): default logger
headers: HTTPHeaderDict = None,
- cookiejar: CookieJar = None,
+ cookiejar: YoutubeDLCookieJar = None,
timeout: float | int | None = None,
- proxies: dict = None,
- source_address: str = None,
+ proxies: dict | None = None,
+ source_address: str | None = None,
verbose: bool = False,
prefer_system_certs: bool = False,
- client_cert: dict[str, str | None] = None,
+ client_cert: dict[str, str | None] | None = None,
verify: bool = True,
legacy_ssl_support: bool = False,
**_,
self._logger = logger
self.headers = headers or {}
- self.cookiejar = cookiejar if cookiejar is not None else CookieJar()
- self.timeout = float(timeout or 20)
+ self.cookiejar = cookiejar if cookiejar is not None else YoutubeDLCookieJar()
+ self.timeout = float(timeout or DEFAULT_TIMEOUT)
self.proxies = proxies or {}
self.source_address = source_address
self.verbose = verbose
def _merge_headers(self, request_headers):
    """Combine the handler-wide headers and ``request_headers`` into a new HTTPHeaderDict."""
    # Precedence between the two sources is decided by HTTPHeaderDict (defined elsewhere).
    merged = HTTPHeaderDict(self.headers, request_headers)
    return merged
def _calculate_timeout(self, request):
    """Return the timeout to use for ``request`` as a float.

    The request's 'timeout' extension is used when it is truthy. Otherwise
    the handler-wide ``self.timeout`` applies.
    """
    per_request = request.extensions.get('timeout')
    return float(per_request if per_request else self.timeout)
def _get_cookiejar(self, request):
    """Return the cookiejar to use for ``request``.

    The request's 'cookiejar' extension takes precedence over the
    handler-wide ``self.cookiejar``. The fallback is chosen by comparing
    against None rather than by truthiness. http.cookiejar.CookieJar defines
    __len__ (the number of stored cookies), so an explicitly supplied jar
    that is still empty is falsy. With a truthiness check it would be
    silently replaced by the handler default. (YoutubeDLCookieJar presumably
    inherits this behaviour from CookieJar; verify.)
    """
    cookiejar = request.extensions.get('cookiejar')
    return self.cookiejar if cookiejar is None else cookiejar
def _get_proxies(self, request):
    """Return a shallow copy of the proxies to use for ``request``.

    The request's own proxies are used when non-empty. Otherwise the
    handler-wide ``self.proxies`` applies. Copying lets callers modify the
    result without touching either source mapping.
    """
    if request.proxies:
        chosen = request.proxies
    else:
        chosen = self.proxies
    return chosen.copy()
def _check_url_scheme(self, request: Request):
scheme = urllib.parse.urlparse(request.url).scheme.lower()
if self._SUPPORTED_URL_SCHEMES is not None and scheme not in self._SUPPORTED_URL_SCHEMES:
# Skip proxy scheme checks
continue
- # Scheme-less proxies are not supported
- if urllib.request._parse_proxy(proxy_url)[0] is None:
- raise UnsupportedRequest(f'Proxy "{proxy_url}" missing scheme')
+ try:
+ if urllib.request._parse_proxy(proxy_url)[0] is None:
+ # Scheme-less proxies are not supported
+ raise UnsupportedRequest(f'Proxy "{proxy_url}" missing scheme')
+ except ValueError as e:
+ # parse_proxy may raise on some invalid proxy urls such as "/a/b/c"
+ raise UnsupportedRequest(f'Invalid proxy url "{proxy_url}": {e}')
scheme = urllib.parse.urlparse(proxy_url).scheme.lower()
if scheme not in self._SUPPORTED_PROXY_SCHEMES:
raise UnsupportedRequest(f'Unsupported proxy type: "{scheme}"')
- def _check_cookiejar_extension(self, extensions):
- if not extensions.get('cookiejar'):
- return
- if not isinstance(extensions['cookiejar'], CookieJar):
- raise UnsupportedRequest('cookiejar is not a CookieJar')
-
- def _check_timeout_extension(self, extensions):
- if extensions.get('timeout') is None:
- return
- if not isinstance(extensions['timeout'], (float, int)):
- raise UnsupportedRequest('timeout is not a float or int')
-
def _check_extensions(self, extensions):
    """Type-check the built-in extensions present in ``extensions``.

    The base implementation only asserts the types of 'cookiejar' and
    'timeout' and pops nothing. A subclass that supports an extension should
    extend this method and pop it from ``extensions``. Anything still present
    afterwards is rejected as unsupported by _validate.
    """
    # NOTE(review): these are asserts, so the checks vanish under `python -O`.
    cookiejar = extensions.get('cookiejar')
    assert isinstance(cookiejar, (YoutubeDLCookieJar, NoneType))
    timeout = extensions.get('timeout')
    assert isinstance(timeout, (float, int, NoneType))
def _validate(self, request):
    """Check that this handler can process ``request``, raising if it cannot.

    The checks run in this order:

    - the URL scheme;
    - the proxies (the request's own, or the handler's when the request has none);
    - the extensions.

    _check_extensions receives a copy of request.extensions and pops the ones
    it supports. Any that remain cause UnsupportedRequest. The request's own
    extensions dict is left untouched.
    """
    self._check_url_scheme(request)
    self._check_proxies(request.proxies or self.proxies)
    remaining = request.extensions.copy()
    self._check_extensions(remaining)
    if not remaining:
        return
    # TODO: add support for optional extensions
    unsupported = ', '.join(remaining.keys())
    raise UnsupportedRequest(f'Unsupported extensions: {unsupported}')
@wrap_request_errors
def validate(self, request: Request):
@abc.abstractmethod
def _send(self, request: Request):
    """Process ``request`` from start to finish.

    Abstract: every concrete request handler must override this.
    """
    pass
def close(self):  # noqa: B027
    """Release resources held by the handler. A no-op here; subclasses may override."""
@classproperty
self,
url: str,
data: RequestData = None,
- headers: typing.Mapping = None,
- proxies: dict = None,
- query: dict = None,
- method: str = None,
- extensions: dict = None
+ headers: typing.Mapping | None = None,
+ proxies: dict | None = None,
+ query: dict | None = None,
+ method: str | None = None,
+ extensions: dict | None = None,
):
self._headers = HTTPHeaderDict()
raise TypeError('url must be a string')
elif url.startswith('//'):
url = 'http:' + url
- self._url = escape_url(url)
+ self._url = normalize_url(url)
@property
def method(self):
@headers.setter
def headers(self, new_headers: Mapping):
- """Replaces headers of the request. If not a CaseInsensitiveDict, it will be converted to one."""
+ """Replaces headers of the request. If not a HTTPHeaderDict, it will be converted to one."""
if isinstance(new_headers, HTTPHeaderDict):
self._headers = new_headers
elif isinstance(new_headers, Mapping):
else:
raise TypeError('headers must be a mapping')
def update(self, url=None, data=None, headers=None, query=None, extensions=None):
    """Update this request in place.

    @param url: replacement URL; the current URL is kept when this is falsy.
    @param data: replacement body; the current body is kept only when this is None.
    @param headers: headers merged into the existing headers.
    @param query: query parameters applied to the URL via update_url_query.
    @param extensions: extensions merged into the existing extensions.
    """
    # NOTE(review): keep this statement order. Assigning self.data may go
    # through a property setter defined elsewhere in the class.
    self.data = self.data if data is None else data
    self.headers.update(headers or {})
    self.extensions.update(extensions or {})
    target_url = url or self.url
    self.url = update_url_query(target_url, query or {})
def copy(self):
@param headers: response headers.
@param status: Response HTTP status code. Default is 200 OK.
@param reason: HTTP status reason. Will use built-in reasons based on status code if not provided.
+ @param extensions: Dictionary of handler-specific response extensions.
"""
def __init__(
self,
- fp: typing.IO,
+ fp: io.IOBase,
url: str,
headers: Mapping[str, str],
status: int = 200,
- reason: str = None):
+ reason: str | None = None,
+ extensions: dict | None = None,
+ ):
self.fp = fp
self.headers = Message()
self.reason = reason or HTTPStatus(status).phrase
except ValueError:
self.reason = None
+ self.extensions = extensions or {}
def readable(self):
    """Report whether the wrapped file object is readable (delegates to fp.readable())."""
    fp = self.fp
    return fp.readable()
- def read(self, amt: int = None) -> bytes:
+ def read(self, amt: int | None = None) -> bytes:
# Expected errors raised here should be of type RequestError or subclasses.
# Subclasses should redefine this method with more precise error handling.
try:
# The following methods are for compatibility reasons and are deprecated
@property
def code(self):
    """Deprecated alias of ``status``; emits a deprecation warning."""
    message = 'Response.code is deprecated, use Response.status'
    deprecation_warning(message, stacklevel=2)
    return self.status
def getcode(self):
    """Deprecated alias of ``status``; emits a deprecation warning."""
    message = 'Response.getcode() is deprecated, use Response.status'
    deprecation_warning(message, stacklevel=2)
    return self.status
def geturl(self):
    """Deprecated alias of ``url``; emits a deprecation warning."""
    message = 'Response.geturl() is deprecated, use Response.url'
    deprecation_warning(message, stacklevel=2)
    return self.url
def info(self):
    """Deprecated alias of ``headers``; emits a deprecation warning."""
    message = 'Response.info() is deprecated, use Response.headers'
    deprecation_warning(message, stacklevel=2)
    return self.headers
def getheader(self, name, default=None):
    """Deprecated alias of ``get_header(name, default)``; emits a deprecation warning."""
    message = 'Response.getheader() is deprecated, use Response.get_header'
    deprecation_warning(message, stacklevel=2)
    return self.get_header(name, default)
+
+
if typing.TYPE_CHECKING:
    # Aliases used only in annotations; they are never defined at runtime.
    RequestData = bytes | Iterable[bytes] | typing.IO | None
    Preference = typing.Callable[[RequestHandler, Request], int]

# Global set of preference functions; register_preference() adds to it.
# NOTE(review): the `set[Preference]` annotation only evaluates safely at runtime
# when annotations are postponed (from __future__ import annotations);
# presumably the file has that import outside this view.
_RH_PREFERENCES: set[Preference] = set()