jfr.im git - yt-dlp.git/commitdiff
[networking] Add strict Request extension checking (#7604)
author: coletdjnz <redacted>
Sun, 23 Jul 2023 05:17:15 +0000 (17:17 +1200)
committer: GitHub <redacted>
Sun, 23 Jul 2023 05:17:15 +0000 (05:17 +0000)
Authored by: coletdjnz
Co-authored-by: pukkandan <redacted>
test/test_networking.py
yt_dlp/networking/_urllib.py
yt_dlp/networking/common.py

index d4eba2a5dfe67eec5aa6dac7490f13564c2ad982..1bd6afc88baad5ea0ed8e9b180a60c0176c6faab 100644 (file)
@@ -804,10 +804,10 @@ def test_httplib_validation_errors(self, handler):
             assert not isinstance(exc_info.value, TransportError)
 
 
             assert not isinstance(exc_info.value, TransportError)
 
 
-def run_validation(handler, fail, req, **handler_kwargs):
+def run_validation(handler, error, req, **handler_kwargs):
     with handler(**handler_kwargs) as rh:
     with handler(**handler_kwargs) as rh:
-        if fail:
-            with pytest.raises(UnsupportedRequest):
+        if error:
+            with pytest.raises(error):
                 rh.validate(req)
         else:
             rh.validate(req)
                 rh.validate(req)
         else:
             rh.validate(req)
@@ -824,6 +824,9 @@ class NoCheckRH(ValidationRH):
         _SUPPORTED_PROXY_SCHEMES = None
         _SUPPORTED_URL_SCHEMES = None
 
         _SUPPORTED_PROXY_SCHEMES = None
         _SUPPORTED_URL_SCHEMES = None
 
+        def _check_extensions(self, extensions):
+            extensions.clear()
+
     class HTTPSupportedRH(ValidationRH):
         _SUPPORTED_URL_SCHEMES = ('http',)
 
     class HTTPSupportedRH(ValidationRH):
         _SUPPORTED_URL_SCHEMES = ('http',)
 
@@ -834,26 +837,26 @@ class HTTPSupportedRH(ValidationRH):
             ('https', False, {}),
             ('data', False, {}),
             ('ftp', False, {}),
             ('https', False, {}),
             ('data', False, {}),
             ('ftp', False, {}),
-            ('file', True, {}),
+            ('file', UnsupportedRequest, {}),
             ('file', False, {'enable_file_urls': True}),
         ]),
         (NoCheckRH, [('http', False, {})]),
             ('file', False, {'enable_file_urls': True}),
         ]),
         (NoCheckRH, [('http', False, {})]),
-        (ValidationRH, [('http', True, {})])
+        (ValidationRH, [('http', UnsupportedRequest, {})])
     ]
 
     PROXY_SCHEME_TESTS = [
         # scheme, expected to fail
         ('Urllib', [
             ('http', False),
     ]
 
     PROXY_SCHEME_TESTS = [
         # scheme, expected to fail
         ('Urllib', [
             ('http', False),
-            ('https', True),
+            ('https', UnsupportedRequest),
             ('socks4', False),
             ('socks4a', False),
             ('socks5', False),
             ('socks5h', False),
             ('socks4', False),
             ('socks4a', False),
             ('socks5', False),
             ('socks5h', False),
-            ('socks', True),
+            ('socks', UnsupportedRequest),
         ]),
         (NoCheckRH, [('http', False)]),
         ]),
         (NoCheckRH, [('http', False)]),
-        (HTTPSupportedRH, [('http', True)]),
+        (HTTPSupportedRH, [('http', UnsupportedRequest)]),
     ]
 
     PROXY_KEY_TESTS = [
     ]
 
     PROXY_KEY_TESTS = [
@@ -863,8 +866,22 @@ class HTTPSupportedRH(ValidationRH):
             ('unrelated', False),
         ]),
         (NoCheckRH, [('all', False)]),
             ('unrelated', False),
         ]),
         (NoCheckRH, [('all', False)]),
