]>
Commit | Line | Data |
---|---|---|
1 | #!/usr/bin/env python3 | |
2 | ||
3 | # Allow direct execution | |
4 | import os | |
5 | import sys | |
6 | ||
7 | import pytest | |
8 | ||
9 | from test.helper import verify_address_availability | |
10 | ||
11 | sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
12 | ||
import contextlib
import http.client
import http.cookiejar
import http.server
import json
import random
import ssl
import threading
20 | ||
21 | from yt_dlp import socks | |
22 | from yt_dlp.cookies import YoutubeDLCookieJar | |
23 | from yt_dlp.dependencies import websockets | |
24 | from yt_dlp.networking import Request | |
25 | from yt_dlp.networking.exceptions import ( | |
26 | CertificateVerifyError, | |
27 | HTTPError, | |
28 | ProxyError, | |
29 | RequestError, | |
30 | SSLError, | |
31 | TransportError, | |
32 | ) | |
33 | from yt_dlp.utils.networking import HTTPHeaderDict | |
34 | ||
# Absolute path of the directory containing this test module; the server helpers
# below resolve the bundled test certificates relative to it
TEST_DIR = os.path.dirname(os.path.abspath(__file__))
36 | ||
37 | ||
def websocket_handler(websocket):
    """Answer the first message received on `websocket`, then return.

    A handful of messages act as commands and get a computed reply:

    - ``b'bytes'`` -> ``'2'`` and ``'str'`` -> ``'1'``
    - ``'headers'`` -> the handshake request headers, JSON-encoded
    - ``'path'`` -> the handshake request path
    - ``'source_address'`` -> the IP address of the connected peer

    Every other message is echoed back unchanged. The return value is whatever
    ``websocket.send`` returns, or None if the connection yields no message.
    """
    # Replies are wrapped in callables so request data is only touched when asked for
    text_commands = {
        'headers': lambda: json.dumps(dict(websocket.request.headers)),
        'path': lambda: websocket.request.path,
        'source_address': lambda: websocket.remote_address[0],
        'str': lambda: '1',
    }

    for message in websocket:
        if isinstance(message, bytes):
            reply = '2' if message == b'bytes' else message
        elif isinstance(message, str) and message in text_commands:
            reply = text_commands[message]()
        else:
            reply = message
        # Only a single message is served per connection
        return websocket.send(reply)
53 | ||
54 | ||
def process_request(self, request):
    """Decide how the test server answers an opening handshake request.

    A path of the form ``/gen_<code>`` makes the server reply with HTTP status
    ``<code>`` instead of upgrading the connection; any other path is accepted.

    :param self: object whose ``protocol`` provides ``accept()`` and ``reject()``
    :param request: handshake request; only its ``path`` is inspected
    :return: the response to send back to the client
    :raises ValueError: if ``<code>`` is not an integer known to `http.HTTPStatus`
    """
    if not request.path.startswith('/gen_'):
        return self.protocol.accept(request)

    status = http.HTTPStatus(int(request.path[5:]))
    # Every redirect status (3xx) gets a Location header, not just 300
    if 300 <= status.value < 400:
        return websockets.http11.Response(
            status.value, status.phrase, websockets.datastructures.Headers([('Location', '/')]), b'')
    return self.protocol.reject(status.value, status.phrase)
63 | ||
64 | ||
def create_websocket_server(**ws_kwargs):
    """Start a test websocket server on 127.0.0.1 in a background daemon thread.

    The server is bound to port 0 and the port actually assigned is read back
    from its socket. Extra keyword arguments are forwarded to
    ``websockets.sync.server.serve``.

    :return: tuple of (server thread, port number)
    """
    import websockets.sync.server

    server = websockets.sync.server.serve(
        websocket_handler, '127.0.0.1', 0,
        process_request=process_request, open_timeout=2, **ws_kwargs)
    port = server.socket.getsockname()[1]

    serving_thread = threading.Thread(target=server.serve_forever, daemon=True)
    serving_thread.start()
    return serving_thread, port
75 | ||
76 | ||
def create_ws_websocket_server():
    """Start a test server without TLS; returns (server thread, port)."""
    plain_server = create_websocket_server()
    return plain_server
79 | ||
80 | ||
def create_wss_websocket_server():
    """Start a TLS test server presenting ``testcert.pem``; returns (server thread, port)."""
    tls_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    # No separate key file is passed alongside testcert.pem
    tls_context.load_cert_chain(os.path.join(TEST_DIR, 'testcert.pem'), None)
    return create_websocket_server(ssl_context=tls_context)
