]> jfr.im git - yt-dlp.git/blame - test/test_websockets.py
[networking] Remove `_CompatHTTPError` (#8871)
[yt-dlp.git] / test / test_websockets.py
CommitLineData
ccfd70f4 1#!/usr/bin/env python3
2
3# Allow direct execution
4import os
5import sys
6
7import pytest
8
69d31914 9from test.helper import verify_address_availability
10
ccfd70f4 11sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
12
13import http.client
14import http.cookiejar
15import http.server
16import json
17import random
18import ssl
19import threading
20
21from yt_dlp import socks
22from yt_dlp.cookies import YoutubeDLCookieJar
23from yt_dlp.dependencies import websockets
24from yt_dlp.networking import Request
25from yt_dlp.networking.exceptions import (
26 CertificateVerifyError,
27 HTTPError,
28 ProxyError,
29 RequestError,
30 SSLError,
31 TransportError,
32)
33from yt_dlp.utils.networking import HTTPHeaderDict
34
35from test.conftest import validate_and_send
36
# Absolute path of the directory holding this test module (fixtures such as testcert.pem live here).
# os.path.split(...)[0] is the same value os.path.dirname returns.
TEST_DIR = os.path.split(os.path.abspath(__file__))[0]
38
39
def websocket_handler(websocket):
    """Answer the first message received on *websocket*, then return.

    - the bytes message ``b'bytes'`` is answered with ``'2'``;
    - the text commands ``'headers'``, ``'path'``, ``'source_address'`` and ``'str'``
      are answered with the JSON-encoded request headers, the request path,
      the peer IP address and ``'1'`` respectively;
    - anything else is echoed back unchanged.

    Returns whatever ``websocket.send`` returns, or None if no message arrives.
    """
    # Text commands -> lazily computed reply (only evaluated for the matching command)
    text_replies = {
        'headers': lambda: json.dumps(dict(websocket.request.headers)),
        'path': lambda: websocket.request.path,
        'source_address': lambda: websocket.remote_address[0],
        'str': lambda: '1',
    }
    for incoming in websocket:
        if isinstance(incoming, bytes):
            reply = '2' if incoming == b'bytes' else incoming
        elif isinstance(incoming, str) and incoming in text_replies:
            reply = text_replies[incoming]()
        else:
            reply = incoming
        # Only the first message is handled; the handler returns right after replying
        return websocket.send(reply)
55
56
def process_request(self, request):
    """Handshake hook passed to the test server as ``process_request``.

    ``self`` is the server-side connection object supplied by the server; its
    ``protocol`` is used to accept or reject the handshake.

    A request path of the form ``/gen_<code>`` makes the server answer the
    handshake with HTTP status ``<code>`` instead of upgrading:
    - any 3xx code gets a response carrying ``Location: /`` (a redirect);
    - every other code is rejected with that status and its standard phrase.
    All other paths are accepted normally.

    Raises ValueError if ``<code>`` is not a status known to ``http.HTTPStatus``.
    """
    if request.path.startswith('/gen_'):
        status = http.HTTPStatus(int(request.path[5:]))
        # Cover the whole 3xx range. The previous `300 <= status.value <= 300` only
        # matched 300, so 301/302/303/307/308 were sent through reject() without
        # the Location header and the redirect case was never actually exercised.
        if 300 <= status.value < 400:
            return websockets.http11.Response(
                status.value, status.phrase, websockets.datastructures.Headers([('Location', '/')]), b'')
        return self.protocol.reject(status.value, status.phrase)
    return self.protocol.accept(request)
65
66
def create_websocket_server(**ws_kwargs):
    """Start a websockets sync test server on an ephemeral 127.0.0.1 port.

    Uses ``websocket_handler`` for messages and ``process_request`` for the
    handshake; extra keyword arguments (e.g. ``ssl_context``) are forwarded to
    ``websockets.sync.server.serve``. The server runs ``serve_forever`` in a
    daemon thread. Returns ``(thread, port)``.
    """
    import websockets.sync.server
    server = websockets.sync.server.serve(
        websocket_handler, '127.0.0.1', 0,
        process_request=process_request, **ws_kwargs)
    # Port 0 above lets the OS pick a free port; read back which one it chose
    port = server.socket.getsockname()[1]
    thread = threading.Thread(target=server.serve_forever, daemon=True)
    thread.start()
    return thread, port
75
76
def create_ws_websocket_server():
    """Start a plain-text (ws://) test server; returns ``(thread, port)``."""
    server_thread, port = create_websocket_server()
    return server_thread, port
79
80
def create_wss_websocket_server():
    """Start a TLS (wss://) test server presenting ``testcert.pem``; returns ``(thread, port)``."""
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    context.load_cert_chain(os.path.join(TEST_DIR, 'testcert.pem'), None)
    return create_websocket_server(ssl_context=context)
86
87
# Directory with the CA and client certificates/keys used by the mutual-TLS tests
MTLS_CERT_DIR = os.path.join(os.path.join(TEST_DIR, 'testdata'), 'certificate')
89
90
def create_mtls_wss_websocket_server():
    """Start a wss:// test server that requires a client certificate; returns ``(thread, port)``.

    Client certificates are verified against ``ca.crt`` in ``MTLS_CERT_DIR``;
    the server itself presents ``testcert.pem``.
    """
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    # Refuse the handshake unless the client presents a certificate trusted by the test CA
    context.verify_mode = ssl.CERT_REQUIRED
    context.load_verify_locations(cafile=os.path.join(MTLS_CERT_DIR, 'ca.crt'))
    context.load_cert_chain(os.path.join(TEST_DIR, 'testcert.pem'), None)
    return create_websocket_server(ssl_context=context)
101
102
@pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers')
class TestWebsSocketRequestHandlerConformance:
    """Exercise the websockets request handler against live local test servers.

    ``setup_class`` starts four servers (plain ws://, wss://, a wss:// server
    with no certificate loaded, and a mutual-TLS wss:// server), each in a
    daemon thread, and records their base URLs on the class.
    """

    @classmethod
    def setup_class(cls):
        def local_url(scheme, port):
            return f'{scheme}://127.0.0.1:{port}'

        cls.ws_thread, cls.ws_port = create_ws_websocket_server()
        cls.ws_base_url = local_url('ws', cls.ws_port)

        cls.wss_thread, cls.wss_port = create_wss_websocket_server()
        cls.wss_base_url = local_url('wss', cls.wss_port)

        # TLS context without any certificate loaded: handshakes against it fail
        cls.bad_wss_thread, cls.bad_wss_port = create_websocket_server(
            ssl_context=ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER))
        cls.bad_wss_host = local_url('wss', cls.bad_wss_port)

        cls.mtls_wss_thread, cls.mtls_wss_port = create_mtls_wss_websocket_server()
        cls.mtls_wss_base_url = local_url('wss', cls.mtls_wss_port)

    @staticmethod
    def _echoed_headers(conn):
        # Ask the test server to reply with the handshake request headers it received
        conn.send('headers')
        return json.loads(conn.recv())

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_basic_websockets(self, handler):
        with handler() as rh:
            conn = validate_and_send(rh, Request(self.ws_base_url))
            assert 'upgrade' in conn.headers
            assert conn.status == 101
            conn.send('foo')
            assert conn.recv() == 'foo'
            conn.close()

    # The server answers with the opcode of the frame it received (1 = text, 2 = binary):
    # https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6
    @pytest.mark.parametrize('msg,opcode', [('str', 1), (b'bytes', 2)])
    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_send_types(self, handler, msg, opcode):
        with handler() as rh:
            conn = validate_and_send(rh, Request(self.ws_base_url))
            conn.send(msg)
            reply = conn.recv()
            assert int(reply) == opcode
            conn.close()

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_verify_cert(self, handler):
        # With verification on (the default), the self-signed server cert is rejected ...
        with handler() as rh:
            with pytest.raises(CertificateVerifyError):
                validate_and_send(rh, Request(self.wss_base_url))

        # ... and with verification off the handshake succeeds
        with handler(verify=False) as rh:
            conn = validate_and_send(rh, Request(self.wss_base_url))
            assert conn.status == 101
            conn.close()

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_ssl_error(self, handler):
        with handler(verify=False) as rh:
            with pytest.raises(SSLError, match=r'ssl(?:v3|/tls) alert handshake failure') as exc_info:
                validate_and_send(rh, Request(self.bad_wss_host))
            # A failed handshake must not be reported as a certificate-verification problem
            assert not issubclass(exc_info.type, CertificateVerifyError)

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    @pytest.mark.parametrize('path,expected', [
        # Unicode characters should be encoded with uppercase percent-encoding
        ('/中文', '/%E4%B8%AD%E6%96%87'),
        # don't normalize existing percent encodings
        ('/%c7%9f', '/%c7%9f'),
    ])
    def test_percent_encode(self, handler, path, expected):
        with handler() as rh:
            conn = validate_and_send(rh, Request(f'{self.ws_base_url}{path}'))
            conn.send('path')
            assert conn.recv() == expected
            assert conn.status == 101
            conn.close()

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_remove_dot_segments(self, handler):
        with handler() as rh:
            # Not exhaustive, but enough to tell whether the handler removes dot segments
            conn = validate_and_send(rh, Request(f'{self.ws_base_url}/a/b/./../../test'))
            assert conn.status == 101
            conn.send('path')
            assert conn.recv() == '/test'
            conn.close()

    # Only status codes known to http.HTTPStatus can be generated by the test server.
    # Redirects are not supported for websockets, so 3xx must surface as HTTPError too.
    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    @pytest.mark.parametrize('status', (200, 204, 301, 302, 303, 400, 500, 511))
    def test_raise_http_error(self, handler, status):
        with handler() as rh:
            with pytest.raises(HTTPError) as exc_info:
                validate_and_send(rh, Request(f'{self.ws_base_url}/gen_{status}'))
            assert exc_info.value.status == status

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    @pytest.mark.parametrize('params,extensions', [
        ({'timeout': 0.00001}, {}),
        ({}, {'timeout': 0.00001}),
    ])
    def test_timeout(self, handler, params, extensions):
        # A tiny timeout, set either on the handler or per request, must abort the connection
        with handler(**params) as rh:
            with pytest.raises(TransportError):
                validate_and_send(rh, Request(self.ws_base_url, extensions=extensions))

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_cookies(self, handler):
        cookie = http.cookiejar.Cookie(
            version=0, name='test', value='ytdlp', port=None, port_specified=False,
            domain='127.0.0.1', domain_specified=True, domain_initial_dot=False, path='/',
            path_specified=True, secure=False, expires=None, discard=False, comment=None,
            comment_url=None, rest={})
        cookiejar = YoutubeDLCookieJar()
        cookiejar.set_cookie(cookie)

        # Handler-level cookiejar
        with handler(cookiejar=cookiejar) as rh:
            conn = validate_and_send(rh, Request(self.ws_base_url))
            assert self._echoed_headers(conn)['cookie'] == 'test=ytdlp'
            conn.close()

        with handler() as rh:
            # No cookiejar at all -> no Cookie header
            conn = validate_and_send(rh, Request(self.ws_base_url))
            assert 'cookie' not in self._echoed_headers(conn)
            conn.close()

            # Per-request cookiejar via the 'cookiejar' extension
            conn = validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': cookiejar}))
            assert self._echoed_headers(conn)['cookie'] == 'test=ytdlp'
            conn.close()

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_source_address(self, handler):
        source_address = f'127.0.0.{random.randint(5, 255)}'
        verify_address_availability(source_address)
        with handler(source_address=source_address) as rh:
            conn = validate_and_send(rh, Request(self.ws_base_url))
            conn.send('source_address')
            assert source_address == conn.recv()
            conn.close()

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_response_url(self, handler):
        target = f'{self.ws_base_url}/something'
        with handler() as rh:
            conn = validate_and_send(rh, Request(target))
            assert conn.url == target
            conn.close()

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_request_headers(self, handler):
        with handler(headers=HTTPHeaderDict({'test1': 'test', 'test2': 'test2'})) as rh:
            # Global headers only
            conn = validate_and_send(rh, Request(self.ws_base_url))
            received = HTTPHeaderDict(self._echoed_headers(conn))
            assert received['test1'] == 'test'
            conn.close()

            # Per-request headers are merged over the global ones
            conn = validate_and_send(rh, Request(
                self.ws_base_url, headers={'test2': 'changed', 'test3': 'test3'}))
            received = HTTPHeaderDict(self._echoed_headers(conn))
            assert received['test1'] == 'test'
            assert received['test2'] == 'changed'
            assert received['test3'] == 'test3'
            conn.close()

    @pytest.mark.parametrize('client_cert', (
        {'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithkey.crt')},
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
            'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'client.key'),
        },
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithencryptedkey.crt'),
            'client_certificate_password': 'foobar',
        },
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
            'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'clientencrypted.key'),
            'client_certificate_password': 'foobar',
        }
    ))
    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_mtls(self, handler, client_cert):
        # verify=False disables client-side validation of the self-signed testcert.pem;
        # what is under test is the server-side check of the client certificate.
        with handler(verify=False, client_cert=client_cert) as rh:
            conn = validate_and_send(rh, Request(self.mtls_wss_base_url))
            conn.close()
