jfr.im git - yt-dlp.git/blob - test/test_websockets.py
Release 2024.04.09
[yt-dlp.git] / test / test_websockets.py
1 #!/usr/bin/env python3
2
3 # Allow direct execution
4 import os
5 import sys
6
7 import pytest
8
9 from test.helper import verify_address_availability
10
11 sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
12
13 import http.client
14 import http.cookiejar
15 import http.server
16 import json
17 import random
18 import ssl
19 import threading
20
21 from yt_dlp import socks
22 from yt_dlp.cookies import YoutubeDLCookieJar
23 from yt_dlp.dependencies import websockets
24 from yt_dlp.networking import Request
25 from yt_dlp.networking.exceptions import (
26 CertificateVerifyError,
27 HTTPError,
28 ProxyError,
29 RequestError,
30 SSLError,
31 TransportError,
32 )
33 from yt_dlp.utils.networking import HTTPHeaderDict
34
# Absolute path of the directory containing this test module;
# the certificate fixtures used by the TLS servers below are resolved against it
TEST_DIR = os.path.split(os.path.abspath(__file__))[0]
36
37
def websocket_handler(websocket):
    """
    Connection handler for the test websocket server.

    Answers the first message received on the connection and then returns:
      - b'bytes' (binary)        -> '2'
      - 'str' (text)             -> '1'
      - 'headers' (text)         -> JSON dump of the handshake request headers
      - 'path' (text)            -> path of the handshake request
      - 'source_address' (text)  -> IP address of the remote peer
      - anything else            -> echoed back unchanged

    Returns whatever `websocket.send` returns, or None if no message was received.
    """
    # Replies are built lazily so request/peer attributes are only read when asked for
    text_replies = {
        'headers': lambda: json.dumps(dict(websocket.request.headers)),
        'path': lambda: websocket.request.path,
        'source_address': lambda: websocket.remote_address[0],
        'str': lambda: '1',
    }

    for incoming in websocket:
        reply = incoming
        if isinstance(incoming, bytes):
            if incoming == b'bytes':
                reply = '2'
        elif isinstance(incoming, str) and incoming in text_replies:
            reply = text_replies[incoming]()
        return websocket.send(reply)
53
54
def process_request(self, request):
    """
    Handshake hook for the test websocket server.

    Requests whose path is `/gen_<status>` are rejected with that HTTP status
    (it must be a code known to `http.HTTPStatus`). For every 3xx status the
    rejection also carries a `Location: /` header, so that a client which followed
    redirects would have somewhere to go. Any other path is accepted as a normal
    websocket handshake.

    @param request  The handshake request; only `request.path` is inspected here.
    @returns        The response produced by `self.protocol.reject` or `self.protocol.accept`.
    @raises ValueError  If the text after `/gen_` is not an integer or not a known HTTP status.
    """
    if not request.path.startswith('/gen_'):
        return self.protocol.accept(request)

    status = http.HTTPStatus(int(request.path[5:]))
    response = self.protocol.reject(status.value, status.phrase)
    # The original condition was `300 <= status.value <= 300`, which only ever matched 300
    # and so never added the Location header for the 301/302/303 cases used by the tests
    if 300 <= status.value < 400:
        response.headers['Location'] = '/'
    return response
63
64
def create_websocket_server(**ws_kwargs):
    """
    Start a websocket test server on 127.0.0.1 using an OS-assigned port.

    The server uses `websocket_handler` for connections and `process_request` for
    handshakes, and is run via `serve_forever` on a daemon thread.

    @param ws_kwargs  Extra keyword arguments passed through to `websockets.sync.server.serve`
                      (e.g. `ssl_context`).
    @returns          Tuple of (server thread, port the server is listening on).
    """
    import websockets.sync.server

    server = websockets.sync.server.serve(
        websocket_handler, '127.0.0.1', 0,
        process_request=process_request, open_timeout=2, **ws_kwargs)
    # Port 0 was requested, so read back the port actually bound
    listening_port = server.socket.getsockname()[1]

    serving_thread = threading.Thread(target=server.serve_forever, daemon=True)
    serving_thread.start()
    return serving_thread, listening_port
75
76
def create_ws_websocket_server():
    """Start a plain (unencrypted) websocket test server; returns (thread, port)."""
    return create_websocket_server()
79
80
def create_wss_websocket_server():
    """
    Start a TLS websocket test server that presents `testcert.pem` from TEST_DIR.

    @returns  Tuple of (server thread, port), as returned by `create_websocket_server`.
    """
    tls_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    tls_context.load_cert_chain(os.path.join(TEST_DIR, 'testcert.pem'), None)
    return create_websocket_server(ssl_context=tls_context)
86
87
# Directory holding the CA and client certificate fixtures used by the mTLS tests
MTLS_CERT_DIR = os.path.join(TEST_DIR, 'testdata', 'certificate')


