]> jfr.im git - yt-dlp.git/blame - yt_dlp/networking/common.py
[networking] Add strict Request extension checking (#7604)
[yt-dlp.git] / yt_dlp / networking / common.py
CommitLineData
227bf1a3 1from __future__ import annotations
2
3import abc
4import copy
5import enum
6import functools
7import io
8import typing
9import urllib.parse
10import urllib.request
11import urllib.response
12from collections.abc import Iterable, Mapping
13from email.message import Message
14from http import HTTPStatus
15from http.cookiejar import CookieJar
16
17from ._helper import make_ssl_context, wrap_request_errors
18from .exceptions import (
19 NoSupportingHandlers,
20 RequestError,
21 TransportError,
22 UnsupportedRequest,
23)
86aea0d3 24from ..compat.types import NoneType
227bf1a3 25from ..utils import (
26 bug_reports_message,
27 classproperty,
3d2623a8 28 deprecation_warning,
227bf1a3 29 error_to_str,
30 escape_url,
31 update_url_query,
32)
33from ..utils.networking import HTTPHeaderDict
34
if typing.TYPE_CHECKING:
    # Types accepted as a request body (see the Request.data setter).
    # Only defined for static type checkers: with `from __future__ import annotations`
    # the annotations that mention it are never evaluated at runtime.
    RequestData = bytes | Iterable[bytes] | typing.IO | None
37
38
class RequestDirector:
    """RequestDirector class

    Dispatches each request to the first registered RequestHandler that accepts it.

    @param logger: Logger instance.
    @param verbose: Print debug request information to stdout.
    """

    def __init__(self, logger, verbose=False):
        self.handlers: dict[str, RequestHandler] = {}
        self.logger = logger  # TODO(Grub4k): default logger
        self.verbose = verbose

    def close(self):
        """Close every registered handler, in the order they are stored."""
        for rh in self.handlers.values():
            rh.close()

    def add_handler(self, handler: RequestHandler):
        """Register a handler. An existing handler with the same RH_KEY is replaced."""
        assert isinstance(handler, RequestHandler), 'handler must be a RequestHandler'
        self.handlers[handler.RH_KEY] = handler

    def _print_verbose(self, msg):
        # Debug output is only emitted in verbose mode
        if not self.verbose:
            return
        self.logger.stdout(f'director: {msg}')

    def send(self, request: Request) -> Response:
        """
        Hand the request to the first handler that validates it and return its Response.

        Handlers are tried in reverse order of their position in `self.handlers`.
        A handler rejecting the request (UnsupportedRequest) or failing with a
        non-RequestError exception is skipped; a RequestError from a handler propagates.

        @raises RequestError: if no handlers are configured (or raised by a handler).
        @raises NoSupportingHandlers: if no handler could send the request.
        """
        if not self.handlers:
            raise RequestError('No request handlers configured')

        assert isinstance(request, Request)

        rejections = []
        failures = []
        # TODO (future): add a per-request preference system
        candidates = list(self.handlers.values())[::-1]
        for rh in candidates:
            self._print_verbose(f'Checking if "{rh.RH_NAME}" supports this request.')
            try:
                rh.validate(request)
            except UnsupportedRequest as err:
                self._print_verbose(
                    f'"{rh.RH_NAME}" cannot handle this request (reason: {error_to_str(err)})')
                rejections.append(err)
                continue

            self._print_verbose(f'Sending request via "{rh.RH_NAME}"')
            try:
                resp = rh.send(request)
            except RequestError:
                raise
            except Exception as err:
                # Anything other than a RequestError is a handler bug: report it and try the next one
                self.logger.error(
                    f'[{rh.RH_NAME}] Unexpected error: {error_to_str(err)}{bug_reports_message()}',
                    is_error=False)
                failures.append(err)
                continue

            assert isinstance(resp, Response)
            return resp

        raise NoSupportingHandlers(rejections, failures)
104
105
# Registry of RequestHandler classes keyed by their RH_KEY; populated by register_rh()
_REQUEST_HANDLERS = {}
107
108
def register_rh(handler):
    """Register a RequestHandler class in _REQUEST_HANDLERS under its RH_KEY.

    The class is returned unchanged, so this can be used as a class decorator.
    Registering a non-RequestHandler or a duplicate RH_KEY fails an assertion.
    """
    assert issubclass(handler, RequestHandler), f'{handler} must be a subclass of RequestHandler'
    key = handler.RH_KEY
    assert key not in _REQUEST_HANDLERS, f'RequestHandler {key} already registered'
    _REQUEST_HANDLERS[key] = handler
    return handler
115
116
class Features(enum.Enum):
    """Optional capabilities a RequestHandler can declare in its `_SUPPORTED_FEATURES`."""

    ALL_PROXY = 1  # the "all" proxies key: fallback proxy for every protocol
    NO_PROXY = 2  # the "no" proxies key: hosts that should bypass the proxy
121
class RequestHandler(abc.ABC):

    """Request Handler class

    Request handlers are classes that, given a Request,
    process the request from start to finish and return a Response.

    Concrete subclasses must implement the _send(request) method,
    which performs the underlying request logic and returns a Response.

    The RH_NAME class variable may contain a display name for the RequestHandler.
    By default, it is derived from the class name.

    The concrete request handler class name MUST end with the suffix "RH".

    All exceptions raised by a RequestHandler should be instances of RequestError.
    Any other exception raised is treated as a bug in the handler.

    If a Request is not supported by the handler, an UnsupportedRequest
    should be raised with a reason.

    By default, _validate() checks the request against the following class variables:
    - `_SUPPORTED_URL_SCHEMES`: a tuple of supported URL schemes.
        Any Request whose URL scheme is not in this tuple raises UnsupportedRequest.

    - `_SUPPORTED_PROXY_SCHEMES`: a tuple of supported proxy URL schemes. Any Request that contains
        a proxy URL whose scheme is not in this tuple raises UnsupportedRequest.

    - `_SUPPORTED_FEATURES`: a tuple of supported features, as defined in the Features enum.

    Any of the above may be set to None to disable the corresponding check.

    Parameters:
    @param logger: logger instance
    @param headers: HTTP Headers to include when sending requests.
    @param cookiejar: Cookiejar to use for requests.
    @param timeout: Socket timeout to use when sending requests. A falsy value (None or 0) falls back to 20.
    @param proxies: Proxies to use for sending requests.
    @param source_address: Client-side IP address to bind to for requests.
    @param verbose: Print debug request and traffic information to stdout.
    @param prefer_system_certs: Whether to prefer system certificates over other means (e.g. certifi).
    @param client_cert: SSL client certificate configuration.
            dict with {client_certificate, client_certificate_key, client_certificate_password}
    @param verify: Verify SSL certificates
    @param legacy_ssl_support: Enable legacy SSL options such as legacy server connect and older cipher support.

    Some configuration options may also be set on individual Requests. In that case,
    either the Request's option takes precedence or the two are merged.

    Requests may carry additional optional parameters, defined as extensions.
    RequestHandler subclasses may choose to support custom extensions.

    To support an extension, a subclass should extend _check_extensions(extensions)
    so that it pops and validates that extension.
    - Any extension still left in `extensions` is treated as unsupported and UnsupportedRequest is raised.

    The following extensions are defined for RequestHandler:
    - `cookiejar`: Cookiejar to use for this request.
    - `timeout`: socket timeout to use for this request.
    To enable these, add extensions.pop('<extension>', None) to _check_extensions

    Apart from the URL protocol, the proxies dict may contain the following keys:
    - `all`: proxy to use for all protocols. Used as a fallback if no proxy is set for a specific protocol.
    - `no`: comma-separated list of hostnames (optionally with port) that should not use a proxy.
    Note: a RequestHandler may not support these, as declared in `_SUPPORTED_FEATURES`.

    """

    _SUPPORTED_URL_SCHEMES = ()
    _SUPPORTED_PROXY_SCHEMES = ()
    _SUPPORTED_FEATURES = ()

    def __init__(
        self, *,
        logger,  # TODO(Grub4k): default logger
        headers: HTTPHeaderDict = None,
        cookiejar: CookieJar = None,
        timeout: float | int | None = None,
        proxies: dict = None,
        source_address: str = None,
        verbose: bool = False,
        prefer_system_certs: bool = False,
        client_cert: dict[str, str | None] = None,
        verify: bool = True,
        legacy_ssl_support: bool = False,
        **_,
    ):

        self._logger = logger
        self.headers = headers or {}
        self.cookiejar = CookieJar() if cookiejar is None else cookiejar
        # None (or 0) means "use the default" of 20
        self.timeout = float(timeout or 20)
        self.proxies = proxies or {}
        self.source_address = source_address
        self.verbose = verbose
        self.prefer_system_certs = prefer_system_certs
        self._client_cert = client_cert or {}
        self.verify = verify
        self.legacy_ssl_support = legacy_ssl_support
        super().__init__()

    def _make_sslcontext(self):
        """Build an SSL context from this handler's verification/certificate settings."""
        return make_ssl_context(
            verify=self.verify,
            legacy_support=self.legacy_ssl_support,
            use_certifi=not self.prefer_system_certs,
            **self._client_cert,
        )

    def _merge_headers(self, request_headers):
        """Combine the handler-level headers with the per-request ones into an HTTPHeaderDict."""
        return HTTPHeaderDict(self.headers, request_headers)

    def _check_url_scheme(self, request: Request):
        """Reject the request if its URL scheme is unsupported; return the lowercased scheme."""
        scheme = urllib.parse.urlparse(request.url).scheme.lower()
        supported = self._SUPPORTED_URL_SCHEMES
        if supported is not None and scheme not in supported:
            raise UnsupportedRequest(f'Unsupported url scheme: "{scheme}"')
        return scheme  # returned for further processing by subclasses

    def _check_proxies(self, proxies):
        """Raise UnsupportedRequest for any proxy entry this handler cannot honour."""
        features = self._SUPPORTED_FEATURES
        url_schemes = self._SUPPORTED_URL_SCHEMES
        proxy_schemes = self._SUPPORTED_PROXY_SCHEMES

        for proxy_key, proxy_url in proxies.items():
            if proxy_url is None:
                continue

            if proxy_key == 'no':
                if features is not None and Features.NO_PROXY not in features:
                    raise UnsupportedRequest('"no" proxy is not supported')
                continue

            if proxy_key == 'all' and features is not None and Features.ALL_PROXY not in features:
                raise UnsupportedRequest('"all" proxy is not supported')

            # A proxy keyed to a protocol this handler does not serve is irrelevant here.
            # Skipping it lets a proxy be configured for a protocol supported by one handler
            # without making other handlers (that lack that protocol) reject the request.
            if url_schemes is not None and proxy_key not in (*url_schemes, 'all'):
                continue

            if proxy_schemes is None:
                continue  # proxy scheme checks are disabled

            # Proxies without a scheme are not supported
            if urllib.request._parse_proxy(proxy_url)[0] is None:
                raise UnsupportedRequest(f'Proxy "{proxy_url}" missing scheme')

            proxy_scheme = urllib.parse.urlparse(proxy_url).scheme.lower()
            if proxy_scheme not in proxy_schemes:
                raise UnsupportedRequest(f'Unsupported proxy type: "{proxy_scheme}"')

    def _check_extensions(self, extensions):
        """Check extensions for unsupported extensions. Subclasses should extend this,
        popping every extension they support; whatever remains is rejected by _validate()."""
        assert isinstance(extensions.get('cookiejar'), (CookieJar, NoneType))
        assert isinstance(extensions.get('timeout'), (float, int, NoneType))

    def _validate(self, request):
        """Run the URL scheme, proxy and extension checks; raise UnsupportedRequest on failure."""
        self._check_url_scheme(request)
        # Per-request proxies take precedence over the handler-level ones
        self._check_proxies(request.proxies or self.proxies)
        leftover = request.extensions.copy()
        self._check_extensions(leftover)
        if leftover:
            # TODO: add support for optional extensions
            raise UnsupportedRequest(f'Unsupported extensions: {", ".join(leftover.keys())}')

    @wrap_request_errors
    def validate(self, request: Request):
        """Raise if this handler cannot process `request`."""
        if not isinstance(request, Request):
            raise TypeError('Expected an instance of Request')
        self._validate(request)

    @wrap_request_errors
    def send(self, request: Request) -> Response:
        """Send `request` via the subclass's _send() and return its Response."""
        if not isinstance(request, Request):
            raise TypeError('Expected an instance of Request')
        return self._send(request)

    @abc.abstractmethod
    def _send(self, request: Request):
        """Handle a request from start to finish. Redefine in subclasses."""

    def close(self):
        """Release any resources held by the handler. No-op by default."""

    @classproperty
    def RH_NAME(cls):
        # Display name: the class name without its last two characters (the "RH" suffix)
        return cls.__name__[:-2]

    @classproperty
    def RH_KEY(cls):
        # Registry key; unlike RH_NAME this enforces the "RH" naming convention
        assert cls.__name__.endswith('RH'), 'RequestHandler class names must end with "RH"'
        return cls.__name__[:-2]

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()
320
321
class Request:
    """
    A request to be made.
    Partially backwards-compatible with urllib.request.Request.

    @param url: URL to send. Will be sanitized.
    @param data: payload data to send. Must be bytes, an iterable of bytes, a file-like object or None
    @param headers: headers to send.
    @param proxies: proxy dict mapping of proto:proxy to use for the request and any redirects.
    @param query: URL query parameters to merge into the URL.
    @param method: HTTP method to use. If not given, POST is used when payload data is present, else GET
    @param extensions: Dictionary of Request extensions to add, as supported by handlers.
    """

    def __init__(
        self,
        url: str,
        data: RequestData = None,
        headers: typing.Mapping = None,
        proxies: dict = None,
        query: dict = None,
        method: str = None,
        extensions: dict = None
    ):

        self._headers = HTTPHeaderDict()
        self._data = None

        self.url = update_url_query(url, query) if query else url
        self.method = method
        if headers:
            self.headers = headers
        # The data setter reads and adjusts the headers, so it has to run after they are set
        self.data = data
        self.proxies = proxies or {}
        self.extensions = extensions or {}

    @property
    def url(self):
        return self._url

    @url.setter
    def url(self, url):
        if not isinstance(url, str):
            raise TypeError('url must be a string')
        if url.startswith('//'):
            # Protocol-relative URL: assume plain http
            url = f'http:{url}'
        self._url = escape_url(url)

    @property
    def method(self):
        """The explicit (uppercased) method if one was set, otherwise POST with data / GET without."""
        if self._method:
            return self._method
        return 'GET' if self.data is None else 'POST'

    @method.setter
    def method(self, method):
        if method is not None and not isinstance(method, str):
            raise TypeError('method must be a string')
        self._method = None if method is None else method.upper()

    @property
    def data(self):
        return self._data

    @data.setter
    def data(self, data: RequestData):
        # Catch common mistakes: str and mappings are iterable but are not valid payloads
        acceptable = isinstance(data, (bytes, io.IOBase, Iterable)) and not isinstance(data, (str, Mapping))
        if data is not None and not acceptable:
            raise TypeError('data must be bytes, iterable of bytes, or a file-like object')

        headers = self.headers
        previous = self._data

        # No payload before and after: a Content-Length header cannot be meaningful
        if data == previous and previous is None:
            headers.pop('Content-Length', None)

        # https://docs.python.org/3/library/urllib.request.html#urllib.request.Request.data
        if data != previous:
            # A Content-Length computed for the old payload is stale now
            if previous is not None:
                headers.pop('Content-Length', None)
            self._data = data

        current = self._data
        if current is None:
            headers.pop('Content-Type', None)
        elif 'Content-Type' not in headers:
            headers['Content-Type'] = 'application/x-www-form-urlencoded'

    @property
    def headers(self) -> HTTPHeaderDict:
        return self._headers

    @headers.setter
    def headers(self, new_headers: Mapping):
        """Replace the request headers. Any Mapping that is not an HTTPHeaderDict is converted to one."""
        if isinstance(new_headers, HTTPHeaderDict):
            self._headers = new_headers
            return
        if not isinstance(new_headers, Mapping):
            raise TypeError('headers must be a mapping')
        self._headers = HTTPHeaderDict(new_headers)

    def update(self, url=None, data=None, headers=None, query=None):
        """Update the URL, payload, headers and query parameters in place; None leaves a field unchanged."""
        # The data setter always runs (even when reassigning the current payload) so headers get normalized
        self.data = self.data if data is None else data
        self.headers.update(headers or {})
        self.url = update_url_query(url or self.url, query or {})

    def copy(self):
        """Return a new Request with deep-copied headers/proxies, shallow-copied extensions and shared data."""
        return type(self)(
            url=self.url,
            headers=copy.deepcopy(self.headers),
            proxies=copy.deepcopy(self.proxies),
            data=self._data,
            extensions=copy.copy(self.extensions),
            method=self._method,
        )
441
442
# Convenience constructors: a Request with the HTTP method preset to HEAD / PUT
HEADRequest = functools.partial(Request, method='HEAD')
PUTRequest = functools.partial(Request, method='PUT')
445
446
class Response(io.IOBase):
    """
    Base class for HTTP response adapters.

    By default, it provides a basic wrapper for a file-like response object.

    Interface partially backwards-compatible with addinfourl and http.client.HTTPResponse.

    @param fp: Original, file-like, response.
    @param url: URL that this is a response of.
    @param headers: response headers.
    @param status: Response HTTP status code. Default is 200 OK.
    @param reason: HTTP status reason. Will use built-in reasons based on status code if not provided.
    """

    def __init__(
            self,
            fp: typing.IO,
            url: str,
            headers: Mapping[str, str],
            status: int = 200,
            reason: str = None):

        self.fp = fp
        self.headers = Message()
        for name, value in headers.items():
            self.headers.add_header(name, value)
        self.status = status
        self.url = url
        try:
            self.reason = reason or HTTPStatus(status).phrase
        except ValueError:
            # Not a status code known to http.HTTPStatus and no explicit reason given
            self.reason = None

    def readable(self):
        return self.fp.readable()

    def read(self, amt: int = None) -> bytes:
        """Read up to `amt` bytes (everything if None) from the underlying response.

        @raises TransportError: wrapping any exception raised by the underlying object.
        """
        # Expected errors raised here should be of type RequestError or subclasses.
        # Subclasses should redefine this method with more precise error handling.
        try:
            return self.fp.read(amt)
        except Exception as e:
            raise TransportError(cause=e) from e

    def close(self):
        """Close the underlying response object and mark this wrapper as closed.

        The wrapper is marked closed even if closing the underlying object raises;
        that exception still propagates to the caller.
        """
        try:
            self.fp.close()
        finally:
            # Without this, a failing fp.close() would leave `closed` False and
            # IOBase's finalizer would call close() again at garbage collection.
            super().close()

    def get_header(self, name, default=None):
        """Get the header value for `name`, or `default` if it is absent.
        If there are multiple matching headers, return them all separated by commas."""
        headers = self.headers.get_all(name)
        if not headers:
            return default
        if name.title() == 'Set-Cookie':
            # Special case, only get the first one
            # https://www.rfc-editor.org/rfc/rfc9110.html#section-5.3-4.1
            return headers[0]
        return ', '.join(headers)

    # The following methods are for compatibility reasons and are deprecated
    @property
    def code(self):
        deprecation_warning('Response.code is deprecated, use Response.status', stacklevel=2)
        return self.status

    def getcode(self):
        deprecation_warning('Response.getcode() is deprecated, use Response.status', stacklevel=2)
        return self.status

    def geturl(self):
        deprecation_warning('Response.geturl() is deprecated, use Response.url', stacklevel=2)
        return self.url

    def info(self):
        deprecation_warning('Response.info() is deprecated, use Response.headers', stacklevel=2)
        return self.headers

    def getheader(self, name, default=None):
        deprecation_warning('Response.getheader() is deprecated, use Response.get_header', stacklevel=2)
        return self.get_header(name, default)
525
526 def getheader(self, name, default=None):
3d2623a8 527 deprecation_warning('Response.getheader() is deprecated, use Response.get_header', stacklevel=2)
227bf1a3 528 return self.get_header(name, default)