jfr.im git - yt-dlp.git/blame - test/test_websockets.py
Add new options `--impersonate` and `--list-impersonate-targets`
[yt-dlp.git] / test / test_websockets.py
CommitLineData
ccfd70f4 1#!/usr/bin/env python3
2
3# Allow direct execution
4import os
5import sys
6
7import pytest
8
69d31914 9from test.helper import verify_address_availability
10
ccfd70f4 11sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
12
13import http.client
14import http.cookiejar
15import http.server
16import json
17import random
18import ssl
19import threading
20
21from yt_dlp import socks
22from yt_dlp.cookies import YoutubeDLCookieJar
23from yt_dlp.dependencies import websockets
24from yt_dlp.networking import Request
25from yt_dlp.networking.exceptions import (
26 CertificateVerifyError,
27 HTTPError,
28 ProxyError,
29 RequestError,
30 SSLError,
31 TransportError,
32)
33from yt_dlp.utils.networking import HTTPHeaderDict
34
# Absolute path of the directory that contains this test module.
TEST_DIR = os.path.dirname(os.path.abspath(__file__))
36
37
def websocket_handler(websocket):
    """Answer the first message received on *websocket* and return.

    Recognised commands get a canned reply; anything else is echoed back:

    * ``b'bytes'``          -> ``'2'``
    * ``'str'``             -> ``'1'``
    * ``'headers'``         -> JSON dump of the handshake request headers
    * ``'path'``            -> path of the handshake request
    * ``'source_address'``  -> host part of the peer address

    Returns the result of ``websocket.send(...)``, or ``None`` when the
    connection yields no message at all.
    """
    # Builders are lazy so that request/peer attributes are only touched
    # for the command that actually needs them.
    text_replies = {
        'headers': lambda: json.dumps(dict(websocket.request.headers)),
        'path': lambda: websocket.request.path,
        'source_address': lambda: websocket.remote_address[0],
        'str': lambda: '1',
    }

    for message in websocket:
        if isinstance(message, bytes):
            reply = '2' if message == b'bytes' else message
        elif isinstance(message, str) and message in text_replies:
            reply = text_replies[message]()
        else:
            reply = message
        # Only the first message is handled; the handler ends right after replying.
        return websocket.send(reply)
53
54
def process_request(self, request):
    """Handshake hook that lets a test request an arbitrary HTTP status.

    A request whose path starts with ``/gen_`` is answered with the HTTP
    status whose numeric code follows the prefix (``/gen_404`` -> 404):

    * 3xx codes get a response carrying a ``Location: /`` header, so the
      client is faced with a real redirect;
    * every other code is produced through ``self.protocol.reject``.

    Any other path is accepted as a normal WebSocket handshake.

    :param self: object exposing ``protocol.accept`` / ``protocol.reject``
        (presumably the server connection handed in by the websockets
        ``process_request`` hook -- confirm against the websockets docs).
    :param request: handshake request; only ``request.path`` is read here.
    :raises ValueError: if the text after ``/gen_`` is not an integer or is
        not a member of ``http.HTTPStatus``.
    """
    if not request.path.startswith('/gen_'):
        return self.protocol.accept(request)

    status = http.HTTPStatus(int(request.path[5:]))
    # The whole redirection class (300-399) needs the Location header; the
    # previous bound `<= 300` matched status 300 only.
    if 300 <= status.value < 400:
        return websockets.http11.Response(
            status.value, status.phrase, websockets.datastructures.Headers([('Location', '/')]), b'')
    return self.protocol.reject(status.value, status.phrase)
63
64
def create_websocket_server(**ws_kwargs):
    """Start a test WebSocket server on 127.0.0.1 in a daemon thread.

    The server is bound to port 0, so the operating system chooses the port.
    Extra keyword arguments are forwarded to ``websockets.sync.server.serve``.

    :return: ``(thread, port)`` -- the thread running ``serve_forever`` and
        the port number the server socket is bound to.
    """
    import websockets.sync.server

    server = websockets.sync.server.serve(
        websocket_handler, '127.0.0.1', 0,
        process_request=process_request, open_timeout=2, **ws_kwargs)
    bound_port = server.socket.getsockname()[1]

    serving_thread = threading.Thread(target=server.serve_forever)
    serving_thread.daemon = True
    serving_thread.start()
    return serving_thread, bound_port
