jfr.im git - yt-dlp.git/commitdiff
[networking] Add strict Request extension checking (#7604)
author: coletdjnz <redacted>
Sun, 23 Jul 2023 05:17:15 +0000 (17:17 +1200)
committer: GitHub <redacted>
Sun, 23 Jul 2023 05:17:15 +0000 (05:17 +0000)
Authored by: coletdjnz
Co-authored-by: pukkandan <redacted>
test/test_networking.py
yt_dlp/networking/_urllib.py
yt_dlp/networking/common.py

index d4eba2a5dfe67eec5aa6dac7490f13564c2ad982..1bd6afc88baad5ea0ed8e9b180a60c0176c6faab 100644 (file)
@@ -804,10 +804,10 @@ def test_httplib_validation_errors(self, handler):
             assert not isinstance(exc_info.value, TransportError)
 
 
-def run_validation(handler, fail, req, **handler_kwargs):
+def run_validation(handler, error, req, **handler_kwargs):
     with handler(**handler_kwargs) as rh:
-        if fail:
-            with pytest.raises(UnsupportedRequest):
+        if error:
+            with pytest.raises(error):
                 rh.validate(req)
         else:
             rh.validate(req)
@@ -824,6 +824,9 @@ class NoCheckRH(ValidationRH):
         _SUPPORTED_PROXY_SCHEMES = None
         _SUPPORTED_URL_SCHEMES = None
 
+        def _check_extensions(self, extensions):
+            extensions.clear()
+
     class HTTPSupportedRH(ValidationRH):
         _SUPPORTED_URL_SCHEMES = ('http',)
 
@@ -834,26 +837,26 @@ class HTTPSupportedRH(ValidationRH):
             ('https', False, {}),
             ('data', False, {}),
             ('ftp', False, {}),
-            ('file', True, {}),
+            ('file', UnsupportedRequest, {}),
             ('file', False, {'enable_file_urls': True}),
         ]),
         (NoCheckRH, [('http', False, {})]),
-        (ValidationRH, [('http', True, {})])
+        (ValidationRH, [('http', UnsupportedRequest, {})])
     ]
 
     PROXY_SCHEME_TESTS = [
         # scheme, expected to fail
         ('Urllib', [
             ('http', False),
-            ('https', True),
+            ('https', UnsupportedRequest),
             ('socks4', False),
             ('socks4a', False),
             ('socks5', False),
             ('socks5h', False),
-            ('socks', True),
+            ('socks', UnsupportedRequest),
         ]),
         (NoCheckRH, [('http', False)]),
-        (HTTPSupportedRH, [('http', True)]),
+        (HTTPSupportedRH, [('http', UnsupportedRequest)]),
     ]
 
     PROXY_KEY_TESTS = [
@@ -863,8 +866,22 @@ class HTTPSupportedRH(ValidationRH):
             ('unrelated', False),
         ]),
         (NoCheckRH, [('all', False)]),
