]>
Commit | Line | Data |
---|---|---|
ccfd70f4 | 1 | #!/usr/bin/env python3 |
2 | ||
3 | # Allow direct execution | |
4 | import os | |
5 | import sys | |
53b4d44f | 6 | import time |
ccfd70f4 | 7 | |
8 | import pytest | |
9 | ||
69d31914 | 10 | from test.helper import verify_address_availability |
53b4d44f | 11 | from yt_dlp.networking.common import Features, DEFAULT_TIMEOUT |
69d31914 | 12 | |
ccfd70f4 | 13 | sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
14 | ||
15 | import http.client | |
16 | import http.cookiejar | |
17 | import http.server | |
18 | import json | |
19 | import random | |
20 | import ssl | |
21 | import threading | |
22 | ||
3c7a287e | 23 | from yt_dlp import socks, traverse_obj |
ccfd70f4 | 24 | from yt_dlp.cookies import YoutubeDLCookieJar |
25 | from yt_dlp.dependencies import websockets | |
26 | from yt_dlp.networking import Request | |
27 | from yt_dlp.networking.exceptions import ( | |
28 | CertificateVerifyError, | |
29 | HTTPError, | |
30 | ProxyError, | |
31 | RequestError, | |
32 | SSLError, | |
33 | TransportError, | |
34 | ) | |
35 | from yt_dlp.utils.networking import HTTPHeaderDict | |
36 | ||
# Absolute path of the directory holding this test module.
# Certificate fixtures used by the TLS servers below are resolved relative to it.
TEST_DIR = os.path.split(os.path.abspath(__file__))[0]
38 | ||
39 | ||
def websocket_handler(websocket):
    """Answer the first message received on *websocket* and return.

    Recognised commands:
      * ``b'bytes'``         -> ``'2'``
      * ``'str'``            -> ``'1'``
      * ``'headers'``        -> the handshake request headers as a JSON object
      * ``'path'``           -> the handshake request path
      * ``'source_address'`` -> the peer's IP address

    Any other message is echoed back unchanged. The result of
    ``websocket.send`` is returned; if the connection yields no message at
    all, the function returns ``None``.
    """
    for message in websocket:
        # Default to echoing; only the known commands override the reply.
        reply = message
        if isinstance(message, bytes):
            if message == b'bytes':
                reply = '2'
        elif isinstance(message, str):
            if message == 'headers':
                reply = json.dumps(dict(websocket.request.headers))
            elif message == 'path':
                reply = websocket.request.path
            elif message == 'source_address':
                reply = websocket.remote_address[0]
            elif message == 'str':
                reply = '1'
        # Only the first message is ever handled.
        return websocket.send(reply)
55 | ||
56 | ||
def process_request(self, request):
    """Handshake hook that lets tests request an arbitrary HTTP status.

    A request whose path starts with ``/gen_`` is answered with the HTTP
    status whose numeric code follows that prefix (e.g. ``/gen_404``):

      * 3xx codes produce a response carrying a ``Location: /`` header
      * every other code is rejected via ``self.protocol.reject``

    Any other path is accepted as a normal websocket upgrade.

    Raises ``ValueError`` if the suffix is not an integer or is not a member
    of ``http.HTTPStatus``.
    """
    if not request.path.startswith('/gen_'):
        return self.protocol.accept(request)

    status = http.HTTPStatus(int(request.path[5:]))
    # The whole redirection class gets a Location header. The previous check
    # (`300 <= value <= 300`) only ever matched status 300.
    if 300 <= status.value < 400:
        return websockets.http11.Response(
            status.value, status.phrase, websockets.datastructures.Headers([('Location', '/')]), b'')
    return self.protocol.reject(status.value, status.phrase)
65 | ||
66 | ||
def create_websocket_server(**ws_kwargs):
    """Start a websocket test server on 127.0.0.1 in a daemon thread.

    The server binds to port 0 and the port actually assigned is read back
    from the listening socket. Extra keyword arguments are forwarded to
    ``websockets.sync.server.serve``.

    Returns a ``(thread, port)`` tuple.
    """
    import websockets.sync.server

    server = websockets.sync.server.serve(
        websocket_handler, '127.0.0.1', 0,
        process_request=process_request, open_timeout=2, **ws_kwargs)
    port = server.socket.getsockname()[1]

    # Daemon thread, so a still-running server never blocks interpreter exit
    serve_thread = threading.Thread(target=server.serve_forever, daemon=True)
    serve_thread.start()
    return serve_thread, port
