]> jfr.im git - yt-dlp.git/blobdiff - test/test_websockets.py
[ie] Make `_search_nextjs_data` non fatal (#8937)
[yt-dlp.git] / test / test_websockets.py
index 39d3c7d7221439edaaf796d9512ef842de21a1c4..b294b0932b90d626ce6ada2204e74062c1128717 100644 (file)
@@ -6,6 +6,8 @@
 
 import pytest
 
+from test.helper import verify_address_availability
+
 sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 
 import http.client
@@ -30,8 +32,6 @@
 )
 from yt_dlp.utils.networking import HTTPHeaderDict
 
-from test.conftest import validate_and_send
-
 TEST_DIR = os.path.dirname(os.path.abspath(__file__))
 
 
@@ -64,7 +64,9 @@ def process_request(self, request):
 
 def create_websocket_server(**ws_kwargs):
     import websockets.sync.server
-    wsd = websockets.sync.server.serve(websocket_handler, '127.0.0.1', 0, process_request=process_request, **ws_kwargs)
+    wsd = websockets.sync.server.serve(
+        websocket_handler, '127.0.0.1', 0,
+        process_request=process_request, open_timeout=2, **ws_kwargs)
     ws_port = wsd.socket.getsockname()[1]
     ws_server_thread = threading.Thread(target=wsd.serve_forever)
     ws_server_thread.daemon = True
@@ -98,6 +100,19 @@ def create_mtls_wss_websocket_server():
     return create_websocket_server(ssl_context=sslctx)
 
 
+def ws_validate_and_send(rh, req):
+    """Validate req with rh, then send it, retrying up to 3 times if the handshake is dropped."""
+    rh.validate(req)
+    tries = 3
+    for remaining in reversed(range(tries)):
+        try:
+            return rh.send(req)
+        except TransportError as err:
+            # websockets server sometimes hangs on new connections; retry while attempts remain
+            if not remaining or 'connection closed during handshake' not in str(err):
+                raise
+
+
 @pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers')
 class TestWebsSocketRequestHandlerConformance:
     @classmethod
@@ -117,7 +132,7 @@ def setup_class(cls):
     @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
     def test_basic_websockets(self, handler):
         with handler() as rh:
-            ws = validate_and_send(rh, Request(self.ws_base_url))
+            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
             assert 'upgrade' in ws.headers
             assert ws.status == 101
             ws.send('foo')
@@ -129,7 +144,7 @@ def test_basic_websockets(self, handler):
     @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
     def test_send_types(self, handler, msg, opcode):
         with handler() as rh:
-            ws = validate_and_send(rh, Request(self.ws_base_url))
+            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
             ws.send(msg)
             assert int(ws.recv()) == opcode
             ws.close()
@@ -138,18 +153,18 @@ def test_send_types(self, handler, msg, opcode):
     def test_verify_cert(self, handler):
         with handler() as rh:
             with pytest.raises(CertificateVerifyError):
-                validate_and_send(rh, Request(self.wss_base_url))
+                ws_validate_and_send(rh, Request(self.wss_base_url))
 
         with handler(verify=False) as rh:
-            ws = validate_and_send(rh, Request(self.wss_base_url))
+            ws = ws_validate_and_send(rh, Request(self.wss_base_url))
             assert ws.status == 101
             ws.close()
 
     @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
     def test_ssl_error(self, handler):
         with handler(verify=False) as rh:
-            with pytest.raises(SSLError, match='sslv3 alert handshake failure') as exc_info:
-                validate_and_send(rh, Request(self.bad_wss_host))
+            with pytest.raises(SSLError, match=r'ssl(?:v3|/tls) alert handshake failure') as exc_info:
+                ws_validate_and_send(rh, Request(self.bad_wss_host))
             assert not issubclass(exc_info.type, CertificateVerifyError)
 
     @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
@@ -161,7 +176,7 @@ def test_ssl_error(self, handler):
     ])
     def test_percent_encode(self, handler, path, expected):
         with handler() as rh:
-            ws = validate_and_send(rh, Request(f'{self.ws_base_url}{path}'))
+            ws = ws_validate_and_send(rh, Request(f'{self.ws_base_url}{path}'))
             ws.send('path')
             assert ws.recv() == expected
             assert ws.status == 101
