]>
Commit | Line | Data |
---|---|---|
ccfd70f4 | 1 | #!/usr/bin/env python3 |
2 | ||
3 | # Allow direct execution | |
4 | import os | |
5 | import sys | |
6 | ||
7 | import pytest | |
8 | ||
69d31914 | 9 | from test.helper import verify_address_availability |
10 | ||
ccfd70f4 | 11 | sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
12 | ||
13 | import http.client | |
14 | import http.cookiejar | |
15 | import http.server | |
16 | import json | |
17 | import random | |
18 | import ssl | |
19 | import threading | |
20 | ||
21 | from yt_dlp import socks | |
22 | from yt_dlp.cookies import YoutubeDLCookieJar | |
23 | from yt_dlp.dependencies import websockets | |
24 | from yt_dlp.networking import Request | |
25 | from yt_dlp.networking.exceptions import ( | |
26 | CertificateVerifyError, | |
27 | HTTPError, | |
28 | ProxyError, | |
29 | RequestError, | |
30 | SSLError, | |
31 | TransportError, | |
32 | ) | |
33 | from yt_dlp.utils.networking import HTTPHeaderDict | |
34 | ||
35 | from test.conftest import validate_and_send | |
36 | ||
# Absolute directory containing this test module; test fixtures are resolved relative to it.
TEST_DIR = os.path.split(os.path.abspath(__file__))[0]
38 | ||
39 | ||
def websocket_handler(websocket):
    """Reply to the first message received on ``websocket``, then return.

    Text commands let tests inspect the connection from the server side:

    - ``'headers'``: JSON object of the opening-handshake request headers
    - ``'path'``: the request path as received by the server
    - ``'source_address'``: the client's IP address
    - ``'str'``: ``'1'``; the binary message ``b'bytes'`` gets ``'2'``
      (the RFC 6455 text/binary opcodes)

    Any other message is echoed back unchanged. Returns whatever
    ``websocket.send`` returns, or ``None`` if no message arrives.
    """
    text_commands = {
        'headers': lambda: json.dumps(dict(websocket.request.headers)),
        'path': lambda: websocket.request.path,
        'source_address': lambda: websocket.remote_address[0],
        'str': lambda: '1',
    }

    for incoming in websocket:
        if isinstance(incoming, bytes):
            reply = '2' if incoming == b'bytes' else incoming
        elif isinstance(incoming, str) and incoming in text_commands:
            reply = text_commands[incoming]()
        else:
            reply = incoming
        # Only the first message is answered; returning ends the handler.
        return websocket.send(reply)
55 | ||
56 | ||
def process_request(self, request):
    """Handshake hook for the test server that can force an HTTP status.

    ``self`` is presumably the ``websockets`` server connection (it exposes
    ``protocol``); ``request`` is the client's opening handshake.

    - ``/gen_<code>`` where ``<code>`` is 3xx: respond with that status and a
      ``Location: /`` header, i.e. a redirect the client must not follow.
    - ``/gen_<code>`` with any other status: reject the handshake with it.
    - any other path: accept the handshake normally.

    Raises ``ValueError`` if ``<code>`` is not an integer known to
    ``http.HTTPStatus``.
    """
    if request.path.startswith('/gen_'):
        status = http.HTTPStatus(int(request.path[5:]))
        # Cover the whole 3xx range so 301/302/303/307/308 really get a
        # redirect-style response carrying a Location header.
        if 300 <= status.value < 400:
            return websockets.http11.Response(
                status.value, status.phrase, websockets.datastructures.Headers([('Location', '/')]), b'')
        return self.protocol.reject(status.value, status.phrase)
    return self.protocol.accept(request)
65 | ||
66 | ||
def create_websocket_server(**ws_kwargs):
    """Start a ``websockets`` sync test server on an ephemeral loopback port.

    The server uses ``websocket_handler`` for messages and ``process_request``
    for the handshake, and runs ``serve_forever`` in a daemon thread. Extra
    keyword arguments (e.g. ``ssl_context``) are forwarded to
    ``websockets.sync.server.serve``.

    Returns a ``(thread, port)`` tuple.
    """
    import websockets.sync.server

    server = websockets.sync.server.serve(
        websocket_handler,
        '127.0.0.1',
        0,  # let the OS choose a free port
        process_request=process_request,
        **ws_kwargs,
    )
    port = server.socket.getsockname()[1]

    # Daemon thread so a still-running server never blocks interpreter exit.
    serving_thread = threading.Thread(target=server.serve_forever, daemon=True)
    serving_thread.start()
    return serving_thread, port
75 | ||
76 | ||
def create_ws_websocket_server():
    """Start a plain-text (``ws://``) test server; returns ``(thread, port)``."""
    thread_and_port = create_websocket_server()
    return thread_and_port
