]>
Commit | Line | Data |
---|---|---|
1 | #!/usr/bin/env python3 | |
2 | ||
3 | # Allow direct execution | |
4 | import os | |
5 | import sys | |
6 | ||
7 | import pytest | |
8 | ||
9 | sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
10 | ||
11 | import http.client | |
12 | import http.cookiejar | |
13 | import http.server | |
14 | import json | |
15 | import random | |
16 | import ssl | |
17 | import threading | |
18 | ||
19 | from yt_dlp import socks | |
20 | from yt_dlp.cookies import YoutubeDLCookieJar | |
21 | from yt_dlp.dependencies import websockets | |
22 | from yt_dlp.networking import Request | |
23 | from yt_dlp.networking.exceptions import ( | |
24 | CertificateVerifyError, | |
25 | HTTPError, | |
26 | ProxyError, | |
27 | RequestError, | |
28 | SSLError, | |
29 | TransportError, | |
30 | ) | |
31 | from yt_dlp.utils.networking import HTTPHeaderDict | |
32 | ||
33 | from test.conftest import validate_and_send | |
34 | ||
# Absolute path of the directory holding this test module; certificate and
# testdata fixtures are resolved relative to it.
TEST_DIR = os.path.split(os.path.abspath(__file__))[0]
36 | ||
37 | ||
def websocket_handler(websocket):
    """Answer the first message received on ``websocket``, then return.

    Control messages understood by the test server:

    * binary ``b'bytes'``   -> ``'2'`` (the binary-frame opcode)
    * text ``'str'``        -> ``'1'`` (the text-frame opcode)
    * ``'headers'``         -> JSON object of the handshake request headers
    * ``'path'``            -> the handshake request path
    * ``'source_address'``  -> the peer IP address seen by the server

    Any other message is echoed back unchanged. The return value is whatever
    ``websocket.send`` returns, or ``None`` when no message arrives.
    """
    # Replies are built lazily so request attributes are only read when asked for.
    text_replies = {
        'headers': lambda: json.dumps(dict(websocket.request.headers)),
        'path': lambda: websocket.request.path,
        'source_address': lambda: websocket.remote_address[0],
        'str': lambda: '1',
    }
    for incoming in websocket:
        if isinstance(incoming, bytes):
            reply = '2' if incoming == b'bytes' else incoming
        elif isinstance(incoming, str) and incoming in text_replies:
            reply = text_replies[incoming]()
        else:
            reply = incoming
        # Only the first message is handled.
        return websocket.send(reply)
53 | ||
54 | ||
def process_request(self, request):
    """``process_request`` hook for the test server: fake HTTP error responses.

    A request for ``/gen_<status>`` is answered with that HTTP status instead of
    completing the websocket handshake. Redirect statuses (3xx) additionally
    carry a ``Location: /`` header so they resemble real redirects. Any other
    path is accepted as a normal websocket upgrade.

    ``self`` is the connection object websockets hands to the hook (it must
    expose ``.protocol`` with ``accept``/``reject``); ``request`` is the parsed
    handshake request.

    Raises ValueError if the ``/gen_`` suffix is not an integer or not a
    known ``http.HTTPStatus``.
    """
    if request.path.startswith('/gen_'):
        status = http.HTTPStatus(int(request.path[5:]))
        # The whole 3xx range is treated as a redirect. The bound must be `< 400`;
        # `<= 300` would only match 300 and leave 301/302/303 without a Location header.
        if 300 <= status.value < 400:
            return websockets.http11.Response(
                status.value, status.phrase, websockets.datastructures.Headers([('Location', '/')]), b'')
        return self.protocol.reject(status.value, status.phrase)
    return self.protocol.accept(request)
