)
from ..utils.networking import HTTPHeaderDict, normalize_url
-if typing.TYPE_CHECKING:
- RequestData = bytes | Iterable[bytes] | typing.IO | None
+
def register_preference(*handlers: type[RequestHandler]):
    """Decorator factory that registers a request-handler preference function.

    The decorated function is invoked as ``preference(handler, *args, **kwargs)``
    and is expected to return an int score. It is only consulted when ``handler``
    is an instance of one of ``handlers``; for any other handler the wrapper
    returns 0. Passing no handler classes makes the preference apply to all
    handlers.

    The wrapper (not the original function) is added to the module-level
    ``_RH_PREFERENCES`` set and is what the decorator returns.
    """
    assert all(issubclass(handler, RequestHandler) for handler in handlers)

    def decorator(preference: Preference):
        @functools.wraps(preference)
        def scoped_preference(handler, *args, **kwargs):
            # An empty `handlers` tuple means "applies to every handler".
            if handlers and not isinstance(handler, handlers):
                return 0
            return preference(handler, *args, **kwargs)

        _RH_PREFERENCES.add(scoped_preference)
        return scoped_preference

    return decorator
class RequestDirector:
Helper class that, when given a request, forwards it to a RequestHandler that supports it.
+ Preference functions of the form func(handler, request) -> int
+ can be registered in the `preferences` set. For each request, the scores
+ returned by all preferences are summed per handler, and handlers are tried
+ from the highest total to the lowest.
+
@param logger: Logger instance.
@param verbose: Print debug request information to stdout.
"""
def __init__(self, logger, verbose=False):
    """Create an empty director.

    @param logger: Logger instance; its ``stdout`` method receives verbose output.
    @param verbose: When true, debug messages are printed via the logger.
    """
    self.logger = logger  # TODO(Grub4k): default logger
    self.verbose = verbose
    # Registered handlers keyed by their RH_KEY (see add_handler).
    self.handlers: dict[str, RequestHandler] = {}
    # Preference callables; their scores are summed per handler to rank handlers (see _get_handlers).
    self.preferences: set[Preference] = set()
def close(self):
    """Close every registered handler in registration order, then forget all of them."""
    for rh_key in self.handlers:
        self.handlers[rh_key].close()
    self.handlers.clear()
def add_handler(self, handler: RequestHandler):
    """Add a handler. If a handler of the same RH_KEY exists, it will overwrite it"""
    assert isinstance(handler, RequestHandler), 'handler must be a RequestHandler'
    # Handlers are keyed by RH_KEY, so re-adding a key replaces the previous handler.
    registry = self.handlers
    registry[handler.RH_KEY] = handler
def _get_handlers(self, request: Request) -> list[RequestHandler]:
    """Return the registered handlers ordered from most to least preferred for ``request``.

    Each handler's score is the sum of ``pref(handler, request)`` over every
    function in ``self.preferences``. The scores are reported through
    ``_print_verbose``. Handlers with equal scores keep the iteration order of
    ``self.handlers``, because ``sorted`` stays stable even with ``reverse=True``.
    """
    scores = {}
    for rh in self.handlers.values():
        scores[rh] = sum(pref(rh, request) for pref in self.preferences)

    summary = ', '.join(f'{rh.RH_NAME}={score}' for rh, score in scores.items())
    self._print_verbose(f'Handler preferences for this request: {summary}')
    # NOTE(review): the loop this replaced iterated reversed(handlers), i.e. the
    # last-added handler first; ties now favour the first-added handler. Confirm intended.
    return sorted(self.handlers.values(), key=scores.get, reverse=True)
+
def _print_verbose(self, msg):
    """Send ``msg``, prefixed with 'director: ', to the logger's stdout when verbose mode is on."""
    if not self.verbose:
        return
    self.logger.stdout(f'director: {msg}')
unexpected_errors = []
unsupported_errors = []
- # TODO (future): add a per-request preference system
- for handler in reversed(list(self.handlers.values())):
+ for handler in self._get_handlers(request):
self._print_verbose(f'Checking if "{handler.RH_NAME}" supports this request.')
try:
handler.validate(request)
def _merge_headers(self, request_headers):
    """Build a new HTTPHeaderDict from the handler's default headers plus ``request_headers``.

    The two mappings are passed to the HTTPHeaderDict constructor in that order;
    presumably later arguments win on conflicting keys. Confirm against HTTPHeaderDict.
    """
    defaults = self.headers
    return HTTPHeaderDict(defaults, request_headers)
