jfr.im git - yt-dlp.git/blobdiff - test/test_networking.py
[networking] Add strict Request extension checking (#7604)
[yt-dlp.git] / test / test_networking.py
index b60ed283be27c70f781d8b528e5b192946e8d4bd..1bd6afc88baad5ea0ed8e9b180a60c0176c6faab 100644 (file)
@@ -804,10 +804,10 @@ def test_httplib_validation_errors(self, handler):
             assert not isinstance(exc_info.value, TransportError)
 
 
-def run_validation(handler, fail, req, **handler_kwargs):
+def run_validation(handler, error, req, **handler_kwargs):
     with handler(**handler_kwargs) as rh:
-        if fail:
-            with pytest.raises(UnsupportedRequest):
+        if error:
+            with pytest.raises(error):
                 rh.validate(req)
         else:
             rh.validate(req)
@@ -824,6 +824,9 @@ class NoCheckRH(ValidationRH):
         _SUPPORTED_PROXY_SCHEMES = None
         _SUPPORTED_URL_SCHEMES = None
 
+        def _check_extensions(self, extensions):
+            extensions.clear()
+
     class HTTPSupportedRH(ValidationRH):
         _SUPPORTED_URL_SCHEMES = ('http',)
 
@@ -834,26 +837,26 @@ class HTTPSupportedRH(ValidationRH):
             ('https', False, {}),
             ('data', False, {}),
             ('ftp', False, {}),
-            ('file', True, {}),
+            ('file', UnsupportedRequest, {}),
             ('file', False, {'enable_file_urls': True}),
         ]),
         (NoCheckRH, [('http', False, {})]),
-        (ValidationRH, [('http', True, {})])
+        (ValidationRH, [('http', UnsupportedRequest, {})])
     ]
 
     PROXY_SCHEME_TESTS = [
         # scheme, expected to fail
         ('Urllib', [
             ('http', False),
-            ('https', True),
+            ('https', UnsupportedRequest),
             ('socks4', False),
             ('socks4a', False),
             ('socks5', False),
             ('socks5h', False),
-            ('socks', True),
+            ('socks', UnsupportedRequest),
         ]),
         (NoCheckRH, [('http', False)]),
-        (HTTPSupportedRH, [('http', True)]),
+        (HTTPSupportedRH, [('http', UnsupportedRequest)]),
     ]
 
     PROXY_KEY_TESTS = [
@@ -863,8 +866,22 @@ class HTTPSupportedRH(ValidationRH):
             ('unrelated', False),
         ]),
         (NoCheckRH, [('all', False)]),
-        (HTTPSupportedRH, [('all', True)]),
-        (HTTPSupportedRH, [('no', True)]),
+        (HTTPSupportedRH, [('all', UnsupportedRequest)]),
+        (HTTPSupportedRH, [('no', UnsupportedRequest)]),
+    ]
+
+    EXTENSION_TESTS = [
+        ('Urllib', [
+            ({'cookiejar': 'notacookiejar'}, AssertionError),
+            ({'cookiejar': CookieJar()}, False),
+            ({'timeout': 1}, False),
+            ({'timeout': 'notatimeout'}, AssertionError),
+            ({'unsupported': 'value'}, UnsupportedRequest),
+        ]),
+        (NoCheckRH, [
+            ({'cookiejar': 'notacookiejar'}, False),
+            ({'somerandom': 'test'}, False),  # but any extension is allowed through
+        ]),
     ]
 
     @pytest.mark.parametrize('handler,scheme,fail,handler_kwargs', [
@@ -907,15 +924,16 @@ def test_empty_proxy(self, handler):
     @pytest.mark.parametrize('proxy_url', ['//example.com', 'example.com', '127.0.0.1'])
     @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
     def test_missing_proxy_scheme(self, handler, proxy_url):
-        run_validation(handler, True, Request('http://', proxies={'http': 'example.com'}))
-
-    @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
-    def test_cookiejar_extension(self, handler):
-        run_validation(handler, True, Request('http://', extensions={'cookiejar': 'notacookiejar'}))
+        run_validation(handler, UnsupportedRequest, Request('http://', proxies={'http': 'example.com'}))
 
-    @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
-    def test_timeout_extension(self, handler):
-        run_validation(handler, True, Request('http://', extensions={'timeout': 'notavalidtimeout'}))
+    @pytest.mark.parametrize('handler,extensions,fail', [
+        (handler_tests[0], extensions, fail)
+        for handler_tests in EXTENSION_TESTS
+        for extensions, fail in handler_tests[1]
+    ], indirect=['handler'])
+    def test_extension(self, handler, extensions, fail):
+        run_validation(
+            handler, fail, Request('http://', extensions=extensions))
 
     def test_invalid_request_type(self):
         rh = self.ValidationRH(logger=FakeLogger())
@@ -1152,7 +1170,7 @@ def test_build_handler_params(self):
             'debug_printtraffic': True,
             'compat_opts': ['no-certifi'],
             'nocheckcertificate': True,
-            'legacy_server_connect': True,
+            'legacyserverconnect': True,
         }) as ydl:
             rh = self.build_handler(ydl)
             assert rh.headers.get('test') == 'testtest'
@@ -1280,6 +1298,17 @@ def test_content_type_header(self):
         req.data = b'test3'
         assert req.headers.get('Content-Type') == 'application/x-www-form-urlencoded'
 
+    def test_update_req(self):
+        req = Request('http://example.com')
+        assert req.data is None
+        assert req.method == 'GET'
+        assert 'Content-Type' not in req.headers
+        # Test that zero-byte payloads will be sent
+        req.update(data=b'')
+        assert req.data == b''
+        assert req.method == 'POST'
+        assert req.headers.get('Content-Type') == 'application/x-www-form-urlencoded'
+
     def test_proxies(self):
         req = Request(url='http://example.com', proxies={'http': 'http://127.0.0.1:8080'})
         assert req.proxies == {'http': 'http://127.0.0.1:8080'}