63 | ||
64 | ||
def create_websocket_server(**ws_kwargs):
    """Start a websockets sync server on an OS-assigned 127.0.0.1 port.

    Extra keyword arguments (e.g. ``ssl_context``) are forwarded to
    ``websockets.sync.server.serve``. The server runs ``serve_forever`` on a
    daemon thread. Returns ``(thread, port)``; the server object itself is not
    returned, so nothing here ever shuts it down.
    """
    import websockets.sync.server
    server = websockets.sync.server.serve(
        websocket_handler, '127.0.0.1', 0, process_request=process_request, **ws_kwargs)
    port = server.socket.getsockname()[1]
    serving_thread = threading.Thread(target=server.serve_forever, daemon=True)
    serving_thread.start()
    return serving_thread, port
73 | ||
74 | ||
def create_ws_websocket_server():
    """Start a plain-text (``ws://``) test server; returns ``(thread, port)``."""
    server_thread, port = create_websocket_server()
    return server_thread, port
77 | ||
78 | ||
def create_wss_websocket_server():
    """Start a TLS (``wss://``) test server presenting ``testcert.pem``.

    No client certificate is requested. Returns ``(thread, port)``.
    """
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    context.load_cert_chain(os.path.join(TEST_DIR, 'testcert.pem'), None)
    return create_websocket_server(ssl_context=context)
84 | ||
85 | ||
# Directory holding the CA certificate plus the client certificates/keys used by the mTLS tests.
MTLS_CERT_DIR = os.path.join(TEST_DIR, 'testdata', 'certificate')


def create_mtls_wss_websocket_server():
    """Start a ``wss://`` test server that requires a client certificate.

    The server presents ``testcert.pem`` and verifies clients against
    ``MTLS_CERT_DIR/ca.crt`` (``CERT_REQUIRED``). Returns ``(thread, port)``.
    """
    server_cert = os.path.join(TEST_DIR, 'testcert.pem')
    trusted_ca = os.path.join(MTLS_CERT_DIR, 'ca.crt')

    context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    context.verify_mode = ssl.CERT_REQUIRED
    context.load_verify_locations(cafile=trusted_ca)
    context.load_cert_chain(server_cert, None)
    return create_websocket_server(ssl_context=context)
