certifi
requests>=2.31.0,<3
urllib3>=1.26.17,<3
+websockets>=12.0
pytest.skip(f'{RH_KEY} request handler is not available')
return functools.partial(handler, logger=FakeLogger)
+
+
def validate_and_send(rh, req):
    """Validate ``req`` with request handler ``rh`` and then send it through ``rh``.

    Any exception raised by ``rh.validate`` propagates and nothing is sent.
    Returns whatever ``rh.send`` returns.
    """
    handler = rh
    handler.validate(req)
    response = handler.send(req)
    return response
from yt_dlp.utils._utils import _YDLLogger as FakeLogger
from yt_dlp.utils.networking import HTTPHeaderDict
+from test.conftest import validate_and_send
+
TEST_DIR = os.path.dirname(os.path.abspath(__file__))
self._headers_buffer.append(f'{keyword}: {value}\r\n'.encode())
-def validate_and_send(rh, req):
- rh.validate(req)
- return rh.send(req)
-
-
class TestRequestHandlerBase:
@classmethod
def setup_class(cls):
])
@pytest.mark.parametrize('handler', ['Requests'], indirect=True)
def test_response_error_mapping(self, handler, monkeypatch, raised, expected, match):
- from urllib3.response import HTTPResponse as Urllib3Response
from requests.models import Response as RequestsResponse
+ from urllib3.response import HTTPResponse as Urllib3Response
+
from yt_dlp.networking._requests import RequestsResponseAdapter
requests_res = RequestsResponse()
requests_res.raw = Urllib3Response(body=b'', status=200)
('http', False, {}),
('https', False, {}),
]),
+ ('Websockets', [
+ ('ws', False, {}),
+ ('wss', False, {}),
+ ]),
(NoCheckRH, [('http', False, {})]),
(ValidationRH, [('http', UnsupportedRequest, {})])
]
PROXY_SCHEME_TESTS = [
# handler, req_scheme, [(proxy scheme, expected to fail)]
- ('Urllib', [
+ ('Urllib', 'http', [
('http', False),
('https', UnsupportedRequest),
('socks4', False),
('socks5h', False),
('socks', UnsupportedRequest),
]),
- ('Requests', [
+ ('Requests', 'http', [
('http', False),
('https', False),
('socks4', False),
('socks5', False),
('socks5h', False),
]),
- (NoCheckRH, [('http', False)]),
- (HTTPSupportedRH, [('http', UnsupportedRequest)]),
+ (NoCheckRH, 'http', [('http', False)]),
+ (HTTPSupportedRH, 'http', [('http', UnsupportedRequest)]),
+ ('Websockets', 'ws', [('http', UnsupportedRequest)]),
+ (NoCheckRH, 'http', [('http', False)]),
+ (HTTPSupportedRH, 'http', [('http', UnsupportedRequest)]),
]
PROXY_KEY_TESTS = [
]
EXTENSION_TESTS = [
- ('Urllib', [
+ ('Urllib', 'http', [
({'cookiejar': 'notacookiejar'}, AssertionError),
({'cookiejar': YoutubeDLCookieJar()}, False),
({'cookiejar': CookieJar()}, AssertionError),
({'timeout': 'notatimeout'}, AssertionError),
({'unsupported': 'value'}, UnsupportedRequest),
]),
- ('Requests', [
+ ('Requests', 'http', [
({'cookiejar': 'notacookiejar'}, AssertionError),
({'cookiejar': YoutubeDLCookieJar()}, False),
({'timeout': 1}, False),
({'timeout': 'notatimeout'}, AssertionError),
({'unsupported': 'value'}, UnsupportedRequest),
]),
- (NoCheckRH, [
+ (NoCheckRH, 'http', [
({'cookiejar': 'notacookiejar'}, False),
({'somerandom': 'test'}, False), # but any extension is allowed through
]),
+ ('Websockets', 'ws', [
+ ({'cookiejar': YoutubeDLCookieJar()}, False),
+ ({'timeout': 2}, False),
+ ]),
]
@pytest.mark.parametrize('handler,scheme,fail,handler_kwargs', [
run_validation(handler, fail, Request('http://', proxies={proxy_key: 'http://example.com'}))
run_validation(handler, fail, Request('http://'), proxies={proxy_key: 'http://example.com'})
- @pytest.mark.parametrize('handler,scheme,fail', [
- (handler_tests[0], scheme, fail)
+ @pytest.mark.parametrize('handler,req_scheme,scheme,fail', [
+ (handler_tests[0], handler_tests[1], scheme, fail)
for handler_tests in PROXY_SCHEME_TESTS
- for scheme, fail in handler_tests[1]
+ for scheme, fail in handler_tests[2]
], indirect=['handler'])
- def test_proxy_scheme(self, handler, scheme, fail):
- run_validation(handler, fail, Request('http://', proxies={'http': f'{scheme}://example.com'}))
- run_validation(handler, fail, Request('http://'), proxies={'http': f'{scheme}://example.com'})
+ def test_proxy_scheme(self, handler, req_scheme, scheme, fail):
+ run_validation(handler, fail, Request(f'{req_scheme}://', proxies={req_scheme: f'{scheme}://example.com'}))
+ run_validation(handler, fail, Request(f'{req_scheme}://'), proxies={req_scheme: f'{scheme}://example.com'})
@pytest.mark.parametrize('handler', ['Urllib', HTTPSupportedRH, 'Requests'], indirect=True)
def test_empty_proxy(self, handler):
def test_invalid_proxy_url(self, handler, proxy_url):
run_validation(handler, UnsupportedRequest, Request('http://', proxies={'http': proxy_url}))
- @pytest.mark.parametrize('handler,extensions,fail', [
- (handler_tests[0], extensions, fail)
+ @pytest.mark.parametrize('handler,scheme,extensions,fail', [
+ (handler_tests[0], handler_tests[1], extensions, fail)
for handler_tests in EXTENSION_TESTS
- for extensions, fail in handler_tests[1]
+ for extensions, fail in handler_tests[2]
], indirect=['handler'])
- def test_extension(self, handler, extensions, fail):
+ def test_extension(self, handler, scheme, extensions, fail):
run_validation(
- handler, fail, Request('http://', extensions=extensions))
+ handler, fail, Request(f'{scheme}://', extensions=extensions))
def test_invalid_request_type(self):
rh = self.ValidationRH(logger=FakeLogger())
self._request_director = self.build_request_director([FakeRH])
class AllUnsupportedRHYDL(FakeYDL):
    """``FakeYDL`` variant whose only request handler supports nothing.

    The single handler declares no supported features, proxy schemes or URL
    schemes. Requests made through this instance are therefore expected to be
    refused as unsupported by request-director validation (assumption based on
    the empty ``_SUPPORTED_*`` tuples; confirm against ``RequestHandler``).
    """

    def __init__(self, *args, **kwargs):
        # A throwaway handler class; its _send is never expected to run.
        class UnsupportedRH(RequestHandler):
            _SUPPORTED_FEATURES = ()
            _SUPPORTED_PROXY_SCHEMES = ()
            _SUPPORTED_URL_SCHEMES = ()

            def _send(self, request: Request):
                pass

        super().__init__(*args, **kwargs)
        # Replace the director built by FakeYDL so that only UnsupportedRH is available.
        self._request_director = self.build_request_director([UnsupportedRH])
+
+
class TestRequestDirector:
def test_handler_operations(self):
with pytest.raises(RequestError, match=r'file:// URLs are disabled by default'):
ydl.urlopen('file://')
@pytest.mark.parametrize('scheme', ['ws', 'wss'])
def test_websocket_unavailable_error(self, scheme):
    """ws:// and wss:// requests with no WebSocket-capable handler get a dependency hint."""
    url = f'{scheme}://'
    with AllUnsupportedRHYDL() as ydl, \
            pytest.raises(RequestError, match=r'This request requires WebSocket support'):
        ydl.urlopen(url)
+
def test_legacy_server_connect_error(self):
with FakeRHYDL() as ydl:
for error in ('UNSAFE_LEGACY_RENEGOTIATION_DISABLED', 'SSLV3_ALERT_HANDSHAKE_FAILURE'):
self.wfile.write(payload.encode())
class SocksWebSocketTestRequestHandler(SocksTestRequestHandler):
    """SOCKS test request handler that talks WebSocket to the proxied client.

    It completes a server-side websocket handshake on the proxied socket, sends
    ``self.socks_info`` once as a JSON text message, and then closes the
    connection. ``socks_info`` is presumably populated by
    ``SocksTestRequestHandler``; confirm in the parent class.
    """

    def handle(self):
        import websockets.sync.server
        server_connection_cls = websockets.sync.server.ServerConnection

        connection = server_connection_cls(
            socket=self.request,
            protocol=websockets.ServerProtocol(),
            close_timeout=0,  # don't wait for the client's close frame
        )
        connection.handshake()
        payload = json.dumps(self.socks_info)
        connection.send(payload)
        connection.close()
+
+
@contextlib.contextmanager
def socks_server(socks_server_class, request_handler, bind_ip=None, **socks_server_kwargs):
server = server_thread = None
return json.loads(handler.send(request).read().decode())
class WebSocketSocksTestProxyContext(SocksProxyTestContext):
    """SOCKS proxy test context that drives a request handler over ``ws://``.

    The proxy side uses ``SocksWebSocketTestRequestHandler``, which answers with
    the SOCKS details it observed, encoded as a JSON websocket message.
    """
    REQUEST_HANDLER_CLASS = SocksWebSocketTestRequestHandler

    def socks_info_request(self, handler, target_domain=None, target_port=None, **req_kwargs):
        """Open a websocket through ``handler`` and return the decoded SOCKS info.

        :param handler: request handler used to validate and send the request.
        :param target_domain: destination host; defaults to ``127.0.0.1``.
        :param target_port: destination port; defaults to ``40000``.
        :param req_kwargs: extra keyword arguments forwarded to ``Request``.
        :returns: the object decoded from the JSON message received.

        The websocket is always closed, even when sending or receiving raises.
        """
        request = Request(f'ws://{target_domain or "127.0.0.1"}:{target_port or "40000"}', **req_kwargs)
        handler.validate(request)
        ws = handler.send(request)
        try:
            ws.send('socks_info')
            socks_info = ws.recv()
        finally:
            # Previously the connection leaked if send/recv raised.
            ws.close()
        return json.loads(socks_info)
+
+
CTX_MAP = {
'http': HTTPSocksTestProxyContext,
+ 'ws': WebSocketSocksTestProxyContext,
}
class TestSocks4Proxy:
- @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
+ @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
def test_socks4_no_auth(self, handler, ctx):
with handler() as rh:
with ctx.socks_server(Socks4ProxyHandler) as server_address:
rh, proxies={'all': f'socks4://{server_address}'})
assert response['version'] == 4
- @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
+ @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
def test_socks4_auth(self, handler, ctx):
with handler() as rh:
with ctx.socks_server(Socks4ProxyHandler, user_id='user') as server_address:
rh, proxies={'all': f'socks4://user:@{server_address}'})
assert response['version'] == 4
- @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
+ @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
def test_socks4a_ipv4_target(self, handler, ctx):
with ctx.socks_server(Socks4ProxyHandler) as server_address:
with handler(proxies={'all': f'socks4a://{server_address}'}) as rh:
assert response['version'] == 4
assert (response['ipv4_address'] == '127.0.0.1') != (response['domain_address'] == '127.0.0.1')
- @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
+ @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
def test_socks4a_domain_target(self, handler, ctx):
with ctx.socks_server(Socks4ProxyHandler) as server_address:
with handler(proxies={'all': f'socks4a://{server_address}'}) as rh:
assert response['ipv4_address'] is None
assert response['domain_address'] == 'localhost'
- @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
+ @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
def test_ipv4_client_source_address(self, handler, ctx):
with ctx.socks_server(Socks4ProxyHandler) as server_address:
source_address = f'127.0.0.{random.randint(5, 255)}'
assert response['client_address'][0] == source_address
assert response['version'] == 4
- @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
+ @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
@pytest.mark.parametrize('reply_code', [
Socks4CD.REQUEST_REJECTED_OR_FAILED,
Socks4CD.REQUEST_REJECTED_CANNOT_CONNECT_TO_IDENTD,
with pytest.raises(ProxyError):
ctx.socks_info_request(rh)
- @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
+ @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
def test_ipv6_socks4_proxy(self, handler, ctx):
with ctx.socks_server(Socks4ProxyHandler, bind_ip='::1') as server_address:
with handler(proxies={'all': f'socks4://{server_address}'}) as rh:
assert response['ipv4_address'] == '127.0.0.1'
assert response['version'] == 4
- @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
+ @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
def test_timeout(self, handler, ctx):
with ctx.socks_server(Socks4ProxyHandler, sleep=2) as server_address:
with handler(proxies={'all': f'socks4://{server_address}'}, timeout=0.5) as rh:
class TestSocks5Proxy:
- @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
+ @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
def test_socks5_no_auth(self, handler, ctx):
with ctx.socks_server(Socks5ProxyHandler) as server_address:
with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
assert response['auth_methods'] == [0x0]
assert response['version'] == 5
- @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
+ @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
def test_socks5_user_pass(self, handler, ctx):
with ctx.socks_server(Socks5ProxyHandler, auth=('test', 'testpass')) as server_address:
with handler() as rh:
assert response['auth_methods'] == [Socks5Auth.AUTH_NONE, Socks5Auth.AUTH_USER_PASS]
assert response['version'] == 5
- @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
+ @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
def test_socks5_ipv4_target(self, handler, ctx):
with ctx.socks_server(Socks5ProxyHandler) as server_address:
with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
assert response['ipv4_address'] == '127.0.0.1'
assert response['version'] == 5
- @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
+ @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
def test_socks5_domain_target(self, handler, ctx):
with ctx.socks_server(Socks5ProxyHandler) as server_address:
with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
assert (response['ipv4_address'] == '127.0.0.1') != (response['ipv6_address'] == '::1')
assert response['version'] == 5
- @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
+ @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
def test_socks5h_domain_target(self, handler, ctx):
with ctx.socks_server(Socks5ProxyHandler) as server_address:
with handler(proxies={'all': f'socks5h://{server_address}'}) as rh:
assert response['domain_address'] == 'localhost'
assert response['version'] == 5
- @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
+ @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
def test_socks5h_ip_target(self, handler, ctx):
with ctx.socks_server(Socks5ProxyHandler) as server_address:
with handler(proxies={'all': f'socks5h://{server_address}'}) as rh:
assert response['domain_address'] is None
assert response['version'] == 5
- @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
+ @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
def test_socks5_ipv6_destination(self, handler, ctx):
with ctx.socks_server(Socks5ProxyHandler) as server_address:
with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
assert response['ipv6_address'] == '::1'
assert response['version'] == 5
- @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
+ @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
def test_ipv6_socks5_proxy(self, handler, ctx):
with ctx.socks_server(Socks5ProxyHandler, bind_ip='::1') as server_address:
with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
# XXX: is there any feasible way of testing IPv6 source addresses?
# Same would go for non-proxy source_address test...
- @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
+ @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
def test_ipv4_client_source_address(self, handler, ctx):
with ctx.socks_server(Socks5ProxyHandler) as server_address:
source_address = f'127.0.0.{random.randint(5, 255)}'
assert response['client_address'][0] == source_address
assert response['version'] == 5
- @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
+ @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
@pytest.mark.parametrize('reply_code', [
Socks5Reply.GENERAL_FAILURE,
Socks5Reply.CONNECTION_NOT_ALLOWED,
with pytest.raises(ProxyError):
ctx.socks_info_request(rh)
- @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http')], indirect=True)
+ @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Websockets', 'ws')], indirect=True)
def test_timeout(self, handler, ctx):
with ctx.socks_server(Socks5ProxyHandler, sleep=2) as server_address:
with handler(proxies={'all': f'socks5://{server_address}'}, timeout=1) as rh:
--- /dev/null
+#!/usr/bin/env python3
+
+# Allow direct execution
+import os
+import sys
+
+import pytest
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+import http.client
+import http.cookiejar
+import http.server
+import json
+import random
+import ssl
+import threading
+
+from yt_dlp import socks
+from yt_dlp.cookies import YoutubeDLCookieJar
+from yt_dlp.dependencies import websockets
+from yt_dlp.networking import Request
+from yt_dlp.networking.exceptions import (
+ CertificateVerifyError,
+ HTTPError,
+ ProxyError,
+ RequestError,
+ SSLError,
+ TransportError,
+)
+from yt_dlp.utils.networking import HTTPHeaderDict
+
+from test.conftest import validate_and_send
+
# Absolute directory of this test module; certificate fixtures below are resolved relative to it.
TEST_DIR = os.path.dirname(os.path.abspath(__file__))
+
+
def websocket_handler(websocket):
    """Answer the first message received on a test websocket connection, then return.

    Recognised commands:
      * binary ``b'bytes'``   -> ``'2'`` (binary frame opcode)
      * text ``'str'``        -> ``'1'`` (text frame opcode)
      * ``'headers'``         -> JSON object of the handshake request headers
      * ``'path'``            -> handshake request path
      * ``'source_address'``  -> the client's IP address
    Any other message is echoed back unchanged. Returns the result of
    ``websocket.send``, or ``None`` if no message arrives.
    """
    # Lazily evaluated so request attributes are only read for the matching command.
    text_commands = {
        'headers': lambda: json.dumps(dict(websocket.request.headers)),
        'path': lambda: websocket.request.path,
        'source_address': lambda: websocket.remote_address[0],
        'str': lambda: '1',
    }
    for incoming in websocket:
        reply = incoming
        if isinstance(incoming, bytes):
            if incoming == b'bytes':
                reply = '2'
        elif isinstance(incoming, str) and incoming in text_commands:
            reply = text_commands[incoming]()
        # Only the first message is served; returning ends the handler.
        return websocket.send(reply)
+
+
def process_request(self, request):
    """Handshake hook for the test websocket server.

    ``self`` is the server connection (only its ``protocol`` attribute is used)
    and ``request`` is the parsed handshake request.

    A path of the form ``/gen_<code>`` makes the server refuse the upgrade and
    answer with that HTTP status. Any 3xx code is sent as a redirect carrying
    ``Location: /``; every other code goes through ``self.protocol.reject``.
    All other paths are accepted normally.

    Raises ValueError if ``<code>`` is not an integer known to ``http.HTTPStatus``.
    """
    if not request.path.startswith('/gen_'):
        return self.protocol.accept(request)

    status = http.HTTPStatus(int(request.path[len('/gen_'):]))
    # Cover the whole 3xx range (the old check `300 <= x <= 300` only matched 300).
    # Tests can then verify that real redirects are reported as HTTP errors
    # rather than followed.
    if 300 <= status.value < 400:
        return websockets.http11.Response(
            status.value, status.phrase, websockets.datastructures.Headers([('Location', '/')]), b'')
    return self.protocol.reject(status.value, status.phrase)
+
+
def create_websocket_server(**ws_kwargs):
    """Start a test websocket server on an ephemeral 127.0.0.1 port in a daemon thread.

    The server uses ``websocket_handler`` and the ``process_request`` hook.
    Extra keyword arguments (for example ``ssl_context``) are forwarded to
    ``websockets.sync.server.serve``.

    :returns: ``(thread, port)``.
    """
    import websockets.sync.server

    server = websockets.sync.server.serve(
        websocket_handler, '127.0.0.1', 0, process_request=process_request, **ws_kwargs)
    port = server.socket.getsockname()[1]

    thread = threading.Thread(target=server.serve_forever, daemon=True)
    thread.start()
    return thread, port
+
+
def create_ws_websocket_server():
    """Start a plain-text (``ws://``) test websocket server and return ``(thread, port)``."""
    server_thread, port = create_websocket_server()
    return server_thread, port
+
+
def create_wss_websocket_server():
    """Start a TLS (``wss://``) test websocket server that presents ``testcert.pem``.

    :returns: ``(thread, port)``.
    """
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    context.load_cert_chain(os.path.join(TEST_DIR, 'testcert.pem'), None)
    return create_websocket_server(ssl_context=context)
+
+
# Directory holding the CA certificate and the client certificates/keys used by the mutual-TLS tests.
MTLS_CERT_DIR = os.path.join(TEST_DIR, 'testdata', 'certificate')
+
+
def create_mtls_wss_websocket_server():
    """Start a ``wss://`` test server that requires a client certificate (mutual TLS).

    The server presents ``testcert.pem`` and only accepts clients whose
    certificate verifies against ``ca.crt`` in ``MTLS_CERT_DIR``.

    :returns: ``(thread, port)``.
    """
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    context.verify_mode = ssl.CERT_REQUIRED
    context.load_verify_locations(cafile=os.path.join(MTLS_CERT_DIR, 'ca.crt'))
    context.load_cert_chain(os.path.join(TEST_DIR, 'testcert.pem'), None)
    return create_websocket_server(ssl_context=context)
+
+
@pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers')
class TestWebsSocketRequestHandlerConformance:
    """Conformance tests for the ``Websockets`` request handler against local servers.

    ``setup_class`` starts four servers on ephemeral ports:
      * plain ``ws://``;
      * ``wss://`` presenting ``testcert.pem``;
      * ``wss://`` with an SSL context that has no certificate loaded;
      * ``wss://`` that requires a client certificate signed by the test CA (mTLS).
    """

    @classmethod
    def setup_class(cls):
        cls.ws_thread, cls.ws_port = create_ws_websocket_server()
        cls.ws_base_url = f'ws://127.0.0.1:{cls.ws_port}'

        cls.wss_thread, cls.wss_port = create_wss_websocket_server()
        cls.wss_base_url = f'wss://127.0.0.1:{cls.wss_port}'

        # No certificate is loaded into this context, so TLS handshakes with it are expected to fail
        cls.bad_wss_thread, cls.bad_wss_port = create_websocket_server(ssl_context=ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER))
        cls.bad_wss_host = f'wss://127.0.0.1:{cls.bad_wss_port}'

        cls.mtls_wss_thread, cls.mtls_wss_port = create_mtls_wss_websocket_server()
        cls.mtls_wss_base_url = f'wss://127.0.0.1:{cls.mtls_wss_port}'

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_basic_websockets(self, handler):
        """A plain ws:// connection upgrades (status 101) and echoes a text message."""
        with handler() as rh:
            ws = validate_and_send(rh, Request(self.ws_base_url))
            assert 'upgrade' in ws.headers
            assert ws.status == 101
            ws.send('foo')
            assert ws.recv() == 'foo'
            ws.close()

    # https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6
    @pytest.mark.parametrize('msg,opcode', [('str', 1), (b'bytes', 2)])
    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_send_types(self, handler, msg, opcode):
        """str is sent as a text frame and bytes as a binary frame; the server replies with the opcode."""
        with handler() as rh:
            ws = validate_and_send(rh, Request(self.ws_base_url))
            ws.send(msg)
            assert int(ws.recv()) == opcode
            ws.close()

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_verify_cert(self, handler):
        """The self-signed test certificate is rejected unless verification is disabled."""
        with handler() as rh:
            with pytest.raises(CertificateVerifyError):
                validate_and_send(rh, Request(self.wss_base_url))

        with handler(verify=False) as rh:
            ws = validate_and_send(rh, Request(self.wss_base_url))
            assert ws.status == 101
            ws.close()

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_ssl_error(self, handler):
        """A TLS handshake failure surfaces as SSLError, not as CertificateVerifyError."""
        with handler(verify=False) as rh:
            with pytest.raises(SSLError, match='sslv3 alert handshake failure') as exc_info:
                validate_and_send(rh, Request(self.bad_wss_host))
            assert not issubclass(exc_info.type, CertificateVerifyError)

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    @pytest.mark.parametrize('path,expected', [
        # Unicode characters should be encoded with uppercase percent-encoding.
        # Written as escapes (U+4E2D U+6587) so the literal survives re-encoding of this file.
        ('/\u4e2d\u6587', '/%E4%B8%AD%E6%96%87'),
        # don't normalize existing percent encodings
        ('/%c7%9f', '/%c7%9f'),
    ])
    def test_percent_encode(self, handler, path, expected):
        """Non-ASCII path characters are percent-encoded; existing escapes are kept as-is."""
        with handler() as rh:
            ws = validate_and_send(rh, Request(f'{self.ws_base_url}{path}'))
            ws.send('path')
            assert ws.recv() == expected
            assert ws.status == 101
            ws.close()

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_remove_dot_segments(self, handler):
        """``.`` and ``..`` segments are resolved before the path is sent."""
        with handler() as rh:
            # This isn't a comprehensive test,
            # but it should be enough to check whether the handler is removing dot segments
            ws = validate_and_send(rh, Request(f'{self.ws_base_url}/a/b/./../../test'))
            assert ws.status == 101
            ws.send('path')
            assert ws.recv() == '/test'
            ws.close()

    # We are restricted to known HTTP status codes in http.HTTPStatus
    # Redirects are not supported for websockets
    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    @pytest.mark.parametrize('status', (200, 204, 301, 302, 303, 400, 500, 511))
    def test_raise_http_error(self, handler, status):
        """Any non-101 handshake response, redirects included, raises HTTPError carrying the status."""
        with handler() as rh:
            with pytest.raises(HTTPError) as exc_info:
                validate_and_send(rh, Request(f'{self.ws_base_url}/gen_{status}'))
            assert exc_info.value.status == status

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    @pytest.mark.parametrize('params,extensions', [
        ({'timeout': 0.00001}, {}),
        ({}, {'timeout': 0.00001}),
    ])
    def test_timeout(self, handler, params, extensions):
        """A tiny handler-level or per-request timeout makes the connection fail."""
        with handler(**params) as rh:
            with pytest.raises(TransportError):
                validate_and_send(rh, Request(self.ws_base_url, extensions=extensions))

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_cookies(self, handler):
        """Cookies come from the handler's cookiejar or from a per-request cookiejar extension."""
        cookiejar = YoutubeDLCookieJar()
        cookiejar.set_cookie(http.cookiejar.Cookie(
            version=0, name='test', value='ytdlp', port=None, port_specified=False,
            domain='127.0.0.1', domain_specified=True, domain_initial_dot=False, path='/',
            path_specified=True, secure=False, expires=None, discard=False, comment=None,
            comment_url=None, rest={}))

        with handler(cookiejar=cookiejar) as rh:
            ws = validate_and_send(rh, Request(self.ws_base_url))
            ws.send('headers')
            assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
            ws.close()

        with handler() as rh:
            ws = validate_and_send(rh, Request(self.ws_base_url))
            ws.send('headers')
            assert 'cookie' not in json.loads(ws.recv())
            ws.close()

            ws = validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': cookiejar}))
            ws.send('headers')
            assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
            ws.close()

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_source_address(self, handler):
        """The connection originates from the configured source address."""
        source_address = f'127.0.0.{random.randint(5, 255)}'
        with handler(source_address=source_address) as rh:
            ws = validate_and_send(rh, Request(self.ws_base_url))
            ws.send('source_address')
            assert source_address == ws.recv()
            ws.close()

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_response_url(self, handler):
        """The response reports the requested URL."""
        with handler() as rh:
            url = f'{self.ws_base_url}/something'
            ws = validate_and_send(rh, Request(url))
            assert ws.url == url
            ws.close()

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_request_headers(self, handler):
        """Handler-level headers are sent and per-request headers override/extend them."""
        with handler(headers=HTTPHeaderDict({'test1': 'test', 'test2': 'test2'})) as rh:
            # Global Headers
            ws = validate_and_send(rh, Request(self.ws_base_url))
            ws.send('headers')
            headers = HTTPHeaderDict(json.loads(ws.recv()))
            assert headers['test1'] == 'test'
            ws.close()

            # Per request headers, merged with global
            ws = validate_and_send(rh, Request(
                self.ws_base_url, headers={'test2': 'changed', 'test3': 'test3'}))
            ws.send('headers')
            headers = HTTPHeaderDict(json.loads(ws.recv()))
            assert headers['test1'] == 'test'
            assert headers['test2'] == 'changed'
            assert headers['test3'] == 'test3'
            ws.close()

    @pytest.mark.parametrize('client_cert', (
        {'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithkey.crt')},
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
            'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'client.key'),
        },
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithencryptedkey.crt'),
            'client_certificate_password': 'foobar',
        },
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
            'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'clientencrypted.key'),
            'client_certificate_password': 'foobar',
        }
    ))
    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_mtls(self, handler, client_cert):
        """Each supported client-certificate configuration satisfies the mTLS server."""
        with handler(
            # Disable client-side validation of unacceptable self-signed testcert.pem
            # The test is of a check on the server side, so unaffected
            verify=False,
            client_cert=client_cert
        ) as rh:
            validate_and_send(rh, Request(self.mtls_wss_base_url)).close()
+
+
def create_fake_ws_connection(raised):
    """Build a fake websockets ``ClientConnection`` for exercising error mapping.

    ``raised`` is a zero-argument callable; ``send`` and ``recv`` raise the
    exception it returns, and ``close`` does nothing. The constructor deliberately
    skips the parent ``__init__`` (no socket or protocol is set up) and only
    installs a minimal ``response`` describing a 101 handshake.
    """
    import websockets.sync.client

    class FakeResponse:
        body = b''
        headers = {}
        status_code = 101
        reason_phrase = 'test'

    def fail(self, *args, **kwargs):
        raise raised()

    class FakeWsConnection(websockets.sync.client.ClientConnection):
        send = fail
        recv = fail

        def __init__(self, *args, **kwargs):
            self.response = FakeResponse()

        def close(self, *args, **kwargs):
            return

    return FakeWsConnection()
+
+
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
class TestWebsocketsRequestHandler:
    """Checks that websockets, OS, SSL and socks exceptions map onto yt-dlp networking errors."""

    @staticmethod
    def _failing_adapter(raised):
        """Wrap a fake connection whose send/recv raise ``raised()`` in the response adapter."""
        from yt_dlp.networking._websockets import WebsocketsResponseAdapter
        return WebsocketsResponseAdapter(create_fake_ws_connection(raised), url='ws://fake-url')

    @pytest.mark.parametrize('raised,expected', [
        # https://websockets.readthedocs.io/en/stable/reference/exceptions.html
        (lambda: websockets.exceptions.InvalidURI(msg='test', uri='test://'), RequestError),
        # Requires a response object. Should be covered by HTTP error tests.
        # (lambda: websockets.exceptions.InvalidStatus(), TransportError),
        (lambda: websockets.exceptions.InvalidHandshake(), TransportError),
        # These are subclasses of InvalidHandshake
        (lambda: websockets.exceptions.InvalidHeader(name='test'), TransportError),
        (lambda: websockets.exceptions.NegotiationError(), TransportError),
        # Catch-all
        (lambda: websockets.exceptions.WebSocketException(), TransportError),
        (lambda: TimeoutError(), TransportError),
        # These may be raised by our create_connection implementation, which should also be caught
        (lambda: OSError(), TransportError),
        (lambda: ssl.SSLError(), SSLError),
        (lambda: ssl.SSLCertVerificationError(), CertificateVerifyError),
        (lambda: socks.ProxyError(), ProxyError),
    ])
    def test_request_error_mapping(self, handler, monkeypatch, raised, expected):
        """An exception raised while connecting surfaces as exactly ``expected``."""
        import websockets.sync.client

        import yt_dlp.networking._websockets

        def fake_connect(*args, **kwargs):
            raise raised()

        with handler() as rh:
            # Stub out socket creation; the patched connect raises the test exception instead.
            monkeypatch.setattr(yt_dlp.networking._websockets, 'create_connection', lambda *args, **kwargs: None)
            monkeypatch.setattr(websockets.sync.client, 'connect', fake_connect)
            with pytest.raises(expected) as exc_info:
                rh.send(Request('ws://fake-url'))
            assert exc_info.type is expected

    @pytest.mark.parametrize('raised,expected,match', [
        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send
        (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
        (lambda: RuntimeError(), TransportError, None),
        (lambda: TimeoutError(), TransportError, None),
        (lambda: TypeError(), RequestError, None),
        (lambda: socks.ProxyError(), ProxyError, None),
        # Catch-all
        (lambda: websockets.exceptions.WebSocketException(), TransportError, None),
    ])
    def test_ws_send_error_mapping(self, handler, monkeypatch, raised, expected, match):
        """Errors from ``ClientConnection.send`` are translated by the adapter's ``send``."""
        adapter = self._failing_adapter(raised)
        with pytest.raises(expected, match=match) as exc_info:
            adapter.send('test')
        assert exc_info.type is expected

    @pytest.mark.parametrize('raised,expected,match', [
        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv
        (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
        (lambda: RuntimeError(), TransportError, None),
        (lambda: TimeoutError(), TransportError, None),
        (lambda: socks.ProxyError(), ProxyError, None),
        # Catch-all
        (lambda: websockets.exceptions.WebSocketException(), TransportError, None),
    ])
    def test_ws_recv_error_mapping(self, handler, monkeypatch, raised, expected, match):
        """Errors from ``ClientConnection.recv`` are translated by the adapter's ``recv``."""
        adapter = self._failing_adapter(raised)
        with pytest.raises(expected, match=match) as exc_info:
            adapter.recv()
        assert exc_info.type is expected
return self._request_director.send(req)
except NoSupportingHandlers as e:
for ue in e.unsupported_errors:
+ # FIXME: This depends on the order of errors.
if not (ue.handler and ue.msg):
continue
if ue.handler.RH_KEY == 'Urllib' and 'unsupported url scheme: "file"' in ue.msg.lower():
if 'unsupported proxy type: "https"' in ue.msg.lower():
raise RequestError(
'To use an HTTPS proxy for this request, one of the following dependencies needs to be installed: requests')
+
+ elif (
+ re.match(r'unsupported url scheme: "wss?"', ue.msg.lower())
+ and 'websockets' not in self._request_director.handlers
+ ):
+ raise RequestError(
+ 'This request requires WebSocket support. '
+ 'Ensure one of the following dependencies are installed: websockets',
+ cause=ue) from ue
raise
except SSLError as e:
if 'UNSAFE_LEGACY_RENEGOTIATION_DISABLED' in str(e):
from .common import FileDownloader
from .external import FFmpegFD
from ..networking import Request
-from ..utils import DownloadError, WebSocketsWrapper, str_or_none, try_get
+from ..utils import DownloadError, str_or_none, try_get
class NiconicoDmcFD(FileDownloader):
ws_url = info_dict['url']
ws_extractor = info_dict['ws']
ws_origin_host = info_dict['origin']
- cookies = info_dict.get('cookies')
live_quality = info_dict.get('live_quality', 'high')
live_latency = info_dict.get('live_latency', 'high')
dl = FFmpegFD(self.ydl, self.params or {})
def communicate_ws(reconnect):
if reconnect:
- ws = WebSocketsWrapper(ws_url, {
- 'Cookies': str_or_none(cookies) or '',
- 'Origin': f'https://{ws_origin_host}',
- 'Accept': '*/*',
- 'User-Agent': self.params['http_headers']['User-Agent'],
- })
+ ws = self.ydl.urlopen(Request(ws_url, headers={'Origin': f'https://{ws_origin_host}'}))
if self.ydl.params.get('verbose', False):
self.to_screen('[debug] Sending startWatching request')
ws.send(json.dumps({
from .common import InfoExtractor
from ..compat import compat_parse_qs
-from ..dependencies import websockets
from ..networking import Request
from ..utils import (
ExtractorError,
- WebSocketsWrapper,
js_to_json,
traverse_obj,
update_url_query,
}]
def _real_extract(self, url):
- if not websockets:
- raise ExtractorError('websockets library is not available. Please install it.', expected=True)
video_id = self._match_id(url)
webpage = self._download_webpage('https://live.fc2.com/%s/' % video_id, video_id)
ws_url = update_url_query(control_server['url'], {'control_token': control_server['control_token']})
playlist_data = None
- self.to_screen('%s: Fetching HLS playlist info via WebSocket' % video_id)
- ws = WebSocketsWrapper(ws_url, {
- 'Cookie': str(self._get_cookies('https://live.fc2.com/'))[12:],
+ ws = self._request_webpage(Request(ws_url, headers={
'Origin': 'https://live.fc2.com',
- 'Accept': '*/*',
- 'User-Agent': self.get_param('http_headers')['User-Agent'],
- })
+ }), video_id, note='Fetching HLS playlist info via WebSocket')
self.write_debug('Sending HLS server request')
from urllib.parse import urlparse
from .common import InfoExtractor, SearchInfoExtractor
-from ..dependencies import websockets
+from ..networking import Request
from ..networking.exceptions import HTTPError
from ..utils import (
ExtractorError,
OnDemandPagedList,
- WebSocketsWrapper,
bug_reports_message,
clean_html,
float_or_none,
_KNOWN_LATENCY = ('high', 'low')
def _real_extract(self, url):
- if not websockets:
- raise ExtractorError('websockets library is not available. Please install it.', expected=True)
video_id = self._match_id(url)
webpage, urlh = self._download_webpage_handle(f'https://live.nicovideo.jp/watch/{video_id}', video_id)
})
hostname = remove_start(urlparse(urlh.url).hostname, 'sp.')
- cookies = try_get(urlh.url, self._downloader._calc_cookies)
latency = try_get(self._configuration_arg('latency'), lambda x: x[0])
if latency not in self._KNOWN_LATENCY:
latency = 'high'
- ws = WebSocketsWrapper(ws_url, {
- 'Cookies': str_or_none(cookies) or '',
- 'Origin': f'https://{hostname}',
- 'Accept': '*/*',
- 'User-Agent': self.get_param('http_headers')['User-Agent'],
- })
+ ws = self._request_webpage(
+ Request(ws_url, headers={'Origin': f'https://{hostname}'}),
+ video_id=video_id, note='Connecting to WebSocket server')
self.write_debug('[debug] Sending HLS server request')
ws.send(json.dumps({
'protocol': 'niconico_live',
'ws': ws,
'video_id': video_id,
- 'cookies': cookies,
'live_latency': latency,
'origin': hostname,
})
pass
except Exception as e:
warnings.warn(f'Failed to import "requests" request handler: {e}' + bug_reports_message())
+
+# Optional request handler: an ImportError is treated as "dependency not
+# available" and ignored silently; any other failure while importing the
+# module is surfaced as a warning (with the bug-report hint) instead of crashing.
+try:
+ from . import _websockets
+except ImportError:
+ pass
+except Exception as e:
+ warnings.warn(f'Failed to import "websockets" request handler: {e}' + bug_reports_message())
+
--- /dev/null
+from __future__ import annotations
+
+import io
+import logging
+import ssl
+import sys
+
+from ._helper import create_connection, select_proxy, make_socks_proxy_opts, create_socks_proxy_socket
+from .common import Response, register_rh, Features
+from .exceptions import (
+ CertificateVerifyError,
+ HTTPError,
+ RequestError,
+ SSLError,
+ TransportError, ProxyError,
+)
+from .websocket import WebSocketRequestHandler, WebSocketResponse
+from ..compat import functools
+from ..dependencies import websockets
+from ..utils import int_or_none
+from ..socks import ProxyError as SocksProxyError
+
+if not websockets:
+    raise ImportError('websockets is not installed')
+
+import websockets.version
+
+# Non-numeric components (e.g. "0rc1" in "12.0rc1" or "dev0" in development
+# builds) are mapped to 0 instead of None: a None inside the tuple would make
+# the comparison below raise TypeError rather than perform the version check.
+websockets_version = tuple(
+    int_or_none(part, default=0) for part in websockets.version.version.split('.'))
+if websockets_version < (12, 0):
+    raise ImportError('Only websockets>=12.0 is supported')
+
+import websockets.sync.client
+from websockets.uri import parse_uri
+
+
+class WebsocketsResponseAdapter(WebSocketResponse):
+    """Expose a websockets sync ClientConnection through the WebSocketResponse API.
+
+    The handshake response (body, headers, status, reason) populates the
+    Response fields; send()/recv() proxy to the live connection and translate
+    websockets / SOCKS exceptions into yt-dlp networking exceptions.
+    """
+
+    def __init__(self, wsw: websockets.sync.client.ClientConnection, url):
+        super().__init__(
+            # The handshake response may carry no body; use an empty buffer then
+            fp=io.BytesIO(wsw.response.body or b''),
+            url=url,
+            headers=wsw.response.headers,
+            status=wsw.response.status_code,
+            reason=wsw.response.reason_phrase,
+        )
+        self.wsw = wsw
+
+    def close(self):
+        """Close the WebSocket connection, then release the Response itself."""
+        try:
+            self.wsw.close()
+        finally:
+            # Release the Response buffer even if closing the connection failed
+            super().close()
+
+    def send(self, message):
+        """Send ``message``: a str is sent as a text frame, bytes as a binary frame.
+
+        Raises ProxyError on SOCKS proxy errors; TransportError on websockets
+        errors, RuntimeError or TimeoutError; RequestError on TypeError
+        (e.g. a message of an unsupported type).
+        """
+        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send
+        try:
+            return self.wsw.send(message)
+        except SocksProxyError as e:
+            raise ProxyError(cause=e) from e
+        except (websockets.exceptions.WebSocketException, RuntimeError, TimeoutError) as e:
+            raise TransportError(cause=e) from e
+        except TypeError as e:
+            raise RequestError(cause=e) from e
+
+    def recv(self):
+        """Return the next message received on the connection.
+
+        Raises ProxyError on SOCKS proxy errors; TransportError on websockets
+        errors, RuntimeError or TimeoutError.
+        """
+        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv
+        try:
+            return self.wsw.recv()
+        except SocksProxyError as e:
+            raise ProxyError(cause=e) from e
+        except (websockets.exceptions.WebSocketException, RuntimeError, TimeoutError) as e:
+            raise TransportError(cause=e) from e
+
+
+@register_rh
+class WebsocketsRH(WebSocketRequestHandler):
+    """
+    Websockets request handler
+    https://websockets.readthedocs.io
+    https://github.com/python-websockets/websockets
+    """
+    _SUPPORTED_URL_SCHEMES = ('wss', 'ws')
+    _SUPPORTED_PROXY_SCHEMES = ('socks4', 'socks4a', 'socks5', 'socks5h')
+    _SUPPORTED_FEATURES = (Features.ALL_PROXY, Features.NO_PROXY)
+    RH_NAME = 'websockets'
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # Handlers this instance attaches to the process-global websockets
+        # loggers, remembered so close() can detach them. Without this every new
+        # instance stacked another handler (duplicated log lines) and the global
+        # loggers kept referencing handlers of closed request handlers.
+        self._logging_handlers = {}
+        for name in ('websockets.client', 'websockets.server'):
+            logger = logging.getLogger(name)
+            handler = logging.StreamHandler(stream=sys.stdout)
+            handler.setFormatter(logging.Formatter(f'{self.RH_NAME}: %(message)s'))
+            self._logging_handlers[name] = handler
+            logger.addHandler(handler)
+            if self.verbose:
+                logger.setLevel(logging.DEBUG)
+
+    def close(self):
+        """Detach the log handlers added in __init__, then close the handler."""
+        for name, handler in self._logging_handlers.items():
+            logging.getLogger(name).removeHandler(handler)
+        super().close()
+
+    def _check_extensions(self, extensions):
+        super()._check_extensions(extensions)
+        # Both extensions are consumed by _send()
+        extensions.pop('timeout', None)
+        extensions.pop('cookiejar', None)
+
+    def _send(self, request):
+        """Open a WebSocket connection for ``request`` and wrap it in a WebsocketsResponseAdapter.
+
+        Raises ProxyError, RequestError (invalid URI), CertificateVerifyError,
+        SSLError, HTTPError (non-101 handshake status) or TransportError.
+        """
+        timeout = float(request.extensions.get('timeout') or self.timeout)
+        headers = self._merge_headers(request.headers)
+        # An explicit Cookie header takes precedence over the cookiejar
+        if 'cookie' not in headers:
+            cookiejar = request.extensions.get('cookiejar') or self.cookiejar
+            cookie_header = cookiejar.get_cookie_header(request.url)
+            if cookie_header:
+                headers['cookie'] = cookie_header
+
+        create_conn_kwargs = {
+            'source_address': (self.source_address, 0) if self.source_address else None,
+            'timeout': timeout
+        }
+        proxy = select_proxy(request.url, request.proxies or self.proxies or {})
+        try:
+            # Parsed inside the try block so an InvalidURI raised here is mapped
+            # to RequestError below instead of escaping as a raw websockets error
+            wsuri = parse_uri(request.url)
+            if proxy:
+                socks_proxy_options = make_socks_proxy_opts(proxy)
+                # Socket creation goes through the SOCKS proxy towards the WebSocket host
+                sock = create_connection(
+                    address=(socks_proxy_options['addr'], socks_proxy_options['port']),
+                    _create_socket_func=functools.partial(
+                        create_socks_proxy_socket, (wsuri.host, wsuri.port), socks_proxy_options),
+                    **create_conn_kwargs
+                )
+            else:
+                sock = create_connection(
+                    address=(wsuri.host, wsuri.port),
+                    **create_conn_kwargs
+                )
+            conn = websockets.sync.client.connect(
+                sock=sock,
+                uri=request.url,
+                additional_headers=headers,
+                open_timeout=timeout,
+                # Suppress websockets' own User-Agent; presumably the merged headers carry one
+                user_agent_header=None,
+                ssl_context=self._make_sslcontext() if wsuri.secure else None,
+                close_timeout=0,  # not ideal, but prevents yt-dlp hanging
+            )
+            return WebsocketsResponseAdapter(conn, url=request.url)
+
+        # Exceptions as per https://websockets.readthedocs.io/en/stable/reference/sync/client.html
+        except SocksProxyError as e:
+            raise ProxyError(cause=e) from e
+        except websockets.exceptions.InvalidURI as e:
+            raise RequestError(cause=e) from e
+        except ssl.SSLCertVerificationError as e:
+            raise CertificateVerifyError(cause=e) from e
+        except ssl.SSLError as e:
+            raise SSLError(cause=e) from e
+        except websockets.exceptions.InvalidStatus as e:
+            raise HTTPError(
+                Response(
+                    fp=io.BytesIO(e.response.body),
+                    url=request.url,
+                    headers=e.response.headers,
+                    status=e.response.status_code,
+                    reason=e.response.reason_phrase),
+            ) from e
+        except (OSError, TimeoutError, websockets.exceptions.WebSocketException) as e:
+            raise TransportError(cause=e) from e
--- /dev/null
+from __future__ import annotations
+
+import abc
+
+from .common import Response, RequestHandler
+
+
+class WebSocketResponse(Response):
+    """Response of a WebSocket handshake that can also exchange messages.
+
+    This base class only defines the interface: both send() and recv()
+    raise NotImplementedError and are meant to be overridden.
+    """
+
+    def send(self, message: bytes | str):
+        """
+        Send a message to the server.
+
+        @param message: The message to send. A string (str) is sent as a text frame, bytes is sent as a binary frame.
+        """
+        raise NotImplementedError
+
+    def recv(self):
+        """Receive the next message from the server (to be implemented by subclasses)."""
+        raise NotImplementedError
+
+
+class WebSocketRequestHandler(RequestHandler, abc.ABC):
+    """Abstract base class for request handlers that serve WebSocket requests."""
"""No longer used and new code should not use. Exists only for API compat."""
+import asyncio
+import atexit
import platform
import struct
import sys
has_websockets = bool(websockets)
+class WebSocketsWrapper:
+    """Wraps websockets module to use in non-async scopes"""
+    pool = None
+
+    def __init__(self, url, headers=None, connect=True, **ws_kwargs):
+        # Private event loop on which every coroutine of this wrapper is run
+        self.loop = asyncio.new_event_loop()
+        # XXX: "loop" is deprecated
+        self.conn = websockets.connect(
+            url, extra_headers=headers, ping_interval=None,
+            close_timeout=float('inf'), loop=self.loop, ping_timeout=float('inf'), **ws_kwargs)
+        if connect:
+            self.__enter__()
+            atexit.register(self.__exit__, None, None, None)
+
+    def __enter__(self):
+        # Connect only once; later calls reuse the established connection
+        if not self.pool:
+            self.pool = self.run_with_loop(self.conn.__aenter__(), self.loop)
+        return self
+
+    def send(self, *args):
+        self.run_with_loop(self.pool.send(*args), self.loop)
+
+    def recv(self, *args):
+        return self.run_with_loop(self.pool.recv(*args), self.loop)
+
+    def __exit__(self, type, value, traceback):
+        # __exit__ can run twice (a `with` block plus the atexit hook registered
+        # in __init__); by the second call the loop is closed and cannot run anything
+        if self.loop.is_closed():
+            return None
+        try:
+            return self.run_with_loop(self.conn.__aexit__(type, value, traceback), self.loop)
+        finally:
+            # Cancel leftover tasks *before* closing the loop: cancelling schedules
+            # callbacks on the loop, which raises RuntimeError once it is closed
+            try:
+                self._cancel_all_tasks(self.loop)
+            finally:
+                self.loop.close()
+
+    # taken from https://github.com/python/cpython/blob/3.9/Lib/asyncio/runners.py with modifications
+    # for contributors: If there's any new library using asyncio needs to be run in non-async, move these function out of this class
+    @staticmethod
+    def run_with_loop(main, loop):
+        """Run coroutine ``main`` to completion on ``loop`` and return its result.
+
+        Afterwards async generators (and, when supported, the default executor)
+        are shut down. Raises ValueError if ``main`` is not a coroutine.
+        """
+        if not asyncio.iscoroutine(main):
+            raise ValueError(f'a coroutine was expected, got {main!r}')
+
+        try:
+            return loop.run_until_complete(main)
+        finally:
+            loop.run_until_complete(loop.shutdown_asyncgens())
+            if hasattr(loop, 'shutdown_default_executor'):
+                loop.run_until_complete(loop.shutdown_default_executor())
+
+    @staticmethod
+    def _cancel_all_tasks(loop):
+        """Cancel every pending task of ``loop`` and report unexpected task exceptions."""
+        to_cancel = asyncio.all_tasks(loop)
+
+        if not to_cancel:
+            return
+
+        for task in to_cancel:
+            task.cancel()
+
+        # gather() infers the loop from the tasks; passing "loop" is a TypeError on Python 3.10+
+        loop.run_until_complete(
+            asyncio.gather(*to_cancel, return_exceptions=True))
+
+        for task in to_cancel:
+            if task.cancelled():
+                continue
+            if task.exception() is not None:
+                loop.call_exception_handler({
+                    'message': 'unhandled exception during asyncio.run() shutdown',
+                    'exception': task.exception(),
+                    'task': task,
+                })
+
+
def load_plugins(name, suffix, namespace):
from ..plugins import load_plugins
ret = load_plugins(name, suffix)
-import asyncio
-import atexit
import base64
import binascii
import calendar
compat_os_name,
compat_shlex_quote,
)
-from ..dependencies import websockets, xattr
+from ..dependencies import xattr
__name__ = __name__.rsplit('.', 1)[0] # Pretend to be the parent module
return self.parser.parse_args(self.all_args)
-class WebSocketsWrapper:
- """Wraps websockets module to use in non-async scopes"""
- pool = None
-
- def __init__(self, url, headers=None, connect=True):
- self.loop = asyncio.new_event_loop()
- # XXX: "loop" is deprecated
- self.conn = websockets.connect(
- url, extra_headers=headers, ping_interval=None,
- close_timeout=float('inf'), loop=self.loop, ping_timeout=float('inf'))
- if connect:
- self.__enter__()
- atexit.register(self.__exit__, None, None, None)
-
- def __enter__(self):
- if not self.pool:
- self.pool = self.run_with_loop(self.conn.__aenter__(), self.loop)
- return self
-
- def send(self, *args):
- self.run_with_loop(self.pool.send(*args), self.loop)
-
- def recv(self, *args):
- return self.run_with_loop(self.pool.recv(*args), self.loop)
-
- def __exit__(self, type, value, traceback):
- try:
- return self.run_with_loop(self.conn.__aexit__(type, value, traceback), self.loop)
- finally:
- self.loop.close()
- self._cancel_all_tasks(self.loop)
-
- # taken from https://github.com/python/cpython/blob/3.9/Lib/asyncio/runners.py with modifications
- # for contributors: If there's any new library using asyncio needs to be run in non-async, move these function out of this class
- @staticmethod
- def run_with_loop(main, loop):
- if not asyncio.iscoroutine(main):
- raise ValueError(f'a coroutine was expected, got {main!r}')
-
- try:
- return loop.run_until_complete(main)
- finally:
- loop.run_until_complete(loop.shutdown_asyncgens())
- if hasattr(loop, 'shutdown_default_executor'):
- loop.run_until_complete(loop.shutdown_default_executor())
-
- @staticmethod
- def _cancel_all_tasks(loop):
- to_cancel = asyncio.all_tasks(loop)
-
- if not to_cancel:
- return
-
- for task in to_cancel:
- task.cancel()
-
- # XXX: "loop" is removed in python 3.10+
- loop.run_until_complete(
- asyncio.gather(*to_cancel, loop=loop, return_exceptions=True))
-
- for task in to_cancel:
- if task.cancelled():
- continue
- if task.exception() is not None:
- loop.call_exception_handler({
- 'message': 'unhandled exception during asyncio.run() shutdown',
- 'exception': task.exception(),
- 'task': task,
- })
-
-
def merge_headers(*dicts):
"""Merge dicts of http headers case insensitively, prioritizing the latter ones"""
return {k.title(): v for k, v in itertools.chain.from_iterable(map(dict.items, dicts))}