def _calculate_timeout(self, request):
    """Return the timeout to use for ``request`` as a float.

    A truthy 'timeout' request extension takes precedence. A missing, None or
    otherwise falsy value (including 0) falls back to ``self.timeout``.
    """
    timeout = request.extensions.get('timeout')
    if not timeout:
        timeout = self.timeout
    return float(timeout)
+
def _get_cookiejar(self, request):
    """Return the cookie jar to use for ``request``.

    A 'cookiejar' request extension takes precedence over the handler-wide
    ``self.cookiejar``. The check is against None rather than truthiness:
    ``http.cookiejar.CookieJar`` defines ``__len__``, so an explicitly supplied
    but still empty jar is falsy. A truthiness check would silently replace it
    with the handler's default jar.
    """
    cookiejar = request.extensions.get('cookiejar')
    return self.cookiejar if cookiejar is None else cookiejar
+
def _get_proxies(self, request):
    """Return a shallow copy of the proxy mapping to use for ``request``.

    A non-empty ``request.proxies`` wins; an empty one falls back to the
    handler-wide ``self.proxies``. Returning a copy lets callers modify the
    result without touching either source mapping.
    """
    proxies = request.proxies
    if not proxies:
        proxies = self.proxies
    return proxies.copy()
+
def _check_url_scheme(self, request: Request):
scheme = urllib.parse.urlparse(request.url).scheme.lower()
if self._SUPPORTED_URL_SCHEMES is not None and scheme not in self._SUPPORTED_URL_SCHEMES:
@headers.setter
def headers(self, new_headers: Mapping):
- """Replaces headers of the request. If not a CaseInsensitiveDict, it will be converted to one."""
+ """Replaces headers of the request. If not a HTTPHeaderDict, it will be converted to one."""
if isinstance(new_headers, HTTPHeaderDict):
self._headers = new_headers
elif isinstance(new_headers, Mapping):
else:
raise TypeError('headers must be a mapping')
def update(self, url=None, data=None, headers=None, query=None, extensions=None):
    """Update this request in place.

    ``data`` is reassigned on every call, either to the new value or to the
    current one when ``data`` is None. ``headers`` and ``extensions`` are merged
    into the existing mappings. The URL is reassigned from ``url`` (or the
    current URL) with ``query`` merged in via update_url_query.
    Assignments run in this fixed order: data, headers, extensions, url.
    """
    new_data = self.data if data is None else data
    self.data = new_data
    self.headers.update(headers or {})
    self.extensions.update(extensions or {})
    base_url = url or self.url
    self.url = update_url_query(base_url, query or {})
def copy(self):
@param headers: response headers.
@param status: Response HTTP status code. Default is 200 OK.
@param reason: HTTP status reason. Will use built-in reasons based on status code if not provided.
+ @param extensions: Dictionary of handler-specific response extensions.
"""
def __init__(
self,
- fp: typing.IO,
+ fp: io.IOBase,
url: str,
headers: Mapping[str, str],
status: int = 200,
- reason: str = None):
+ reason: str = None,
+ extensions: dict = None
+ ):
self.fp = fp
self.headers = Message()
self.reason = reason or HTTPStatus(status).phrase
except ValueError:
self.reason = None
+ self.extensions = extensions or {}
def readable(self):
    """Report whether the wrapped file object is readable (delegates to ``fp.readable()``)."""
    fp = self.fp
    return fp.readable()
def getheader(self, name, default=None):
    """Deprecated alias for ``get_header``; emits a deprecation warning, then delegates."""
    deprecation_warning('Response.getheader() is deprecated, use Response.get_header', stacklevel=2)
    value = self.get_header(name, default)
    return value
+
+
if typing.TYPE_CHECKING:
    # Preference function signature: (handler, request) -> integer score.
    Preference = typing.Callable[[RequestHandler, Request], int]
    # Types accepted for a request body.
    RequestData = bytes | Iterable[bytes] | typing.IO | None

# Module-wide registry that register_preference() adds its wrapped preference functions to.
_RH_PREFERENCES: set[Preference] = set()