99 | ||
100 | ||
@pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers')
class TestWebsSocketRequestHandlerConformance:
    """Conformance tests for the ``Websockets`` request handler against local servers.

    ``setup_class`` starts four servers, each on its own daemon thread:

    * ``ws_base_url``: plain ``ws://`` server
    * ``wss_base_url``: ``wss://`` server presenting the self-signed ``testcert.pem``
    * ``bad_wss_host``: ``wss://`` server whose SSL context has no certificate
      loaded, used to provoke a TLS handshake failure
    * ``mtls_wss_base_url``: ``wss://`` server that requires a client certificate
    """

    # Shared marker: every test here runs against the 'Websockets' handler only.
    # It is deleted from the class namespace at the end of the class body.
    _websockets_only = pytest.mark.parametrize('handler', ['Websockets'], indirect=True)

    @classmethod
    def setup_class(cls):
        def local_url(scheme, port):
            return f'{scheme}://127.0.0.1:{port}'

        cls.ws_thread, cls.ws_port = create_ws_websocket_server()
        cls.ws_base_url = local_url('ws', cls.ws_port)

        cls.wss_thread, cls.wss_port = create_wss_websocket_server()
        cls.wss_base_url = local_url('wss', cls.wss_port)

        no_cert_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        cls.bad_wss_thread, cls.bad_wss_port = create_websocket_server(ssl_context=no_cert_context)
        cls.bad_wss_host = local_url('wss', cls.bad_wss_port)

        cls.mtls_wss_thread, cls.mtls_wss_port = create_mtls_wss_websocket_server()
        cls.mtls_wss_base_url = local_url('wss', cls.mtls_wss_port)

    @_websockets_only
    def test_basic_websockets(self, handler):
        with handler() as request_handler:
            conn = validate_and_send(request_handler, Request(self.ws_base_url))
            assert 'upgrade' in conn.headers
            assert conn.status == 101
            conn.send('foo')
            assert conn.recv() == 'foo'
            conn.close()

    # https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6
    # The server replies with the opcode of the frame type it received.
    @pytest.mark.parametrize('msg,opcode', [('str', 1), (b'bytes', 2)])
    @_websockets_only
    def test_send_types(self, handler, msg, opcode):
        with handler() as request_handler:
            conn = validate_and_send(request_handler, Request(self.ws_base_url))
            conn.send(msg)
            assert int(conn.recv()) == opcode
            conn.close()

    @_websockets_only
    def test_verify_cert(self, handler):
        # With default settings the self-signed test certificate must be rejected...
        with handler() as request_handler:
            with pytest.raises(CertificateVerifyError):
                validate_and_send(request_handler, Request(self.wss_base_url))

        # ...while verify=False lets the handshake complete.
        with handler(verify=False) as request_handler:
            conn = validate_and_send(request_handler, Request(self.wss_base_url))
            assert conn.status == 101
            conn.close()

    @_websockets_only
    def test_ssl_error(self, handler):
        with handler(verify=False) as request_handler:
            with pytest.raises(SSLError, match=r'ssl(?:v3|/tls) alert handshake failure') as exc_info:
                validate_and_send(request_handler, Request(self.bad_wss_host))
            # A generic TLS failure must not be reported as a certificate problem.
            assert not issubclass(exc_info.type, CertificateVerifyError)

    @_websockets_only
    @pytest.mark.parametrize('path,expected', [
        # Unicode characters should be encoded with uppercase percent-encoding
        ('/中文', '/%E4%B8%AD%E6%96%87'),
        # don't normalize existing percent encodings
        ('/%c7%9f', '/%c7%9f'),
    ])
    def test_percent_encode(self, handler, path, expected):
        with handler() as request_handler:
            conn = validate_and_send(request_handler, Request(f'{self.ws_base_url}{path}'))
            conn.send('path')
            assert conn.recv() == expected
            assert conn.status == 101
            conn.close()

    @_websockets_only
    def test_remove_dot_segments(self, handler):
        with handler() as request_handler:
            # Not exhaustive; enough to tell whether the handler strips dot segments.
            conn = validate_and_send(request_handler, Request(f'{self.ws_base_url}/a/b/./../../test'))
            assert conn.status == 101
            conn.send('path')
            assert conn.recv() == '/test'
            conn.close()

    # Only statuses known to http.HTTPStatus can be generated by the test server.
    # Redirects are not followed for websockets, so 3xx must also surface as HTTPError.
    @_websockets_only
    @pytest.mark.parametrize('status', (200, 204, 301, 302, 303, 400, 500, 511))
    def test_raise_http_error(self, handler, status):
        with handler() as request_handler:
            with pytest.raises(HTTPError) as exc_info:
                validate_and_send(request_handler, Request(f'{self.ws_base_url}/gen_{status}'))
            assert exc_info.value.status == status

    @_websockets_only
    @pytest.mark.parametrize('params,extensions', [
        ({'timeout': 0.00001}, {}),
        ({}, {'timeout': 0.00001}),
    ])
    def test_timeout(self, handler, params, extensions):
        with handler(**params) as request_handler:
            with pytest.raises(TransportError):
                validate_and_send(request_handler, Request(self.ws_base_url, extensions=extensions))

    @_websockets_only
    def test_cookies(self, handler):
        cookie = http.cookiejar.Cookie(
            version=0, name='test', value='ytdlp', port=None, port_specified=False,
            domain='127.0.0.1', domain_specified=True, domain_initial_dot=False, path='/',
            path_specified=True, secure=False, expires=None, discard=False, comment=None,
            comment_url=None, rest={})
        cookiejar = YoutubeDLCookieJar()
        cookiejar.set_cookie(cookie)

        def echoed_headers(conn):
            # Ask the server to echo the handshake headers it received.
            conn.send('headers')
            return json.loads(conn.recv())

        # Cookiejar supplied to the handler itself
        with handler(cookiejar=cookiejar) as request_handler:
            conn = validate_and_send(request_handler, Request(self.ws_base_url))
            assert echoed_headers(conn)['cookie'] == 'test=ytdlp'
            conn.close()

        with handler() as request_handler:
            # No cookiejar anywhere: no Cookie header
            conn = validate_and_send(request_handler, Request(self.ws_base_url))
            assert 'cookie' not in echoed_headers(conn)
            conn.close()

            # Cookiejar supplied per request via extensions
            conn = validate_and_send(request_handler, Request(self.ws_base_url, extensions={'cookiejar': cookiejar}))
            assert echoed_headers(conn)['cookie'] == 'test=ytdlp'
            conn.close()

    @_websockets_only
    def test_source_address(self, handler):
        source_address = f'127.0.0.{random.randint(5, 255)}'
        with handler(source_address=source_address) as request_handler:
            conn = validate_and_send(request_handler, Request(self.ws_base_url))
            conn.send('source_address')
            assert source_address == conn.recv()
            conn.close()

    @_websockets_only
    def test_response_url(self, handler):
        with handler() as request_handler:
            target = f'{self.ws_base_url}/something'
            conn = validate_and_send(request_handler, Request(target))
            assert conn.url == target
            conn.close()

    @_websockets_only
    def test_request_headers(self, handler):
        def echoed_headers(conn):
            conn.send('headers')
            return HTTPHeaderDict(json.loads(conn.recv()))

        global_headers = HTTPHeaderDict({'test1': 'test', 'test2': 'test2'})
        with handler(headers=global_headers) as request_handler:
            # Handler-level (global) headers are sent
            conn = validate_and_send(request_handler, Request(self.ws_base_url))
            received = echoed_headers(conn)
            assert received['test1'] == 'test'
            conn.close()

            # Per-request headers are merged over the global ones
            conn = validate_and_send(request_handler, Request(
                self.ws_base_url, headers={'test2': 'changed', 'test3': 'test3'}))
            received = echoed_headers(conn)
            assert received['test1'] == 'test'
            assert received['test2'] == 'changed'
            assert received['test3'] == 'test3'
            conn.close()

    @pytest.mark.parametrize('client_cert', (
        {'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithkey.crt')},
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
            'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'client.key'),
        },
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithencryptedkey.crt'),
            'client_certificate_password': 'foobar',
        },
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
            'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'clientencrypted.key'),
            'client_certificate_password': 'foobar',
        }
    ))
    @_websockets_only
    def test_mtls(self, handler, client_cert):
        # verify=False only disables client-side validation of the self-signed
        # testcert.pem; the check under test happens on the server side.
        handler_options = {'verify': False, 'client_cert': client_cert}
        with handler(**handler_options) as request_handler:
            validate_and_send(request_handler, Request(self.mtls_wss_base_url)).close()

    del _websockets_only