def create_mtls_wss_websocket_server():
    """
    Start a TLS websocket test server that requires a client certificate.

    The server presents `testcert.pem` from TEST_DIR and verifies client
    certificates against `ca.crt` from MTLS_CERT_DIR.

    @returns  Tuple of (server thread, port), as returned by `create_websocket_server`.
    """
    tls_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    # Clients must present a certificate that chains to the test CA
    tls_context.verify_mode = ssl.CERT_REQUIRED
    tls_context.load_verify_locations(cafile=os.path.join(MTLS_CERT_DIR, 'ca.crt'))
    tls_context.load_cert_chain(os.path.join(TEST_DIR, 'testcert.pem'), None)

    return create_websocket_server(ssl_context=tls_context)
101
102
def ws_validate_and_send(rh, req):
    """
    Validate `req` with the request handler, then send it.

    The send is attempted up to three times: a TransportError whose message contains
    'connection closed during handshake' is retried, any other TransportError
    (or the one from the final attempt) is re-raised.

    @param rh   Request handler providing `validate(req)` and `send(req)`.
    @param req  The request to validate and send.
    @returns    Whatever `rh.send(req)` returns.
    """
    rh.validate(req)
    attempts_left = 3
    while True:
        attempts_left -= 1
        try:
            return rh.send(req)
        except TransportError as err:
            # websockets server sometimes hangs on new connections
            if attempts_left <= 0 or 'connection closed during handshake' not in str(err):
                raise
