3 # Allow direct execution
9 from test
.helper
import verify_address_availability
11 sys
.path
.insert(0, os
.path
.dirname(os
.path
.dirname(os
.path
.abspath(__file__
))))
21 from yt_dlp
import socks
22 from yt_dlp
.cookies
import YoutubeDLCookieJar
23 from yt_dlp
.dependencies
import websockets
24 from yt_dlp
.networking
import Request
25 from yt_dlp
.networking
.exceptions
import (
26 CertificateVerifyError
,
33 from yt_dlp
.utils
.networking
import HTTPHeaderDict
# Absolute path of the directory containing this test module; certificate
# fixtures (testcert.pem, testdata/certificate/...) are resolved relative to it.
TEST_DIR = os.path.split(os.path.abspath(__file__))[0]
def websocket_handler(websocket):
    """Reply to the first message received on ``websocket``, then return.

    Recognised text commands make the server report what it saw during the
    handshake, so tests can inspect what the client actually sent:

    * ``'headers'`` -- the request headers, JSON-encoded as an object
    * ``'path'`` -- the request path
    * ``'source_address'`` -- the first element of ``websocket.remote_address``
    * ``'str'`` -- the string ``'1'``

    The binary message ``b'bytes'`` is answered with ``'2'``. Any other
    message (text, binary or otherwise) is echoed back unchanged.

    :returns: the result of ``websocket.send`` for the reply, or ``None`` if
        the connection yields no message at all.
    """
    # Lazily evaluated so request attributes are only read when asked for.
    text_commands = {
        'headers': lambda: json.dumps(dict(websocket.request.headers)),
        'path': lambda: websocket.request.path,
        'source_address': lambda: websocket.remote_address[0],
        'str': lambda: '1',
    }
    for incoming in websocket:
        if isinstance(incoming, bytes):
            reply = '2' if incoming == b'bytes' else incoming
        elif isinstance(incoming, str) and incoming in text_commands:
            reply = text_commands[incoming]()
        else:
            reply = incoming
        # Only the first message is ever handled.
        return websocket.send(reply)
def process_request(self, request):
    """Handshake hook passed to the test server as ``process_request``.

    A request path of the form ``/gen_<code>`` makes the server answer the
    handshake with that HTTP status instead of upgrading the connection. Every
    other path is accepted as a normal websocket handshake.

    :param self: the connection object websockets passes as the first argument
        (presumably the server connection -- only ``self.protocol`` is used).
    :param request: the incoming handshake request; only ``request.path`` is read.
    :returns: the HTTP response to send for the handshake.
    :raises ValueError: if the text after ``/gen_`` is not a known HTTP status.
    """
    if not request.path.startswith('/gen_'):
        return self.protocol.accept(request)

    status = http.HTTPStatus(int(request.path[5:]))
    # Treat the whole 3xx range as redirects and attach a Location header, as a
    # real redirect would. (A `300 <= value <= 300` check only matched 300, so
    # 301/302/303 were rejected without a Location header.)
    if 300 <= status.value < 400:
        return websockets.http11.Response(
            status.value, status.phrase, websockets.datastructures.Headers([('Location', '/')]), b'')
    return self.protocol.reject(status.value, status.phrase)
def create_websocket_server(**ws_kwargs):
    """Start a websockets test server on an OS-assigned localhost port.

    Connections are served by ``websocket_handler`` and handshakes are vetted
    by ``process_request``. ``serve_forever`` runs in a daemon thread, so the
    server does not keep the interpreter alive.

    :param ws_kwargs: extra keyword arguments forwarded to
        ``websockets.sync.server.serve`` (for example ``ssl_context``).
    :returns: ``(thread, port)`` -- the serving thread and the bound port.
    """
    import websockets.sync.server

    server = websockets.sync.server.serve(
        websocket_handler,
        '127.0.0.1',
        0,  # port 0: let the OS pick a free port
        process_request=process_request,
        open_timeout=2,
        **ws_kwargs,
    )
    port = server.socket.getsockname()[1]
    serving_thread = threading.Thread(target=server.serve_forever, daemon=True)
    serving_thread.start()
    return serving_thread, port
def create_ws_websocket_server():
    """Start a plain-text (``ws://``) test server; returns ``(thread, port)``."""
    server_thread, server_port = create_websocket_server()
    return server_thread, server_port
def create_wss_websocket_server():
    """Start a TLS (``wss://``) test server presenting ``testcert.pem``.

    :returns: ``(thread, port)`` as returned by ``create_websocket_server``.
    """
    server_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    # keyfile is None, so the private key is read from testcert.pem as well.
    server_ctx.load_cert_chain(os.path.join(TEST_DIR, 'testcert.pem'), None)
    return create_websocket_server(ssl_context=server_ctx)
# Directory holding the CA, client certificates and keys for the mutual-TLS tests.
MTLS_CERT_DIR = os.path.join(os.path.join(TEST_DIR, 'testdata'), 'certificate')
def create_mtls_wss_websocket_server():
    """Start a ``wss://`` test server that requires a client certificate.

    The server presents ``testcert.pem`` and only accepts clients whose
    certificate verifies against ``ca.crt`` from ``MTLS_CERT_DIR``.

    :returns: ``(thread, port)`` as returned by ``create_websocket_server``.
    """
    ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    ctx.verify_mode = ssl.CERT_REQUIRED  # refuse clients without a valid certificate
    ctx.load_verify_locations(cafile=os.path.join(MTLS_CERT_DIR, 'ca.crt'))
    ctx.load_cert_chain(os.path.join(TEST_DIR, 'testcert.pem'), None)
    return create_websocket_server(ssl_context=ctx)
# Send `req` through request handler `rh` (the name suggests it is validated
# first -- that call is not visible here). The visible `except TransportError`
# branch allows another attempt while attempts remain and the error mentions
# 'connection closed during handshake'.
# NOTE(review): this copy is incomplete -- `max_tries` is never assigned, the
# `except` clause has no visible `try:`, and the send/return, retry and
# re-raise statements are missing. Restore them before running.
103 def ws_validate_and_send(rh
, req
):
106 for i
in range(max_tries
):
109 except TransportError
as e
:
110 if i
< (max_tries
- 1) and 'connection closed during handshake' in str(e
):
111 # websockets server sometimes hangs on new connections
# Conformance tests for the 'Websockets' request handler, run against local
# servers started once in setup_class. Skipped when the websockets dependency
# is unavailable.
# NOTE(review): several statements are missing from this copy (for example the
# ws.send(...) calls that the ws.recv() assertions depend on, and the
# `with handler(...) as rh:` header in test_mtls); restore them before running.
116 @pytest.mark.skipif(not websockets
, reason
='websockets must be installed to test websocket request handlers')
117 class TestWebsSocketRequestHandlerConformance
:
# Starts the class-wide servers: plain ws://, wss:// presenting testcert.pem,
# a "bad" wss:// whose SSLContext has no certificate loaded, and an mTLS
# wss:// server that requires client certificates.
119 def setup_class(cls
):
120 cls
.ws_thread
, cls
.ws_port
= create_ws_websocket_server()
121 cls
.ws_base_url
= f
'ws://127.0.0.1:{cls.ws_port}'
123 cls
.wss_thread
, cls
.wss_port
= create_wss_websocket_server()
124 cls
.wss_base_url
= f
'wss://127.0.0.1:{cls.wss_port}'
126 cls
.bad_wss_thread
, cls
.bad_wss_port
= create_websocket_server(ssl_context
=ssl
.SSLContext(ssl
.PROTOCOL_TLS_SERVER
))
127 cls
.bad_wss_host
= f
'wss://127.0.0.1:{cls.bad_wss_port}'
129 cls
.mtls_wss_thread
, cls
.mtls_wss_port
= create_mtls_wss_websocket_server()
130 cls
.mtls_wss_base_url
= f
'wss://127.0.0.1:{cls.mtls_wss_port}'
# Plain ws:// handshake: expects status 101 and an 'upgrade' response header.
# The recv() == 'foo' check relies on websocket_handler echoing unknown text.
# NOTE(review): the matching ws.send('foo') is not visible in this copy.
132 @pytest.mark.parametrize('handler', ['Websockets'], indirect
=True)
133 def test_basic_websockets(self
, handler
):
134 with handler() as rh
:
135 ws
= ws_validate_and_send(rh
, Request(self
.ws_base_url
))
136 assert 'upgrade' in ws
.headers
137 assert ws
.status
== 101
139 assert ws
.recv() == 'foo'
# Text vs binary frames: websocket_handler answers 'str' with '1' and b'bytes'
# with '2', matching the RFC 6455 data-frame opcodes checked below.
# NOTE(review): the ws.send(msg) call is not visible in this copy.
142 # https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6
143 @pytest.mark.parametrize('msg,opcode', [('str', 1), (b
'bytes', 2)])
144 @pytest.mark.parametrize('handler', ['Websockets'], indirect
=True)
145 def test_send_types(self
, handler
, msg
, opcode
):
146 with handler() as rh
:
147 ws
= ws_validate_and_send(rh
, Request(self
.ws_base_url
))
149 assert int(ws
.recv()) == opcode
# With default settings the handler must reject the self-signed wss:// server
# certificate (CertificateVerifyError); with verify=False the handshake succeeds.
152 @pytest.mark.parametrize('handler', ['Websockets'], indirect
=True)
153 def test_verify_cert(self
, handler
):
154 with handler() as rh
:
155 with pytest
.raises(CertificateVerifyError
):
156 ws_validate_and_send(rh
, Request(self
.wss_base_url
))
158 with handler(verify
=False) as rh
:
159 ws
= ws_validate_and_send(rh
, Request(self
.wss_base_url
))
160 assert ws
.status
== 101
# The bad_wss server's SSLContext has no certificate loaded, so the TLS
# handshake fails; this must surface as SSLError, not CertificateVerifyError.
163 @pytest.mark.parametrize('handler', ['Websockets'], indirect
=True)
164 def test_ssl_error(self
, handler
):
165 with handler(verify
=False) as rh
:
166 with pytest
.raises(SSLError
, match
=r
'ssl(?:v3|/tls) alert handshake failure') as exc_info
:
167 ws_validate_and_send(rh
, Request(self
.bad_wss_host
))
168 assert not issubclass(exc_info
.type, CertificateVerifyError
)
# The server reports the request path it received (websocket_handler's 'path'
# command). NOTE(review): the ws.send('path') call is not visible in this copy.
170 @pytest.mark.parametrize('handler', ['Websockets'], indirect
=True)
171 @pytest.mark.parametrize('path,expected', [
172 # Unicode characters should be encoded with uppercase percent-encoding
173 ('/中文', '/%E4%B8%AD%E6%96%87'),
174 # don't normalize existing percent encodings
175 ('/%c7%9f', '/%c7%9f'),
177 def test_percent_encode(self
, handler
, path
, expected
):
178 with handler() as rh
:
179 ws
= ws_validate_and_send(rh
, Request(f
'{self.ws_base_url}{path}'))
181 assert ws
.recv() == expected
182 assert ws
.status
== 101
# Dot segments in the URL path must be resolved before sending, so the server
# sees '/test'. NOTE(review): the ws.send('path') call is not visible here.
185 @pytest.mark.parametrize('handler', ['Websockets'], indirect
=True)
186 def test_remove_dot_segments(self
, handler
):
187 with handler() as rh
:
188 # This isn't a comprehensive test,
189 # but it should be enough to check whether the handler is removing dot segments
190 ws
= ws_validate_and_send(rh
, Request(f
'{self.ws_base_url}/a/b/./../../test'))
191 assert ws
.status
== 101
193 assert ws
.recv() == '/test'
# process_request answers /gen_<status> with that status instead of 101; the
# handler must raise HTTPError carrying the same status.
196 # We are restricted to known HTTP status codes in http.HTTPStatus
197 # Redirects are not supported for websockets
198 @pytest.mark.parametrize('handler', ['Websockets'], indirect
=True)
199 @pytest.mark.parametrize('status', (200, 204, 301, 302, 303, 400, 500, 511))
200 def test_raise_http_error(self
, handler
, status
):
201 with handler() as rh
:
202 with pytest
.raises(HTTPError
) as exc_info
:
203 ws_validate_and_send(rh
, Request(f
'{self.ws_base_url}/gen_{status}'))
204 assert exc_info
.value
.status
== status
# A vanishingly small timeout, given either to the handler or as a per-request
# extension, must fail with TransportError.
206 @pytest.mark.parametrize('handler', ['Websockets'], indirect
=True)
207 @pytest.mark.parametrize('params,extensions', [
208 ({'timeout': sys.float_info.min}
, {}),
209 ({}, {'timeout': sys.float_info.min}
),
211 def test_timeout(self
, handler
, params
, extensions
):
212 with handler(**params
) as rh
:
213 with pytest
.raises(TransportError
):
214 ws_validate_and_send(rh
, Request(self
.ws_base_url
, extensions
=extensions
))
# Cookies from a handler-level cookiejar, or from a per-request 'cookiejar'
# extension, must appear in the handshake headers the server echoes as JSON.
# NOTE(review): the ws.send('headers') calls are not visible in this copy.
216 @pytest.mark.parametrize('handler', ['Websockets'], indirect
=True)
217 def test_cookies(self
, handler
):
218 cookiejar
= YoutubeDLCookieJar()
219 cookiejar
.set_cookie(http
.cookiejar
.Cookie(
220 version
=0, name
='test', value
='ytdlp', port
=None, port_specified
=False,
221 domain
='127.0.0.1', domain_specified
=True, domain_initial_dot
=False, path
='/',
222 path_specified
=True, secure
=False, expires
=None, discard
=False, comment
=None,
223 comment_url
=None, rest
={}))
225 with handler(cookiejar
=cookiejar
) as rh
:
226 ws
= ws_validate_and_send(rh
, Request(self
.ws_base_url
))
228 assert json
.loads(ws
.recv())['cookie'] == 'test=ytdlp'
231 with handler() as rh
:
232 ws
= ws_validate_and_send(rh
, Request(self
.ws_base_url
))
234 assert 'cookie' not in json
.loads(ws
.recv())
237 ws
= ws_validate_and_send(rh
, Request(self
.ws_base_url
, extensions
={'cookiejar': cookiejar}
))
239 assert json
.loads(ws
.recv())['cookie'] == 'test=ytdlp'
# Binds to a random 127.0.0.x source address; the server echoes the peer
# address it sees (remote_address[0]), which must match.
242 @pytest.mark.parametrize('handler', ['Websockets'], indirect
=True)
243 def test_source_address(self
, handler
):
244 source_address
= f
'127.0.0.{random.randint(5, 255)}'
245 verify_address_availability(source_address
)
246 with handler(source_address
=source_address
) as rh
:
247 ws
= ws_validate_and_send(rh
, Request(self
.ws_base_url
))
248 ws
.send('source_address')
249 assert source_address
== ws
.recv()
# NOTE(review): this test's assertions (presumably on the response URL) are
# missing from this copy; as shown it only opens the connection.
252 @pytest.mark.parametrize('handler', ['Websockets'], indirect
=True)
253 def test_response_url(self
, handler
):
254 with handler() as rh
:
255 url
= f
'{self.ws_base_url}/something'
256 ws
= ws_validate_and_send(rh
, Request(url
))
# Handler-level headers must be sent, and per-request headers merged over them
# (per-request values win). NOTE(review): the ws.send('headers') calls are not
# visible in this copy.
260 @pytest.mark.parametrize('handler', ['Websockets'], indirect
=True)
261 def test_request_headers(self
, handler
):
262 with handler(headers
=HTTPHeaderDict({'test1': 'test', 'test2': 'test2'}
)) as rh
:
264 ws
= ws_validate_and_send(rh
, Request(self
.ws_base_url
))
266 headers
= HTTPHeaderDict(json
.loads(ws
.recv()))
267 assert headers
['test1'] == 'test'
270 # Per request headers, merged with global
271 ws
= ws_validate_and_send(rh
, Request(
272 self
.ws_base_url
, headers
={'test2': 'changed', 'test3': 'test3'}
))
274 headers
= HTTPHeaderDict(json
.loads(ws
.recv()))
275 assert headers
['test1'] == 'test'
276 assert headers
['test2'] == 'changed'
277 assert headers
['test3'] == 'test3'
# Mutual TLS: each parametrized client-certificate option set (combined
# cert+key, separate cert and key, and encrypted keys with password 'foobar')
# must be accepted by the mTLS server. NOTE(review): some dict/tuple brackets
# and the `with handler(...) as rh:` header are missing in this copy.
280 @pytest.mark.parametrize('client_cert', (
281 {'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithkey.crt')}
,
283 'client_certificate': os
.path
.join(MTLS_CERT_DIR
, 'client.crt'),
284 'client_certificate_key': os
.path
.join(MTLS_CERT_DIR
, 'client.key'),
287 'client_certificate': os
.path
.join(MTLS_CERT_DIR
, 'clientwithencryptedkey.crt'),
288 'client_certificate_password': 'foobar',
291 'client_certificate': os
.path
.join(MTLS_CERT_DIR
, 'client.crt'),
292 'client_certificate_key': os
.path
.join(MTLS_CERT_DIR
, 'clientencrypted.key'),
293 'client_certificate_password': 'foobar',
296 @pytest.mark.parametrize('handler', ['Websockets'], indirect
=True)
297 def test_mtls(self
, handler
, client_cert
):
299 # Disable client-side validation of unacceptable self-signed testcert.pem
300 # The test is of a check on the server side, so unaffected
302 client_cert
=client_cert
304 ws_validate_and_send(rh
, Request(self
.mtls_wss_base_url
)).close()
# Build a stand-in for websockets.sync.client.ClientConnection for the error
# mapping tests. `raised` is a zero-argument factory for an exception (see the
# lambdas in the parametrizations that pass it in).
# NOTE(review): this copy is incomplete -- the FakeResponse class statement and
# the bodies of send/recv/close (presumably `raise raised()`) are missing.
307 def create_fake_ws_connection(raised
):
308 import websockets
.sync
.client
310 class FakeWsConnection(websockets
.sync
.client
.ClientConnection
):
# Only the fake `response` (reason_phrase 'test') is visible in __init__.
311 def __init__(self
, *args
, **kwargs
):
316 reason_phrase
= 'test'
318 self
.response
= FakeResponse()
# NOTE(review): body missing in this copy; presumably raises raised() -- confirm.
320 def send(self
, *args
, **kwargs
):
# NOTE(review): body missing in this copy; presumably raises raised() -- confirm.
323 def recv(self
, *args
, **kwargs
):
# NOTE(review): body missing in this copy.
326 def close(self
, *args
, **kwargs
):
329 return FakeWsConnection()
# Unit tests checking that exceptions raised inside the websockets-based
# handler are mapped onto yt-dlp networking exceptions (RequestError,
# TransportError, SSLError, CertificateVerifyError, ProxyError). Every test is
# parametrized with the 'Websockets' handler.
332 @pytest.mark.parametrize('handler', ['Websockets'], indirect
=True)
333 class TestWebsocketsRequestHandler
:
# Each case pairs a zero-argument factory for the exception raised during
# connection setup with the exact yt-dlp exception type (`is`, not a subclass)
# that rh.send() must raise. NOTE(review): the list's closing `])` and
# fake_connect's body (presumably `raise raised()`) are missing in this copy.
334 @pytest.mark.parametrize('raised,expected', [
335 # https://websockets.readthedocs.io/en/stable/reference/exceptions.html
336 (lambda: websockets
.exceptions
.InvalidURI(msg
='test', uri
='test://'), RequestError
),
337 # Requires a response object. Should be covered by HTTP error tests.
338 # (lambda: websockets.exceptions.InvalidStatus(), TransportError),
339 (lambda: websockets
.exceptions
.InvalidHandshake(), TransportError
),
340 # These are subclasses of InvalidHandshake
341 (lambda: websockets
.exceptions
.InvalidHeader(name
='test'), TransportError
),
342 (lambda: websockets
.exceptions
.NegotiationError(), TransportError
),
344 (lambda: websockets
.exceptions
.WebSocketException(), TransportError
),
345 (lambda: TimeoutError(), TransportError
),
346 # These may be raised by our create_connection implementation, which should also be caught
347 (lambda: OSError(), TransportError
),
348 (lambda: ssl
.SSLError(), SSLError
),
349 (lambda: ssl
.SSLCertVerificationError(), CertificateVerifyError
),
350 (lambda: socks
.ProxyError(), ProxyError
),
352 def test_request_error_mapping(self
, handler
, monkeypatch
, raised
, expected
):
353 import websockets
.sync
.client
355 import yt_dlp
.networking
._websockets
356 with handler() as rh
:
357 def fake_connect(*args
, **kwargs
):
# create_connection is stubbed to return None (presumably so no real socket is
# opened) and websockets.sync.client.connect is replaced by fake_connect.
359 monkeypatch
.setattr(yt_dlp
.networking
._websockets
, 'create_connection', lambda *args
, **kwargs
: None)
360 monkeypatch
.setattr(websockets
.sync
.client
, 'connect', fake_connect
)
361 with pytest
.raises(expected
) as exc_info
:
362 rh
.send(Request('ws://fake-url'))
363 assert exc_info
.type is expected
# Errors from the underlying connection's send() must map to the exact
# expected yt-dlp type. NOTE(review): the call under pytest.raises (presumably
# ws.send(...)) and the list's closing `])` are missing in this copy.
365 @pytest.mark.parametrize('raised,expected,match', [
366 # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send
367 (lambda: websockets
.exceptions
.ConnectionClosed(None, None), TransportError
, None),
368 (lambda: RuntimeError(), TransportError
, None),
369 (lambda: TimeoutError(), TransportError
, None),
370 (lambda: TypeError(), RequestError
, None),
371 (lambda: socks
.ProxyError(), ProxyError
, None),
373 (lambda: websockets
.exceptions
.WebSocketException(), TransportError
, None),
375 def test_ws_send_error_mapping(self
, handler
, monkeypatch
, raised
, expected
, match
):
376 from yt_dlp
.networking
._websockets
import WebsocketsResponseAdapter
377 ws
= WebsocketsResponseAdapter(create_fake_ws_connection(raised
), url
='ws://fake-url')
378 with pytest
.raises(expected
, match
=match
) as exc_info
:
380 assert exc_info
.type is expected
# Errors from the underlying connection's recv() must map to the exact
# expected yt-dlp type. NOTE(review): the call under pytest.raises (presumably
# ws.recv()) and the list's closing `])` are missing in this copy.
382 @pytest.mark.parametrize('raised,expected,match', [
383 # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv
384 (lambda: websockets
.exceptions
.ConnectionClosed(None, None), TransportError
, None),
385 (lambda: RuntimeError(), TransportError
, None),
386 (lambda: TimeoutError(), TransportError
, None),
387 (lambda: socks
.ProxyError(), ProxyError
, None),
389 (lambda: websockets
.exceptions
.WebSocketException(), TransportError
, None),
391 def test_ws_recv_error_mapping(self
, handler
, monkeypatch
, raised
, expected
, match
):
392 from yt_dlp
.networking
._websockets
import WebsocketsResponseAdapter
393 ws
= WebsocketsResponseAdapter(create_fake_ws_connection(raised
), url
='ws://fake-url')
394 with pytest
.raises(expected
, match
=match
) as exc_info
:
396 assert exc_info
.type is expected