79 | ||
80 | ||
def create_wss_websocket_server():
    """Start a TLS (``wss://``) test server presenting ``testcert.pem``.

    Returns a ``(thread, port)`` tuple.
    """
    server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    server_context.load_cert_chain(os.path.join(TEST_DIR, 'testcert.pem'), None)
    return create_websocket_server(ssl_context=server_context)
86 | ||
87 | ||
# Directory holding the CA, client certificates and keys used by the mutual-TLS tests.
MTLS_CERT_DIR = os.path.join(TEST_DIR, os.path.join('testdata', 'certificate'))
89 | ||
90 | ||
def create_mtls_wss_websocket_server():
    """Start a ``wss://`` test server that requires a client certificate (mutual TLS).

    The server presents ``testcert.pem`` and only accepts client certificates
    that verify against ``ca.crt`` in ``MTLS_CERT_DIR``.

    Returns a ``(thread, port)`` tuple.
    """
    server_cert = os.path.join(TEST_DIR, 'testcert.pem')
    client_ca = os.path.join(MTLS_CERT_DIR, 'ca.crt')

    server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    # Reject clients that do not present a certificate trusted by the test CA.
    server_context.verify_mode = ssl.CERT_REQUIRED
    server_context.load_verify_locations(cafile=client_ca)
    server_context.load_cert_chain(server_cert, None)

    return create_websocket_server(ssl_context=server_context)
