]> jfr.im git - yt-dlp.git/commitdiff
[rh:websockets] Migrate websockets to networking framework (#7720)
author coletdjnz <redacted>
Mon, 20 Nov 2023 08:04:04 +0000 (08:04 +0000)
committer GitHub <redacted>
Mon, 20 Nov 2023 08:04:04 +0000 (08:04 +0000)
* Adds a basic WebSocket framework
* Introduces new minimum `websockets` version of 12.0
* Deprecates `WebSocketsWrapper`

Fixes https://github.com/yt-dlp/yt-dlp/issues/8439

Authored by: coletdjnz

14 files changed:
requirements.txt
test/conftest.py
test/test_networking.py
test/test_socks.py
test/test_websockets.py [new file with mode: 0644]
yt_dlp/YoutubeDL.py
yt_dlp/downloader/niconico.py
yt_dlp/extractor/fc2.py
yt_dlp/extractor/niconico.py
yt_dlp/networking/__init__.py
yt_dlp/networking/_websockets.py [new file with mode: 0644]
yt_dlp/networking/websocket.py [new file with mode: 0644]
yt_dlp/utils/_legacy.py
yt_dlp/utils/_utils.py

index 5b6270a7dab47ec039423feb0b302da931504be7..d983fa03ffcc31bd4725a1697e766f9a3d49266c 100644 (file)
@@ -6,3 +6,4 @@ brotlicffi; implementation_name!='cpython'
 certifi
 requests>=2.31.0,<3
 urllib3>=1.26.17,<3
+websockets>=12.0
index 15549d30b9d8aecd7bce0713c643f28c3ed64a1d..2fbc269e1fb7b4b509a83115d1a5a1ad890033f7 100644 (file)
@@ -19,3 +19,8 @@ def handler(request):
         pytest.skip(f'{RH_KEY} request handler is not available')
 
     return functools.partial(handler, logger=FakeLogger)
+
+
+def validate_and_send(rh, req):
+    """Validate `req` with request handler `rh`, then send it and return the result of `rh.send`."""
+    rh.validate(req)
+    return rh.send(req)
index 4466fc0485c11936d73b5d8dd4e3605a05463d31..64af6e459a730cd7c7124955298efd51c8d60e59 100644 (file)
@@ -52,6 +52,8 @@
 from yt_dlp.utils._utils import _YDLLogger as FakeLogger
 from yt_dlp.utils.networking import HTTPHeaderDict
 
+from test.conftest import validate_and_send
+
 TEST_DIR = os.path.dirname(os.path.abspath(__file__))
 
 
@@ -275,11 +277,6 @@ def send_header(self, keyword, value):
         self._headers_buffer.append(f'{keyword}: {value}\r\n'.encode())
 
 
-def validate_and_send(rh, req):
-    rh.validate(req)
-    return rh.send(req)
-
-
 class TestRequestHandlerBase:
     @classmethod
     def setup_class(cls):
@@ -872,8 +869,9 @@ def request(self, *args, **kwargs):
     ])
     @pytest.mark.parametrize('handler', ['Requests'], indirect=True)
     def test_response_error_mapping(self, handler, monkeypatch, raised, expected, match):
-        from urllib3.response import HTTPResponse as Urllib3Response
         from requests.models import Response as RequestsResponse
+        from urllib3.response import HTTPResponse as Urllib3Response
+
         from yt_dlp.networking._requests import RequestsResponseAdapter
         requests_res = RequestsResponse()
         requests_res.raw = Urllib3Response(body=b'', status=200)
@@ -929,13 +927,17 @@ class HTTPSupportedRH(ValidationRH):
             ('http', False, {}),
             ('https', False, {}),
         ]),