75
76
def create_ws_websocket_server():
    """Start a plain ``ws://`` test server.

    :return: whatever ``create_websocket_server()`` returns when called
        without extra arguments (the server thread and its port).
    """
    thread_and_port = create_websocket_server()
    return thread_and_port
79
80
def create_wss_websocket_server():
    """Start a ``wss://`` test server that presents ``testcert.pem`` from TEST_DIR.

    :return: the ``(thread, port)`` pair from ``create_websocket_server``.
    """
    tls_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    server_cert = os.path.join(TEST_DIR, 'testcert.pem')
    tls_context.load_cert_chain(server_cert, None)
    return create_websocket_server(ssl_context=tls_context)
86
87
# Directory with the CA / client certificate files referenced by the mTLS server and test.
MTLS_CERT_DIR = os.path.join(TEST_DIR, 'testdata', 'certificate')
89
90
def create_mtls_wss_websocket_server():
    """Start a ``wss://`` test server that requires a client certificate.

    The server presents ``testcert.pem`` from TEST_DIR and verifies client
    certificates against ``ca.crt`` from MTLS_CERT_DIR.

    :return: the ``(thread, port)`` pair from ``create_websocket_server``.
    """
    server_cert = os.path.join(TEST_DIR, 'testcert.pem')
    trusted_ca = os.path.join(MTLS_CERT_DIR, 'ca.crt')

    tls_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    # CERT_REQUIRED makes the handshake fail unless the client sends a certificate.
    tls_context.verify_mode = ssl.CERT_REQUIRED
    tls_context.load_verify_locations(cafile=trusted_ca)
    tls_context.load_cert_chain(server_cert, None)

    return create_websocket_server(ssl_context=tls_context)
101
102
def ws_validate_and_send(rh, req):
    """Validate *req* with *rh*, then send it, retrying flaky handshakes.

    ``rh.send(req)`` is attempted up to three times. A ``TransportError``
    whose message contains ``'connection closed during handshake'`` triggers
    another attempt while attempts remain; any other ``TransportError``, or
    one raised on the last attempt, propagates to the caller.

    :return: whatever ``rh.send(req)`` returns.
    """
    rh.validate(req)
    max_tries = 3
    remaining = max_tries
    while True:
        remaining -= 1
        try:
            return rh.send(req)
        except TransportError as err:
            # websockets server sometimes hangs on new connections
            if remaining > 0 and 'connection closed during handshake' in str(err):
                continue
            raise
