]> jfr.im git - yt-dlp.git/blob - test/test_websockets.py
39d3c7d7221439edaaf796d9512ef842de21a1c4
[yt-dlp.git] / test / test_websockets.py
1 #!/usr/bin/env python3
2
3 # Allow direct execution
4 import os
5 import sys
6
7 import pytest
8
9 sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
10
11 import http.client
12 import http.cookiejar
13 import http.server
14 import json
15 import random
16 import ssl
17 import threading
18
19 from yt_dlp import socks
20 from yt_dlp.cookies import YoutubeDLCookieJar
21 from yt_dlp.dependencies import websockets
22 from yt_dlp.networking import Request
23 from yt_dlp.networking.exceptions import (
24 CertificateVerifyError,
25 HTTPError,
26 ProxyError,
27 RequestError,
28 SSLError,
29 TransportError,
30 )
31 from yt_dlp.utils.networking import HTTPHeaderDict
32
33 from test.conftest import validate_and_send
34
# Directory containing this test module; test certificates are resolved relative to it.
TEST_DIR = os.path.split(os.path.abspath(__file__))[0]
36
37
def websocket_handler(websocket):
    """Answer the first message received on ``websocket``, then return.

    Recognised commands:
      b'bytes'          -> '2'
      'str'             -> '1'
      'headers'         -> JSON object of the handshake request headers
      'path'            -> the handshake request path
      'source_address'  -> the first element of ``websocket.remote_address``
    Any other message is echoed back unchanged. Returns whatever
    ``websocket.send`` returns (or None if no message arrives).
    """
    # Values are computed lazily so request attributes are only read when asked for.
    text_commands = {
        'headers': lambda: json.dumps(dict(websocket.request.headers)),
        'path': lambda: websocket.request.path,
        'source_address': lambda: websocket.remote_address[0],
        'str': lambda: '1',
    }
    for incoming in websocket:
        if isinstance(incoming, bytes):
            reply = '2' if incoming == b'bytes' else incoming
        elif isinstance(incoming, str) and incoming in text_commands:
            reply = text_commands[incoming]()
        else:
            reply = incoming
        # Only one message is handled per connection.
        return websocket.send(reply)
53
54
def process_request(self, request):
    """Handshake hook passed as ``process_request`` to the test servers.

    ``self`` is the connection object the server passes in (presumably the
    websockets ``ServerConnection``; only its ``protocol`` attribute is used).

    A path of the form ``/gen_<status>`` makes the server answer the opening
    handshake with that HTTP status instead of upgrading:
      * any 3xx status gets a response carrying ``Location: /`` so tests can
        check that the client does not follow websocket redirects;
      * every other status is rejected via ``self.protocol.reject``.
    All other paths are accepted via ``self.protocol.accept``.

    Raises ValueError if the ``/gen_`` suffix is not a known ``http.HTTPStatus``.
    """
    if request.path.startswith('/gen_'):
        status = http.HTTPStatus(int(request.path[5:]))
        # The whole 3xx range is a redirect; the HTTP-error tests use 301/302/303,
        # which the previous `300 <= value <= 300` check never matched.
        if 300 <= status.value < 400:
            return websockets.http11.Response(
                status.value, status.phrase, websockets.datastructures.Headers([('Location', '/')]), b'')
        return self.protocol.reject(status.value, status.phrase)
    return self.protocol.accept(request)
63
64
def create_websocket_server(**ws_kwargs):
    """Start a websockets sync server on an ephemeral 127.0.0.1 port.

    The server runs ``serve_forever`` in a daemon thread and uses
    ``websocket_handler`` / ``process_request`` as its handlers. Extra keyword
    arguments (e.g. ``ssl_context``) are forwarded to
    ``websockets.sync.server.serve``.

    Returns a ``(thread, port)`` tuple.
    """
    import websockets.sync.server

    server = websockets.sync.server.serve(
        websocket_handler, '127.0.0.1', 0, process_request=process_request, **ws_kwargs)
    port = server.socket.getsockname()[1]  # port 0 above lets the OS pick one

    server_thread = threading.Thread(target=server.serve_forever, daemon=True)
    server_thread.start()
    return server_thread, port
73
74
def create_ws_websocket_server():
    """Start a plain-text (``ws://``) test server; returns ``(thread, port)``."""
    server_thread, port = create_websocket_server()
    return server_thread, port
77
78
def create_wss_websocket_server():
    """Start a TLS (``wss://``) test server presenting the self-signed ``testcert.pem``.

    Returns ``(thread, port)``.
    """
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    context.load_cert_chain(certfile=os.path.join(TEST_DIR, 'testcert.pem'))
    return create_websocket_server(ssl_context=context)