114
115
@pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers')
class TestWebsSocketRequestHandlerConformance:
    """
    Conformance tests for websocket request handlers, run against local test servers
    (plain ws, wss, a wss server without a certificate, and an mTLS wss server).

    Every websocket opened by a test is closed in a `finally` block, so that a failing
    assertion does not leave a connection to the shared test servers open.
    """

    @classmethod
    def setup_class(cls):
        """Start the test servers shared by all tests in this class and record their base URLs."""
        cls.ws_thread, cls.ws_port = create_ws_websocket_server()
        cls.ws_base_url = f'ws://127.0.0.1:{cls.ws_port}'

        cls.wss_thread, cls.wss_port = create_wss_websocket_server()
        cls.wss_base_url = f'wss://127.0.0.1:{cls.wss_port}'

        # No certificate is loaded into this context; test_ssl_error expects handshakes with it to fail
        cls.bad_wss_thread, cls.bad_wss_port = create_websocket_server(ssl_context=ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER))
        cls.bad_wss_host = f'wss://127.0.0.1:{cls.bad_wss_port}'

        cls.mtls_wss_thread, cls.mtls_wss_port = create_mtls_wss_websocket_server()
        cls.mtls_wss_base_url = f'wss://127.0.0.1:{cls.mtls_wss_port}'

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_basic_websockets(self, handler):
        """A ws:// connection is upgraded (101) and an unrecognised text message is echoed back."""
        with handler() as rh:
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            try:
                assert 'upgrade' in ws.headers
                assert ws.status == 101
                ws.send('foo')
                assert ws.recv() == 'foo'
            finally:
                ws.close()

    # https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6
    @pytest.mark.parametrize('msg,opcode', [('str', 1), (b'bytes', 2)])
    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_send_types(self, handler, msg, opcode):
        """str payloads arrive at the server as text and bytes payloads as binary."""
        with handler() as rh:
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            try:
                ws.send(msg)
                assert int(ws.recv()) == opcode
            finally:
                ws.close()

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_verify_cert(self, handler):
        """The test certificate is rejected by default and accepted with verify=False."""
        with handler() as rh:
            with pytest.raises(CertificateVerifyError):
                ws_validate_and_send(rh, Request(self.wss_base_url))

        with handler(verify=False) as rh:
            ws = ws_validate_and_send(rh, Request(self.wss_base_url))
            try:
                assert ws.status == 101
            finally:
                ws.close()

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_ssl_error(self, handler):
        """A TLS handshake failure raises SSLError, but not its CertificateVerifyError subclass."""
        with handler(verify=False) as rh:
            with pytest.raises(SSLError, match=r'ssl(?:v3|/tls) alert handshake failure') as exc_info:
                ws_validate_and_send(rh, Request(self.bad_wss_host))
            assert not issubclass(exc_info.type, CertificateVerifyError)

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    @pytest.mark.parametrize('path,expected', [
        # Unicode characters should be encoded with uppercase percent-encoding
        ('/中文', '/%E4%B8%AD%E6%96%87'),
        # don't normalize existing percent encodings
        ('/%c7%9f', '/%c7%9f'),
    ])
    def test_percent_encode(self, handler, path, expected):
        """The request path seen by the server is percent-encoded as expected."""
        with handler() as rh:
            ws = ws_validate_and_send(rh, Request(f'{self.ws_base_url}{path}'))
            try:
                ws.send('path')
                assert ws.recv() == expected
                assert ws.status == 101
            finally:
                ws.close()

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_remove_dot_segments(self, handler):
        """Dot segments are removed from the path before the request is sent."""
        with handler() as rh:
            # This isn't a comprehensive test,
            # but it should be enough to check whether the handler is removing dot segments
            ws = ws_validate_and_send(rh, Request(f'{self.ws_base_url}/a/b/./../../test'))
            try:
                assert ws.status == 101
                ws.send('path')
                assert ws.recv() == '/test'
            finally:
                ws.close()

    # We are restricted to known HTTP status codes in http.HTTPStatus
    # Redirects are not supported for websockets
    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    @pytest.mark.parametrize('status', (200, 204, 301, 302, 303, 400, 500, 511))
    def test_raise_http_error(self, handler, status):
        """A non-101 handshake response raises HTTPError carrying the response status."""
        with handler() as rh:
            with pytest.raises(HTTPError) as exc_info:
                ws_validate_and_send(rh, Request(f'{self.ws_base_url}/gen_{status}'))
            assert exc_info.value.status == status

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    @pytest.mark.parametrize('params,extensions', [
        ({'timeout': sys.float_info.min}, {}),
        ({}, {'timeout': sys.float_info.min}),
    ])
    def test_timeout(self, handler, params, extensions):
        """A minimal timeout, set on the handler or per request, raises TransportError."""
        with handler(**params) as rh:
            with pytest.raises(TransportError):
                ws_validate_and_send(rh, Request(self.ws_base_url, extensions=extensions))

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_cookies(self, handler):
        """Cookies from the handler cookiejar or the per-request cookiejar extension are sent."""
        cookiejar = YoutubeDLCookieJar()
        cookiejar.set_cookie(http.cookiejar.Cookie(
            version=0, name='test', value='ytdlp', port=None, port_specified=False,
            domain='127.0.0.1', domain_specified=True, domain_initial_dot=False, path='/',
            path_specified=True, secure=False, expires=None, discard=False, comment=None,
            comment_url=None, rest={}))

        with handler(cookiejar=cookiejar) as rh:
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            try:
                ws.send('headers')
                assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
            finally:
                ws.close()

        with handler() as rh:
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            try:
                ws.send('headers')
                assert 'cookie' not in json.loads(ws.recv())
            finally:
                ws.close()

            ws = ws_validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': cookiejar}))
            try:
                ws.send('headers')
                assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
            finally:
                ws.close()

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_source_address(self, handler):
        """The connection originates from the configured source address."""
        source_address = f'127.0.0.{random.randint(5, 255)}'
        verify_address_availability(source_address)
        with handler(source_address=source_address) as rh:
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            try:
                ws.send('source_address')
                assert source_address == ws.recv()
            finally:
                ws.close()

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_response_url(self, handler):
        """The response reports the URL that was requested."""
        with handler() as rh:
            url = f'{self.ws_base_url}/something'
            ws = ws_validate_and_send(rh, Request(url))
            try:
                assert ws.url == url
            finally:
                ws.close()

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_request_headers(self, handler):
        """Handler-level headers are sent, and per-request headers are merged over them."""
        with handler(headers=HTTPHeaderDict({'test1': 'test', 'test2': 'test2'})) as rh:
            # Global Headers
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            try:
                ws.send('headers')
                headers = HTTPHeaderDict(json.loads(ws.recv()))
                assert headers['test1'] == 'test'
            finally:
                ws.close()

            # Per request headers, merged with global
            ws = ws_validate_and_send(rh, Request(
                self.ws_base_url, headers={'test2': 'changed', 'test3': 'test3'}))
            try:
                ws.send('headers')
                headers = HTTPHeaderDict(json.loads(ws.recv()))
                assert headers['test1'] == 'test'
                assert headers['test2'] == 'changed'
                assert headers['test3'] == 'test3'
            finally:
                ws.close()

    @pytest.mark.parametrize('client_cert', (
        {'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithkey.crt')},
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
            'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'client.key'),
        },
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithencryptedkey.crt'),
            'client_certificate_password': 'foobar',
        },
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
            'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'clientencrypted.key'),
            'client_certificate_password': 'foobar',
        },
    ))
    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_mtls(self, handler, client_cert):
        """Each supported client certificate configuration can connect to the mTLS server."""
        with handler(
            # Disable client-side validation of unacceptable self-signed testcert.pem
            # The test is of a check on the server side, so unaffected
            verify=False,
            client_cert=client_cert,
        ) as rh:
            ws_validate_and_send(rh, Request(self.mtls_wss_base_url)).close()
