]> jfr.im git - yt-dlp.git/blame - test/test_websockets.py
[cleanup] Add more ruff rules (#10149)
[yt-dlp.git] / test / test_websockets.py
CommitLineData
ccfd70f4 1#!/usr/bin/env python3
2
3# Allow direct execution
4import os
5import sys
53b4d44f 6import time
ccfd70f4 7
8import pytest
9
69d31914 10from test.helper import verify_address_availability
53b4d44f 11from yt_dlp.networking.common import Features, DEFAULT_TIMEOUT
69d31914 12
ccfd70f4 13sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
14
15import http.client
16import http.cookiejar
17import http.server
18import json
19import random
20import ssl
21import threading
22
3c7a287e 23from yt_dlp import socks, traverse_obj
ccfd70f4 24from yt_dlp.cookies import YoutubeDLCookieJar
25from yt_dlp.dependencies import websockets
26from yt_dlp.networking import Request
27from yt_dlp.networking.exceptions import (
28 CertificateVerifyError,
29 HTTPError,
30 ProxyError,
31 RequestError,
32 SSLError,
33 TransportError,
34)
35from yt_dlp.utils.networking import HTTPHeaderDict
36
# Absolute path of the directory that contains this test module
TEST_DIR = os.path.dirname(os.path.abspath(__file__))
38
39
def websocket_handler(websocket):
    """Reply to the first message received on *websocket*, then return.

    The text commands 'headers', 'path', 'source_address' and 'str', and the
    binary command b'bytes', each get a dedicated reply; any other message is
    echoed back unchanged. The value returned is whatever ``websocket.send``
    returns (``None`` if no message is received at all).
    """
    # Replies are wrapped in callables so that request attributes are only
    # read for the command that actually needs them.
    text_replies = {
        'headers': lambda: json.dumps(dict(websocket.request.headers)),
        'path': lambda: websocket.request.path,
        'source_address': lambda: websocket.remote_address[0],
        'str': lambda: '1',
    }
    for message in websocket:
        if isinstance(message, bytes):
            reply = '2' if message == b'bytes' else message
        elif isinstance(message, str) and message in text_replies:
            reply = text_replies[message]()
        else:
            reply = message
        return websocket.send(reply)
55
56
def process_request(self, request):
    """Handshake hook that lets a path of the form ``/gen_<code>`` force an HTTP status.

    For such paths, ``<code>`` is parsed as an ``http.HTTPStatus`` value
    (unknown codes raise ``ValueError``). Any 3xx status is answered with a
    bare response carrying ``Location: /``; every other status is rejected
    through ``self.protocol.reject``. Requests for all other paths are
    accepted through ``self.protocol.accept``.
    """
    if not request.path.startswith('/gen_'):
        return self.protocol.accept(request)

    status = http.HTTPStatus(int(request.path[5:]))
    # Cover the whole 3xx range. The previous check `300 <= value <= 300`
    # only matched status 300, so 301/302/303 never got a Location header.
    if 300 <= status.value < 400:
        return websockets.http11.Response(
            status.value, status.phrase, websockets.datastructures.Headers([('Location', '/')]), b'')
    return self.protocol.reject(status.value, status.phrase)
65
66
def create_websocket_server(**ws_kwargs):
    """Start a websockets sync server on 127.0.0.1 in a background daemon thread.

    Extra keyword arguments are forwarded to ``websockets.sync.server.serve``.
    Returns a ``(thread, port)`` tuple.
    """
    import websockets.sync.server

    server = websockets.sync.server.serve(
        websocket_handler, '127.0.0.1', 0,
        process_request=process_request, open_timeout=2, **ws_kwargs)
    # Port 0 was requested, so read the actual port back from the bound socket
    port = server.socket.getsockname()[1]
    thread = threading.Thread(target=server.serve_forever, daemon=True)
    thread.start()
    return thread, port
77
78
def create_ws_websocket_server():
    """Start a websocket test server with default settings; returns ``(thread, port)``."""
    thread_and_port = create_websocket_server()
    return thread_and_port
81
82
def create_wss_websocket_server():
    """Start a TLS websocket test server that presents ``testcert.pem``; returns ``(thread, port)``."""
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    context.load_cert_chain(os.path.join(TEST_DIR, 'testcert.pem'), None)
    return create_websocket_server(ssl_context=context)
88
89
# Directory holding the CA and client certificate files used by the mTLS tests
MTLS_CERT_DIR = os.path.join(TEST_DIR, 'testdata', 'certificate')
91
92
def create_mtls_wss_websocket_server():
    """Start a TLS websocket test server that requires a client certificate.

    The server presents ``testcert.pem`` and verifies client certificates
    against ``ca.crt`` from ``MTLS_CERT_DIR``. Returns ``(thread, port)``.
    """
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    context.load_cert_chain(os.path.join(TEST_DIR, 'testcert.pem'), None)
    # Require the peer to present a certificate that validates against the test CA
    context.load_verify_locations(cafile=os.path.join(MTLS_CERT_DIR, 'ca.crt'))
    context.verify_mode = ssl.CERT_REQUIRED
    return create_websocket_server(ssl_context=context)
103
104
def ws_validate_and_send(rh, req):
    """Validate *req* with handler *rh*, then send it and return the response.

    The send is attempted up to three times. It is retried only when a
    ``TransportError`` mentioning 'connection closed during handshake' is
    raised; any other error, or the error from the last attempt, propagates.
    """
    rh.validate(req)
    attempts_left = 3
    while True:
        attempts_left -= 1
        try:
            return rh.send(req)
        except TransportError as err:
            # websockets server sometimes hangs on new connections
            if attempts_left == 0 or 'connection closed during handshake' not in str(err):
                raise
116
117
@pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers')
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
class TestWebsSocketRequestHandlerConformance:
    """Conformance tests for websocket request handlers, run against local test servers."""

    @classmethod
    def setup_class(cls):
        """Start the plain, TLS, certificate-less TLS and mTLS servers shared by the tests."""
        cls.ws_thread, cls.ws_port = create_ws_websocket_server()
        cls.ws_base_url = f'ws://127.0.0.1:{cls.ws_port}'

        cls.wss_thread, cls.wss_port = create_wss_websocket_server()
        cls.wss_base_url = f'wss://127.0.0.1:{cls.wss_port}'

        # No certificate is loaded into this context; test_ssl_error relies on the handshake failing
        cls.bad_wss_thread, cls.bad_wss_port = create_websocket_server(ssl_context=ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER))
        cls.bad_wss_host = f'wss://127.0.0.1:{cls.bad_wss_port}'

        cls.mtls_wss_thread, cls.mtls_wss_port = create_mtls_wss_websocket_server()
        cls.mtls_wss_base_url = f'wss://127.0.0.1:{cls.mtls_wss_port}'

    def test_basic_websockets(self, handler):
        """A connection upgrades with status 101 and echoes an arbitrary message."""
        with handler() as rh:
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            assert 'upgrade' in ws.headers
            assert ws.status == 101
            ws.send('foo')
            assert ws.recv() == 'foo'
            ws.close()

    # https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6
    @pytest.mark.parametrize('msg,opcode', [('str', 1), (b'bytes', 2)])
    def test_send_types(self, handler, msg, opcode):
        """Text and binary payloads reach the server as str and bytes respectively."""
        with handler() as rh:
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            ws.send(msg)
            assert int(ws.recv()) == opcode
            ws.close()

    def test_verify_cert(self, handler):
        """The self-signed test certificate is rejected unless verification is disabled."""
        with handler() as rh:
            with pytest.raises(CertificateVerifyError):
                ws_validate_and_send(rh, Request(self.wss_base_url))

        with handler(verify=False) as rh:
            ws = ws_validate_and_send(rh, Request(self.wss_base_url))
            assert ws.status == 101
            ws.close()

    def test_ssl_error(self, handler):
        """A TLS handshake failure raises SSLError, not its CertificateVerifyError subclass."""
        with handler(verify=False) as rh:
            with pytest.raises(SSLError, match=r'ssl(?:v3|/tls) alert handshake failure') as exc_info:
                ws_validate_and_send(rh, Request(self.bad_wss_host))
            assert not issubclass(exc_info.type, CertificateVerifyError)

    @pytest.mark.parametrize('path,expected', [
        # Unicode characters should be encoded with uppercase percent-encoding
        ('/中文', '/%E4%B8%AD%E6%96%87'),
        # don't normalize existing percent encodings
        ('/%c7%9f', '/%c7%9f'),
    ])
    def test_percent_encode(self, handler, path, expected):
        """The request path seen by the server is percent-encoded as expected."""
        with handler() as rh:
            ws = ws_validate_and_send(rh, Request(f'{self.ws_base_url}{path}'))
            ws.send('path')
            assert ws.recv() == expected
            assert ws.status == 101
            ws.close()

    def test_remove_dot_segments(self, handler):
        """Dot segments are removed from the path before the request is sent."""
        with handler() as rh:
            # This isn't a comprehensive test,
            # but it should be enough to check whether the handler is removing dot segments
            ws = ws_validate_and_send(rh, Request(f'{self.ws_base_url}/a/b/./../../test'))
            assert ws.status == 101
            ws.send('path')
            assert ws.recv() == '/test'
            ws.close()

    # We are restricted to known HTTP status codes in http.HTTPStatus
    # Redirects are not supported for websockets
    @pytest.mark.parametrize('status', (200, 204, 301, 302, 303, 400, 500, 511))
    def test_raise_http_error(self, handler, status):
        """A non-101 handshake response raises HTTPError carrying that status."""
        with handler() as rh:
            with pytest.raises(HTTPError) as exc_info:
                ws_validate_and_send(rh, Request(f'{self.ws_base_url}/gen_{status}'))
            assert exc_info.value.status == status

    @pytest.mark.parametrize('params,extensions', [
        ({'timeout': sys.float_info.min}, {}),
        ({}, {'timeout': sys.float_info.min}),
    ])
    def test_read_timeout(self, handler, params, extensions):
        """A minimal timeout, set on the handler or per request, raises TransportError."""
        with handler(**params) as rh:
            with pytest.raises(TransportError):
                ws_validate_and_send(rh, Request(self.ws_base_url, extensions=extensions))

    def test_connect_timeout(self, handler):
        """Connecting to an unreachable address fails well within DEFAULT_TIMEOUT."""
        # nothing should be listening on this port
        connect_timeout_url = 'ws://10.255.255.255'
        with handler(timeout=0.01) as rh, pytest.raises(TransportError):
            now = time.time()
            ws_validate_and_send(rh, Request(connect_timeout_url))
        # Checked after the raises block: statements following the raising call inside it never run
        assert time.time() - now < DEFAULT_TIMEOUT

        # Per request timeout, should override handler timeout
        request = Request(connect_timeout_url, extensions={'timeout': 0.01})
        with handler() as rh, pytest.raises(TransportError):
            now = time.time()
            ws_validate_and_send(rh, request)
        assert time.time() - now < DEFAULT_TIMEOUT

    def test_cookies(self, handler):
        """Cookies are sent from the handler cookiejar or from a per-request cookiejar."""
        cookiejar = YoutubeDLCookieJar()
        cookiejar.set_cookie(http.cookiejar.Cookie(
            version=0, name='test', value='ytdlp', port=None, port_specified=False,
            domain='127.0.0.1', domain_specified=True, domain_initial_dot=False, path='/',
            path_specified=True, secure=False, expires=None, discard=False, comment=None,
            comment_url=None, rest={}))

        with handler(cookiejar=cookiejar) as rh:
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            ws.send('headers')
            assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
            ws.close()

        with handler() as rh:
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            ws.send('headers')
            assert 'cookie' not in json.loads(ws.recv())
            ws.close()

            ws = ws_validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': cookiejar}))
            ws.send('headers')
            assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
            ws.close()

    def test_source_address(self, handler):
        """The server sees the connection coming from the configured source address."""
        source_address = f'127.0.0.{random.randint(5, 255)}'
        verify_address_availability(source_address)
        with handler(source_address=source_address) as rh:
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            ws.send('source_address')
            assert source_address == ws.recv()
            ws.close()

    def test_response_url(self, handler):
        """The response exposes the URL that was requested."""
        with handler() as rh:
            url = f'{self.ws_base_url}/something'
            ws = ws_validate_and_send(rh, Request(url))
            assert ws.url == url
            ws.close()

    def test_request_headers(self, handler):
        """Handler-level headers are sent, and per-request headers are merged over them."""
        with handler(headers=HTTPHeaderDict({'test1': 'test', 'test2': 'test2'})) as rh:
            # Global Headers
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            ws.send('headers')
            headers = HTTPHeaderDict(json.loads(ws.recv()))
            assert headers['test1'] == 'test'
            ws.close()

            # Per request headers, merged with global
            ws = ws_validate_and_send(rh, Request(
                self.ws_base_url, headers={'test2': 'changed', 'test3': 'test3'}))
            ws.send('headers')
            headers = HTTPHeaderDict(json.loads(ws.recv()))
            assert headers['test1'] == 'test'
            assert headers['test2'] == 'changed'
            assert headers['test3'] == 'test3'
            ws.close()

    @pytest.mark.parametrize('client_cert', (
        {'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithkey.crt')},
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
            'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'client.key'),
        },
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithencryptedkey.crt'),
            'client_certificate_password': 'foobar',
        },
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
            'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'clientencrypted.key'),
            'client_certificate_password': 'foobar',
        },
    ))
    def test_mtls(self, handler, client_cert):
        """Each supported client certificate configuration connects to the mTLS server."""
        with handler(
            # Disable client-side validation of unacceptable self-signed testcert.pem
            # The test is of a check on the server side, so unaffected
            verify=False,
            client_cert=client_cert,
        ) as rh:
            ws_validate_and_send(rh, Request(self.mtls_wss_base_url)).close()

    def test_request_disable_proxy(self, handler):
        """Setting the proxy to None on a request bypasses the handler-level proxy."""
        for proxy_proto in handler._SUPPORTED_PROXY_SCHEMES or ['ws']:
            # Given handler is configured with a proxy
            with handler(proxies={'ws': f'{proxy_proto}://10.255.255.255'}, timeout=5) as rh:
                # When a proxy is explicitly set to None for the request
                # (use the same 'ws' key the handler proxy is configured under)
                ws = ws_validate_and_send(rh, Request(self.ws_base_url, proxies={'ws': None}))
                # Then no proxy should be used
                assert ws.status == 101
                ws.close()

    @pytest.mark.skip_handlers_if(
        lambda _, handler: Features.NO_PROXY not in handler._SUPPORTED_FEATURES, 'handler does not support NO_PROXY')
    def test_noproxy(self, handler):
        """A 'no' proxy entry matching the request host bypasses the configured proxy."""
        for proxy_proto in handler._SUPPORTED_PROXY_SCHEMES or ['ws']:
            # Given the handler is configured with a proxy
            with handler(proxies={'ws': f'{proxy_proto}://10.255.255.255'}, timeout=5) as rh:
                for no_proxy in (f'127.0.0.1:{self.ws_port}', '127.0.0.1', 'localhost'):
                    # When request no proxy includes the request url host
                    ws = ws_validate_and_send(rh, Request(self.ws_base_url, proxies={'no': no_proxy}))
                    # Then the proxy should not be used
                    assert ws.status == 101
                    ws.close()

    @pytest.mark.skip_handlers_if(
        lambda _, handler: Features.ALL_PROXY not in handler._SUPPORTED_FEATURES, 'handler does not support ALL_PROXY')
    def test_allproxy(self, handler):
        """An 'all' proxy, set on the handler or per request, is applied to ws requests."""
        supported_proto = traverse_obj(handler._SUPPORTED_PROXY_SCHEMES, 0, default='ws')
        # This is a bit of a hacky test, but it should be enough to check whether the handler is using the proxy.
        # 0.1s might not be enough of a timeout if proxy is not used in all cases, but should still get failures.
        with handler(proxies={'all': f'{supported_proto}://10.255.255.255'}, timeout=0.1) as rh:
            with pytest.raises(TransportError):
                ws_validate_and_send(rh, Request(self.ws_base_url)).close()

        with handler(timeout=0.1) as rh:
            with pytest.raises(TransportError):
                ws_validate_and_send(
                    rh, Request(self.ws_base_url, proxies={'all': f'{supported_proto}://10.255.255.255'})).close()