84
85
# Directory holding the CA and client certificates used by the mutual-TLS tests.
MTLS_CERT_DIR = os.path.join(os.path.join(TEST_DIR, 'testdata'), 'certificate')
87
88
def create_mtls_wss_websocket_server():
    """Start a TLS test server that requires a client certificate.

    The server presents ``testcert.pem``. It accepts only clients whose
    certificate verifies against ``ca.crt`` in ``MTLS_CERT_DIR``.

    Returns ``(thread, port)``.
    """
    server_cert = os.path.join(TEST_DIR, 'testcert.pem')
    ca_bundle = os.path.join(MTLS_CERT_DIR, 'ca.crt')

    context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    context.verify_mode = ssl.CERT_REQUIRED  # reject clients without a valid cert
    context.load_verify_locations(cafile=ca_bundle)
    context.load_cert_chain(certfile=server_cert)

    return create_websocket_server(ssl_context=context)
99
100
@pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers')
class TestWebsSocketRequestHandlerConformance:
    """End-to-end tests of the ``Websockets`` request handler against local servers.

    ``setup_class`` starts four servers shared by every test:
      * a plain ``ws://`` server,
      * a ``wss://`` server,
      * a ``wss://`` server whose TLS context has no certificate loaded,
      * a ``wss://`` server that requires a client certificate (mTLS).
    """

    @classmethod
    def setup_class(cls):
        def local_url(scheme, port):
            return f'{scheme}://127.0.0.1:{port}'

        cls.ws_thread, cls.ws_port = create_ws_websocket_server()
        cls.ws_base_url = local_url('ws', cls.ws_port)

        cls.wss_thread, cls.wss_port = create_wss_websocket_server()
        cls.wss_base_url = local_url('wss', cls.wss_port)

        # No certificate is loaded into this context, so the TLS handshake fails
        cls.bad_wss_thread, cls.bad_wss_port = create_websocket_server(
            ssl_context=ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER))
        cls.bad_wss_host = local_url('wss', cls.bad_wss_port)

        cls.mtls_wss_thread, cls.mtls_wss_port = create_mtls_wss_websocket_server()
        cls.mtls_wss_base_url = local_url('wss', cls.mtls_wss_port)

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_basic_websockets(self, handler):
        with handler() as request_handler:
            conn = validate_and_send(request_handler, Request(self.ws_base_url))
            assert 'upgrade' in conn.headers
            assert conn.status == 101
            # Unrecognised messages are echoed back by the server
            conn.send('foo')
            assert conn.recv() == 'foo'
            conn.close()

    # https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6
    @pytest.mark.parametrize('msg,opcode', [('str', 1), (b'bytes', 2)])
    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_send_types(self, handler, msg, opcode):
        # The server answers 'str' with '1' and b'bytes' with '2' (text/binary opcodes)
        with handler() as request_handler:
            conn = validate_and_send(request_handler, Request(self.ws_base_url))
            conn.send(msg)
            received_opcode = int(conn.recv())
            assert received_opcode == opcode
            conn.close()

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_verify_cert(self, handler):
        # testcert.pem is self-signed: verification fails by default ...
        with handler() as request_handler:
            with pytest.raises(CertificateVerifyError):
                validate_and_send(request_handler, Request(self.wss_base_url))

        # ... and the connection succeeds once verification is disabled
        with handler(verify=False) as request_handler:
            conn = validate_and_send(request_handler, Request(self.wss_base_url))
            assert conn.status == 101
            conn.close()

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_ssl_error(self, handler):
        with handler(verify=False) as request_handler:
            with pytest.raises(SSLError, match='sslv3 alert handshake failure') as exc_info:
                validate_and_send(request_handler, Request(self.bad_wss_host))
            # A generic TLS failure must not be reported as a certificate problem
            assert not issubclass(exc_info.type, CertificateVerifyError)

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    @pytest.mark.parametrize('path,expected', [
        # Unicode characters should be encoded with uppercase percent-encoding
        ('/中文', '/%E4%B8%AD%E6%96%87'),
        # don't normalize existing percent encodings
        ('/%c7%9f', '/%c7%9f'),
    ])
    def test_percent_encode(self, handler, path, expected):
        with handler() as request_handler:
            conn = validate_and_send(request_handler, Request(self.ws_base_url + path))
            conn.send('path')
            assert conn.recv() == expected
            assert conn.status == 101
            conn.close()

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_remove_dot_segments(self, handler):
        with handler() as request_handler:
            # Not exhaustive, but enough to show whether dot segments are removed
            conn = validate_and_send(request_handler, Request(self.ws_base_url + '/a/b/./../../test'))
            assert conn.status == 101
            conn.send('path')
            assert conn.recv() == '/test'
            conn.close()

    # The server can only generate status codes known to http.HTTPStatus.
    # Redirects are not supported for websockets, so 3xx must also raise HTTPError.
    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    @pytest.mark.parametrize('status', (200, 204, 301, 302, 303, 400, 500, 511))
    def test_raise_http_error(self, handler, status):
        with handler() as request_handler:
            with pytest.raises(HTTPError) as exc_info:
                validate_and_send(request_handler, Request(f'{self.ws_base_url}/gen_{status}'))
            assert exc_info.value.status == status

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    @pytest.mark.parametrize('params,extensions', [
        ({'timeout': 0.00001}, {}),
        ({}, {'timeout': 0.00001}),
    ])
    def test_timeout(self, handler, params, extensions):
        # A near-zero timeout, set on the handler or per request, must abort the connection
        with handler(**params) as request_handler:
            with pytest.raises(TransportError):
                validate_and_send(request_handler, Request(self.ws_base_url, extensions=extensions))

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_cookies(self, handler):
        cookiejar = YoutubeDLCookieJar()
        cookiejar.set_cookie(http.cookiejar.Cookie(
            version=0, name='test', value='ytdlp', port=None, port_specified=False,
            domain='127.0.0.1', domain_specified=True, domain_initial_dot=False, path='/',
            path_specified=True, secure=False, expires=None, discard=False, comment=None,
            comment_url=None, rest={}))

        def echoed_headers(request_handler, **request_kwargs):
            # Ask the server for the headers it saw on the opening handshake
            conn = validate_and_send(request_handler, Request(self.ws_base_url, **request_kwargs))
            conn.send('headers')
            headers = json.loads(conn.recv())
            conn.close()
            return headers

        # Cookie jar configured on the handler
        with handler(cookiejar=cookiejar) as request_handler:
            assert echoed_headers(request_handler)['cookie'] == 'test=ytdlp'

        with handler() as request_handler:
            # No cookie jar anywhere: no Cookie header
            assert 'cookie' not in echoed_headers(request_handler)
            # Cookie jar supplied per request through extensions
            with_jar = echoed_headers(request_handler, extensions={'cookiejar': cookiejar})
            assert with_jar['cookie'] == 'test=ytdlp'

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_source_address(self, handler):
        source_address = f'127.0.0.{random.randint(5, 255)}'
        with handler(source_address=source_address) as request_handler:
            conn = validate_and_send(request_handler, Request(self.ws_base_url))
            conn.send('source_address')
            assert conn.recv() == source_address
            conn.close()

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_response_url(self, handler):
        url = f'{self.ws_base_url}/something'
        with handler() as request_handler:
            conn = validate_and_send(request_handler, Request(url))
            assert conn.url == url
            conn.close()

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_request_headers(self, handler):
        def echoed_headers(request_handler, **request_kwargs):
            conn = validate_and_send(request_handler, Request(self.ws_base_url, **request_kwargs))
            conn.send('headers')
            headers = HTTPHeaderDict(json.loads(conn.recv()))
            conn.close()
            return headers

        global_headers = HTTPHeaderDict({'test1': 'test', 'test2': 'test2'})
        with handler(headers=global_headers) as request_handler:
            # Global headers are sent on every request
            received = echoed_headers(request_handler)
            assert received['test1'] == 'test'

            # Per-request headers are merged with (and override) the global ones
            received = echoed_headers(request_handler, headers={'test2': 'changed', 'test3': 'test3'})
            assert received['test1'] == 'test'
            assert received['test2'] == 'changed'
            assert received['test3'] == 'test3'

    @pytest.mark.parametrize('client_cert', (
        {'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithkey.crt')},
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
            'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'client.key'),
        },
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithencryptedkey.crt'),
            'client_certificate_password': 'foobar',
        },
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
            'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'clientencrypted.key'),
            'client_certificate_password': 'foobar',
        }
    ))
    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_mtls(self, handler, client_cert):
        # verify=False: the self-signed testcert.pem would fail client-side validation.
        # This test is about the server checking the client certificate, which is unaffected.
        with handler(verify=False, client_cert=client_cert) as request_handler:
            conn = validate_and_send(request_handler, Request(self.mtls_wss_base_url))
            conn.close()