86 | ||
87 | ||
# Directory holding the CA certificate and the client certificates/keys used by the mutual-TLS tests
MTLS_CERT_DIR = os.path.join(TEST_DIR, 'testdata', 'certificate')
89 | ||
90 | ||
def create_mtls_wss_websocket_server():
    """Start a TLS test server that requires a client certificate.

    The server presents ``testcert.pem`` and verifies client certificates
    against ``ca.crt`` from `MTLS_CERT_DIR`.

    :return: tuple of (server thread, port number)
    """
    server_cert = os.path.join(TEST_DIR, 'testcert.pem')
    client_ca = os.path.join(MTLS_CERT_DIR, 'ca.crt')

    tls_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    # Reject clients that do not present a certificate issued by the test CA
    tls_context.verify_mode = ssl.CERT_REQUIRED
    tls_context.load_verify_locations(cafile=client_ca)
    tls_context.load_cert_chain(server_cert, None)

    return create_websocket_server(ssl_context=tls_context)
101 | ||
102 | ||
def ws_validate_and_send(rh, req):
    """Validate `req` with the handler `rh`, then send it and return the response.

    Sending is attempted up to three times, but only a TransportError whose
    message contains 'connection closed during handshake' is retried; any other
    error, or that error on the final attempt, propagates to the caller.
    """
    rh.validate(req)
    retries_left = 2  # three attempts in total
    while True:
        try:
            return rh.send(req)
        except TransportError as err:
            # websockets server sometimes hangs on new connections
            if retries_left <= 0 or 'connection closed during handshake' not in str(err):
                raise
        retries_left -= 1