-        (HTTPSupportedRH, [('all', True)]),
-        (HTTPSupportedRH, [('no', True)]),
+        (HTTPSupportedRH, [('all', UnsupportedRequest)]),
+        (HTTPSupportedRH, [('no', UnsupportedRequest)]),
+    ]
+
+    EXTENSION_TESTS = [
+        ('Urllib', [
+            ({'cookiejar': 'notacookiejar'}, AssertionError),
+            ({'cookiejar': CookieJar()}, False),
+            ({'timeout': 1}, False),
+            ({'timeout': 'notatimeout'}, AssertionError),
+            ({'unsupported': 'value'}, UnsupportedRequest),
+        ]),
+        (NoCheckRH, [
+            ({'cookiejar': 'notacookiejar'}, False),
+            ({'somerandom': 'test'}, False),  # but any extension is allowed through
+        ]),
     ]
 
     @pytest.mark.parametrize('handler,scheme,fail,handler_kwargs', [
@@ -907,15 +924,16 @@ def test_empty_proxy(self, handler):
     @pytest.mark.parametrize('proxy_url', ['//example.com', 'example.com', '127.0.0.1'])
     @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
     def test_missing_proxy_scheme(self, handler, proxy_url):
-        run_validation(handler, True, Request('http://', proxies={'http': 'example.com'}))
-
-    @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
-    def test_cookiejar_extension(self, handler):
-        run_validation(handler, True, Request('http://', extensions={'cookiejar': 'notacookiejar'}))
+        run_validation(handler, UnsupportedRequest, Request('http://', proxies={'http': 'example.com'}))
 
-    @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
-    def test_timeout_extension(self, handler):
-        run_validation(handler, True, Request('http://', extensions={'timeout': 'notavalidtimeout'}))
+    @pytest.mark.parametrize('handler,extensions,fail', [
+        (handler_tests[0], extensions, fail)
+        for handler_tests in EXTENSION_TESTS
+        for extensions, fail in handler_tests[1]
+    ], indirect=['handler'])
+    def test_extension(self, handler, extensions, fail):
+        run_validation(
+            handler, fail, Request('http://', extensions=extensions))
 
     def test_invalid_request_type(self):
         rh = self.ValidationRH(logger=FakeLogger())
index ff3a22c8c18809163c2454dfb752edabe086b07d..3fe5fa52ea9272411c38246c867a6a3360492b27 100644 (file)
@@ -385,6 +385,11 @@ def __init__(self, *, enable_file_urls: bool = False, **kwargs):
         if self.enable_file_urls:
             self._SUPPORTED_URL_SCHEMES = (*self._SUPPORTED_URL_SCHEMES, 'file')
 
+    def _check_extensions(self, extensions):
+        super()._check_extensions(extensions)
+        extensions.pop('cookiejar', None)
+        extensions.pop('timeout', None)
+
     def _create_instance(self, proxies, cookiejar):
         opener = urllib.request.OpenerDirector()
         handlers = [
index 7f745797805da96123d1bffdb2f4c1184bc155b7..ab26a06282e813eb9f5da8b7d83404c9031179c8 100644 (file)
@@ -21,6 +21,7 @@
     TransportError,
     UnsupportedRequest,
 )
+from ..compat.types import NoneType
 from ..utils import (
     bug_reports_message,
     classproperty,
@@ -147,6 +148,7 @@ class RequestHandler(abc.ABC):
         a proxy url with an url scheme not in this list will raise an UnsupportedRequest.
 
     - `_SUPPORTED_FEATURES`: a tuple of supported features, as defined in Features enum.
+
     The above may be set to None to disable the checks.
 
     Parameters:
@@ -169,9 +171,14 @@ class RequestHandler(abc.ABC):
     Requests may have additional optional parameters defined as extensions.
      RequestHandler subclasses may choose to support custom extensions.
 
+    If an extension is supported, subclasses should extend _check_extensions(extensions)
+    to pop and validate the extension.
+    - Extensions left in `extensions` are treated as unsupported and UnsupportedRequest will be raised.
+
     The following extensions are defined for RequestHandler:
-    - `cookiejar`: Cookiejar to use for this request
-    - `timeout`: socket timeout to use for this request
+    - `cookiejar`: Cookiejar to use for this request.
+    - `timeout`: socket timeout to use for this request.
+    To enable these, add extensions.pop('<extension>', None) to _check_extensions
 
     Apart from the url protocol, proxies dict may contain the following keys:
     - `all`: proxy to use for all protocols. Used as a fallback if no proxy is set for a specific protocol.
@@ -263,26 +270,19 @@ def _check_proxies(self, proxies):
             if scheme not in self._SUPPORTED_PROXY_SCHEMES:
                 raise UnsupportedRequest(f'Unsupported proxy type: "{scheme}"')
 
-    def _check_cookiejar_extension(self, extensions):
-        if not extensions.get('cookiejar'):
-            return
-        if not isinstance(extensions['cookiejar'], CookieJar):
-            raise UnsupportedRequest('cookiejar is not a CookieJar')
-
-    def _check_timeout_extension(self, extensions):
-        if extensions.get('timeout') is None:
-            return
-        if not isinstance(extensions['timeout'], (float, int)):
-            raise UnsupportedRequest('timeout is not a float or int')
-
     def _check_extensions(self, extensions):
-        self._check_cookiejar_extension(extensions)
-        self._check_timeout_extension(extensions)
+        """Check extensions for unsupported extensions. Subclasses should extend this."""
+        assert isinstance(extensions.get('cookiejar'), (CookieJar, NoneType))
+        assert isinstance(extensions.get('timeout'), (float, int, NoneType))
 
     def _validate(self, request):
         self._check_url_scheme(request)
         self._check_proxies(request.proxies or self.proxies)
-        self._check_extensions(request.extensions)
+        extensions = request.extensions.copy()
+        self._check_extensions(extensions)
+        if extensions:
+            # TODO: add support for optional extensions
+            raise UnsupportedRequest(f'Unsupported extensions: {", ".join(extensions.keys())}')
 
     @wrap_request_errors
     def validate(self, request: Request):