]> jfr.im git - yt-dlp.git/blob - test/test_websockets.py
[build] Exclude `requests` from `py2exe` (#9982)
[yt-dlp.git] / test / test_websockets.py
1 #!/usr/bin/env python3
2
3 # Allow direct execution
4 import os
5 import sys
6 import time
7
8 import pytest
9
10 from test.helper import verify_address_availability
11 from yt_dlp.networking.common import Features, DEFAULT_TIMEOUT
12
13 sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
14
15 import http.client
16 import http.cookiejar
17 import http.server
18 import json
19 import random
20 import ssl
21 import threading
22
23 from yt_dlp import socks, traverse_obj
24 from yt_dlp.cookies import YoutubeDLCookieJar
25 from yt_dlp.dependencies import websockets
26 from yt_dlp.networking import Request
27 from yt_dlp.networking.exceptions import (
28 CertificateVerifyError,
29 HTTPError,
30 ProxyError,
31 RequestError,
32 SSLError,
33 TransportError,
34 )
35 from yt_dlp.utils.networking import HTTPHeaderDict
36
# Absolute directory holding this test module; fixture files are resolved from here
TEST_DIR = os.path.split(os.path.abspath(__file__))[0]
38
39
def websocket_handler(websocket):
    """Answer the first message received on *websocket*, then return.

    A few magic messages make the server report facts about the connection:

    * ``b'bytes'`` -> ``'2'`` and ``'str'`` -> ``'1'`` (the frame opcode),
    * ``'headers'`` -> JSON object of the request headers,
    * ``'path'`` -> the request path,
    * ``'source_address'`` -> the peer IP address.

    Any other message is echoed back unchanged.
    """
    # Lambdas keep the lookups lazy: only the requested attribute is read
    text_replies = {
        'headers': lambda: json.dumps(dict(websocket.request.headers)),
        'path': lambda: websocket.request.path,
        'source_address': lambda: websocket.remote_address[0],
        'str': lambda: '1',
    }
    for message in websocket:
        reply = message
        if isinstance(message, bytes):
            if message == b'bytes':
                reply = '2'
        elif isinstance(message, str) and message in text_replies:
            reply = text_replies[message]()
        # Only the first message is handled; returning ends the connection handler
        return websocket.send(reply)
55
56
def process_request(self, request):
    """``process_request`` hook for the test websockets server.

    *self* is the server-side connection object the hook is called with.
    A request for ``/gen_<status>`` is answered with a plain HTTP response of
    that status instead of completing the websocket handshake; redirection
    statuses (3xx) additionally carry a ``Location: /`` header so a client
    that (wrongly) followed redirects would be caught. Every other path is
    accepted and upgraded.
    """
    if request.path.startswith('/gen_'):
        status = http.HTTPStatus(int(request.path[5:]))
        # The whole 3xx range is a redirect; the old bound (``<= 300``) only matched 300 itself
        if 300 <= status.value < 400:
            return websockets.http11.Response(
                status.value, status.phrase, websockets.datastructures.Headers([('Location', '/')]), b'')
        return self.protocol.reject(status.value, status.phrase)
    return self.protocol.accept(request)
65
66
def create_websocket_server(**ws_kwargs):
    """Start a websockets test server on an ephemeral 127.0.0.1 port.

    The server uses ``websocket_handler`` and the ``process_request`` hook,
    with a 2 second opening-handshake timeout; *ws_kwargs* are forwarded to
    ``websockets.sync.server.serve`` (e.g. ``ssl_context``). It runs
    ``serve_forever`` in a daemon thread.

    Returns a ``(thread, port)`` tuple.
    """
    import websockets.sync.server
    server = websockets.sync.server.serve(
        websocket_handler,
        '127.0.0.1',
        0,  # let the OS pick a free port
        process_request=process_request,
        open_timeout=2,
        **ws_kwargs,
    )
    port = server.socket.getsockname()[1]
    thread = threading.Thread(target=server.serve_forever, daemon=True)
    thread.start()
    return thread, port
77
78
def create_ws_websocket_server():
    """Start a plain (non-TLS) ``ws`` test server; returns ``(thread, port)``."""
    thread, port = create_websocket_server()
    return thread, port
81
82
def create_wss_websocket_server():
    """Start a TLS (``wss``) test server presenting ``testcert.pem``.

    Returns the ``(thread, port)`` tuple from ``create_websocket_server``.
    """
    tls_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    tls_ctx.load_cert_chain(certfile=os.path.join(TEST_DIR, 'testcert.pem'), keyfile=None)
    return create_websocket_server(ssl_context=tls_ctx)
88
89
# Directory with the CA and client certificates/keys used by the mutual-TLS tests
MTLS_CERT_DIR = os.path.join(TEST_DIR, 'testdata', 'certificate')


def create_mtls_wss_websocket_server():
    """Start a ``wss`` test server that requires a client certificate (mutual TLS).

    The server presents ``testcert.pem`` and only completes the handshake
    with clients whose certificate verifies against ``ca.crt`` in
    ``MTLS_CERT_DIR``. Returns the ``(thread, port)`` tuple from
    ``create_websocket_server``.
    """
    server_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    server_ctx.verify_mode = ssl.CERT_REQUIRED
    server_ctx.load_verify_locations(cafile=os.path.join(MTLS_CERT_DIR, 'ca.crt'))
    server_ctx.load_cert_chain(os.path.join(TEST_DIR, 'testcert.pem'), None)
    return create_websocket_server(ssl_context=server_ctx)
103
104
def ws_validate_and_send(rh, req):
    """Validate *req* with request handler *rh*, then send it, retrying flaky handshakes.

    The test websockets server sometimes hangs on new connections, which
    surfaces as a ``TransportError`` mentioning a connection closed during the
    handshake. That case is retried, for at most three send attempts in total.
    Any other error, or the last failed attempt, propagates unchanged.
    """
    rh.validate(req)
    attempts_left = 3
    while True:
        attempts_left -= 1
        try:
            return rh.send(req)
        except TransportError as err:
            flaky_handshake = 'connection closed during handshake' in str(err)
            if not (flaky_handshake and attempts_left > 0):
                raise
116
117
@pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers')
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
class TestWebsSocketRequestHandlerConformance:
    """Conformance tests for websocket request handlers against live local servers.

    ``setup_class`` starts four servers on ephemeral 127.0.0.1 ports, each in
    a daemon thread: plain ``ws``; ``wss`` presenting ``testcert.pem``; ``wss``
    with an SSL context that has no certificate loaded; and a mutual-TLS
    ``wss`` server that requires a client certificate.
    """

    @classmethod
    def setup_class(cls):
        cls.ws_thread, cls.ws_port = create_ws_websocket_server()
        cls.ws_base_url = f'ws://127.0.0.1:{cls.ws_port}'

        cls.wss_thread, cls.wss_port = create_wss_websocket_server()
        cls.wss_base_url = f'wss://127.0.0.1:{cls.wss_port}'

        # No certificate is loaded into this server context; test_ssl_error expects the handshake to fail
        cls.bad_wss_thread, cls.bad_wss_port = create_websocket_server(ssl_context=ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER))
        cls.bad_wss_host = f'wss://127.0.0.1:{cls.bad_wss_port}'

        cls.mtls_wss_thread, cls.mtls_wss_port = create_mtls_wss_websocket_server()
        cls.mtls_wss_base_url = f'wss://127.0.0.1:{cls.mtls_wss_port}'

    def test_basic_websockets(self, handler):
        with handler() as rh:
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            assert 'upgrade' in ws.headers
            assert ws.status == 101
            ws.send('foo')
            assert ws.recv() == 'foo'
            ws.close()

    # https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6
    # The test server replies to 'str' with '1' and to b'bytes' with '2' (the frame opcodes)
    @pytest.mark.parametrize('msg,opcode', [('str', 1), (b'bytes', 2)])
    def test_send_types(self, handler, msg, opcode):
        with handler() as rh:
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            ws.send(msg)
            assert int(ws.recv()) == opcode
            ws.close()

    def test_verify_cert(self, handler):
        with handler() as rh:
            with pytest.raises(CertificateVerifyError):
                ws_validate_and_send(rh, Request(self.wss_base_url))

        with handler(verify=False) as rh:
            ws = ws_validate_and_send(rh, Request(self.wss_base_url))
            assert ws.status == 101
            ws.close()

    def test_ssl_error(self, handler):
        with handler(verify=False) as rh:
            with pytest.raises(SSLError, match=r'ssl(?:v3|/tls) alert handshake failure') as exc_info:
                ws_validate_and_send(rh, Request(self.bad_wss_host))
            # Must follow the ``raises`` block: statements after the raising call inside it never run
            assert not issubclass(exc_info.type, CertificateVerifyError)

    @pytest.mark.parametrize('path,expected', [
        # Unicode characters should be encoded with uppercase percent-encoding
        ('/中文', '/%E4%B8%AD%E6%96%87'),
        # don't normalize existing percent encodings
        ('/%c7%9f', '/%c7%9f'),
    ])
    def test_percent_encode(self, handler, path, expected):
        with handler() as rh:
            ws = ws_validate_and_send(rh, Request(f'{self.ws_base_url}{path}'))
            ws.send('path')
            assert ws.recv() == expected
            assert ws.status == 101
            ws.close()

    def test_remove_dot_segments(self, handler):
        with handler() as rh:
            # This isn't a comprehensive test,
            # but it should be enough to check whether the handler is removing dot segments
            ws = ws_validate_and_send(rh, Request(f'{self.ws_base_url}/a/b/./../../test'))
            assert ws.status == 101
            ws.send('path')
            assert ws.recv() == '/test'
            ws.close()

    # We are restricted to known HTTP status codes in http.HTTPStatus
    # Redirects are not supported for websockets: a 3xx must surface as HTTPError, not be followed
    @pytest.mark.parametrize('status', (200, 204, 301, 302, 303, 400, 500, 511))
    def test_raise_http_error(self, handler, status):
        with handler() as rh:
            with pytest.raises(HTTPError) as exc_info:
                ws_validate_and_send(rh, Request(f'{self.ws_base_url}/gen_{status}'))
            assert exc_info.value.status == status

    # sys.float_info.min is the smallest positive float, so the timeout expires immediately
    @pytest.mark.parametrize('params,extensions', [
        ({'timeout': sys.float_info.min}, {}),
        ({}, {'timeout': sys.float_info.min}),
    ])
    def test_read_timeout(self, handler, params, extensions):
        with handler(**params) as rh:
            with pytest.raises(TransportError):
                ws_validate_and_send(rh, Request(self.ws_base_url, extensions=extensions))

    def test_connect_timeout(self, handler):
        # Nothing should be listening at this (non-routable) address
        connect_timeout_url = 'ws://10.255.255.255'
        with handler(timeout=0.01) as rh:
            now = time.time()
            with pytest.raises(TransportError):
                ws_validate_and_send(rh, Request(connect_timeout_url))
            # The elapsed-time check must sit after the ``raises`` block to execute at all
            assert time.time() - now < DEFAULT_TIMEOUT

        # Per request timeout, should override handler timeout
        request = Request(connect_timeout_url, extensions={'timeout': 0.01})
        with handler() as rh:
            now = time.time()
            with pytest.raises(TransportError):
                ws_validate_and_send(rh, request)
            # Would take at least DEFAULT_TIMEOUT if the per-request timeout were ignored
            assert time.time() - now < DEFAULT_TIMEOUT

    def test_cookies(self, handler):
        cookiejar = YoutubeDLCookieJar()
        cookiejar.set_cookie(http.cookiejar.Cookie(
            version=0, name='test', value='ytdlp', port=None, port_specified=False,
            domain='127.0.0.1', domain_specified=True, domain_initial_dot=False, path='/',
            path_specified=True, secure=False, expires=None, discard=False, comment=None,
            comment_url=None, rest={}))

        with handler(cookiejar=cookiejar) as rh:
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            ws.send('headers')
            assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
            ws.close()

        with handler() as rh:
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            ws.send('headers')
            assert 'cookie' not in json.loads(ws.recv())
            ws.close()

            ws = ws_validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': cookiejar}))
            ws.send('headers')
            assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
            ws.close()

    def test_source_address(self, handler):
        source_address = f'127.0.0.{random.randint(5, 255)}'
        verify_address_availability(source_address)
        with handler(source_address=source_address) as rh:
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            ws.send('source_address')
            assert source_address == ws.recv()
            ws.close()

    def test_response_url(self, handler):
        with handler() as rh:
            url = f'{self.ws_base_url}/something'
            ws = ws_validate_and_send(rh, Request(url))
            assert ws.url == url
            ws.close()

    def test_request_headers(self, handler):
        with handler(headers=HTTPHeaderDict({'test1': 'test', 'test2': 'test2'})) as rh:
            # Global Headers
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            ws.send('headers')
            headers = HTTPHeaderDict(json.loads(ws.recv()))
            assert headers['test1'] == 'test'
            ws.close()

            # Per request headers, merged with global
            ws = ws_validate_and_send(rh, Request(
                self.ws_base_url, headers={'test2': 'changed', 'test3': 'test3'}))
            ws.send('headers')
            headers = HTTPHeaderDict(json.loads(ws.recv()))
            assert headers['test1'] == 'test'
            assert headers['test2'] == 'changed'
            assert headers['test3'] == 'test3'
            ws.close()

    @pytest.mark.parametrize('client_cert', (
        {'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithkey.crt')},
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
            'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'client.key'),
        },
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithencryptedkey.crt'),
            'client_certificate_password': 'foobar',
        },
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
            'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'clientencrypted.key'),
            'client_certificate_password': 'foobar',
        }
    ))
    def test_mtls(self, handler, client_cert):
        with handler(
            # Disable client-side validation of unacceptable self-signed testcert.pem
            # The test is of a check on the server side, so unaffected
            verify=False,
            client_cert=client_cert
        ) as rh:
            ws_validate_and_send(rh, Request(self.mtls_wss_base_url)).close()

    def test_request_disable_proxy(self, handler):
        for proxy_proto in handler._SUPPORTED_PROXY_SCHEMES or ['ws']:
            # Given handler is configured with a proxy
            with handler(proxies={'ws': f'{proxy_proto}://10.255.255.255'}, timeout=5) as rh:
                # When a proxy is explicitly set to None for the request
                # (keyed by the request URL's scheme, 'ws', matching the handler's proxy key)
                ws = ws_validate_and_send(rh, Request(self.ws_base_url, proxies={'ws': None}))
                # Then no proxy should be used
                assert ws.status == 101
                ws.close()

    @pytest.mark.skip_handlers_if(
        lambda _, handler: Features.NO_PROXY not in handler._SUPPORTED_FEATURES, 'handler does not support NO_PROXY')
    def test_noproxy(self, handler):
        for proxy_proto in handler._SUPPORTED_PROXY_SCHEMES or ['ws']:
            # Given the handler is configured with a proxy
            with handler(proxies={'ws': f'{proxy_proto}://10.255.255.255'}, timeout=5) as rh:
                for no_proxy in (f'127.0.0.1:{self.ws_port}', '127.0.0.1', 'localhost'):
                    # When request no proxy includes the request url host
                    ws = ws_validate_and_send(rh, Request(self.ws_base_url, proxies={'no': no_proxy}))
                    # Then the proxy should not be used
                    assert ws.status == 101
                    ws.close()

    @pytest.mark.skip_handlers_if(
        lambda _, handler: Features.ALL_PROXY not in handler._SUPPORTED_FEATURES, 'handler does not support ALL_PROXY')
    def test_allproxy(self, handler):
        supported_proto = traverse_obj(handler._SUPPORTED_PROXY_SCHEMES, 0, default='ws')
        # This is a bit of a hacky test, but it should be enough to check whether the handler is using the proxy.
        # 0.1s might not be enough of a timeout if proxy is not used in all cases, but should still get failures.
        with handler(proxies={'all': f'{supported_proto}://10.255.255.255'}, timeout=0.1) as rh:
            with pytest.raises(TransportError):
                ws_validate_and_send(rh, Request(self.ws_base_url)).close()

        with handler(timeout=0.1) as rh:
            with pytest.raises(TransportError):
                ws_validate_and_send(
                    rh, Request(self.ws_base_url, proxies={'all': f'{supported_proto}://10.255.255.255'})).close()