114 | ||
115 | ||
@pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers')
class TestWebsSocketRequestHandlerConformance:
    """Conformance tests for websocket request handlers, run against local test servers.

    `setup_class` starts four servers shared by every test: plain ``ws://``,
    ``wss://``, a ``wss://`` server with no certificate loaded, and a ``wss://``
    server that requires a client certificate.

    Connections opened by a test are wrapped in `contextlib.closing` so they are
    closed even when an assertion fails, instead of lingering on the shared servers.
    """

    @classmethod
    def setup_class(cls):
        """Start the shared test servers and record their base URLs."""
        cls.ws_thread, cls.ws_port = create_ws_websocket_server()
        cls.ws_base_url = f'ws://127.0.0.1:{cls.ws_port}'

        cls.wss_thread, cls.wss_port = create_wss_websocket_server()
        cls.wss_base_url = f'wss://127.0.0.1:{cls.wss_port}'

        # TLS context without any certificate loaded; used by test_ssl_error
        cls.bad_wss_thread, cls.bad_wss_port = create_websocket_server(ssl_context=ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER))
        cls.bad_wss_host = f'wss://127.0.0.1:{cls.bad_wss_port}'

        cls.mtls_wss_thread, cls.mtls_wss_port = create_mtls_wss_websocket_server()
        cls.mtls_wss_base_url = f'wss://127.0.0.1:{cls.mtls_wss_port}'

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_basic_websockets(self, handler):
        """A connection upgrades with status 101 and echoes a text message."""
        with handler() as rh:
            with contextlib.closing(ws_validate_and_send(rh, Request(self.ws_base_url))) as ws:
                assert 'upgrade' in ws.headers
                assert ws.status == 101
                ws.send('foo')
                assert ws.recv() == 'foo'

    # https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6
    @pytest.mark.parametrize('msg,opcode', [('str', 1), (b'bytes', 2)])
    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_send_types(self, handler, msg, opcode):
        """str and bytes payloads reach the server as text and binary messages respectively."""
        with handler() as rh:
            with contextlib.closing(ws_validate_and_send(rh, Request(self.ws_base_url))) as ws:
                ws.send(msg)
                assert int(ws.recv()) == opcode

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_verify_cert(self, handler):
        """The test certificate is rejected by default and accepted with verify=False."""
        with handler() as rh:
            with pytest.raises(CertificateVerifyError):
                ws_validate_and_send(rh, Request(self.wss_base_url))

        with handler(verify=False) as rh:
            with contextlib.closing(ws_validate_and_send(rh, Request(self.wss_base_url))) as ws:
                assert ws.status == 101

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_ssl_error(self, handler):
        """A TLS handshake failure raises SSLError, not the CertificateVerifyError subclass."""
        with handler(verify=False) as rh:
            with pytest.raises(SSLError, match=r'ssl(?:v3|/tls) alert handshake failure') as exc_info:
                ws_validate_and_send(rh, Request(self.bad_wss_host))
            assert not issubclass(exc_info.type, CertificateVerifyError)

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    @pytest.mark.parametrize('path,expected', [
        # Unicode characters should be encoded with uppercase percent-encoding
        ('/中文', '/%E4%B8%AD%E6%96%87'),
        # don't normalize existing percent encodings
        ('/%c7%9f', '/%c7%9f'),
    ])
    def test_percent_encode(self, handler, path, expected):
        """The request path seen by the server is percent-encoded as `expected`."""
        with handler() as rh:
            with contextlib.closing(ws_validate_and_send(rh, Request(f'{self.ws_base_url}{path}'))) as ws:
                ws.send('path')
                assert ws.recv() == expected
                assert ws.status == 101

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_remove_dot_segments(self, handler):
        """Dot segments are removed from the path before the request is sent."""
        with handler() as rh:
            # This isn't a comprehensive test,
            # but it should be enough to check whether the handler is removing dot segments
            with contextlib.closing(
                    ws_validate_and_send(rh, Request(f'{self.ws_base_url}/a/b/./../../test'))) as ws:
                assert ws.status == 101
                ws.send('path')
                assert ws.recv() == '/test'

    # We are restricted to known HTTP status codes in http.HTTPStatus
    # Redirects are not supported for websockets
    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    @pytest.mark.parametrize('status', (200, 204, 301, 302, 303, 400, 500, 511))
    def test_raise_http_error(self, handler, status):
        """A non-101 handshake response raises HTTPError carrying the response status."""
        with handler() as rh:
            with pytest.raises(HTTPError) as exc_info:
                ws_validate_and_send(rh, Request(f'{self.ws_base_url}/gen_{status}'))
            assert exc_info.value.status == status

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    @pytest.mark.parametrize('params,extensions', [
        ({'timeout': sys.float_info.min}, {}),
        ({}, {'timeout': sys.float_info.min}),
    ])
    def test_timeout(self, handler, params, extensions):
        """A minimal timeout, set on the handler or per request, raises TransportError."""
        with handler(**params) as rh:
            with pytest.raises(TransportError):
                ws_validate_and_send(rh, Request(self.ws_base_url, extensions=extensions))

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_cookies(self, handler):
        """Cookies from the handler's or the request's cookiejar are sent; none are sent otherwise."""
        cookiejar = YoutubeDLCookieJar()
        cookiejar.set_cookie(http.cookiejar.Cookie(
            version=0, name='test', value='ytdlp', port=None, port_specified=False,
            domain='127.0.0.1', domain_specified=True, domain_initial_dot=False, path='/',
            path_specified=True, secure=False, expires=None, discard=False, comment=None,
            comment_url=None, rest={}))

        with handler(cookiejar=cookiejar) as rh:
            with contextlib.closing(ws_validate_and_send(rh, Request(self.ws_base_url))) as ws:
                ws.send('headers')
                assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'

        with handler() as rh:
            with contextlib.closing(ws_validate_and_send(rh, Request(self.ws_base_url))) as ws:
                ws.send('headers')
                assert 'cookie' not in json.loads(ws.recv())

            with contextlib.closing(ws_validate_and_send(
                    rh, Request(self.ws_base_url, extensions={'cookiejar': cookiejar}))) as ws:
                ws.send('headers')
                assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_source_address(self, handler):
        """The server sees the connection coming from the configured source address."""
        source_address = f'127.0.0.{random.randint(5, 255)}'
        verify_address_availability(source_address)
        with handler(source_address=source_address) as rh:
            with contextlib.closing(ws_validate_and_send(rh, Request(self.ws_base_url))) as ws:
                ws.send('source_address')
                assert source_address == ws.recv()

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_response_url(self, handler):
        """The response reports the URL that was requested."""
        with handler() as rh:
            url = f'{self.ws_base_url}/something'
            with contextlib.closing(ws_validate_and_send(rh, Request(url))) as ws:
                assert ws.url == url

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_request_headers(self, handler):
        """Handler-level headers are sent and per-request headers are merged over them."""
        with handler(headers=HTTPHeaderDict({'test1': 'test', 'test2': 'test2'})) as rh:
            # Global Headers
            with contextlib.closing(ws_validate_and_send(rh, Request(self.ws_base_url))) as ws:
                ws.send('headers')
                headers = HTTPHeaderDict(json.loads(ws.recv()))
                assert headers['test1'] == 'test'

            # Per request headers, merged with global
            with contextlib.closing(ws_validate_and_send(rh, Request(
                    self.ws_base_url, headers={'test2': 'changed', 'test3': 'test3'}))) as ws:
                ws.send('headers')
                headers = HTTPHeaderDict(json.loads(ws.recv()))
                assert headers['test1'] == 'test'
                assert headers['test2'] == 'changed'
                assert headers['test3'] == 'test3'

    @pytest.mark.parametrize('client_cert', (
        {'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithkey.crt')},
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
            'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'client.key'),
        },
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithencryptedkey.crt'),
            'client_certificate_password': 'foobar',
        },
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
            'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'clientencrypted.key'),
            'client_certificate_password': 'foobar',
        }
    ))
    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_mtls(self, handler, client_cert):
        """Each supported client certificate configuration connects to the mutual-TLS server."""
        with handler(
            # Disable client-side validation of unacceptable self-signed testcert.pem
            # The test is of a check on the server side, so unaffected
            verify=False,
            client_cert=client_cert
        ) as rh:
            ws_validate_and_send(rh, Request(self.mtls_wss_base_url)).close()