114
115
@pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers')
class TestWebsSocketRequestHandlerConformance:
    """Conformance tests that run a request handler against local test servers.

    ``setup_class`` starts four servers on 127.0.0.1 -- plain ws, wss, a wss
    server whose SSL context has no certificate loaded, and a wss server that
    requires a client certificate -- and stores their threads, ports and base
    URLs on the class. The ``handler`` argument of every test is supplied
    through indirect parametrization (fixture defined outside this file).
    """

    @classmethod
    def setup_class(cls):
        """Start the test servers and record their base URLs on the class."""
        host = '127.0.0.1'

        cls.ws_thread, cls.ws_port = create_ws_websocket_server()
        cls.ws_base_url = f'ws://{host}:{cls.ws_port}'

        cls.wss_thread, cls.wss_port = create_wss_websocket_server()
        cls.wss_base_url = f'wss://{host}:{cls.wss_port}'

        # TLS server without any certificate loaded; used by test_ssl_error
        certless_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        cls.bad_wss_thread, cls.bad_wss_port = create_websocket_server(ssl_context=certless_context)
        cls.bad_wss_host = f'wss://{host}:{cls.bad_wss_port}'

        cls.mtls_wss_thread, cls.mtls_wss_port = create_mtls_wss_websocket_server()
        cls.mtls_wss_base_url = f'wss://{host}:{cls.mtls_wss_port}'

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_basic_websockets(self, handler):
        """The upgrade succeeds with status 101 and a text message is echoed back."""
        with handler() as request_handler:
            conn = ws_validate_and_send(request_handler, Request(self.ws_base_url))
            assert 'upgrade' in conn.headers
            assert conn.status == 101
            conn.send('foo')
            assert conn.recv() == 'foo'
            conn.close()

    # https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6
    @pytest.mark.parametrize('msg,opcode', [('str', 1), (b'bytes', 2)])
    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_send_types(self, handler, msg, opcode):
        """The server replies '1' to the 'str' message and '2' to the b'bytes' message."""
        with handler() as request_handler:
            conn = ws_validate_and_send(request_handler, Request(self.ws_base_url))
            conn.send(msg)
            reply = conn.recv()
            assert int(reply) == opcode
            conn.close()

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_verify_cert(self, handler):
        """Connecting to the wss server fails verification unless verify=False is given."""
        with handler() as request_handler, pytest.raises(CertificateVerifyError):
            ws_validate_and_send(request_handler, Request(self.wss_base_url))

        with handler(verify=False) as request_handler:
            conn = ws_validate_and_send(request_handler, Request(self.wss_base_url))
            assert conn.status == 101
            conn.close()

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_ssl_error(self, handler):
        """A handshake failure raises SSLError, not its CertificateVerifyError subclass."""
        with handler(verify=False) as request_handler:
            with pytest.raises(SSLError, match=r'ssl(?:v3|/tls) alert handshake failure') as exc_info:
                ws_validate_and_send(request_handler, Request(self.bad_wss_host))
            assert not issubclass(exc_info.type, CertificateVerifyError)

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    @pytest.mark.parametrize('path,expected', [
        # Unicode characters should be encoded with uppercase percent-encoding
        ('/中文', '/%E4%B8%AD%E6%96%87'),
        # don't normalize existing percent encodings
        ('/%c7%9f', '/%c7%9f'),
    ])
    def test_percent_encode(self, handler, path, expected):
        """The path seen by the server matches the expected percent-encoded form."""
        target = f'{self.ws_base_url}{path}'
        with handler() as request_handler:
            conn = ws_validate_and_send(request_handler, Request(target))
            conn.send('path')
            assert conn.recv() == expected
            assert conn.status == 101
            conn.close()

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_remove_dot_segments(self, handler):
        """Dot segments in the request path are removed before the request is sent."""
        # This isn't a comprehensive test,
        # but it should be enough to check whether the handler is removing dot segments
        target = f'{self.ws_base_url}/a/b/./../../test'
        with handler() as request_handler:
            conn = ws_validate_and_send(request_handler, Request(target))
            assert conn.status == 101
            conn.send('path')
            assert conn.recv() == '/test'
            conn.close()

    # We are restricted to known HTTP status codes in http.HTTPStatus
    # Redirects are not supported for websockets
    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    @pytest.mark.parametrize('status', (200, 204, 301, 302, 303, 400, 500, 511))
    def test_raise_http_error(self, handler, status):
        """A non-101 handshake response raises HTTPError carrying the same status."""
        target = f'{self.ws_base_url}/gen_{status}'
        with handler() as request_handler:
            with pytest.raises(HTTPError) as exc_info:
                ws_validate_and_send(request_handler, Request(target))
            assert exc_info.value.status == status

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    @pytest.mark.parametrize('params,extensions', [
        ({'timeout': sys.float_info.min}, {}),
        ({}, {'timeout': sys.float_info.min}),
    ])
    def test_timeout(self, handler, params, extensions):
        """A minimal timeout, set on the handler or per request, raises TransportError."""
        with handler(**params) as request_handler, pytest.raises(TransportError):
            ws_validate_and_send(request_handler, Request(self.ws_base_url, extensions=extensions))

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_cookies(self, handler):
        """Cookies from a handler-level or request-level cookiejar are sent; none otherwise."""
        test_cookie = http.cookiejar.Cookie(
            version=0, name='test', value='ytdlp', port=None, port_specified=False,
            domain='127.0.0.1', domain_specified=True, domain_initial_dot=False, path='/',
            path_specified=True, secure=False, expires=None, discard=False, comment=None,
            comment_url=None, rest={})
        cookiejar = YoutubeDLCookieJar()
        cookiejar.set_cookie(test_cookie)

        # Cookiejar configured on the handler
        with handler(cookiejar=cookiejar) as request_handler:
            conn = ws_validate_and_send(request_handler, Request(self.ws_base_url))
            conn.send('headers')
            seen_headers = json.loads(conn.recv())
            assert seen_headers['cookie'] == 'test=ytdlp'
            conn.close()

        with handler() as request_handler:
            # No cookiejar anywhere: no cookie header
            conn = ws_validate_and_send(request_handler, Request(self.ws_base_url))
            conn.send('headers')
            seen_headers = json.loads(conn.recv())
            assert 'cookie' not in seen_headers
            conn.close()

            # Cookiejar passed through the request extensions
            conn = ws_validate_and_send(
                request_handler, Request(self.ws_base_url, extensions={'cookiejar': cookiejar}))
            conn.send('headers')
            seen_headers = json.loads(conn.recv())
            assert seen_headers['cookie'] == 'test=ytdlp'
            conn.close()

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_source_address(self, handler):
        """The server sees the connection coming from the configured source address."""
        last_octet = random.randint(5, 255)
        source_address = f'127.0.0.{last_octet}'
        verify_address_availability(source_address)
        with handler(source_address=source_address) as request_handler:
            conn = ws_validate_and_send(request_handler, Request(self.ws_base_url))
            conn.send('source_address')
            assert source_address == conn.recv()
            conn.close()

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_response_url(self, handler):
        """The response exposes the URL that was requested."""
        url = f'{self.ws_base_url}/something'
        with handler() as request_handler:
            conn = ws_validate_and_send(request_handler, Request(url))
            assert conn.url == url
            conn.close()

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_request_headers(self, handler):
        """Handler-level headers are sent, and per-request headers are merged over them."""
        global_headers = HTTPHeaderDict({'test1': 'test', 'test2': 'test2'})
        with handler(headers=global_headers) as request_handler:
            # Global Headers
            conn = ws_validate_and_send(request_handler, Request(self.ws_base_url))
            conn.send('headers')
            seen_headers = HTTPHeaderDict(json.loads(conn.recv()))
            assert seen_headers['test1'] == 'test'
            conn.close()

            # Per request headers, merged with global
            per_request = {'test2': 'changed', 'test3': 'test3'}
            conn = ws_validate_and_send(request_handler, Request(self.ws_base_url, headers=per_request))
            conn.send('headers')
            seen_headers = HTTPHeaderDict(json.loads(conn.recv()))
            assert seen_headers['test1'] == 'test'
            assert seen_headers['test2'] == 'changed'
            assert seen_headers['test3'] == 'test3'
            conn.close()

    @pytest.mark.parametrize('client_cert', (
        {'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithkey.crt')},
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
            'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'client.key'),
        },
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithencryptedkey.crt'),
            'client_certificate_password': 'foobar',
        },
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
            'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'clientencrypted.key'),
            'client_certificate_password': 'foobar',
        }
    ))
    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_mtls(self, handler, client_cert):
        """Each client certificate configuration can connect to the mTLS server."""
        handler_options = {
            # Disable client-side validation of unacceptable self-signed testcert.pem
            # The test is of a check on the server side, so unaffected
            'verify': False,
            'client_cert': client_cert,
        }
        with handler(**handler_options) as request_handler:
            conn = ws_validate_and_send(request_handler, Request(self.mtls_wss_base_url))
            conn.close()