@@ -172,7 +187,7 @@ def test_remove_dot_segments(self, handler):
         with handler() as rh:
             # This isn't a comprehensive test,
             # but it should be enough to check whether the handler is removing dot segments
-            ws = validate_and_send(rh, Request(f'{self.ws_base_url}/a/b/./../../test'))
+            ws = ws_validate_and_send(rh, Request(f'{self.ws_base_url}/a/b/./../../test'))
             assert ws.status == 101
             ws.send('path')
             assert ws.recv() == '/test'
@@ -185,18 +200,18 @@ def test_remove_dot_segments(self, handler):
     def test_raise_http_error(self, handler, status):
         with handler() as rh:
             with pytest.raises(HTTPError) as exc_info:
-                validate_and_send(rh, Request(f'{self.ws_base_url}/gen_{status}'))
+                ws_validate_and_send(rh, Request(f'{self.ws_base_url}/gen_{status}'))
             assert exc_info.value.status == status
 
     @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
     @pytest.mark.parametrize('params,extensions', [
-        ({'timeout': 0.00001}, {}),
-        ({}, {'timeout': 0.00001}),
+        ({'timeout': sys.float_info.min}, {}),
+        ({}, {'timeout': sys.float_info.min}),
     ])
     def test_timeout(self, handler, params, extensions):
         with handler(**params) as rh:
             with pytest.raises(TransportError):
-                validate_and_send(rh, Request(self.ws_base_url, extensions=extensions))
+                ws_validate_and_send(rh, Request(self.ws_base_url, extensions=extensions))
 
     @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
     def test_cookies(self, handler):
@@ -208,18 +223,18 @@ def test_cookies(self, handler):
             comment_url=None, rest={}))
 
         with handler(cookiejar=cookiejar) as rh:
-            ws = validate_and_send(rh, Request(self.ws_base_url))
+            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
             ws.send('headers')
             assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
             ws.close()
 
         with handler() as rh:
-            ws = validate_and_send(rh, Request(self.ws_base_url))
+            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
             ws.send('headers')
             assert 'cookie' not in json.loads(ws.recv())
             ws.close()
 
-            ws = validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': cookiejar}))
+            ws = ws_validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': cookiejar}))
             ws.send('headers')
             assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
             ws.close()
@@ -227,8 +242,9 @@ def test_cookies(self, handler):
     @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
     def test_source_address(self, handler):
         source_address = f'127.0.0.{random.randint(5, 255)}'
+        verify_address_availability(source_address)
         with handler(source_address=source_address) as rh:
-            ws = validate_and_send(rh, Request(self.ws_base_url))
+            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
             ws.send('source_address')
             assert source_address == ws.recv()
             ws.close()
@@ -237,7 +253,7 @@ def test_source_address(self, handler):
     def test_response_url(self, handler):
         with handler() as rh:
             url = f'{self.ws_base_url}/something'
-            ws = validate_and_send(rh, Request(url))
+            ws = ws_validate_and_send(rh, Request(url))
             assert ws.url == url
             ws.close()
 
@@ -245,14 +261,14 @@ def test_response_url(self, handler):
     def test_request_headers(self, handler):
         with handler(headers=HTTPHeaderDict({'test1': 'test', 'test2': 'test2'})) as rh:
             # Global Headers
-            ws = validate_and_send(rh, Request(self.ws_base_url))
+            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
             ws.send('headers')
             headers = HTTPHeaderDict(json.loads(ws.recv()))
             assert headers['test1'] == 'test'
             ws.close()
 
             # Per request headers, merged with global
-            ws = validate_and_send(rh, Request(
+            ws = ws_validate_and_send(rh, Request(
                 self.ws_base_url, headers={'test2': 'changed', 'test3': 'test3'}))
             ws.send('headers')
             headers = HTTPHeaderDict(json.loads(ws.recv()))
@@ -285,7 +301,7 @@ def test_mtls(self, handler, client_cert):
             verify=False,
             client_cert=client_cert
         ) as rh:
-            validate_and_send(rh, Request(self.mtls_wss_base_url)).close()
+            ws_validate_and_send(rh, Request(self.mtls_wss_base_url)).close()
 
 
 def create_fake_ws_connection(raised):