77 | ||
78 | ||
def create_ws_websocket_server():
    """Start a plain (non-TLS) websocket test server; returns ``(thread, port)``."""
    server_thread, port = create_websocket_server()
    return server_thread, port
81 | ||
82 | ||
def create_wss_websocket_server():
    """Start a TLS websocket test server using ``testcert.pem`` from TEST_DIR.

    Returns ``(thread, port)``.
    """
    tls_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    # No separate key file is passed: load_cert_chain then reads the key from the certificate file
    tls_context.load_cert_chain(os.path.join(TEST_DIR, 'testcert.pem'), None)
    return create_websocket_server(ssl_context=tls_context)
88 | ||
89 | ||
# Directory with the CA and client certificate fixtures used by the mTLS tests
MTLS_CERT_DIR = os.path.join(os.path.join(TEST_DIR, 'testdata'), 'certificate')
91 | ||
92 | ||
def create_mtls_wss_websocket_server():
    """Start a TLS websocket test server that requires a client certificate.

    The server presents ``testcert.pem`` and only accepts clients whose
    certificate verifies against ``ca.crt`` in MTLS_CERT_DIR.

    Returns ``(thread, port)``.
    """
    server_cert = os.path.join(TEST_DIR, 'testcert.pem')
    client_ca = os.path.join(MTLS_CERT_DIR, 'ca.crt')

    tls_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    # CERT_REQUIRED on a server context makes the handshake fail without a valid client certificate
    tls_context.verify_mode = ssl.CERT_REQUIRED
    tls_context.load_verify_locations(cafile=client_ca)
    tls_context.load_cert_chain(server_cert, None)

    return create_websocket_server(ssl_context=tls_context)
103 | ||
104 | ||
def ws_validate_and_send(rh, req):
    """Validate *req* with handler *rh*, then send it and return the response.

    The send is attempted up to three times, but only a TransportError whose
    message contains 'connection closed during handshake' triggers a retry.
    Any other error, or that error on the final attempt, propagates.
    """
    rh.validate(req)
    max_tries = 3
    attempt = 0
    while True:
        attempt += 1
        try:
            return rh.send(req)
        except TransportError as err:
            if attempt < max_tries and 'connection closed during handshake' in str(err):
                # websockets server sometimes hangs on new connections
                continue
            raise