289
290
def create_fake_ws_connection(raised):
    """Build a stand-in websockets ``ClientConnection`` whose I/O always fails.

    ``raised`` is a zero-argument callable that returns the exception to
    raise. ``send`` and ``recv`` raise it on every call, ``close`` is a no-op,
    and the real ``ClientConnection.__init__`` is bypassed. The instance only
    carries a minimal ``response`` object with status 101.
    """
    import websockets.sync.client

    class FakeResponse:
        body = b''
        headers = {}
        status_code = 101
        reason_phrase = 'test'

    def always_raise(self, *args, **kwargs):
        raise raised()

    class FakeWsConnection(websockets.sync.client.ClientConnection):
        def __init__(self, *args, **kwargs):
            # Deliberately skip super().__init__: no real socket/protocol is needed
            self.response = FakeResponse()

        send = always_raise
        recv = always_raise

        def close(self, *args, **kwargs):
            return

    return FakeWsConnection()
314
315
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
class TestWebsocketsRequestHandler:
    """Checks how exceptions from websockets and the socket layer are mapped to
    yt-dlp networking exceptions. Each test asserts the exact exception type."""

    @pytest.mark.parametrize('raised,expected', [
        # https://websockets.readthedocs.io/en/stable/reference/exceptions.html
        (lambda: websockets.exceptions.InvalidURI(msg='test', uri='test://'), RequestError),
        # Requires a response object. Should be covered by HTTP error tests.
        # (lambda: websockets.exceptions.InvalidStatus(), TransportError),
        (lambda: websockets.exceptions.InvalidHandshake(), TransportError),
        # These are subclasses of InvalidHandshake
        (lambda: websockets.exceptions.InvalidHeader(name='test'), TransportError),
        (lambda: websockets.exceptions.NegotiationError(), TransportError),
        # Catch-all
        (lambda: websockets.exceptions.WebSocketException(), TransportError),
        (lambda: TimeoutError(), TransportError),
        # These may be raised by our create_connection implementation, which should also be caught
        (lambda: OSError(), TransportError),
        (lambda: ssl.SSLError(), SSLError),
        (lambda: ssl.SSLCertVerificationError(), CertificateVerifyError),
        (lambda: socks.ProxyError(), ProxyError),
    ])
    def test_request_error_mapping(self, handler, monkeypatch, raised, expected):
        import websockets.sync.client

        import yt_dlp.networking._websockets

        def fake_connect(*args, **kwargs):
            raise raised()

        with handler() as rh:
            # Replace socket creation with a no-op, then make the websocket connect step raise
            monkeypatch.setattr(yt_dlp.networking._websockets, 'create_connection', lambda *args, **kwargs: None)
            monkeypatch.setattr(websockets.sync.client, 'connect', fake_connect)
            with pytest.raises(expected) as exc_info:
                rh.send(Request('ws://fake-url'))
        assert exc_info.type is expected

    @pytest.mark.parametrize('raised,expected,match', [
        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send
        (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
        (lambda: RuntimeError(), TransportError, None),
        (lambda: TimeoutError(), TransportError, None),
        (lambda: TypeError(), RequestError, None),
        (lambda: socks.ProxyError(), ProxyError, None),
        # Catch-all
        (lambda: websockets.exceptions.WebSocketException(), TransportError, None),
    ])
    def test_ws_send_error_mapping(self, handler, monkeypatch, raised, expected, match):
        import yt_dlp.networking._websockets as websockets_impl

        adapter = websockets_impl.WebsocketsResponseAdapter(
            create_fake_ws_connection(raised), url='ws://fake-url')
        with pytest.raises(expected, match=match) as exc_info:
            adapter.send('test')
        assert exc_info.type is expected

    @pytest.mark.parametrize('raised,expected,match', [
        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv
        (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
        (lambda: RuntimeError(), TransportError, None),
        (lambda: TimeoutError(), TransportError, None),
        (lambda: socks.ProxyError(), ProxyError, None),
        # Catch-all
        (lambda: websockets.exceptions.WebSocketException(), TransportError, None),
    ])
    def test_ws_recv_error_mapping(self, handler, monkeypatch, raised, expected, match):
        import yt_dlp.networking._websockets as websockets_impl

        adapter = websockets_impl.WebsocketsResponseAdapter(
            create_fake_ws_connection(raised), url='ws://fake-url')
        with pytest.raises(expected, match=match) as exc_info:
            adapter.recv()
        assert exc_info.type is expected