348
ccfd70f4 349
def create_fake_ws_connection(raised):
    """Return a ``ClientConnection`` stand-in whose ``send``/``recv`` raise ``raised()``.

    *raised* is a zero-argument callable producing the exception to raise.
    ``close`` does nothing.
    """
    import websockets.sync.client

    class FakeResponse:
        body = b''
        headers = {}
        status_code = 101
        reason_phrase = 'test'

    class FakeWsConnection(websockets.sync.client.ClientConnection):
        def __init__(self, *args, **kwargs):
            # The parent initializer is intentionally not called; only `response` is set
            self.response = FakeResponse()

        def send(self, *args, **kwargs):
            raise raised()

        def recv(self, *args, **kwargs):
            raise raised()

        def close(self, *args, **kwargs):
            return None

    return FakeWsConnection()
373
374
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
class TestWebsocketsRequestHandler:
    """Checks that errors raised by the websockets library are mapped to yt-dlp networking errors."""

    @pytest.mark.parametrize('raised,expected', [
        # https://websockets.readthedocs.io/en/stable/reference/exceptions.html
        (lambda: websockets.exceptions.InvalidURI(msg='test', uri='test://'), RequestError),
        # Requires a response object. Should be covered by HTTP error tests.
        # (lambda: websockets.exceptions.InvalidStatus(), TransportError),
        (lambda: websockets.exceptions.InvalidHandshake(), TransportError),
        # These are subclasses of InvalidHandshake
        (lambda: websockets.exceptions.InvalidHeader(name='test'), TransportError),
        (lambda: websockets.exceptions.NegotiationError(), TransportError),
        # Catch-all
        (lambda: websockets.exceptions.WebSocketException(), TransportError),
        (lambda: TimeoutError(), TransportError),
        # These may be raised by our create_connection implementation, which should also be caught
        (lambda: OSError(), TransportError),
        (lambda: ssl.SSLError(), SSLError),
        (lambda: ssl.SSLCertVerificationError(), CertificateVerifyError),
        (lambda: socks.ProxyError(), ProxyError),
    ])
    def test_request_error_mapping(self, handler, monkeypatch, raised, expected):
        """An exception raised while connecting surfaces as exactly the expected error type."""
        import websockets.sync.client

        import yt_dlp.networking._websockets

        with handler() as rh:
            def raise_on_connect(*args, **kwargs):
                raise raised()

            # Stub out socket creation and make the library's connect() raise instead
            monkeypatch.setattr(yt_dlp.networking._websockets, 'create_connection', lambda *args, **kwargs: None)
            monkeypatch.setattr(websockets.sync.client, 'connect', raise_on_connect)
            with pytest.raises(expected) as exc_info:
                rh.send(Request('ws://fake-url'))
            assert exc_info.type is expected

    @pytest.mark.parametrize('raised,expected,match', [
        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send
        (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
        (lambda: RuntimeError(), TransportError, None),
        (lambda: TimeoutError(), TransportError, None),
        (lambda: TypeError(), RequestError, None),
        (lambda: socks.ProxyError(), ProxyError, None),
        # Catch-all
        (lambda: websockets.exceptions.WebSocketException(), TransportError, None),
    ])
    def test_ws_send_error_mapping(self, handler, monkeypatch, raised, expected, match):
        """An exception raised by the connection's send() surfaces as exactly the expected error type."""
        from yt_dlp.networking._websockets import WebsocketsResponseAdapter

        adapter = WebsocketsResponseAdapter(create_fake_ws_connection(raised), url='ws://fake-url')
        with pytest.raises(expected, match=match) as exc_info:
            adapter.send('test')
        assert exc_info.type is expected

    @pytest.mark.parametrize('raised,expected,match', [
        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv
        (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
        (lambda: RuntimeError(), TransportError, None),
        (lambda: TimeoutError(), TransportError, None),
        (lambda: socks.ProxyError(), ProxyError, None),
        # Catch-all
        (lambda: websockets.exceptions.WebSocketException(), TransportError, None),
    ])
    def test_ws_recv_error_mapping(self, handler, monkeypatch, raised, expected, match):
        """An exception raised by the connection's recv() surfaces as exactly the expected error type."""
        from yt_dlp.networking._websockets import WebsocketsResponseAdapter

        adapter = WebsocketsResponseAdapter(create_fake_ws_connection(raised), url='ws://fake-url')
        with pytest.raises(expected, match=match) as exc_info:
            adapter.recv()
        assert exc_info.type is expected
439 assert exc_info.type is expected