348
349
def create_fake_ws_connection(raised):
    """Return a stub websockets ``ClientConnection`` for error-mapping tests.

    ``send()`` and ``recv()`` raise ``raised()`` (*raised* is a zero-argument
    exception factory), ``close()`` does nothing, and ``response`` is a
    minimal object reporting a ``101`` status with empty headers and body.
    The base-class ``__init__`` is deliberately skipped, so no socket is
    involved.
    """
    import websockets.sync.client

    class _StubResponse:
        body = b''
        headers = {}
        status_code = 101
        reason_phrase = 'test'

    def _raise(self, *args, **kwargs):
        raise raised()

    class FakeWsConnection(websockets.sync.client.ClientConnection):
        send = _raise
        recv = _raise

        def __init__(self, *args, **kwargs):
            self.response = _StubResponse()

        def close(self, *args, **kwargs):
            return

    return FakeWsConnection()
373
374
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
class TestWebsocketsRequestHandler:
    """Unit tests for how the Websockets handler maps library errors to yt-dlp errors.

    No server is involved: connection setup and the connection object are
    replaced by stubs that raise the parametrized exception, and each test
    asserts the exact yt-dlp exception type that comes out.
    """

    @staticmethod
    def _make_adapter(raised):
        # Response adapter around a stub connection whose send()/recv() raise ``raised()``
        from yt_dlp.networking._websockets import WebsocketsResponseAdapter
        return WebsocketsResponseAdapter(create_fake_ws_connection(raised), url='ws://fake-url')

    @pytest.mark.parametrize('raised,expected', [
        # https://websockets.readthedocs.io/en/stable/reference/exceptions.html
        (lambda: websockets.exceptions.InvalidURI(msg='test', uri='test://'), RequestError),
        # Requires a response object. Should be covered by HTTP error tests.
        # (lambda: websockets.exceptions.InvalidStatus(), TransportError),
        (lambda: websockets.exceptions.InvalidHandshake(), TransportError),
        # These are subclasses of InvalidHandshake
        (lambda: websockets.exceptions.InvalidHeader(name='test'), TransportError),
        (lambda: websockets.exceptions.NegotiationError(), TransportError),
        # Catch-all
        (lambda: websockets.exceptions.WebSocketException(), TransportError),
        (lambda: TimeoutError(), TransportError),
        # These may be raised by our create_connection implementation, which should also be caught
        (lambda: OSError(), TransportError),
        (lambda: ssl.SSLError(), SSLError),
        (lambda: ssl.SSLCertVerificationError(), CertificateVerifyError),
        (lambda: socks.ProxyError(), ProxyError),
    ])
    def test_request_error_mapping(self, handler, monkeypatch, raised, expected):
        import websockets.sync.client

        import yt_dlp.networking._websockets

        def fake_connect(*args, **kwargs):
            raise raised()

        with handler() as rh:
            # Skip real socket creation; the patched connect() raises the exception under test
            monkeypatch.setattr(yt_dlp.networking._websockets, 'create_connection', lambda *args, **kwargs: None)
            monkeypatch.setattr(websockets.sync.client, 'connect', fake_connect)
            with pytest.raises(expected) as exc_info:
                rh.send(Request('ws://fake-url'))
        # Exact type, not merely a subclass of the expected exception
        assert exc_info.type is expected

    @pytest.mark.parametrize('raised,expected,match', [
        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send
        (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
        (lambda: RuntimeError(), TransportError, None),
        (lambda: TimeoutError(), TransportError, None),
        (lambda: TypeError(), RequestError, None),
        (lambda: socks.ProxyError(), ProxyError, None),
        # Catch-all
        (lambda: websockets.exceptions.WebSocketException(), TransportError, None),
    ])
    def test_ws_send_error_mapping(self, handler, monkeypatch, raised, expected, match):
        adapter = self._make_adapter(raised)
        with pytest.raises(expected, match=match) as exc_info:
            adapter.send('test')
        assert exc_info.type is expected

    @pytest.mark.parametrize('raised,expected,match', [
        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv
        (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
        (lambda: RuntimeError(), TransportError, None),
        (lambda: TimeoutError(), TransportError, None),
        (lambda: socks.ProxyError(), ProxyError, None),
        # Catch-all
        (lambda: websockets.exceptions.WebSocketException(), TransportError, None),
    ])
    def test_ws_recv_error_mapping(self, handler, monkeypatch, raised, expected, match):
        adapter = self._make_adapter(raised)
        with pytest.raises(expected, match=match) as exc_info:
            adapter.recv()
        assert exc_info.type is expected