-        (HTTPSupportedRH, [('all', True)]),
-        (HTTPSupportedRH, [('no', True)]),
+        (HTTPSupportedRH, [('all', UnsupportedRequest)]),
+        (HTTPSupportedRH, [('no', UnsupportedRequest)]),
+    ]
+
+    EXTENSION_TESTS = [
+        ('Urllib', [
+            ({'cookiejar': 'notacookiejar'}, AssertionError),
+            ({'cookiejar': CookieJar()}, False),
+            ({'timeout': 1}, False),
+            ({'timeout': 'notatimeout'}, AssertionError),
+            ({'unsupported': 'value'}, UnsupportedRequest),
+        ]),
+        (NoCheckRH, [
+            ({'cookiejar': 'notacookiejar'}, False),
+            ({'somerandom': 'test'}, False),  # but any extension is allowed through
+        ]),
     ]
 
     @pytest.mark.parametrize('handler,scheme,fail,handler_kwargs', [
     ]
 
     @pytest.mark.parametrize('handler,scheme,fail,handler_kwargs', [
@@ -907,15 +924,16 @@ def test_empty_proxy(self, handler):
     @pytest.mark.parametrize('proxy_url', ['//example.com', 'example.com', '127.0.0.1'])
     @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
     def test_missing_proxy_scheme(self, handler, proxy_url):
     @pytest.mark.parametrize('proxy_url', ['//example.com', 'example.com', '127.0.0.1'])
     @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
     def test_missing_proxy_scheme(self, handler, proxy_url):
-        run_validation(handler, True, Request('http://', proxies={'http': 'example.com'}))
-
-    @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
-    def test_cookiejar_extension(self, handler):
-        run_validation(handler, True, Request('http://', extensions={'cookiejar': 'notacookiejar'}))
+        run_validation(handler, UnsupportedRequest, Request('http://', proxies={'http': 'example.com'}))
 
 
-    @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
-    def test_timeout_extension(self, handler):
-        run_validation(handler, True, Request('http://', extensions={'timeout': 'notavalidtimeout'}))
+    @pytest.mark.parametrize('handler,extensions,fail', [
+        (handler_tests[0], extensions, fail)
+        for handler_tests in EXTENSION_TESTS
+        for extensions, fail in handler_tests[1]
+    ], indirect=['handler'])
+    def test_extension(self, handler, extensions, fail):
+        run_validation(
+            handler, fail, Request('http://', extensions=extensions))
 
     def test_invalid_request_type(self):
         rh = self.ValidationRH(logger=FakeLogger())
 
     def test_invalid_request_type(self):
         rh = self.ValidationRH(logger=FakeLogger())
index ff3a22c8c18809163c2454dfb752edabe086b07d..3fe5fa52ea9272411c38246c867a6a3360492b27 100644 (file)
@@ -385,6 +385,11 @@ def __init__(self, *, enable_file_urls: bool = False, **kwargs):
         if self.enable_file_urls:
             self._SUPPORTED_URL_SCHEMES = (*self._SUPPORTED_URL_SCHEMES, 'file')
 
         if self.enable_file_urls:
             self._SUPPORTED_URL_SCHEMES = (*self._SUPPORTED_URL_SCHEMES, 'file')
 
+    def _check_extensions(self, extensions):
+        super()._check_extensions(extensions)
+        extensions.pop('cookiejar', None)
+        extensions.pop('timeout', None)
+
     def _create_instance(self, proxies, cookiejar):
         opener = urllib.request.OpenerDirector()
         handlers = [
     def _create_instance(self, proxies, cookiejar):
         opener = urllib.request.OpenerDirector()
         handlers = [
index 7f745797805da96123d1bffdb2f4c1184bc155b7..ab26a06282e813eb9f5da8b7d83404c9031179c8 100644 (file)
@@ -21,6 +21,7 @@
     TransportError,
     UnsupportedRequest,
 )
     TransportError,
     UnsupportedRequest,
 )
+from ..compat.types import NoneType
 from ..utils import (
     bug_reports_message,
     classproperty,
 from ..utils import (
     bug_reports_message,
     classproperty,
@@ -147,6 +148,7 @@ class RequestHandler(abc.ABC):
         a proxy url with an url scheme not in this list will raise an UnsupportedRequest.
 
     - `_SUPPORTED_FEATURES`: a tuple of supported features, as defined in Features enum.
         a proxy url with an url scheme not in this list will raise an UnsupportedRequest.
 
     - `_SUPPORTED_FEATURES`: a tuple of supported features, as defined in Features enum.
+
     The above may be set to None to disable the checks.
 
     Parameters:
     The above may be set to None to disable the checks.
 
     Parameters:
@@ -169,9 +171,14 @@ class RequestHandler(abc.ABC):
     Requests may have additional optional parameters defined as extensions.
      RequestHandler subclasses may choose to support custom extensions.
 
     Requests may have additional optional parameters defined as extensions.
      RequestHandler subclasses may choose to support custom extensions.
 
+    If an extension is supported, subclasses should extend _check_extensions(extensions)
+    to pop and validate the extension.
+    - Extensions left in `extensions` are treated as unsupported and UnsupportedRequest will be raised.
+
     The following extensions are defined for RequestHandler:
     The following extensions are defined for RequestHandler:
-    - `cookiejar`: Cookiejar to use for this request
-    - `timeout`: socket timeout to use for this request
+    - `cookiejar`: Cookiejar to use for this request.
+    - `timeout`: socket timeout to use for this request.
+    To enable these, add extensions.pop('<extension>', None) to _check_extensions
 
     Apart from the url protocol, proxies dict may contain the following keys:
     - `all`: proxy to use for all protocols. Used as a fallback if no proxy is set for a specific protocol.
 
     Apart from the url protocol, proxies dict may contain the following keys:
     - `all`: proxy to use for all protocols. Used as a fallback if no proxy is set for a specific protocol.
@@ -263,26 +270,19 @@ def _check_proxies(self, proxies):
             if scheme not in self._SUPPORTED_PROXY_SCHEMES:
                 raise UnsupportedRequest(f'Unsupported proxy type: "{scheme}"')
 
             if scheme not in self._SUPPORTED_PROXY_SCHEMES:
                 raise UnsupportedRequest(f'Unsupported proxy type: "{scheme}"')
 
-    def _check_cookiejar_extension(self, extensions):
-        if not extensions.get('cookiejar'):
-            return
-        if not isinstance(extensions['cookiejar'], CookieJar):
-            raise UnsupportedRequest('cookiejar is not a CookieJar')
-
-    def _check_timeout_extension(self, extensions):
-        if extensions.get('timeout') is None:
-            return
-        if not isinstance(extensions['timeout'], (float, int)):
-            raise UnsupportedRequest('timeout is not a float or int')
-
     def _check_extensions(self, extensions):
     def _check_extensions(self, extensions):
-        self._check_cookiejar_extension(extensions)
-        self._check_timeout_extension(extensions)
+        """Check extensions for unsupported extensions. Subclasses should extend this."""
+        assert isinstance(extensions.get('cookiejar'), (CookieJar, NoneType))
+        assert isinstance(extensions.get('timeout'), (float, int, NoneType))
 
     def _validate(self, request):
         self._check_url_scheme(request)
         self._check_proxies(request.proxies or self.proxies)
 
     def _validate(self, request):
         self._check_url_scheme(request)
         self._check_proxies(request.proxies or self.proxies)
-        self._check_extensions(request.extensions)
+        extensions = request.extensions.copy()
+        self._check_extensions(extensions)
+        if extensions:
+            # TODO: add support for optional extensions
+            raise UnsupportedRequest(f'Unsupported extensions: {", ".join(extensions.keys())}')
 
     @wrap_request_errors
     def validate(self, request: Request):
 
     @wrap_request_errors
     def validate(self, request: Request):