+        ('Websockets', [
+            ('ws', False, {}),
+            ('wss', False, {}),
+        ]),
         (NoCheckRH, [('http', False, {})]),
         (ValidationRH, [('http', UnsupportedRequest, {})])
     ]
 
     PROXY_SCHEME_TESTS = [
         # scheme, expected to fail
-        ('Urllib', [
+        ('Urllib', 'http', [
             ('http', False),
             ('https', UnsupportedRequest),
             ('socks4', False),
@@ -944,7 +946,7 @@ class HTTPSupportedRH(ValidationRH):
             ('socks5h', False),
             ('socks', UnsupportedRequest),
         ]),
-        ('Requests', [
+        ('Requests', 'http', [
             ('http', False),
             ('https', False),
             ('socks4', False),
@@ -952,8 +954,11 @@ class HTTPSupportedRH(ValidationRH):
             ('socks5', False),
             ('socks5h', False),
         ]),
-        (NoCheckRH, [('http', False)]),
-        (HTTPSupportedRH, [('http', UnsupportedRequest)]),
+        (NoCheckRH, 'http', [('http', False)]),
+        (HTTPSupportedRH, 'http', [('http', UnsupportedRequest)]),
+        ('Websockets', 'ws', [('http', UnsupportedRequest)]),
+        (NoCheckRH, 'http', [('http', False)]),
+        (HTTPSupportedRH, 'http', [('http', UnsupportedRequest)]),
     ]
 
     PROXY_KEY_TESTS = [
@@ -972,7 +977,7 @@ class HTTPSupportedRH(ValidationRH):
     ]
 
     EXTENSION_TESTS = [
-        ('Urllib', [
+        ('Urllib', 'http', [
             ({'cookiejar': 'notacookiejar'}, AssertionError),
             ({'cookiejar': YoutubeDLCookieJar()}, False),
             ({'cookiejar': CookieJar()}, AssertionError),
@@ -980,17 +985,21 @@ class HTTPSupportedRH(ValidationRH):
             ({'timeout': 'notatimeout'}, AssertionError),
             ({'unsupported': 'value'}, UnsupportedRequest),
         ]),
-        ('Requests', [
+        ('Requests', 'http', [
             ({'cookiejar': 'notacookiejar'}, AssertionError),
             ({'cookiejar': YoutubeDLCookieJar()}, False),
             ({'timeout': 1}, False),
             ({'timeout': 'notatimeout'}, AssertionError),
             ({'unsupported': 'value'}, UnsupportedRequest),
         ]),
-        (NoCheckRH, [
+        (NoCheckRH, 'http', [
             ({'cookiejar': 'notacookiejar'}, False),
             ({'somerandom': 'test'}, False),  # but any extension is allowed through
         ]),
+        ('Websockets', 'ws', [
+            ({'cookiejar': YoutubeDLCookieJar()}, False),
+            ({'timeout': 2}, False),
+        ]),
     ]
 
     @pytest.mark.parametrize('handler,scheme,fail,handler_kwargs', [
@@ -1016,14 +1025,14 @@ def test_proxy_key(self, handler, proxy_key, fail):
         run_validation(handler, fail, Request('http://', proxies={proxy_key: 'http://example.com'}))
         run_validation(handler, fail, Request('http://'), proxies={proxy_key: 'http://example.com'})
 
-    @pytest.mark.parametrize('handler,scheme,fail', [
-        (handler_tests[0], scheme, fail)
+    @pytest.mark.parametrize('handler,req_scheme,scheme,fail', [
+        (handler_tests[0], handler_tests[1], scheme, fail)
         for handler_tests in PROXY_SCHEME_TESTS
-        for scheme, fail in handler_tests[1]
+        for scheme, fail in handler_tests[2]
     ], indirect=['handler'])
-    def test_proxy_scheme(self, handler, scheme, fail):
-        run_validation(handler, fail, Request('http://', proxies={'http': f'{scheme}://example.com'}))
-        run_validation(handler, fail, Request('http://'), proxies={'http': f'{scheme}://example.com'})
+    def test_proxy_scheme(self, handler, req_scheme, scheme, fail):
+        run_validation(handler, fail, Request(f'{req_scheme}://', proxies={req_scheme: f'{scheme}://example.com'}))
+        run_validation(handler, fail, Request(f'{req_scheme}://'), proxies={req_scheme: f'{scheme}://example.com'})
 
     @pytest.mark.parametrize('handler', ['Urllib', HTTPSupportedRH, 'Requests'], indirect=True)
     def test_empty_proxy(self, handler):
@@ -1035,14 +1044,14 @@ def test_empty_proxy(self, handler):
     def test_invalid_proxy_url(self, handler, proxy_url):
         run_validation(handler, UnsupportedRequest, Request('http://', proxies={'http': proxy_url}))
 
-    @pytest.mark.parametrize('handler,extensions,fail', [
-        (handler_tests[0], extensions, fail)
+    @pytest.mark.parametrize('handler,scheme,extensions,fail', [
+        (handler_tests[0], handler_tests[1], extensions, fail)
         for handler_tests in EXTENSION_TESTS
-        for extensions, fail in handler_tests[1]
+        for extensions, fail in handler_tests[2]
     ], indirect=['handler'])
-    def test_extension(self, handler, extensions, fail):
+    def test_extension(self, handler, scheme, extensions, fail):
         run_validation(
-            handler, fail, Request('http://', extensions=extensions))
+            handler, fail, Request(f'{scheme}://', extensions=extensions))
 
     def test_invalid_request_type(self):
         rh = self.ValidationRH(logger=FakeLogger())
@@ -1075,6 +1084,22 @@ def __init__(self, *args, **kwargs):
         self._request_director = self.build_request_director([FakeRH])
 
 
+class AllUnsupportedRHYDL(FakeYDL):
+    """FakeYDL whose request director is built from a single handler that declares nothing as supported."""
+
+    def __init__(self, *args, **kwargs):
+
+        class UnsupportedRH(RequestHandler):
+            # _send is a no-op; the three empty tuples below declare no supported
+            # features, proxy schemes or URL schemes for this handler.
+            def _send(self, request: Request):
+                pass
+
+            _SUPPORTED_FEATURES = ()
+            _SUPPORTED_PROXY_SCHEMES = ()
+            _SUPPORTED_URL_SCHEMES = ()
+
+        super().__init__(*args, **kwargs)
+        # Replace the director so that UnsupportedRH is the only registered handler
+        self._request_director = self.build_request_director([UnsupportedRH])
+
+
 class TestRequestDirector:
 
     def test_handler_operations(self):
@@ -1234,6 +1259,12 @@ def test_file_urls_error(self):
             with pytest.raises(RequestError, match=r'file:// URLs are disabled by default'):
                 ydl.urlopen('file://')
 
+    @pytest.mark.parametrize('scheme', (['ws', 'wss']))
+    def test_websocket_unavailable_error(self, scheme):
+        """Opening a ws:// or wss:// URL with AllUnsupportedRHYDL must raise a RequestError mentioning WebSocket support."""
+        with AllUnsupportedRHYDL() as ydl:
+            with pytest.raises(RequestError, match=r'This request requires WebSocket support'):
+                ydl.urlopen(f'{scheme}://')
+
     def test_legacy_server_connect_error(self):
         with FakeRHYDL() as ydl:
             for error in ('UNSAFE_LEGACY_RENEGOTIATION_DISABLED', 'SSLV3_ALERT_HANDSHAKE_FAILURE'):
index d8ac88dad5846751756988d5a1bd8d7f3dc36265..71f783e132dcccbac52d34b16535e1ddbf47c92a 100644 (file)
@@ -210,6 +210,16 @@ def do_GET(self):
             self.wfile.write(payload.encode())
 
 
+class SocksWebSocketTestRequestHandler(SocksTestRequestHandler):
+    """Request handler that answers over WebSocket: handshake, send `self.socks_info` as JSON, close."""
+
+    def handle(self):
+        # Imported locally; this also binds the top-level `websockets` name used below
+        import websockets.sync.server
+        protocol = websockets.ServerProtocol()
+        # Wrap the already-accepted socket (self.request) in a sync server connection
+        connection = websockets.sync.server.ServerConnection(socket=self.request, protocol=protocol, close_timeout=0)
+        connection.handshake()
+        connection.send(json.dumps(self.socks_info))
+        connection.close()
+
+
 @contextlib.contextmanager
 def socks_server(socks_server_class, request_handler, bind_ip=None, **socks_server_kwargs):
     server = server_thread = None
@@ -252,8 +262,22 @@ def socks_info_request(self, handler, target_domain=None, target_port=None, **re
         return json.loads(handler.send(request).read().decode())
 
 
+class WebSocketSocksTestProxyContext(SocksProxyTestContext):
+    """Socks proxy test context that uses SocksWebSocketTestRequestHandler and fetches the socks info over ws://."""
+    REQUEST_HANDLER_CLASS = SocksWebSocketTestRequestHandler
+
+    def socks_info_request(self, handler, target_domain=None, target_port=None, **req_kwargs):
+        """Open a ws:// connection through `handler` and return the JSON-decoded message received.
+
+        The target defaults to 127.0.0.1:40000 when `target_domain` / `target_port` are not given;
+        `req_kwargs` are forwarded to `Request`.
+        """
+        request = Request(f'ws://{target_domain or "127.0.0.1"}:{target_port or "40000"}', **req_kwargs)
+        handler.validate(request)
+        ws = handler.send(request)
+        ws.send('socks_info')
+        socks_info = ws.recv()
+        ws.close()
+        return json.loads(socks_info)
+
+
 CTX_MAP = {
     'http': HTTPSocksTestProxyContext,
+    'ws': WebSocketSocksTestProxyContext,
 }
 
 
@@ -263,7 +287,7 @@ def ctx(request):
 
 
 class TestSocks4Proxy:
-    @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
+    @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
     def test_socks4_no_auth(self, handler, ctx):
         with handler() as rh:
             with ctx.socks_server(Socks4ProxyHandler) as server_address:
@@ -271,7 +295,7 @@ def test_socks4_no_auth(self, handler, ctx):
                     rh, proxies={'all': f'socks4://{server_address}'})
                 assert response['version'] == 4
 
-    @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
+    @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
     def test_socks4_auth(self, handler, ctx):
         with handler() as rh:
             with ctx.socks_server(Socks4ProxyHandler, user_id='user') as server_address:
@@ -281,7 +305,7 @@ def test_socks4_auth(self, handler, ctx):
                     rh, proxies={'all': f'socks4://user:@{server_address}'})
                 assert response['version'] == 4
 
-    @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
+    @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
     def test_socks4a_ipv4_target(self, handler, ctx):
         with ctx.socks_server(Socks4ProxyHandler) as server_address:
             with handler(proxies={'all': f'socks4a://{server_address}'}) as rh:
@@ -289,7 +313,7 @@ def test_socks4a_ipv4_target(self, handler, ctx):
                 assert response['version'] == 4
                 assert (response['ipv4_address'] == '127.0.0.1') != (response['domain_address'] == '127.0.0.1')
 
-    @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
+    @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
     def test_socks4a_domain_target(self, handler, ctx):
         with ctx.socks_server(Socks4ProxyHandler) as server_address:
             with handler(proxies={'all': f'socks4a://{server_address}'}) as rh:
@@ -298,7 +322,7 @@ def test_socks4a_domain_target(self, handler, ctx):
                 assert response['ipv4_address'] is None
                 assert response['domain_address'] == 'localhost'
 
-    @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
+    @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
     def test_ipv4_client_source_address(self, handler, ctx):
         with ctx.socks_server(Socks4ProxyHandler) as server_address:
             source_address = f'127.0.0.{random.randint(5, 255)}'
@@ -308,7 +332,7 @@ def test_ipv4_client_source_address(self, handler, ctx):
                 assert response['client_address'][0] == source_address
                 assert response['version'] == 4
 
-    @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
+    @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
     @pytest.mark.parametrize('reply_code', [
         Socks4CD.REQUEST_REJECTED_OR_FAILED,
         Socks4CD.REQUEST_REJECTED_CANNOT_CONNECT_TO_IDENTD,
@@ -320,7 +344,7 @@ def test_socks4_errors(self, handler, ctx, reply_code):
                 with pytest.raises(ProxyError):
                     ctx.socks_info_request(rh)
 
-    @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
+    @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
     def test_ipv6_socks4_proxy(self, handler, ctx):
         with ctx.socks_server(Socks4ProxyHandler, bind_ip='::1') as server_address:
             with handler(proxies={'all': f'socks4://{server_address}'}) as rh:
@@ -329,7 +353,7 @@ def test_ipv6_socks4_proxy(self, handler, ctx):
                 assert response['ipv4_address'] == '127.0.0.1'
                 assert response['version'] == 4
 
-    @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
+    @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
     def test_timeout(self, handler, ctx):
         with ctx.socks_server(Socks4ProxyHandler, sleep=2) as server_address:
             with handler(proxies={'all': f'socks4://{server_address}'}, timeout=0.5) as rh:
@@ -339,7 +363,7 @@ def test_timeout(self, handler, ctx):
 
 class TestSocks5Proxy:
 
-    @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
+    @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
     def test_socks5_no_auth(self, handler, ctx):
         with ctx.socks_server(Socks5ProxyHandler) as server_address:
             with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
@@ -347,7 +371,7 @@ def test_socks5_no_auth(self, handler, ctx):
                 assert response['auth_methods'] == [0x0]
                 assert response['version'] == 5
 
-    @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
+    @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
     def test_socks5_user_pass(self, handler, ctx):
         with ctx.socks_server(Socks5ProxyHandler, auth=('test', 'testpass')) as server_address:
             with handler() as rh:
@@ -360,7 +384,7 @@ def test_socks5_user_pass(self, handler, ctx):
                 assert response['auth_methods'] == [Socks5Auth.AUTH_NONE, Socks5Auth.AUTH_USER_PASS]
                 assert response['version'] == 5
 
-    @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
+    @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
     def test_socks5_ipv4_target(self, handler, ctx):
         with ctx.socks_server(Socks5ProxyHandler) as server_address:
             with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
@@ -368,7 +392,7 @@ def test_socks5_ipv4_target(self, handler, ctx):
                 assert response['ipv4_address'] == '127.0.0.1'
                 assert response['version'] == 5
 
-    @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
+    @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
     def test_socks5_domain_target(self, handler, ctx):
         with ctx.socks_server(Socks5ProxyHandler) as server_address:
             with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
@@ -376,7 +400,7 @@ def test_socks5_domain_target(self, handler, ctx):
                 assert (response['ipv4_address'] == '127.0.0.1') != (response['ipv6_address'] == '::1')
                 assert response['version'] == 5
 
-    @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
+    @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
     def test_socks5h_domain_target(self, handler, ctx):
         with ctx.socks_server(Socks5ProxyHandler) as server_address:
             with handler(proxies={'all': f'socks5h://{server_address}'}) as rh:
@@ -385,7 +409,7 @@ def test_socks5h_domain_target(self, handler, ctx):
                 assert response['domain_address'] == 'localhost'
                 assert response['version'] == 5
 
-    @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
+    @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
     def test_socks5h_ip_target(self, handler, ctx):
         with ctx.socks_server(Socks5ProxyHandler) as server_address:
             with handler(proxies={'all': f'socks5h://{server_address}'}) as rh:
@@ -394,7 +418,7 @@ def test_socks5h_ip_target(self, handler, ctx):
                 assert response['domain_address'] is None
                 assert response['version'] == 5
 
-    @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
+    @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
     def test_socks5_ipv6_destination(self, handler, ctx):
         with ctx.socks_server(Socks5ProxyHandler) as server_address:
             with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
@@ -402,7 +426,7 @@ def test_socks5_ipv6_destination(self, handler, ctx):
                 assert response['ipv6_address'] == '::1'
                 assert response['version'] == 5
 
-    @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
+    @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
     def test_ipv6_socks5_proxy(self, handler, ctx):
         with ctx.socks_server(Socks5ProxyHandler, bind_ip='::1') as server_address:
             with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
@@ -413,7 +437,7 @@ def test_ipv6_socks5_proxy(self, handler, ctx):
 
     # XXX: is there any feasible way of testing IPv6 source addresses?
     # Same would go for non-proxy source_address test...
-    @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
+    @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
     def test_ipv4_client_source_address(self, handler, ctx):
         with ctx.socks_server(Socks5ProxyHandler) as server_address:
             source_address = f'127.0.0.{random.randint(5, 255)}'
@@ -422,7 +446,7 @@ def test_ipv4_client_source_address(self, handler, ctx):
                 assert response['client_address'][0] == source_address
                 assert response['version'] == 5
 
-    @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
+    @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
     @pytest.mark.parametrize('reply_code', [
         Socks5Reply.GENERAL_FAILURE,
         Socks5Reply.CONNECTION_NOT_ALLOWED,
@@ -439,7 +463,7 @@ def test_socks5_errors(self, handler, ctx, reply_code):
                 with pytest.raises(ProxyError):
                     ctx.socks_info_request(rh)
 
-    @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http')], indirect=True)
+    @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Websockets', 'ws')], indirect=True)
     def test_timeout(self, handler, ctx):
         with ctx.socks_server(Socks5ProxyHandler, sleep=2) as server_address:
             with handler(proxies={'all': f'socks5://{server_address}'}, timeout=1) as rh:
diff --git a/test/test_websockets.py b/test/test_websockets.py
new file mode 100644 (file)
index 0000000..39d3c7d
--- /dev/null
@@ -0,0 +1,380 @@
+#!/usr/bin/env python3
+
+# Allow direct execution
+import os
+import sys
+
+import pytest
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+import http.client
+import http.cookiejar
+import http.server
+import json
+import random
+import ssl
+import threading
+
+from yt_dlp import socks
+from yt_dlp.cookies import YoutubeDLCookieJar
+from yt_dlp.dependencies import websockets
+from yt_dlp.networking import Request
+from yt_dlp.networking.exceptions import (
+    CertificateVerifyError,
+    HTTPError,
+    ProxyError,
+    RequestError,
+    SSLError,
+    TransportError,
+)
+from yt_dlp.utils.networking import HTTPHeaderDict
+
+from test.conftest import validate_and_send
+
+TEST_DIR = os.path.dirname(os.path.abspath(__file__))
+
+
+def websocket_handler(websocket):
+    """Connection handler for the test server: reply to the first received message, then return.
+
+    Recognised messages and their replies:
+      b'bytes'          -> '2'
+      'str'             -> '1'
+      'headers'         -> JSON dump of the handshake request headers
+      'path'            -> the handshake request path
+      'source_address'  -> first element of the connection's remote address
+    Any other message is echoed back unchanged.
+    """
+    # Every path through the loop body returns, so only one message is ever handled
+    for message in websocket:
+        if isinstance(message, bytes):
+            if message == b'bytes':
+                return websocket.send('2')
+        elif isinstance(message, str):
+            if message == 'headers':
+                return websocket.send(json.dumps(dict(websocket.request.headers)))
+            elif message == 'path':
+                return websocket.send(websocket.request.path)
+            elif message == 'source_address':
+                return websocket.send(websocket.remote_address[0])
+            elif message == 'str':
+                return websocket.send('1')
+        # Fallback: echo the message
+        return websocket.send(message)
+
+
+def process_request(self, request):
+    """Decide how the test server answers an incoming handshake request.
+
+    A path of the form ``/gen_<code>`` is answered with HTTP status ``<code>``
+    (which must be a member of ``http.HTTPStatus``) instead of a handshake:
+    3xx codes get a response carrying a ``Location: /`` header and an empty body,
+    any other code is turned into a rejection via ``self.protocol.reject``.
+    All other paths are handed to ``self.protocol.accept``.
+    """
+    if request.path.startswith('/gen_'):
+        status = http.HTTPStatus(int(request.path[5:]))
+        # Cover the whole 3xx redirect class; the former bound (<= 300) only matched 300
+        if 300 <= status.value < 400:
+            return websockets.http11.Response(
+                status.value, status.phrase, websockets.datastructures.Headers([('Location', '/')]), b'')
+        return self.protocol.reject(status.value, status.phrase)
+    return self.protocol.accept(request)
+
+
+def create_websocket_server(**ws_kwargs):
+    """Start a sync `websockets` server on 127.0.0.1 in a daemon thread.
+
+    `ws_kwargs` are forwarded to `websockets.sync.server.serve`.
+    Returns a `(thread, port)` tuple.
+    """
+    import websockets.sync.server
+    # Bind to port 0 and read the actual port back from the listening socket
+    wsd = websockets.sync.server.serve(websocket_handler, '127.0.0.1', 0, process_request=process_request, **ws_kwargs)
+    ws_port = wsd.socket.getsockname()[1]
+    ws_server_thread = threading.Thread(target=wsd.serve_forever)
+    ws_server_thread.daemon = True
+    ws_server_thread.start()
+    return ws_server_thread, ws_port
+
+
+def create_ws_websocket_server():
+    """Start a test server with no extra options (no SSL context); returns `(thread, port)`."""
+    return create_websocket_server()
+
+
+def create_wss_websocket_server():
+    """Start a TLS test server using `testcert.pem` from the test directory; returns `(thread, port)`."""
+    certfn = os.path.join(TEST_DIR, 'testcert.pem')
+    sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+    # No separate key file is passed (keyfile=None)
+    sslctx.load_cert_chain(certfn, None)
+    return create_websocket_server(ssl_context=sslctx)
+
+
+MTLS_CERT_DIR = os.path.join(TEST_DIR, 'testdata', 'certificate')
+
+
+def create_mtls_wss_websocket_server():
+    """Start a TLS test server that requires a client certificate; returns `(thread, port)`.
+
+    The server certificate is `testcert.pem`; client certificates are verified
+    against `ca.crt` from MTLS_CERT_DIR.
+    """
+    certfn = os.path.join(TEST_DIR, 'testcert.pem')
+    cacertfn = os.path.join(MTLS_CERT_DIR, 'ca.crt')
+
+    sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+    # CERT_REQUIRED makes the server demand a certificate from the client
+    sslctx.verify_mode = ssl.CERT_REQUIRED
+    sslctx.load_verify_locations(cafile=cacertfn)
+    sslctx.load_cert_chain(certfn, None)
+
+    return create_websocket_server(ssl_context=sslctx)
+
+
+@pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers')
+class TestWebsSocketRequestHandlerConformance:
+    @classmethod
+    def setup_class(cls):
+        """Start the four test servers (ws, wss, certificate-less wss, mTLS wss) and store their threads, ports and URLs on the class."""
+        cls.ws_thread, cls.ws_port = create_ws_websocket_server()
+        cls.ws_base_url = f'ws://127.0.0.1:{cls.ws_port}'
+
+        cls.wss_thread, cls.wss_port = create_wss_websocket_server()
+        cls.wss_base_url = f'wss://127.0.0.1:{cls.wss_port}'
+
+        # TLS server context with no certificate loaded; used by test_ssl_error
+        cls.bad_wss_thread, cls.bad_wss_port = create_websocket_server(ssl_context=ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER))
+        cls.bad_wss_host = f'wss://127.0.0.1:{cls.bad_wss_port}'
+
+        cls.mtls_wss_thread, cls.mtls_wss_port = create_mtls_wss_websocket_server()
+        cls.mtls_wss_base_url = f'wss://127.0.0.1:{cls.mtls_wss_port}'
+
+    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
+    def test_basic_websockets(self, handler):
+        """A plain ws:// connection yields status 101 with an 'upgrade' header and echoes a text message."""
+        with handler() as rh:
+            ws = validate_and_send(rh, Request(self.ws_base_url))
+            assert 'upgrade' in ws.headers
+            assert ws.status == 101
+            ws.send('foo')
+            assert ws.recv() == 'foo'
+            ws.close()
+
+    # https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6
+    @pytest.mark.parametrize('msg,opcode', [('str', 1), (b'bytes', 2)])
+    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
+    def test_send_types(self, handler, msg, opcode):
+        """Sending 'str' / b'bytes' must make the test server reply '1' / '2' respectively."""
+        with handler() as rh:
+            ws = validate_and_send(rh, Request(self.ws_base_url))
+            ws.send(msg)
+            # The server replies with the number as text; compare it numerically
+            assert int(ws.recv()) == opcode
+            ws.close()
+
+    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
+    def test_verify_cert(self, handler):
+        """Connecting to the wss test server raises CertificateVerifyError by default and succeeds with verify=False."""
+        with handler() as rh:
+            with pytest.raises(CertificateVerifyError):
+                validate_and_send(rh, Request(self.wss_base_url))
+
+        with handler(verify=False) as rh:
+            ws = validate_and_send(rh, Request(self.wss_base_url))
+            assert ws.status == 101
+            ws.close()
+
+    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
+    def test_ssl_error(self, handler):
+        """A handshake failure against the certificate-less wss server surfaces as SSLError, not CertificateVerifyError."""
+        with handler(verify=False) as rh:
+            with pytest.raises(SSLError, match='sslv3 alert handshake failure') as exc_info:
+                validate_and_send(rh, Request(self.bad_wss_host))
+            # Must be a generic SSL error, not the certificate-verification subclass
+            assert not issubclass(exc_info.type, CertificateVerifyError)
+
+    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
+    @pytest.mark.parametrize('path,expected', [
+        # Unicode characters should be encoded with uppercase percent-encoding
+        ('/中文', '/%E4%B8%AD%E6%96%87'),
+        # don't normalize existing percent encodings
+        ('/%c7%9f', '/%c7%9f'),
+    ])
+    def test_percent_encode(self, handler, path, expected):
+        """The path seen by the server (queried with the 'path' message) must equal `expected`."""
+        with handler() as rh:
+            ws = validate_and_send(rh, Request(f'{self.ws_base_url}{path}'))
+            ws.send('path')
+            assert ws.recv() == expected
+            assert ws.status == 101
+            ws.close()
+
+    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
+    def test_remove_dot_segments(self, handler):
+        """Requesting '/a/b/./../../test' must reach the server as '/test'."""
+        with handler() as rh:
+            # This isn't a comprehensive test,
+            # but it should be enough to check whether the handler is removing dot segments
+            ws = validate_and_send(rh, Request(f'{self.ws_base_url}/a/b/./../../test'))
+            assert ws.status == 101
+            ws.send('path')
+            assert ws.recv() == '/test'
+            ws.close()
+
+    # We are restricted to known HTTP status codes in http.HTTPStatus
+    # Redirects are not supported for websockets
+    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
+    @pytest.mark.parametrize('status', (200, 204, 301, 302, 303, 400, 500, 511))
+    def test_raise_http_error(self, handler, status):
+        """Requesting '/gen_<status>' must raise HTTPError carrying that status."""
+        with handler() as rh:
+            with pytest.raises(HTTPError) as exc_info:
+                validate_and_send(rh, Request(f'{self.ws_base_url}/gen_{status}'))
+            assert exc_info.value.status == status
+
+    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
+    @pytest.mark.parametrize('params,extensions', [
+        # timeout given as a handler parameter
+        ({'timeout': 0.00001}, {}),
+        # timeout given as a per-request extension
+        ({}, {'timeout': 0.00001}),
+    ])
+    def test_timeout(self, handler, params, extensions):
+        """A very small timeout, set either on the handler or per request, must raise TransportError."""
+        with handler(**params) as rh:
+            with pytest.raises(TransportError):
+                validate_and_send(rh, Request(self.ws_base_url, extensions=extensions))
+
+    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
+    def test_cookies(self, handler):
+        """Cookies from a handler-level or per-request cookiejar are sent; without a jar no cookie header is sent."""
+        cookiejar = YoutubeDLCookieJar()
+        cookiejar.set_cookie(http.cookiejar.Cookie(
+            version=0, name='test', value='ytdlp', port=None, port_specified=False,
+            domain='127.0.0.1', domain_specified=True, domain_initial_dot=False, path='/',
+            path_specified=True, secure=False, expires=None, discard=False, comment=None,
+            comment_url=None, rest={}))
+
+        # Cookiejar supplied to the handler
+        with handler(cookiejar=cookiejar) as rh:
+            ws = validate_and_send(rh, Request(self.ws_base_url))
+            ws.send('headers')
+            assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
+            ws.close()
+
+        with handler() as rh:
+            # No cookiejar given at all
+            ws = validate_and_send(rh, Request(self.ws_base_url))
+            ws.send('headers')
+            assert 'cookie' not in json.loads(ws.recv())
+            ws.close()
+
+            # Cookiejar supplied through the request's extensions
+            ws = validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': cookiejar}))
+            ws.send('headers')
+            assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
+            ws.close()
+
+    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
+    def test_source_address(self, handler):
+        """The address the server sees for the client must equal the handler's `source_address`."""
+        # Random address in 127.0.0.5 - 127.0.0.255
+        source_address = f'127.0.0.{random.randint(5, 255)}'
+        with handler(source_address=source_address) as rh:
+            ws = validate_and_send(rh, Request(self.ws_base_url))
+            ws.send('source_address')
+            assert source_address == ws.recv()
+            ws.close()
+
+    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
+    def test_response_url(self, handler):
+        """The response's `url` attribute must equal the requested URL."""
+        with handler() as rh:
+            url = f'{self.ws_base_url}/something'
+            ws = validate_and_send(rh, Request(url))
+            assert ws.url == url
+            ws.close()
+
+    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
+    def test_request_headers(self, handler):
+        with handler(headers=HTTPHeaderDict({'test1': 'test', 'test2': 'test2'})) as rh:
+            # Global Headers
+            ws = validate_and_send(rh, Request(self.ws_base_url))
+            ws.send('headers')
+            headers = HTTPHeaderDict(json.loads(ws.recv()))
+            assert headers['test1'] == 'test'
+            ws.close()
+
+            # Per request headers, merged with global
+            ws = validate_and_send(rh, Request(
+                self.ws_base_url, headers={'test2': 'changed', 'test3': 'test3'}))
+            ws.send('headers')
+            headers = HTTPHeaderDict(json.loads(ws.recv()))
+            assert headers['test1'] == 'test'
+            assert headers['test2'] == 'changed'
+            assert headers['test3'] == 'test3'
+            ws.close()
+
+    @pytest.mark.parametrize('client_cert', (
+        {'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithkey.crt')},
+        {
+            'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
+            'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'client.key'),
+        },
+        {
+            'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithencryptedkey.crt'),
+            'client_certificate_password': 'foobar',
+        },
+        {
+            'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
+            'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'clientencrypted.key'),
+            'client_certificate_password': 'foobar',
+        }
+    ))
+    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
+    def test_mtls(self, handler, client_cert):
+        with handler(
+            # Disable client-side validation of unacceptable self-signed testcert.pem
+            # The test is of a check on the server side, so unaffected
+            verify=False,
+            client_cert=client_cert
+        ) as rh:
+            validate_and_send(rh, Request(self.mtls_wss_base_url)).close()
+
+
+def create_fake_ws_connection(raised):
+    import websockets.sync.client
+
+    class FakeWsConnection(websockets.sync.client.ClientConnection):
+        def __init__(self, *args, **kwargs):
+            class FakeResponse:
+                body = b''
+                headers = {}
+                status_code = 101
+                reason_phrase = 'test'
+
+            self.response = FakeResponse()
+
+        def send(self, *args, **kwargs):
+            raise raised()
+
+        def recv(self, *args, **kwargs):
+            raise raised()
+
+        def close(self, *args, **kwargs):
+            return
+
+    return FakeWsConnection()
+
+
+@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
+class TestWebsocketsRequestHandler:
+    @pytest.mark.parametrize('raised,expected', [
+        # https://websockets.readthedocs.io/en/stable/reference/exceptions.html
+        (lambda: websockets.exceptions.InvalidURI(msg='test', uri='test://'), RequestError),
+        # Requires a response object. Should be covered by HTTP error tests.
+        # (lambda: websockets.exceptions.InvalidStatus(), TransportError),
+        (lambda: websockets.exceptions.InvalidHandshake(), TransportError),
+        # These are subclasses of InvalidHandshake
+        (lambda: websockets.exceptions.InvalidHeader(name='test'), TransportError),
+        (lambda: websockets.exceptions.NegotiationError(), TransportError),
+        # Catch-all
+        (lambda: websockets.exceptions.WebSocketException(), TransportError),
+        (lambda: TimeoutError(), TransportError),
+        # These may be raised by our create_connection implementation, which should also be caught
+        (lambda: OSError(), TransportError),
+        (lambda: ssl.SSLError(), SSLError),
+        (lambda: ssl.SSLCertVerificationError(), CertificateVerifyError),
+        (lambda: socks.ProxyError(), ProxyError),
+    ])
+    def test_request_error_mapping(self, handler, monkeypatch, raised, expected):
+        import websockets.sync.client
+
+        import yt_dlp.networking._websockets
+        with handler() as rh:
+            def fake_connect(*args, **kwargs):
+                raise raised()
+            monkeypatch.setattr(yt_dlp.networking._websockets, 'create_connection', lambda *args, **kwargs: None)
+            monkeypatch.setattr(websockets.sync.client, 'connect', fake_connect)
+            with pytest.raises(expected) as exc_info:
+                rh.send(Request('ws://fake-url'))
+            assert exc_info.type is expected
+
+    @pytest.mark.parametrize('raised,expected,match', [
+        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send
+        (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
+        (lambda: RuntimeError(), TransportError, None),
+        (lambda: TimeoutError(), TransportError, None),
+        (lambda: TypeError(), RequestError, None),
+        (lambda: socks.ProxyError(), ProxyError, None),
+        # Catch-all
+        (lambda: websockets.exceptions.WebSocketException(), TransportError, None),
+    ])
+    def test_ws_send_error_mapping(self, handler, monkeypatch, raised, expected, match):
+        from yt_dlp.networking._websockets import WebsocketsResponseAdapter
+        ws = WebsocketsResponseAdapter(create_fake_ws_connection(raised), url='ws://fake-url')
+        with pytest.raises(expected, match=match) as exc_info:
+            ws.send('test')
+        assert exc_info.type is expected
+
+    @pytest.mark.parametrize('raised,expected,match', [
+        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv
+        (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
+        (lambda: RuntimeError(), TransportError, None),
+        (lambda: TimeoutError(), TransportError, None),
+        (lambda: socks.ProxyError(), ProxyError, None),
+        # Catch-all
+        (lambda: websockets.exceptions.WebSocketException(), TransportError, None),
+    ])
+    def test_ws_recv_error_mapping(self, handler, monkeypatch, raised, expected, match):
+        from yt_dlp.networking._websockets import WebsocketsResponseAdapter
+        ws = WebsocketsResponseAdapter(create_fake_ws_connection(raised), url='ws://fake-url')
+        with pytest.raises(expected, match=match) as exc_info:
+            ws.recv()
+        assert exc_info.type is expected
index 740826b452391854c8e47aaf8acbee1157bc5099..85b282bd51ccddd927b6c51f516d872191d14eee 100644 (file)
@@ -4052,6 +4052,7 @@ def urlopen(self, req):
             return self._request_director.send(req)
         except NoSupportingHandlers as e:
             for ue in e.unsupported_errors:
+                # FIXME: This depends on the order of errors.
                 if not (ue.handler and ue.msg):
                     continue
                 if ue.handler.RH_KEY == 'Urllib' and 'unsupported url scheme: "file"' in ue.msg.lower():
@@ -4061,6 +4062,15 @@ def urlopen(self, req):
                 if 'unsupported proxy type: "https"' in ue.msg.lower():
                     raise RequestError(
                         'To use an HTTPS proxy for this request, one of the following dependencies needs to be installed: requests')
+
+                elif (
+                    re.match(r'unsupported url scheme: "wss?"', ue.msg.lower())
+                    and 'websockets' not in self._request_director.handlers
+                ):
+                    raise RequestError(
+                        'This request requires WebSocket support. '
+                        'Ensure one of the following dependencies are installed: websockets',
+                        cause=ue) from ue
             raise
         except SSLError as e:
             if 'UNSAFE_LEGACY_RENEGOTIATION_DISABLED' in str(e):
index 5720f6eb8ff48e27bd6ece8b53da401305885f81..fef8bff73ad9753bfa89d6e6b0f8c6f1a2adbbff 100644 (file)
@@ -6,7 +6,7 @@
 from .common import FileDownloader
 from .external import FFmpegFD
 from ..networking import Request
-from ..utils import DownloadError, WebSocketsWrapper, str_or_none, try_get
+from ..utils import DownloadError, str_or_none, try_get
 
 
 class NiconicoDmcFD(FileDownloader):
@@ -64,7 +64,6 @@ def real_download(self, filename, info_dict):
         ws_url = info_dict['url']
         ws_extractor = info_dict['ws']
         ws_origin_host = info_dict['origin']
-        cookies = info_dict.get('cookies')
         live_quality = info_dict.get('live_quality', 'high')
         live_latency = info_dict.get('live_latency', 'high')
         dl = FFmpegFD(self.ydl, self.params or {})
@@ -76,12 +75,7 @@ def real_download(self, filename, info_dict):
 
         def communicate_ws(reconnect):
             if reconnect:
-                ws = WebSocketsWrapper(ws_url, {
-                    'Cookies': str_or_none(cookies) or '',
-                    'Origin': f'https://{ws_origin_host}',
-                    'Accept': '*/*',
-                    'User-Agent': self.params['http_headers']['User-Agent'],
-                })
+                ws = self.ydl.urlopen(Request(ws_url, headers={'Origin': f'https://{ws_origin_host}'}))
                 if self.ydl.params.get('verbose', False):
                     self.to_screen('[debug] Sending startWatching request')
                 ws.send(json.dumps({
index ba19b6cab421c444f205890da9534d8bd07828ab..bbc4b56931254abc144fef1ad5cec578344ed16a 100644 (file)
@@ -2,11 +2,9 @@
 
 from .common import InfoExtractor
 from ..compat import compat_parse_qs
-from ..dependencies import websockets
 from ..networking import Request
 from ..utils import (
     ExtractorError,
-    WebSocketsWrapper,
     js_to_json,
     traverse_obj,
     update_url_query,
@@ -167,8 +165,6 @@ class FC2LiveIE(InfoExtractor):
     }]
 
     def _real_extract(self, url):
-        if not websockets:
-            raise ExtractorError('websockets library is not available. Please install it.', expected=True)
         video_id = self._match_id(url)
         webpage = self._download_webpage('https://live.fc2.com/%s/' % video_id, video_id)
 
@@ -199,13 +195,9 @@ def _real_extract(self, url):
         ws_url = update_url_query(control_server['url'], {'control_token': control_server['control_token']})
         playlist_data = None
 
-        self.to_screen('%s: Fetching HLS playlist info via WebSocket' % video_id)
-        ws = WebSocketsWrapper(ws_url, {
-            'Cookie': str(self._get_cookies('https://live.fc2.com/'))[12:],
+        ws = self._request_webpage(Request(ws_url, headers={
             'Origin': 'https://live.fc2.com',
-            'Accept': '*/*',
-            'User-Agent': self.get_param('http_headers')['User-Agent'],
-        })
+        }), video_id, note='Fetching HLS playlist info via WebSocket')
 
         self.write_debug('Sending HLS server request')
 
index fa2d709d28f7abd76a644cf592b74e19278c2494..797b5268af4936217b3acbeda416ae58edfdc5eb 100644 (file)
@@ -8,12 +8,11 @@
 from urllib.parse import urlparse
 
 from .common import InfoExtractor, SearchInfoExtractor
-from ..dependencies import websockets
+from ..networking import Request
 from ..networking.exceptions import HTTPError
 from ..utils import (
     ExtractorError,
     OnDemandPagedList,
-    WebSocketsWrapper,
     bug_reports_message,
     clean_html,
     float_or_none,
@@ -934,8 +933,6 @@ class NiconicoLiveIE(InfoExtractor):
     _KNOWN_LATENCY = ('high', 'low')
 
     def _real_extract(self, url):
-        if not websockets:
-            raise ExtractorError('websockets library is not available. Please install it.', expected=True)
         video_id = self._match_id(url)
         webpage, urlh = self._download_webpage_handle(f'https://live.nicovideo.jp/watch/{video_id}', video_id)
 
@@ -950,17 +947,13 @@ def _real_extract(self, url):
         })
 
         hostname = remove_start(urlparse(urlh.url).hostname, 'sp.')
-        cookies = try_get(urlh.url, self._downloader._calc_cookies)
         latency = try_get(self._configuration_arg('latency'), lambda x: x[0])
         if latency not in self._KNOWN_LATENCY:
             latency = 'high'
 
-        ws = WebSocketsWrapper(ws_url, {
-            'Cookies': str_or_none(cookies) or '',
-            'Origin': f'https://{hostname}',
-            'Accept': '*/*',
-            'User-Agent': self.get_param('http_headers')['User-Agent'],
-        })
+        ws = self._request_webpage(
+            Request(ws_url, headers={'Origin': f'https://{hostname}'}),
+            video_id=video_id, note='Connecting to WebSocket server')
 
         self.write_debug('[debug] Sending HLS server request')
         ws.send(json.dumps({
@@ -1034,7 +1027,6 @@ def _real_extract(self, url):
                 'protocol': 'niconico_live',
                 'ws': ws,
                 'video_id': video_id,
-                'cookies': cookies,
                 'live_latency': latency,
                 'origin': hostname,
             })
index aa8d0eabe4c7f6fa17ecdb2335ee6879937beece..96c5a0678f85f3262db928bb8d4d35b145e06b82 100644 (file)
     pass
 except Exception as e:
     warnings.warn(f'Failed to import "requests" request handler: {e}' + bug_reports_message())
+
+try:
+    from . import _websockets
+except ImportError:
+    pass
+except Exception as e:
+    warnings.warn(f'Failed to import "websockets" request handler: {e}' + bug_reports_message())
+
diff --git a/yt_dlp/networking/_websockets.py b/yt_dlp/networking/_websockets.py
new file mode 100644 (file)
index 0000000..ad85554
--- /dev/null
@@ -0,0 +1,159 @@
+from __future__ import annotations
+
+import io
+import logging
+import ssl
+import sys
+
+from ._helper import create_connection, select_proxy, make_socks_proxy_opts, create_socks_proxy_socket
+from .common import Response, register_rh, Features
+from .exceptions import (
+    CertificateVerifyError,
+    HTTPError,
+    RequestError,
+    SSLError,
+    TransportError, ProxyError,
+)
+from .websocket import WebSocketRequestHandler, WebSocketResponse
+from ..compat import functools
+from ..dependencies import websockets
+from ..utils import int_or_none
+from ..socks import ProxyError as SocksProxyError
+
+# Refuse to import this module when the optional `websockets` dependency is missing.
+# The guarded import in networking/__init__.py silently ignores ImportError.
+if not websockets:
+    raise ImportError('websockets is not installed')
+
+import websockets.version
+
+# Enforce the minimum supported version before importing websockets.sync.client below.
+# NOTE(review): int_or_none presumably yields None for a non-numeric component (e.g. an
+# "rc" suffix), which could make the tuple comparison raise TypeError - confirm.
+websockets_version = tuple(map(int_or_none, websockets.version.version.split('.')))
+if websockets_version < (12, 0):
+    raise ImportError('Only websockets>=12.0 is supported')
+
+import websockets.sync.client
+from websockets.uri import parse_uri
+
+
+class WebsocketsResponseAdapter(WebSocketResponse):
+
+    def __init__(self, wsw: websockets.sync.client.ClientConnection, url):
+        super().__init__(
+            fp=io.BytesIO(wsw.response.body or b''),
+            url=url,
+            headers=wsw.response.headers,
+            status=wsw.response.status_code,
+            reason=wsw.response.reason_phrase,
+        )
+        self.wsw = wsw
+
+    def close(self):
+        self.wsw.close()
+        super().close()
+
+    def send(self, message):
+        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send
+        try:
+            return self.wsw.send(message)
+        except (websockets.exceptions.WebSocketException, RuntimeError, TimeoutError) as e:
+            raise TransportError(cause=e) from e
+        except SocksProxyError as e:
+            raise ProxyError(cause=e) from e
+        except TypeError as e:
+            raise RequestError(cause=e) from e
+
+    def recv(self):
+        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv
+        try:
+            return self.wsw.recv()
+        except SocksProxyError as e:
+            raise ProxyError(cause=e) from e
+        except (websockets.exceptions.WebSocketException, RuntimeError, TimeoutError) as e:
+            raise TransportError(cause=e) from e
+
+
+@register_rh
+class WebsocketsRH(WebSocketRequestHandler):
+    """
+    Websockets request handler
+    https://websockets.readthedocs.io
+    https://github.com/python-websockets/websockets
+    """
+    _SUPPORTED_URL_SCHEMES = ('wss', 'ws')
+    _SUPPORTED_PROXY_SCHEMES = ('socks4', 'socks4a', 'socks5', 'socks5h')
+    _SUPPORTED_FEATURES = (Features.ALL_PROXY, Features.NO_PROXY)
+    RH_NAME = 'websockets'
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        for name in ('websockets.client', 'websockets.server'):
+            logger = logging.getLogger(name)
+            handler = logging.StreamHandler(stream=sys.stdout)
+            handler.setFormatter(logging.Formatter(f'{self.RH_NAME}: %(message)s'))
+            logger.addHandler(handler)
+            if self.verbose:
+                logger.setLevel(logging.DEBUG)
+
+    def _check_extensions(self, extensions):
+        super()._check_extensions(extensions)
+        extensions.pop('timeout', None)
+        extensions.pop('cookiejar', None)
+
+    def _send(self, request):
+        timeout = float(request.extensions.get('timeout') or self.timeout)
+        headers = self._merge_headers(request.headers)
+        if 'cookie' not in headers:
+            cookiejar = request.extensions.get('cookiejar') or self.cookiejar
+            cookie_header = cookiejar.get_cookie_header(request.url)
+            if cookie_header:
+                headers['cookie'] = cookie_header
+
+        wsuri = parse_uri(request.url)
+        create_conn_kwargs = {
+            'source_address': (self.source_address, 0) if self.source_address else None,
+            'timeout': timeout
+        }
+        proxy = select_proxy(request.url, request.proxies or self.proxies or {})
+        try:
+            if proxy:
+                socks_proxy_options = make_socks_proxy_opts(proxy)
+                sock = create_connection(
+                    address=(socks_proxy_options['addr'], socks_proxy_options['port']),
+                    _create_socket_func=functools.partial(
+                        create_socks_proxy_socket, (wsuri.host, wsuri.port), socks_proxy_options),
+                    **create_conn_kwargs
+                )
+            else:
+                sock = create_connection(
+                    address=(wsuri.host, wsuri.port),
+                    **create_conn_kwargs
+                )
+            conn = websockets.sync.client.connect(
+                sock=sock,
+                uri=request.url,
+                additional_headers=headers,
+                open_timeout=timeout,
+                user_agent_header=None,
+                ssl_context=self._make_sslcontext() if wsuri.secure else None,
+                close_timeout=0,  # not ideal, but prevents yt-dlp hanging
+            )
+            return WebsocketsResponseAdapter(conn, url=request.url)
+
+        # Exceptions as per https://websockets.readthedocs.io/en/stable/reference/sync/client.html
+        except SocksProxyError as e:
+            raise ProxyError(cause=e) from e
+        except websockets.exceptions.InvalidURI as e:
+            raise RequestError(cause=e) from e
+        except ssl.SSLCertVerificationError as e:
+            raise CertificateVerifyError(cause=e) from e
+        except ssl.SSLError as e:
+            raise SSLError(cause=e) from e
+        except websockets.exceptions.InvalidStatus as e:
+            raise HTTPError(
+                Response(
+                    fp=io.BytesIO(e.response.body),
+                    url=request.url,
+                    headers=e.response.headers,
+                    status=e.response.status_code,
+                    reason=e.response.reason_phrase),
+            ) from e
+        except (OSError, TimeoutError, websockets.exceptions.WebSocketException) as e:
+            raise TransportError(cause=e) from e
diff --git a/yt_dlp/networking/websocket.py b/yt_dlp/networking/websocket.py
new file mode 100644 (file)
index 0000000..09fcf78
--- /dev/null
@@ -0,0 +1,23 @@
+from __future__ import annotations
+
+import abc
+
+from .common import Response, RequestHandler
+
+
+class WebSocketResponse(Response):
+    """Response subclass that adds WebSocket messaging; `send` and `recv` must be overridden."""
+
+    def send(self, message: bytes | str):
+        """
+        Send a message to the server.
+
+        @param message: The message to send. A string (str) is sent as a text frame, bytes is sent as a binary frame.
+        """
+        raise NotImplementedError
+
+    def recv(self):
+        """Receive the next message from the server. Not implemented here; subclasses provide it."""
+        raise NotImplementedError
+
+
+class WebSocketRequestHandler(RequestHandler, abc.ABC):
+    """Abstract base class for WebSocket request handlers; currently adds nothing to RequestHandler."""
+    pass
index dde02092c9925d0fd7ee66835dc9f232746aca55..aa9f46d204a408987639b9557d11298859631fe1 100644 (file)
@@ -1,4 +1,6 @@
 """No longer used and new code should not use. Exists only for API compat."""
+import asyncio
+import atexit
 import platform
 import struct
 import sys
 has_websockets = bool(websockets)
 
 
+class WebSocketsWrapper:
+    """Wraps websockets module to use in non-async scopes"""
+    pool = None
+
+    def __init__(self, url, headers=None, connect=True, **ws_kwargs):
+        self.loop = asyncio.new_event_loop()
+        # XXX: "loop" is deprecated
+        self.conn = websockets.connect(
+            url, extra_headers=headers, ping_interval=None,
+            close_timeout=float('inf'), loop=self.loop, ping_timeout=float('inf'), **ws_kwargs)
+        if connect:
+            self.__enter__()
+        atexit.register(self.__exit__, None, None, None)
+
+    def __enter__(self):
+        if not self.pool:
+            self.pool = self.run_with_loop(self.conn.__aenter__(), self.loop)
+        return self
+
+    def send(self, *args):
+        self.run_with_loop(self.pool.send(*args), self.loop)
+
+    def recv(self, *args):
+        return self.run_with_loop(self.pool.recv(*args), self.loop)
+
+    def __exit__(self, type, value, traceback):
+        try:
+            return self.run_with_loop(self.conn.__aexit__(type, value, traceback), self.loop)
+        finally:
+            self.loop.close()
+            self._cancel_all_tasks(self.loop)
+
+    # taken from https://github.com/python/cpython/blob/3.9/Lib/asyncio/runners.py with modifications
+    # for contributors: If there's any new library using asyncio needs to be run in non-async, move these function out of this class
+    @staticmethod
+    def run_with_loop(main, loop):
+        if not asyncio.iscoroutine(main):
+            raise ValueError(f'a coroutine was expected, got {main!r}')
+
+        try:
+            return loop.run_until_complete(main)
+        finally:
+            loop.run_until_complete(loop.shutdown_asyncgens())
+            if hasattr(loop, 'shutdown_default_executor'):
+                loop.run_until_complete(loop.shutdown_default_executor())
+
+    @staticmethod
+    def _cancel_all_tasks(loop):
+        to_cancel = asyncio.all_tasks(loop)
+
+        if not to_cancel:
+            return
+
+        for task in to_cancel:
+            task.cancel()
+
+        # XXX: "loop" is removed in python 3.10+
+        loop.run_until_complete(
+            asyncio.gather(*to_cancel, loop=loop, return_exceptions=True))
+
+        for task in to_cancel:
+            if task.cancelled():
+                continue
+            if task.exception() is not None:
+                loop.call_exception_handler({
+                    'message': 'unhandled exception during asyncio.run() shutdown',
+                    'exception': task.exception(),
+                    'task': task,
+                })
+
+
 def load_plugins(name, suffix, namespace):
     from ..plugins import load_plugins
     ret = load_plugins(name, suffix)
index 10c7c431103b8b5adb5a0d4bd0ab53358a2f3c48..b0164a8953d29400f6129ededbdfe9bb57ad625c 100644 (file)
@@ -1,5 +1,3 @@
-import asyncio
-import atexit
 import base64
 import binascii
 import calendar
@@ -54,7 +52,7 @@
     compat_os_name,
     compat_shlex_quote,
 )
-from ..dependencies import websockets, xattr
+from ..dependencies import xattr
 
 __name__ = __name__.rsplit('.', 1)[0]  # Pretend to be the parent module
 
@@ -4923,77 +4921,6 @@ def parse_args(self):
         return self.parser.parse_args(self.all_args)
 
 
-class WebSocketsWrapper:
-    """Wraps websockets module to use in non-async scopes"""
-    pool = None
-
-    def __init__(self, url, headers=None, connect=True):
-        self.loop = asyncio.new_event_loop()
-        # XXX: "loop" is deprecated
-        self.conn = websockets.connect(
-            url, extra_headers=headers, ping_interval=None,
-            close_timeout=float('inf'), loop=self.loop, ping_timeout=float('inf'))
-        if connect:
-            self.__enter__()
-        atexit.register(self.__exit__, None, None, None)
-
-    def __enter__(self):
-        if not self.pool:
-            self.pool = self.run_with_loop(self.conn.__aenter__(), self.loop)
-        return self
-
-    def send(self, *args):
-        self.run_with_loop(self.pool.send(*args), self.loop)
-
-    def recv(self, *args):
-        return self.run_with_loop(self.pool.recv(*args), self.loop)
-
-    def __exit__(self, type, value, traceback):
-        try:
-            return self.run_with_loop(self.conn.__aexit__(type, value, traceback), self.loop)
-        finally:
-            self.loop.close()
-            self._cancel_all_tasks(self.loop)
-
-    # taken from https://github.com/python/cpython/blob/3.9/Lib/asyncio/runners.py with modifications
-    # for contributors: If there's any new library using asyncio needs to be run in non-async, move these function out of this class
-    @staticmethod
-    def run_with_loop(main, loop):
-        if not asyncio.iscoroutine(main):
-            raise ValueError(f'a coroutine was expected, got {main!r}')
-
-        try:
-            return loop.run_until_complete(main)
-        finally:
-            loop.run_until_complete(loop.shutdown_asyncgens())
-            if hasattr(loop, 'shutdown_default_executor'):
-                loop.run_until_complete(loop.shutdown_default_executor())
-
-    @staticmethod
-    def _cancel_all_tasks(loop):
-        to_cancel = asyncio.all_tasks(loop)
-
-        if not to_cancel:
-            return
-
-        for task in to_cancel:
-            task.cancel()
-
-        # XXX: "loop" is removed in python 3.10+
-        loop.run_until_complete(
-            asyncio.gather(*to_cancel, loop=loop, return_exceptions=True))
-
-        for task in to_cancel:
-            if task.cancelled():
-                continue
-            if task.exception() is not None:
-                loop.call_exception_handler({
-                    'message': 'unhandled exception during asyncio.run() shutdown',
-                    'exception': task.exception(),
-                    'task': task,
-                })
-
-
 def merge_headers(*dicts):
     """Merge dicts of http headers case insensitively, prioritizing the latter ones"""
     return {k.title(): v for k, v in itertools.chain.from_iterable(map(dict.items, dicts))}