292
293
def create_fake_ws_connection(raised):
    """Build a stand-in ``ClientConnection`` whose ``send``/``recv`` always fail.

    ``raised`` is a zero-argument callable; each ``send``/``recv`` call raises
    the exception it returns. The fake reports a 101 handshake response with an
    empty body and headers, and ``close`` does nothing.
    """
    import websockets.sync.client

    class FakeResponse:
        body = b''
        headers = {}
        status_code = 101
        reason_phrase = 'test'

    class FakeWsConnection(websockets.sync.client.ClientConnection):
        def __init__(self, *args, **kwargs):
            # Deliberately skip ClientConnection.__init__: no real socket/protocol is needed
            self.response = FakeResponse()

        def send(self, *args, **kwargs):
            raise raised()

        # recv fails in exactly the same way as send
        recv = send

        def close(self, *args, **kwargs):
            return

    return FakeWsConnection()
317
318
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
class TestWebsocketsRequestHandler:
    """Check how exceptions from the websockets library map onto yt-dlp networking errors."""

    @staticmethod
    def _failing_adapter(raised):
        # Response adapter around a fake connection whose send()/recv() raise `raised()`
        from yt_dlp.networking._websockets import WebsocketsResponseAdapter
        return WebsocketsResponseAdapter(create_fake_ws_connection(raised), url='ws://fake-url')

    @pytest.mark.parametrize('raised,expected', [
        # https://websockets.readthedocs.io/en/stable/reference/exceptions.html
        (lambda: websockets.exceptions.InvalidURI(msg='test', uri='test://'), RequestError),
        # Requires a response object. Should be covered by HTTP error tests.
        # (lambda: websockets.exceptions.InvalidStatus(), TransportError),
        (lambda: websockets.exceptions.InvalidHandshake(), TransportError),
        # These are subclasses of InvalidHandshake
        (lambda: websockets.exceptions.InvalidHeader(name='test'), TransportError),
        (lambda: websockets.exceptions.NegotiationError(), TransportError),
        # Catch-all
        (lambda: websockets.exceptions.WebSocketException(), TransportError),
        (lambda: TimeoutError(), TransportError),
        # These may be raised by our create_connection implementation, which should also be caught
        (lambda: OSError(), TransportError),
        (lambda: ssl.SSLError(), SSLError),
        (lambda: ssl.SSLCertVerificationError(), CertificateVerifyError),
        (lambda: socks.ProxyError(), ProxyError),
    ])
    def test_request_error_mapping(self, handler, monkeypatch, raised, expected):
        import websockets.sync.client

        import yt_dlp.networking._websockets

        def failing_connect(*args, **kwargs):
            raise raised()

        with handler() as rh:
            # Stub out socket creation and make the websockets handshake raise
            monkeypatch.setattr(yt_dlp.networking._websockets, 'create_connection', lambda *args, **kwargs: None)
            monkeypatch.setattr(websockets.sync.client, 'connect', failing_connect)
            with pytest.raises(expected) as exc_info:
                rh.send(Request('ws://fake-url'))
            # Exact type, not merely a subclass of the expected error
            assert exc_info.type is expected

    @pytest.mark.parametrize('raised,expected,match', [
        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send
        (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
        (lambda: RuntimeError(), TransportError, None),
        (lambda: TimeoutError(), TransportError, None),
        (lambda: TypeError(), RequestError, None),
        (lambda: socks.ProxyError(), ProxyError, None),
        # Catch-all
        (lambda: websockets.exceptions.WebSocketException(), TransportError, None),
    ])
    def test_ws_send_error_mapping(self, handler, monkeypatch, raised, expected, match):
        adapter = self._failing_adapter(raised)
        with pytest.raises(expected, match=match) as exc_info:
            adapter.send('test')
        assert exc_info.type is expected

    @pytest.mark.parametrize('raised,expected,match', [
        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv
        (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
        (lambda: RuntimeError(), TransportError, None),
        (lambda: TimeoutError(), TransportError, None),
        (lambda: socks.ProxyError(), ProxyError, None),
        # Catch-all
        (lambda: websockets.exceptions.WebSocketException(), TransportError, None),
    ])
    def test_ws_recv_error_mapping(self, handler, monkeypatch, raised, expected, match):
        adapter = self._failing_adapter(raised)
        with pytest.raises(expected, match=match) as exc_info:
            adapter.recv()
        assert exc_info.type is expected