116 | ||
117 | ||
@pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers')
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
class TestWebsSocketRequestHandlerConformance:
    """Behavioural tests for websocket request handlers against local test servers.

    ``setup_class`` starts four servers on 127.0.0.1: plain ``ws``, ``wss``
    with the test certificate, ``wss`` with an SSL context that has no
    certificate loaded, and ``wss`` requiring a client certificate.
    """

    @classmethod
    def setup_class(cls):
        cls.ws_thread, cls.ws_port = create_ws_websocket_server()
        cls.ws_base_url = f'ws://127.0.0.1:{cls.ws_port}'

        cls.wss_thread, cls.wss_port = create_wss_websocket_server()
        cls.wss_base_url = f'wss://127.0.0.1:{cls.wss_port}'

        # Server context without any certificate loaded, used by test_ssl_error
        cls.bad_wss_thread, cls.bad_wss_port = create_websocket_server(ssl_context=ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER))
        cls.bad_wss_host = f'wss://127.0.0.1:{cls.bad_wss_port}'

        cls.mtls_wss_thread, cls.mtls_wss_port = create_mtls_wss_websocket_server()
        cls.mtls_wss_base_url = f'wss://127.0.0.1:{cls.mtls_wss_port}'

    def test_basic_websockets(self, handler):
        """A connection upgrades with status 101 and echoes an arbitrary message."""
        with handler() as rh:
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            assert 'upgrade' in ws.headers
            assert ws.status == 101
            ws.send('foo')
            assert ws.recv() == 'foo'
            ws.close()

    # https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6
    @pytest.mark.parametrize('msg,opcode', [('str', 1), (b'bytes', 2)])
    def test_send_types(self, handler, msg, opcode):
        """The server replies '1' to the 'str' message and '2' to the b'bytes' message."""
        with handler() as rh:
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            ws.send(msg)
            assert int(ws.recv()) == opcode
            ws.close()

    def test_verify_cert(self, handler):
        """The test certificate is rejected by default and accepted with verify=False."""
        with handler() as rh:
            with pytest.raises(CertificateVerifyError):
                ws_validate_and_send(rh, Request(self.wss_base_url))

        with handler(verify=False) as rh:
            ws = ws_validate_and_send(rh, Request(self.wss_base_url))
            assert ws.status == 101
            ws.close()

    def test_ssl_error(self, handler):
        """A handshake failure surfaces as SSLError, not as its CertificateVerifyError subclass."""
        with handler(verify=False) as rh:
            with pytest.raises(SSLError, match=r'ssl(?:v3|/tls) alert handshake failure') as exc_info:
                ws_validate_and_send(rh, Request(self.bad_wss_host))
            assert not issubclass(exc_info.type, CertificateVerifyError)

    @pytest.mark.parametrize('path,expected', [
        # Unicode characters should be encoded with uppercase percent-encoding
        ('/中文', '/%E4%B8%AD%E6%96%87'),
        # don't normalize existing percent encodings
        ('/%c7%9f', '/%c7%9f'),
    ])
    def test_percent_encode(self, handler, path, expected):
        with handler() as rh:
            ws = ws_validate_and_send(rh, Request(f'{self.ws_base_url}{path}'))
            ws.send('path')
            assert ws.recv() == expected
            assert ws.status == 101
            ws.close()

    def test_remove_dot_segments(self, handler):
        with handler() as rh:
            # This isn't a comprehensive test,
            # but it should be enough to check whether the handler is removing dot segments
            ws = ws_validate_and_send(rh, Request(f'{self.ws_base_url}/a/b/./../../test'))
            assert ws.status == 101
            ws.send('path')
            assert ws.recv() == '/test'
            ws.close()

    # We are restricted to known HTTP status codes in http.HTTPStatus
    # Redirects are not supported for websockets
    @pytest.mark.parametrize('status', (200, 204, 301, 302, 303, 400, 500, 511))
    def test_raise_http_error(self, handler, status):
        """Any non-101 handshake response is raised as HTTPError with that status."""
        with handler() as rh:
            with pytest.raises(HTTPError) as exc_info:
                ws_validate_and_send(rh, Request(f'{self.ws_base_url}/gen_{status}'))
            assert exc_info.value.status == status

    @pytest.mark.parametrize('params,extensions', [
        ({'timeout': sys.float_info.min}, {}),
        ({}, {'timeout': sys.float_info.min}),
    ])
    def test_read_timeout(self, handler, params, extensions):
        """A minimal timeout, set on the handler or per request, raises TransportError."""
        with handler(**params) as rh:
            with pytest.raises(TransportError):
                ws_validate_and_send(rh, Request(self.ws_base_url, extensions=extensions))

    def test_connect_timeout(self, handler):
        # nothing should be listening on this port
        connect_timeout_url = 'ws://10.255.255.255'
        with handler(timeout=0.01) as rh, pytest.raises(TransportError):
            now = time.time()
            ws_validate_and_send(rh, Request(connect_timeout_url))
        # Checked after the raising block so the assertion actually runs
        assert time.time() - now < DEFAULT_TIMEOUT

        # Per request timeout, should override handler timeout
        request = Request(connect_timeout_url, extensions={'timeout': 0.01})
        with handler() as rh, pytest.raises(TransportError):
            now = time.time()
            ws_validate_and_send(rh, request)
        assert time.time() - now < DEFAULT_TIMEOUT

    def test_cookies(self, handler):
        """Cookies are sent from the handler's cookiejar or from a per-request cookiejar."""
        cookiejar = YoutubeDLCookieJar()
        cookiejar.set_cookie(http.cookiejar.Cookie(
            version=0, name='test', value='ytdlp', port=None, port_specified=False,
            domain='127.0.0.1', domain_specified=True, domain_initial_dot=False, path='/',
            path_specified=True, secure=False, expires=None, discard=False, comment=None,
            comment_url=None, rest={}))

        with handler(cookiejar=cookiejar) as rh:
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            ws.send('headers')
            assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
            ws.close()

        with handler() as rh:
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            ws.send('headers')
            assert 'cookie' not in json.loads(ws.recv())
            ws.close()

            ws = ws_validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': cookiejar}))
            ws.send('headers')
            assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
            ws.close()

    def test_source_address(self, handler):
        """The server sees the connection coming from the configured source address."""
        source_address = f'127.0.0.{random.randint(5, 255)}'
        verify_address_availability(source_address)
        with handler(source_address=source_address) as rh:
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            ws.send('source_address')
            assert source_address == ws.recv()
            ws.close()

    def test_response_url(self, handler):
        with handler() as rh:
            url = f'{self.ws_base_url}/something'
            ws = ws_validate_and_send(rh, Request(url))
            assert ws.url == url
            ws.close()

    def test_request_headers(self, handler):
        with handler(headers=HTTPHeaderDict({'test1': 'test', 'test2': 'test2'})) as rh:
            # Global Headers
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            ws.send('headers')
            headers = HTTPHeaderDict(json.loads(ws.recv()))
            assert headers['test1'] == 'test'
            ws.close()

            # Per request headers, merged with global
            ws = ws_validate_and_send(rh, Request(
                self.ws_base_url, headers={'test2': 'changed', 'test3': 'test3'}))
            ws.send('headers')
            headers = HTTPHeaderDict(json.loads(ws.recv()))
            assert headers['test1'] == 'test'
            assert headers['test2'] == 'changed'
            assert headers['test3'] == 'test3'
            ws.close()

    @pytest.mark.parametrize('client_cert', (
        {'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithkey.crt')},
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
            'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'client.key'),
        },
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithencryptedkey.crt'),
            'client_certificate_password': 'foobar',
        },
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
            'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'clientencrypted.key'),
            'client_certificate_password': 'foobar',
        },
    ))
    def test_mtls(self, handler, client_cert):
        """Each supported client certificate configuration connects to the mTLS server."""
        with handler(
            # Disable client-side validation of unacceptable self-signed testcert.pem
            # The test is of a check on the server side, so unaffected
            verify=False,
            client_cert=client_cert,
        ) as rh:
            ws_validate_and_send(rh, Request(self.mtls_wss_base_url)).close()

    def test_request_disable_proxy(self, handler):
        for proxy_proto in handler._SUPPORTED_PROXY_SCHEMES or ['ws']:
            # Given handler is configured with a proxy
            with handler(proxies={'ws': f'{proxy_proto}://10.255.255.255'}, timeout=5) as rh:
                # When a proxy is explicitly set to None for the request
                # (keyed by 'ws', the same scheme the handler-level proxy is configured for)
                ws = ws_validate_and_send(rh, Request(self.ws_base_url, proxies={'ws': None}))
                # Then no proxy should be used
                assert ws.status == 101
                ws.close()

    @pytest.mark.skip_handlers_if(
        lambda _, handler: Features.NO_PROXY not in handler._SUPPORTED_FEATURES, 'handler does not support NO_PROXY')
    def test_noproxy(self, handler):
        for proxy_proto in handler._SUPPORTED_PROXY_SCHEMES or ['ws']:
            # Given the handler is configured with a proxy
            with handler(proxies={'ws': f'{proxy_proto}://10.255.255.255'}, timeout=5) as rh:
                for no_proxy in (f'127.0.0.1:{self.ws_port}', '127.0.0.1', 'localhost'):
                    # When request no proxy includes the request url host
                    ws = ws_validate_and_send(rh, Request(self.ws_base_url, proxies={'no': no_proxy}))
                    # Then the proxy should not be used
                    assert ws.status == 101
                    ws.close()

    @pytest.mark.skip_handlers_if(
        lambda _, handler: Features.ALL_PROXY not in handler._SUPPORTED_FEATURES, 'handler does not support ALL_PROXY')
    def test_allproxy(self, handler):
        supported_proto = traverse_obj(handler._SUPPORTED_PROXY_SCHEMES, 0, default='ws')
        # This is a bit of a hacky test, but it should be enough to check whether the handler is using the proxy.
        # 0.1s might not be enough of a timeout if proxy is not used in all cases, but should still get failures.
        with handler(proxies={'all': f'{supported_proto}://10.255.255.255'}, timeout=0.1) as rh:
            with pytest.raises(TransportError):
                ws_validate_and_send(rh, Request(self.ws_base_url)).close()

        with handler(timeout=0.1) as rh:
            with pytest.raises(TransportError):
                ws_validate_and_send(
                    rh, Request(self.ws_base_url, proxies={'all': f'{supported_proto}://10.255.255.255'})).close()