101 | ||
102 | ||
@pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers')
class TestWebsSocketRequestHandlerConformance:
    """Conformance tests for the ``Websockets`` request handler.

    ``setup_class`` starts four local servers, shared by every test:

    - plain ``ws://``
    - ``wss://`` presenting ``testcert.pem``
    - ``wss://`` whose TLS context has no certificate loaded (``bad_wss_host``)
    - ``wss://`` requiring a client certificate verifiable against the mTLS test CA
    """

    @classmethod
    def setup_class(cls):
        # Plain-text server
        cls.ws_thread, cls.ws_port = create_ws_websocket_server()
        cls.ws_base_url = f'ws://127.0.0.1:{cls.ws_port}'

        # TLS server with the (self-signed) test certificate
        cls.wss_thread, cls.wss_port = create_wss_websocket_server()
        cls.wss_base_url = f'wss://127.0.0.1:{cls.wss_port}'

        # TLS server without any certificate: handshakes are expected to fail
        certless_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        cls.bad_wss_thread, cls.bad_wss_port = create_websocket_server(ssl_context=certless_context)
        cls.bad_wss_host = f'wss://127.0.0.1:{cls.bad_wss_port}'

        # TLS server that also demands a client certificate
        cls.mtls_wss_thread, cls.mtls_wss_port = create_mtls_wss_websocket_server()
        cls.mtls_wss_base_url = f'wss://127.0.0.1:{cls.mtls_wss_port}'

    @staticmethod
    def _query(ws, command):
        """Send ``command`` over ``ws`` and return the server's single reply."""
        ws.send(command)
        return ws.recv()

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_basic_websockets(self, handler):
        with handler() as rh:
            ws = validate_and_send(rh, Request(self.ws_base_url))
            assert 'upgrade' in ws.headers
            assert ws.status == 101
            # Unknown messages are echoed back by the test server
            assert self._query(ws, 'foo') == 'foo'
            ws.close()

    # https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6
    @pytest.mark.parametrize('msg,opcode', [('str', 1), (b'bytes', 2)])
    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_send_types(self, handler, msg, opcode):
        # The server answers text 'str' with '1' and binary b'bytes' with '2'
        with handler() as rh:
            ws = validate_and_send(rh, Request(self.ws_base_url))
            assert int(self._query(ws, msg)) == opcode
            ws.close()

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_verify_cert(self, handler):
        # testcert.pem is self-signed, so default verification must reject it
        with handler() as rh, pytest.raises(CertificateVerifyError):
            validate_and_send(rh, Request(self.wss_base_url))

        with handler(verify=False) as rh:
            ws = validate_and_send(rh, Request(self.wss_base_url))
            assert ws.status == 101
            ws.close()

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_ssl_error(self, handler):
        # A handshake failure is an SSLError but must not be reported as a certificate error
        handshake_failure = r'ssl(?:v3|/tls) alert handshake failure'
        with handler(verify=False) as rh:
            with pytest.raises(SSLError, match=handshake_failure) as exc_info:
                validate_and_send(rh, Request(self.bad_wss_host))
            assert not issubclass(exc_info.type, CertificateVerifyError)

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    @pytest.mark.parametrize('path,expected', [
        # Unicode characters should be encoded with uppercase percent-encoding
        ('/中文', '/%E4%B8%AD%E6%96%87'),
        # don't normalize existing percent encodings
        ('/%c7%9f', '/%c7%9f'),
    ])
    def test_percent_encode(self, handler, path, expected):
        with handler() as rh:
            ws = validate_and_send(rh, Request(f'{self.ws_base_url}{path}'))
            assert self._query(ws, 'path') == expected
            assert ws.status == 101
            ws.close()

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_remove_dot_segments(self, handler):
        # Not exhaustive, but enough to tell whether the handler removes dot segments
        dotted_url = f'{self.ws_base_url}/a/b/./../../test'
        with handler() as rh:
            ws = validate_and_send(rh, Request(dotted_url))
            assert ws.status == 101
            assert self._query(ws, 'path') == '/test'
            ws.close()

    # Statuses must be members of http.HTTPStatus (the test server builds them from it).
    # Redirects are not supported for websockets, so 3xx must also surface as HTTPError.
    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    @pytest.mark.parametrize('status', (200, 204, 301, 302, 303, 400, 500, 511))
    def test_raise_http_error(self, handler, status):
        status_url = f'{self.ws_base_url}/gen_{status}'
        with handler() as rh:
            with pytest.raises(HTTPError) as exc_info:
                validate_and_send(rh, Request(status_url))
            assert exc_info.value.status == status

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    @pytest.mark.parametrize('params,extensions', [
        ({'timeout': 0.00001}, {}),
        ({}, {'timeout': 0.00001}),
    ])
    def test_timeout(self, handler, params, extensions):
        # A tiny timeout, set on the handler or per request, must abort the connection
        with handler(**params) as rh, pytest.raises(TransportError):
            validate_and_send(rh, Request(self.ws_base_url, extensions=extensions))

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_cookies(self, handler):
        jar = YoutubeDLCookieJar()
        jar.set_cookie(http.cookiejar.Cookie(
            version=0, name='test', value='ytdlp', port=None, port_specified=False,
            domain='127.0.0.1', domain_specified=True, domain_initial_dot=False, path='/',
            path_specified=True, secure=False, expires=None, discard=False, comment=None,
            comment_url=None, rest={}))

        # Cookiejar supplied to the handler
        with handler(cookiejar=jar) as rh:
            ws = validate_and_send(rh, Request(self.ws_base_url))
            assert json.loads(self._query(ws, 'headers'))['cookie'] == 'test=ytdlp'
            ws.close()

        with handler() as rh:
            # Handler without the test jar: the cookie must not be sent
            ws = validate_and_send(rh, Request(self.ws_base_url))
            assert 'cookie' not in json.loads(self._query(ws, 'headers'))
            ws.close()

            # Cookiejar supplied per request through the 'cookiejar' extension
            ws = validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': jar}))
            assert json.loads(self._query(ws, 'headers'))['cookie'] == 'test=ytdlp'
            ws.close()

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_source_address(self, handler):
        bind_address = f'127.0.0.{random.randint(5, 255)}'
        verify_address_availability(bind_address)
        with handler(source_address=bind_address) as rh:
            ws = validate_and_send(rh, Request(self.ws_base_url))
            assert bind_address == self._query(ws, 'source_address')
            ws.close()

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_response_url(self, handler):
        target = f'{self.ws_base_url}/something'
        with handler() as rh:
            ws = validate_and_send(rh, Request(target))
            assert ws.url == target
            ws.close()

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_request_headers(self, handler):
        handler_headers = HTTPHeaderDict({'test1': 'test', 'test2': 'test2'})
        with handler(headers=handler_headers) as rh:
            # Global Headers
            ws = validate_and_send(rh, Request(self.ws_base_url))
            seen = HTTPHeaderDict(json.loads(self._query(ws, 'headers')))
            assert seen['test1'] == 'test'
            ws.close()

            # Per request headers, merged with global
            ws = validate_and_send(rh, Request(
                self.ws_base_url, headers={'test2': 'changed', 'test3': 'test3'}))
            seen = HTTPHeaderDict(json.loads(self._query(ws, 'headers')))
            assert seen['test1'] == 'test'
            assert seen['test2'] == 'changed'
            assert seen['test3'] == 'test3'
            ws.close()

    @pytest.mark.parametrize('client_cert', (
        {'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithkey.crt')},
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
            'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'client.key'),
        },
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithencryptedkey.crt'),
            'client_certificate_password': 'foobar',
        },
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
            'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'clientencrypted.key'),
            'client_certificate_password': 'foobar',
        }
    ))
    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_mtls(self, handler, client_cert):
        # verify=False only disables client-side validation of the self-signed
        # testcert.pem; the server still requires and checks the client certificate.
        with handler(verify=False, client_cert=client_cert) as rh:
            validate_and_send(rh, Request(self.mtls_wss_base_url)).close()