ccfd70f4 305
306
def create_fake_ws_connection(raised):
    """Build a stand-in client connection whose ``send``/``recv`` always fail.

    :param raised: zero-argument callable returning the exception instance
        to raise from ``send()`` and ``recv()``.
    :return: an instance of a ``ClientConnection`` subclass carrying a canned
        handshake response (status 101, reason ``'test'``, no headers, empty
        body); its ``close()`` does nothing.
    """
    import websockets.sync.client

    class FakeResponse:
        body = b''
        headers = {}
        status_code = 101
        reason_phrase = 'test'

    class FakeWsConnection(websockets.sync.client.ClientConnection):
        def __init__(self, *args, **kwargs):
            # ClientConnection.__init__ is not called; only the response is set
            self.response = FakeResponse()

        def send(self, *args, **kwargs):
            raise raised()

        def recv(self, *args, **kwargs):
            raise raised()

        def close(self, *args, **kwargs):
            return None

    fake_connection = FakeWsConnection()
    return fake_connection
330
331
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
class TestWebsocketsRequestHandler:
    """Checks how exceptions raised while connecting, sending or receiving are mapped.

    Each case pairs a factory for the exception to raise with the exact
    networking exception type the caller is expected to see.
    """

    @pytest.mark.parametrize('raised,expected', [
        # https://websockets.readthedocs.io/en/stable/reference/exceptions.html
        (lambda: websockets.exceptions.InvalidURI(msg='test', uri='test://'), RequestError),
        # Requires a response object. Should be covered by HTTP error tests.
        # (lambda: websockets.exceptions.InvalidStatus(), TransportError),
        (lambda: websockets.exceptions.InvalidHandshake(), TransportError),
        # These are subclasses of InvalidHandshake
        (lambda: websockets.exceptions.InvalidHeader(name='test'), TransportError),
        (lambda: websockets.exceptions.NegotiationError(), TransportError),
        # Catch-all
        (lambda: websockets.exceptions.WebSocketException(), TransportError),
        (lambda: TimeoutError(), TransportError),
        # These may be raised by our create_connection implementation, which should also be caught
        (lambda: OSError(), TransportError),
        (lambda: ssl.SSLError(), SSLError),
        (lambda: ssl.SSLCertVerificationError(), CertificateVerifyError),
        (lambda: socks.ProxyError(), ProxyError),
    ])
    def test_request_error_mapping(self, handler, monkeypatch, raised, expected):
        """An exception raised by the patched ``connect`` surfaces as exactly *expected*."""
        import websockets.sync.client as ws_sync_client

        import yt_dlp.networking._websockets as ws_handler_module

        def fake_connect(*args, **kwargs):
            raise raised()

        with handler() as request_handler:
            # Socket creation is stubbed out so that only connect() can fail
            monkeypatch.setattr(ws_handler_module, 'create_connection', lambda *args, **kwargs: None)
            monkeypatch.setattr(ws_sync_client, 'connect', fake_connect)
            with pytest.raises(expected) as exc_info:
                request_handler.send(Request('ws://fake-url'))
            assert exc_info.type is expected

    @pytest.mark.parametrize('raised,expected,match', [
        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send
        (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
        (lambda: RuntimeError(), TransportError, None),
        (lambda: TimeoutError(), TransportError, None),
        (lambda: TypeError(), RequestError, None),
        (lambda: socks.ProxyError(), ProxyError, None),
        # Catch-all
        (lambda: websockets.exceptions.WebSocketException(), TransportError, None),
    ])
    def test_ws_send_error_mapping(self, handler, monkeypatch, raised, expected, match):
        """An exception raised by the connection's ``send`` surfaces as exactly *expected*."""
        from yt_dlp.networking._websockets import WebsocketsResponseAdapter

        failing_connection = create_fake_ws_connection(raised)
        adapter = WebsocketsResponseAdapter(failing_connection, url='ws://fake-url')
        with pytest.raises(expected, match=match) as exc_info:
            adapter.send('test')
        assert exc_info.type is expected

    @pytest.mark.parametrize('raised,expected,match', [
        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv
        (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
        (lambda: RuntimeError(), TransportError, None),
        (lambda: TimeoutError(), TransportError, None),
        (lambda: socks.ProxyError(), ProxyError, None),
        # Catch-all
        (lambda: websockets.exceptions.WebSocketException(), TransportError, None),
    ])
    def test_ws_recv_error_mapping(self, handler, monkeypatch, raised, expected, match):
        """An exception raised by the connection's ``recv`` surfaces as exactly *expected*."""
        from yt_dlp.networking._websockets import WebsocketsResponseAdapter

        failing_connection = create_fake_ws_connection(raised)
        adapter = WebsocketsResponseAdapter(failing_connection, url='ws://fake-url')
        with pytest.raises(expected, match=match) as exc_info:
            adapter.recv()
        assert exc_info.type is expected