348 | ||
ccfd70f4 | 349 | |
def create_fake_ws_connection(raised):
    """Build a stand-in ``ClientConnection`` whose ``send``/``recv`` always fail.

    *raised* is a zero-argument callable returning the exception instance to
    raise. The fake carries a canned handshake response (status 101) and its
    ``close`` is a no-op.
    """
    import websockets.sync.client

    class FakeResponse:
        body = b''
        headers = {}
        status_code = 101
        reason_phrase = 'test'

    class FakeWsConnection(websockets.sync.client.ClientConnection):
        def __init__(self, *args, **kwargs):
            # The parent initialiser is intentionally not called; only the response is set up
            self.response = FakeResponse()

        def send(self, *args, **kwargs):
            raise raised()

        def recv(self, *args, **kwargs):
            raise raised()

        def close(self, *args, **kwargs):
            return None

    return FakeWsConnection()
373 | ||
374 | ||
# Same guard as the conformance tests: these tests dereference `websockets.exceptions`
# and import `websockets.sync.client`, so they cannot run without the package.
@pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers')
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
class TestWebsocketsRequestHandler:
    """Check how library and OS exceptions are translated into networking exceptions.

    Each parametrized case supplies ``raised``, a zero-argument factory for
    the exception to inject, and ``expected``, the exact exception type that
    must surface (checked with ``is``, so subclasses do not satisfy it).
    """

    @pytest.mark.parametrize('raised,expected', [
        # https://websockets.readthedocs.io/en/stable/reference/exceptions.html
        (lambda: websockets.exceptions.InvalidURI(msg='test', uri='test://'), RequestError),
        # Requires a response object. Should be covered by HTTP error tests.
        # (lambda: websockets.exceptions.InvalidStatus(), TransportError),
        (lambda: websockets.exceptions.InvalidHandshake(), TransportError),
        # These are subclasses of InvalidHandshake
        (lambda: websockets.exceptions.InvalidHeader(name='test'), TransportError),
        (lambda: websockets.exceptions.NegotiationError(), TransportError),
        # Catch-all
        (lambda: websockets.exceptions.WebSocketException(), TransportError),
        (lambda: TimeoutError(), TransportError),
        # These may be raised by our create_connection implementation, which should also be caught
        (lambda: OSError(), TransportError),
        (lambda: ssl.SSLError(), SSLError),
        (lambda: ssl.SSLCertVerificationError(), CertificateVerifyError),
        (lambda: socks.ProxyError(), ProxyError),
    ])
    def test_request_error_mapping(self, handler, monkeypatch, raised, expected):
        """Errors raised while connecting are mapped to the expected exception type."""
        import websockets.sync.client

        import yt_dlp.networking._websockets
        with handler() as rh:
            def fake_connect(*args, **kwargs):
                raise raised()
            # Replace socket creation with a no-op so no real connection is attempted
            monkeypatch.setattr(yt_dlp.networking._websockets, 'create_connection', lambda *args, **kwargs: None)
            monkeypatch.setattr(websockets.sync.client, 'connect', fake_connect)
            with pytest.raises(expected) as exc_info:
                rh.send(Request('ws://fake-url'))
            assert exc_info.type is expected

    @pytest.mark.parametrize('raised,expected,match', [
        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send
        (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
        (lambda: RuntimeError(), TransportError, None),
        (lambda: TimeoutError(), TransportError, None),
        (lambda: TypeError(), RequestError, None),
        (lambda: socks.ProxyError(), ProxyError, None),
        # Catch-all
        (lambda: websockets.exceptions.WebSocketException(), TransportError, None),
    ])
    def test_ws_send_error_mapping(self, handler, monkeypatch, raised, expected, match):
        """Errors raised by the connection's ``send`` are mapped by the response adapter."""
        from yt_dlp.networking._websockets import WebsocketsResponseAdapter
        ws = WebsocketsResponseAdapter(create_fake_ws_connection(raised), url='ws://fake-url')
        with pytest.raises(expected, match=match) as exc_info:
            ws.send('test')
        assert exc_info.type is expected

    @pytest.mark.parametrize('raised,expected,match', [
        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv
        (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
        (lambda: RuntimeError(), TransportError, None),
        (lambda: TimeoutError(), TransportError, None),
        (lambda: socks.ProxyError(), ProxyError, None),
        # Catch-all
        (lambda: websockets.exceptions.WebSocketException(), TransportError, None),
    ])
    def test_ws_recv_error_mapping(self, handler, monkeypatch, raised, expected, match):
        """Errors raised by the connection's ``recv`` are mapped by the response adapter."""
        from yt_dlp.networking._websockets import WebsocketsResponseAdapter
        ws = WebsocketsResponseAdapter(create_fake_ws_connection(raised), url='ws://fake-url')
        with pytest.raises(expected, match=match) as exc_info:
            ws.recv()
        assert exc_info.type is expected