305
306
def create_fake_ws_connection(raised):
    """
    Build a stand-in websockets client connection whose send() and recv() always fail.

    @param raised  Zero-argument callable returning the exception instance to raise
                   from `send` and `recv`.
    @returns       A `ClientConnection` subclass instance that skips the real constructor,
                   exposes a canned 101 `response`, and whose `close` does nothing.
    """
    import websockets.sync.client

    class FakeWsConnection(websockets.sync.client.ClientConnection):
        def __init__(self, *args, **kwargs):
            # Deliberately does not call the parent constructor: no socket is involved
            class FakeResponse:
                body = b''
                headers = {}
                status_code = 101
                reason_phrase = 'test'

            self.response = FakeResponse()

        def close(self, *args, **kwargs):
            return None

        def recv(self, *args, **kwargs):
            raise raised()

        def send(self, *args, **kwargs):
            raise raised()

    fake_connection = FakeWsConnection()
    return fake_connection
330
331
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
class TestWebsocketsRequestHandler:
    """Checks how exceptions raised by the websockets library are mapped to yt-dlp networking exceptions."""

    @pytest.mark.parametrize(('raised', 'expected'), [
        # https://websockets.readthedocs.io/en/stable/reference/exceptions.html
        (lambda: websockets.exceptions.InvalidURI(msg='test', uri='test://'), RequestError),
        # Requires a response object. Should be covered by HTTP error tests.
        # (lambda: websockets.exceptions.InvalidStatus(), TransportError),
        (lambda: websockets.exceptions.InvalidHandshake(), TransportError),
        # These are subclasses of InvalidHandshake
        (lambda: websockets.exceptions.InvalidHeader(name='test'), TransportError),
        (lambda: websockets.exceptions.NegotiationError(), TransportError),
        # Catch-all
        (lambda: websockets.exceptions.WebSocketException(), TransportError),
        (lambda: TimeoutError(), TransportError),
        # These may be raised by our create_connection implementation, which should also be caught
        (lambda: OSError(), TransportError),
        (lambda: ssl.SSLError(), SSLError),
        (lambda: ssl.SSLCertVerificationError(), CertificateVerifyError),
        (lambda: socks.ProxyError(), ProxyError),
    ])
    def test_request_error_mapping(self, handler, monkeypatch, raised, expected):
        """An exception raised while connecting surfaces as exactly the expected exception type."""
        import websockets.sync.client

        import yt_dlp.networking._websockets as ws_handler_module

        def raise_on_connect(*args, **kwargs):
            raise raised()

        def no_connection(*args, **kwargs):
            return None

        with handler() as rh:
            # Skip real socket creation and make the websockets connect call fail instead
            monkeypatch.setattr(ws_handler_module, 'create_connection', no_connection)
            monkeypatch.setattr(websockets.sync.client, 'connect', raise_on_connect)
            with pytest.raises(expected) as exc_info:
                rh.send(Request('ws://fake-url'))
            # Must be the exact type, not merely a subclass
            assert exc_info.type is expected

    @pytest.mark.parametrize(('raised', 'expected', 'match'), [
        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send
        (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
        (lambda: RuntimeError(), TransportError, None),
        (lambda: TimeoutError(), TransportError, None),
        (lambda: TypeError(), RequestError, None),
        (lambda: socks.ProxyError(), ProxyError, None),
        # Catch-all
        (lambda: websockets.exceptions.WebSocketException(), TransportError, None),
    ])
    def test_ws_send_error_mapping(self, handler, monkeypatch, raised, expected, match):
        """An exception raised by the connection's send() surfaces as exactly the expected type."""
        from yt_dlp.networking._websockets import WebsocketsResponseAdapter

        failing_connection = create_fake_ws_connection(raised)
        adapter = WebsocketsResponseAdapter(failing_connection, url='ws://fake-url')
        with pytest.raises(expected, match=match) as exc_info:
            adapter.send('test')
        assert exc_info.type is expected

    @pytest.mark.parametrize(('raised', 'expected', 'match'), [
        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv
        (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
        (lambda: RuntimeError(), TransportError, None),
        (lambda: TimeoutError(), TransportError, None),
        (lambda: socks.ProxyError(), ProxyError, None),
        # Catch-all
        (lambda: websockets.exceptions.WebSocketException(), TransportError, None),
    ])
    def test_ws_recv_error_mapping(self, handler, monkeypatch, raised, expected, match):
        """An exception raised by the connection's recv() surfaces as exactly the expected type."""
        from yt_dlp.networking._websockets import WebsocketsResponseAdapter

        failing_connection = create_fake_ws_connection(raised)
        adapter = WebsocketsResponseAdapter(failing_connection, url='ws://fake-url')
        with pytest.raises(expected, match=match) as exc_info:
            adapter.recv()
        assert exc_info.type is expected