289 | ||
290 | ||
def create_fake_ws_connection(raised):
    """Build a stand-in websockets ``ClientConnection`` whose I/O always fails.

    ``raised`` is a zero-argument factory: every ``send()`` and ``recv()`` call
    raises the exception it returns. ``close()`` does nothing. ``response`` is a
    minimal object reporting status 101 with empty headers and body. The base
    class initializer is never called, so the instance has no real socket or
    protocol state.
    """
    import websockets.sync.client

    def always_raise(*args, **kwargs):
        # Used as a method, so `self` arrives inside *args and is ignored.
        raise raised()

    class FakeResponse:
        body = b''
        headers = {}
        status_code = 101
        reason_phrase = 'test'

    class FakeWsConnection(websockets.sync.client.ClientConnection):
        send = always_raise
        recv = always_raise

        def __init__(self, *args, **kwargs):
            self.response = FakeResponse()

        def close(self, *args, **kwargs):
            return

    return FakeWsConnection()
314 | ||
315 | ||
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
class TestWebsocketsRequestHandler:
    """Check that websockets/stdlib exceptions are translated into yt-dlp networking errors.

    Each case pairs a factory for the exception to inject with the yt-dlp
    exception that must surface. ``exc_info.type is expected`` requires that
    exact class, not a subclass.
    """

    @staticmethod
    def _failing_adapter(raised):
        # Wrap a fake connection whose send()/recv() raise `raised()` in the real adapter.
        from yt_dlp.networking._websockets import WebsocketsResponseAdapter
        return WebsocketsResponseAdapter(create_fake_ws_connection(raised), url='ws://fake-url')

    @pytest.mark.parametrize('raised,expected', [
        # https://websockets.readthedocs.io/en/stable/reference/exceptions.html
        (lambda: websockets.exceptions.InvalidURI(msg='test', uri='test://'), RequestError),
        # Requires a response object. Should be covered by HTTP error tests.
        # (lambda: websockets.exceptions.InvalidStatus(), TransportError),
        (lambda: websockets.exceptions.InvalidHandshake(), TransportError),
        # These are subclasses of InvalidHandshake
        (lambda: websockets.exceptions.InvalidHeader(name='test'), TransportError),
        (lambda: websockets.exceptions.NegotiationError(), TransportError),
        # Catch-all
        (lambda: websockets.exceptions.WebSocketException(), TransportError),
        (lambda: TimeoutError(), TransportError),
        # These may be raised by our create_connection implementation, which should also be caught
        (lambda: OSError(), TransportError),
        (lambda: ssl.SSLError(), SSLError),
        (lambda: ssl.SSLCertVerificationError(), CertificateVerifyError),
        (lambda: socks.ProxyError(), ProxyError),
    ])
    def test_request_error_mapping(self, handler, monkeypatch, raised, expected):
        import websockets.sync.client

        import yt_dlp.networking._websockets

        def no_socket(*args, **kwargs):
            return None

        def fake_connect(*args, **kwargs):
            raise raised()

        with handler() as request_handler:
            # Skip real socket creation, then make the websockets handshake raise.
            monkeypatch.setattr(yt_dlp.networking._websockets, 'create_connection', no_socket)
            monkeypatch.setattr(websockets.sync.client, 'connect', fake_connect)
            with pytest.raises(expected) as exc_info:
                request_handler.send(Request('ws://fake-url'))
            assert exc_info.type is expected

    @pytest.mark.parametrize('raised,expected,match', [
        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send
        (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
        (lambda: RuntimeError(), TransportError, None),
        (lambda: TimeoutError(), TransportError, None),
        (lambda: TypeError(), RequestError, None),
        (lambda: socks.ProxyError(), ProxyError, None),
        # Catch-all
        (lambda: websockets.exceptions.WebSocketException(), TransportError, None),
    ])
    def test_ws_send_error_mapping(self, handler, monkeypatch, raised, expected, match):
        adapter = self._failing_adapter(raised)
        with pytest.raises(expected, match=match) as exc_info:
            adapter.send('test')
        assert exc_info.type is expected

    @pytest.mark.parametrize('raised,expected,match', [
        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv
        (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
        (lambda: RuntimeError(), TransportError, None),
        (lambda: TimeoutError(), TransportError, None),
        (lambda: socks.ProxyError(), ProxyError, None),
        # Catch-all
        (lambda: websockets.exceptions.WebSocketException(), TransportError, None),
    ])
    def test_ws_recv_error_mapping(self, handler, monkeypatch, raised, expected, match):
        adapter = self._failing_adapter(raised)
        with pytest.raises(expected, match=match) as exc_info:
            adapter.recv()
        assert exc_info.type is expected