305 | ||
306 | ||
def create_fake_ws_connection(raised):
    """Build a stand-in client connection whose ``send`` and ``recv`` always fail.

    :param raised: zero-argument callable returning the exception instance to raise
    :return: a ``ClientConnection`` subclass instance with a canned 101 response
    """
    import websockets.sync.client

    class FakeResponse:
        body = b''
        headers = {}
        status_code = 101
        reason_phrase = 'test'

    def fail():
        raise raised()

    class FakeWsConnection(websockets.sync.client.ClientConnection):
        def __init__(self, *args, **kwargs):
            # The parent initializer is not called; only `response` is set up
            self.response = FakeResponse()

        def send(self, *args, **kwargs):
            fail()

        def recv(self, *args, **kwargs):
            fail()

        def close(self, *args, **kwargs):
            return None

    return FakeWsConnection()
330 | ||
331 | ||
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
class TestWebsocketsRequestHandler:
    """Checks how exceptions raised by the websockets layer are translated.

    Each case supplies a factory for the exception to raise and the exact
    networking exception type the caller is expected to see.
    """

    @pytest.mark.parametrize('raised,expected', [
        # https://websockets.readthedocs.io/en/stable/reference/exceptions.html
        (lambda: websockets.exceptions.InvalidURI(msg='test', uri='test://'), RequestError),
        # Requires a response object. Should be covered by HTTP error tests.
        # (lambda: websockets.exceptions.InvalidStatus(), TransportError),
        (lambda: websockets.exceptions.InvalidHandshake(), TransportError),
        # These are subclasses of InvalidHandshake
        (lambda: websockets.exceptions.InvalidHeader(name='test'), TransportError),
        (lambda: websockets.exceptions.NegotiationError(), TransportError),
        # Catch-all
        (lambda: websockets.exceptions.WebSocketException(), TransportError),
        (lambda: TimeoutError(), TransportError),
        # These may be raised by our create_connection implementation, which should also be caught
        (lambda: OSError(), TransportError),
        (lambda: ssl.SSLError(), SSLError),
        (lambda: ssl.SSLCertVerificationError(), CertificateVerifyError),
        (lambda: socks.ProxyError(), ProxyError),
    ])
    def test_request_error_mapping(self, handler, monkeypatch, raised, expected):
        """An exception raised while connecting surfaces as exactly `expected`."""
        import websockets.sync.client

        import yt_dlp.networking._websockets

        def failing_connect(*args, **kwargs):
            raise raised()

        with handler() as rh:
            # No real socket is created; the patched connect() fails immediately
            monkeypatch.setattr(yt_dlp.networking._websockets, 'create_connection', lambda *args, **kwargs: None)
            monkeypatch.setattr(websockets.sync.client, 'connect', failing_connect)
            with pytest.raises(expected) as exc_info:
                rh.send(Request('ws://fake-url'))
            assert exc_info.type is expected

    @pytest.mark.parametrize('raised,expected,match', [
        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send
        (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
        (lambda: RuntimeError(), TransportError, None),
        (lambda: TimeoutError(), TransportError, None),
        (lambda: TypeError(), RequestError, None),
        (lambda: socks.ProxyError(), ProxyError, None),
        # Catch-all
        (lambda: websockets.exceptions.WebSocketException(), TransportError, None),
    ])
    def test_ws_send_error_mapping(self, handler, monkeypatch, raised, expected, match):
        """An exception raised by the connection's send() surfaces as exactly `expected`."""
        from yt_dlp.networking._websockets import WebsocketsResponseAdapter

        failing_connection = create_fake_ws_connection(raised)
        adapter = WebsocketsResponseAdapter(failing_connection, url='ws://fake-url')
        with pytest.raises(expected, match=match) as exc_info:
            adapter.send('test')
        assert exc_info.type is expected

    @pytest.mark.parametrize('raised,expected,match', [
        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv
        (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
        (lambda: RuntimeError(), TransportError, None),
        (lambda: TimeoutError(), TransportError, None),
        (lambda: socks.ProxyError(), ProxyError, None),
        # Catch-all
        (lambda: websockets.exceptions.WebSocketException(), TransportError, None),
    ])
    def test_ws_recv_error_mapping(self, handler, monkeypatch, raised, expected, match):
        """An exception raised by the connection's recv() surfaces as exactly `expected`."""
        from yt_dlp.networking._websockets import WebsocketsResponseAdapter

        failing_connection = create_fake_ws_connection(raised)
        adapter = WebsocketsResponseAdapter(failing_connection, url='ws://fake-url')
        with pytest.raises(expected, match=match) as exc_info:
            adapter.recv()
        assert exc_info.type is expected