292 | ||
293 | ||
def create_fake_ws_connection(raised):
    """Build a stand-in ``ClientConnection`` whose ``send``/``recv`` always fail.

    :param raised: zero-argument callable returning the exception to raise
        from every ``send`` and ``recv`` call.
    :return: an instance exposing a canned ``response`` (status 101, empty body
        and headers) and a no-op ``close``.
    """
    import websockets.sync.client

    class FakeResponse:
        body = b''
        headers = {}
        status_code = 101
        reason_phrase = 'test'

    class FakeWsConnection(websockets.sync.client.ClientConnection):
        # The real initializer is deliberately skipped: no socket or protocol
        # state is created, only the canned handshake response.
        def __init__(self, *args, **kwargs):
            self.response = FakeResponse()

        def send(self, *args, **kwargs):
            error = raised()
            raise error

        def recv(self, *args, **kwargs):
            error = raised()
            raise error

        def close(self, *args, **kwargs):
            return None

    return FakeWsConnection()
317 | ||
318 | ||
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
class TestWebsocketsRequestHandler:
    """Unit tests mapping ``websockets``/socket exceptions to yt-dlp networking errors."""

    @staticmethod
    def _fake_adapter(raised):
        """Wrap a fake connection, whose send/recv raise ``raised()``, in the response adapter."""
        from yt_dlp.networking._websockets import WebsocketsResponseAdapter
        return WebsocketsResponseAdapter(create_fake_ws_connection(raised), url='ws://fake-url')

    @pytest.mark.parametrize('raised,expected', [
        # https://websockets.readthedocs.io/en/stable/reference/exceptions.html
        (lambda: websockets.exceptions.InvalidURI(msg='test', uri='test://'), RequestError),
        # Requires a response object. Should be covered by HTTP error tests.
        # (lambda: websockets.exceptions.InvalidStatus(), TransportError),
        (lambda: websockets.exceptions.InvalidHandshake(), TransportError),
        # These are subclasses of InvalidHandshake
        (lambda: websockets.exceptions.InvalidHeader(name='test'), TransportError),
        (lambda: websockets.exceptions.NegotiationError(), TransportError),
        # Catch-all
        (lambda: websockets.exceptions.WebSocketException(), TransportError),
        (lambda: TimeoutError(), TransportError),
        # These may be raised by our create_connection implementation, which should also be caught
        (lambda: OSError(), TransportError),
        (lambda: ssl.SSLError(), SSLError),
        (lambda: ssl.SSLCertVerificationError(), CertificateVerifyError),
        (lambda: socks.ProxyError(), ProxyError),
    ])
    def test_request_error_mapping(self, handler, monkeypatch, raised, expected):
        import websockets.sync.client

        import yt_dlp.networking._websockets

        def failing_connect(*args, **kwargs):
            raise raised()

        with handler() as rh:
            # Stub out yt-dlp's create_connection and make websockets' connect raise raised()
            monkeypatch.setattr(yt_dlp.networking._websockets, 'create_connection', lambda *args, **kwargs: None)
            monkeypatch.setattr(websockets.sync.client, 'connect', failing_connect)
            with pytest.raises(expected) as exc_info:
                rh.send(Request('ws://fake-url'))
            # Exactly the mapped type, not merely a subclass of it
            assert exc_info.type is expected

    @pytest.mark.parametrize('raised,expected,match', [
        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send
        (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
        (lambda: RuntimeError(), TransportError, None),
        (lambda: TimeoutError(), TransportError, None),
        (lambda: TypeError(), RequestError, None),
        (lambda: socks.ProxyError(), ProxyError, None),
        # Catch-all
        (lambda: websockets.exceptions.WebSocketException(), TransportError, None),
    ])
    def test_ws_send_error_mapping(self, handler, monkeypatch, raised, expected, match):
        adapter = self._fake_adapter(raised)
        with pytest.raises(expected, match=match) as exc_info:
            adapter.send('test')
        assert exc_info.type is expected

    @pytest.mark.parametrize('raised,expected,match', [
        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv
        (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
        (lambda: RuntimeError(), TransportError, None),
        (lambda: TimeoutError(), TransportError, None),
        (lambda: socks.ProxyError(), ProxyError, None),
        # Catch-all
        (lambda: websockets.exceptions.WebSocketException(), TransportError, None),
    ])
    def test_ws_recv_error_mapping(self, handler, monkeypatch, raised, expected, match):
        adapter = self._fake_adapter(raised)
        with pytest.raises(expected, match=match) as exc_info:
            adapter.recv()
        assert exc_info.type is expected