# Allow direct execution
import os
import sys
import time

import pytest

from test.helper import verify_address_availability
from yt_dlp.networking.common import Features, DEFAULT_TIMEOUT

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import http.cookiejar
import json
import random
import ssl
import threading

from yt_dlp import socks, traverse_obj
from yt_dlp.cookies import YoutubeDLCookieJar
from yt_dlp.dependencies import websockets
from yt_dlp.networking import Request
from yt_dlp.networking.exceptions import (
    CertificateVerifyError,
    HTTPError,
    ProxyError,
    RequestError,
    SSLError,
    TransportError,
)
from yt_dlp.utils.networking import HTTPHeaderDict
# Absolute path of the directory that contains this test module.
TEST_DIR = os.path.split(os.path.abspath(__file__))[0]
def websocket_handler(websocket):
    """Answer the first message received on ``websocket`` and return.

    A few command messages get a special reply; anything else is echoed back
    unchanged:

    * ``b'bytes'``         -> ``'2'``
    * ``'str'``            -> ``'1'``
    * ``'headers'``        -> JSON object of the handshake request headers
    * ``'path'``           -> path of the handshake request
    * ``'source_address'`` -> host part of the peer address

    Returns whatever ``websocket.send`` returns, or ``None`` when no message
    was received at all.
    """
    for incoming in websocket:
        # Default reply is an echo; the branches below override it for commands.
        reply = incoming
        if isinstance(incoming, bytes):
            if incoming == b'bytes':
                reply = '2'
        elif isinstance(incoming, str):
            if incoming == 'headers':
                reply = json.dumps(dict(websocket.request.headers))
            elif incoming == 'path':
                reply = websocket.request.path
            elif incoming == 'source_address':
                reply = websocket.remote_address[0]
            elif incoming == 'str':
                reply = '1'
        # Only the first message is ever handled: the handler returns right after replying.
        return websocket.send(reply)
def process_request(self, request):
    """Handshake hook: let ``/gen_<code>`` paths force an HTTP status response.

    ``self`` is the object the server passes as first argument; only its
    ``protocol.accept`` / ``protocol.reject`` methods are used here.

    * Any path not starting with ``/gen_`` is accepted as a normal handshake.
    * ``/gen_<code>`` rejects the handshake with that status; ``<code>`` must
      be a member of ``http.HTTPStatus`` or ``ValueError`` is raised.
    * Status 300 is answered with a hand-built response carrying a
      ``Location: /`` header instead.
    """
    prefix = '/gen_'
    if not request.path.startswith(prefix):
        return self.protocol.accept(request)

    status = http.HTTPStatus(int(request.path[len(prefix):]))
    # NOTE(review): the original condition was `300 <= status.value <= 300`,
    # i.e. exactly 300 - confirm whether the whole 3xx range was intended.
    if status.value == 300:
        redirect_headers = websockets.datastructures.Headers([('Location', '/')])
        return websockets.http11.Response(status.value, status.phrase, redirect_headers, b'')
    return self.protocol.reject(status.value, status.phrase)
def create_websocket_server(**ws_kwargs):
    """Start a websockets sync test server on 127.0.0.1 in a daemon thread.

    The server uses ``websocket_handler`` for messages and ``process_request``
    as the handshake hook. ``ws_kwargs`` are forwarded unchanged to
    ``websockets.sync.server.serve``.

    Returns:
        ``(thread, port)`` - the started serving thread and the bound port.
    """
    import websockets.sync.server

    server = websockets.sync.server.serve(
        websocket_handler, '127.0.0.1', 0,
        process_request=process_request, open_timeout=2, **ws_kwargs)
    # The server was bound to port 0, so read the actual port back from the socket.
    port = server.socket.getsockname()[1]
    serving_thread = threading.Thread(target=server.serve_forever, daemon=True)
    serving_thread.start()
    return serving_thread, port
def create_ws_websocket_server():
    """Start the test server without TLS; returns ``(thread, port)``."""
    serving_thread, port = create_websocket_server()
    return serving_thread, port
def create_wss_websocket_server():
    """Start the test server with TLS, using ``testcert.pem`` from ``TEST_DIR``.

    Returns ``(thread, port)`` as produced by ``create_websocket_server``.
    """
    tls_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    # No separate key file is given (second argument is None).
    tls_context.load_cert_chain(os.path.join(TEST_DIR, 'testcert.pem'), None)
    return create_websocket_server(ssl_context=tls_context)
# Directory with the CA / client certificate files referenced by the mTLS server and tests.
MTLS_CERT_DIR = os.path.join(TEST_DIR, 'testdata', 'certificate')
def create_mtls_wss_websocket_server():
    """Start the TLS test server in a mode that requires a client certificate.

    The server presents ``testcert.pem`` and verifies client certificates
    against ``ca.crt`` from ``MTLS_CERT_DIR``.

    Returns ``(thread, port)`` as produced by ``create_websocket_server``.
    """
    server_cert = os.path.join(TEST_DIR, 'testcert.pem')
    client_ca = os.path.join(MTLS_CERT_DIR, 'ca.crt')

    tls_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    # CERT_REQUIRED makes the handshake fail unless the client presents a certificate.
    tls_context.verify_mode = ssl.CERT_REQUIRED
    tls_context.load_verify_locations(cafile=client_ca)
    tls_context.load_cert_chain(server_cert, None)

    return create_websocket_server(ssl_context=tls_context)
def ws_validate_and_send(rh, req):
    """Validate ``req`` with handler ``rh``, then send it, retrying flaky handshakes.

    A ``TransportError`` whose message contains
    ``'connection closed during handshake'`` is retried until ``max_tries``
    attempts have been made; any other error, or the error from the final
    attempt, propagates to the caller.

    Returns:
        Whatever ``rh.send(req)`` returns (the opened websocket response).
    """
    # NOTE(review): the validate call is inferred from the function name - confirm
    # that the handler exposes `validate(request)`.
    rh.validate(req)
    # NOTE(review): the retry count is not visible in the damaged source; 3 is assumed - confirm.
    max_tries = 3
    for i in range(max_tries):
        try:
            return rh.send(req)
        except TransportError as e:
            if i < (max_tries - 1) and 'connection closed during handshake' in str(e):
                # websockets server sometimes hangs on new connections
                continue
            raise
@pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers')
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
class TestWebsSocketRequestHandlerConformance:
    """Behavioural tests for websocket request handlers against local test servers.

    ``setup_class`` starts four servers (plain ws, wss, a wss server without a
    certificate, and a wss server requiring client certificates). Each test
    opens a connection through ``ws_validate_and_send`` and, where it expects a
    reply, first sends one of the commands understood by ``websocket_handler``.
    """

    @classmethod
    def setup_class(cls):
        """Start the local test servers and record their base URLs on the class."""
        cls.ws_thread, cls.ws_port = create_ws_websocket_server()
        cls.ws_base_url = f'ws://127.0.0.1:{cls.ws_port}'

        cls.wss_thread, cls.wss_port = create_wss_websocket_server()
        cls.wss_base_url = f'wss://127.0.0.1:{cls.wss_port}'

        # TLS server whose context has no certificate loaded
        cls.bad_wss_thread, cls.bad_wss_port = create_websocket_server(ssl_context=ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER))
        cls.bad_wss_host = f'wss://127.0.0.1:{cls.bad_wss_port}'

        cls.mtls_wss_thread, cls.mtls_wss_port = create_mtls_wss_websocket_server()
        cls.mtls_wss_base_url = f'wss://127.0.0.1:{cls.mtls_wss_port}'

    def test_basic_websockets(self, handler):
        """A successful upgrade reports status 101 and echoes an arbitrary message."""
        with handler() as rh:
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            assert 'upgrade' in ws.headers
            assert ws.status == 101
            # The test server echoes any message that is not a known command
            ws.send('foo')
            assert ws.recv() == 'foo'
            ws.close()

    # https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6
    @pytest.mark.parametrize('msg,opcode', [('str', 1), (b'bytes', 2)])
    def test_send_types(self, handler, msg, opcode):
        """The server replies '1' to the 'str' message and '2' to the b'bytes' message."""
        with handler() as rh:
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            ws.send(msg)
            assert int(ws.recv()) == opcode
            ws.close()

    def test_verify_cert(self, handler):
        """Certificate verification fails by default and is skipped with verify=False."""
        with handler() as rh:
            with pytest.raises(CertificateVerifyError):
                ws_validate_and_send(rh, Request(self.wss_base_url))

        with handler(verify=False) as rh:
            ws = ws_validate_and_send(rh, Request(self.wss_base_url))
            assert ws.status == 101
            ws.close()

    def test_ssl_error(self, handler):
        """A TLS handshake failure is reported as SSLError, not as CertificateVerifyError."""
        with handler(verify=False) as rh:
            with pytest.raises(SSLError, match=r'ssl(?:v3|/tls) alert handshake failure') as exc_info:
                ws_validate_and_send(rh, Request(self.bad_wss_host))
            assert not issubclass(exc_info.type, CertificateVerifyError)

    @pytest.mark.parametrize('path,expected', [
        # Unicode characters should be encoded with uppercase percent-encoding
        ('/中文', '/%E4%B8%AD%E6%96%87'),
        # don't normalize existing percent encodings
        ('/%c7%9f', '/%c7%9f'),
    ])
    def test_percent_encode(self, handler, path, expected):
        """The request path seen by the server matches the expected percent-encoding."""
        with handler() as rh:
            ws = ws_validate_and_send(rh, Request(f'{self.ws_base_url}{path}'))
            ws.send('path')
            assert ws.recv() == expected
            assert ws.status == 101
            ws.close()

    def test_remove_dot_segments(self, handler):
        """Dot segments are removed from the path before it reaches the server."""
        with handler() as rh:
            # This isn't a comprehensive test,
            # but it should be enough to check whether the handler is removing dot segments
            ws = ws_validate_and_send(rh, Request(f'{self.ws_base_url}/a/b/./../../test'))
            assert ws.status == 101
            ws.send('path')
            assert ws.recv() == '/test'
            ws.close()

    # We are restricted to known HTTP status codes in http.HTTPStatus
    # Redirects are not supported for websockets
    @pytest.mark.parametrize('status', (200, 204, 301, 302, 303, 400, 500, 511))
    def test_raise_http_error(self, handler, status):
        """A non-upgrade HTTP response raises HTTPError carrying the server's status."""
        with handler() as rh:
            with pytest.raises(HTTPError) as exc_info:
                ws_validate_and_send(rh, Request(f'{self.ws_base_url}/gen_{status}'))
            assert exc_info.value.status == status

    @pytest.mark.parametrize('params,extensions', [
        ({'timeout': sys.float_info.min}, {}),
        ({}, {'timeout': sys.float_info.min}),
    ])
    def test_read_timeout(self, handler, params, extensions):
        """A minimal timeout, set on the handler or per request, raises TransportError."""
        with handler(**params) as rh:
            with pytest.raises(TransportError):
                ws_validate_and_send(rh, Request(self.ws_base_url, extensions=extensions))

    def test_connect_timeout(self, handler):
        """Connecting to an unreachable address fails well before DEFAULT_TIMEOUT."""
        # nothing should be listening on this port
        connect_timeout_url = 'ws://10.255.255.255'
        with handler(timeout=0.01) as rh, pytest.raises(TransportError):
            now = time.time()
            ws_validate_and_send(rh, Request(connect_timeout_url))
        assert time.time() - now < DEFAULT_TIMEOUT

        # Per request timeout, should override handler timeout
        request = Request(connect_timeout_url, extensions={'timeout': 0.01})
        with handler() as rh, pytest.raises(TransportError):
            now = time.time()
            ws_validate_and_send(rh, request)
        assert time.time() - now < DEFAULT_TIMEOUT

    def test_cookies(self, handler):
        """Cookies are sent from the handler's cookiejar or from a per-request cookiejar."""
        cookiejar = YoutubeDLCookieJar()
        cookiejar.set_cookie(http.cookiejar.Cookie(
            version=0, name='test', value='ytdlp', port=None, port_specified=False,
            domain='127.0.0.1', domain_specified=True, domain_initial_dot=False, path='/',
            path_specified=True, secure=False, expires=None, discard=False, comment=None,
            comment_url=None, rest={}))

        with handler(cookiejar=cookiejar) as rh:
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            ws.send('headers')
            assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
            ws.close()

        with handler() as rh:
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            ws.send('headers')
            assert 'cookie' not in json.loads(ws.recv())
            ws.close()

            ws = ws_validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': cookiejar}))
            ws.send('headers')
            assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
            ws.close()

    def test_source_address(self, handler):
        """The connection originates from the configured source address."""
        source_address = f'127.0.0.{random.randint(5, 255)}'
        verify_address_availability(source_address)
        with handler(source_address=source_address) as rh:
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            ws.send('source_address')
            assert source_address == ws.recv()
            ws.close()

    def test_response_url(self, handler):
        """The response reports the URL that was requested."""
        with handler() as rh:
            url = f'{self.ws_base_url}/something'
            ws = ws_validate_and_send(rh, Request(url))
            # NOTE(review): assumes the response exposes the request URL as `.url` - confirm
            assert ws.url == url
            ws.close()

    def test_request_headers(self, handler):
        """Handler-level headers are sent and merged with per-request headers."""
        with handler(headers=HTTPHeaderDict({'test1': 'test', 'test2': 'test2'})) as rh:
            # Global headers
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            ws.send('headers')
            headers = HTTPHeaderDict(json.loads(ws.recv()))
            assert headers['test1'] == 'test'
            ws.close()

            # Per request headers, merged with global
            ws = ws_validate_and_send(rh, Request(
                self.ws_base_url, headers={'test2': 'changed', 'test3': 'test3'}))
            ws.send('headers')
            headers = HTTPHeaderDict(json.loads(ws.recv()))
            assert headers['test1'] == 'test'
            assert headers['test2'] == 'changed'
            assert headers['test3'] == 'test3'
            ws.close()

    @pytest.mark.parametrize('client_cert', (
        {'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithkey.crt')},
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
            'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'client.key'),
        },
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithencryptedkey.crt'),
            'client_certificate_password': 'foobar',
        },
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
            'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'clientencrypted.key'),
            'client_certificate_password': 'foobar',
        },
    ))
    def test_mtls(self, handler, client_cert):
        """Each supported client-certificate configuration can connect to the mTLS server."""
        with handler(
            # Disable client-side validation of unacceptable self-signed testcert.pem
            # The test is of a check on the server side, so unaffected
            verify=False,
            client_cert=client_cert,
        ) as rh:
            ws_validate_and_send(rh, Request(self.mtls_wss_base_url)).close()

    def test_request_disable_proxy(self, handler):
        """A per-request proxy override of None bypasses the handler-level proxy."""
        for proxy_proto in handler._SUPPORTED_PROXY_SCHEMES or ['ws']:
            # Given handler is configured with a proxy
            with handler(proxies={'ws': f'{proxy_proto}://10.255.255.255'}, timeout=5) as rh:
                # When a proxy is explicitly set to None for the request
                ws = ws_validate_and_send(rh, Request(self.ws_base_url, proxies={'http': None}))
                # Then no proxy should be used
                assert ws.status == 101
                ws.close()

    @pytest.mark.skip_handlers_if(
        lambda _, handler: Features.NO_PROXY not in handler._SUPPORTED_FEATURES, 'handler does not support NO_PROXY')
    def test_noproxy(self, handler):
        """A per-request 'no' proxy entry matching the target host bypasses the proxy."""
        for proxy_proto in handler._SUPPORTED_PROXY_SCHEMES or ['ws']:
            # Given the handler is configured with a proxy
            with handler(proxies={'ws': f'{proxy_proto}://10.255.255.255'}, timeout=5) as rh:
                for no_proxy in (f'127.0.0.1:{self.ws_port}', '127.0.0.1', 'localhost'):
                    # When request no proxy includes the request url host
                    ws = ws_validate_and_send(rh, Request(self.ws_base_url, proxies={'no': no_proxy}))
                    # Then the proxy should not be used
                    assert ws.status == 101
                    ws.close()

    @pytest.mark.skip_handlers_if(
        lambda _, handler: Features.ALL_PROXY not in handler._SUPPORTED_FEATURES, 'handler does not support ALL_PROXY')
    def test_allproxy(self, handler):
        """An 'all' proxy, set on the handler or per request, is actually used."""
        supported_proto = traverse_obj(handler._SUPPORTED_PROXY_SCHEMES, 0, default='ws')
        # This is a bit of a hacky test, but it should be enough to check whether the handler is using the proxy.
        # 0.1s might not be enough of a timeout if proxy is not used in all cases, but should still get failures.
        with handler(proxies={'all': f'{supported_proto}://10.255.255.255'}, timeout=0.1) as rh:
            with pytest.raises(TransportError):
                ws_validate_and_send(rh, Request(self.ws_base_url)).close()

        with handler(timeout=0.1) as rh:
            with pytest.raises(TransportError):
                ws_validate_and_send(
                    rh, Request(self.ws_base_url, proxies={'all': f'{supported_proto}://10.255.255.255'})).close()
def create_fake_ws_connection(raised):
    """Build a stub client connection whose ``send`` and ``recv`` always fail.

    Args:
        raised: zero-argument callable returning the exception instance that
            ``send()`` and ``recv()`` should raise.

    Returns:
        An instance of a ``websockets.sync.client.ClientConnection`` subclass
        that never opens a real connection.
    """
    import websockets.sync.client

    class FakeWsConnection(websockets.sync.client.ClientConnection):
        def __init__(self, *args, **kwargs):
            # The parent initializer is deliberately not called, so no socket is needed.
            class FakeResponse:
                # NOTE(review): only `reason_phrase` is visible in the damaged source;
                # the other attributes are assumed to be what the response adapter
                # reads from a handshake response - confirm.
                body = b''
                headers = {}
                status_code = 101
                reason_phrase = 'test'

            self.response = FakeResponse()

        def send(self, *args, **kwargs):
            raise raised()

        def recv(self, *args, **kwargs):
            raise raised()

        def close(self, *args, **kwargs):
            return

    return FakeWsConnection()
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
class TestWebsocketsRequestHandler:
    """Checks that errors raised by the websockets library are mapped to yt-dlp exceptions.

    Every parametrized ``raised`` value is a zero-argument factory for the
    exception to inject; ``expected`` is the exact exception type the handler
    or response adapter must raise instead.
    """

    @pytest.mark.parametrize('raised,expected', [
        # https://websockets.readthedocs.io/en/stable/reference/exceptions.html
        (lambda: websockets.exceptions.InvalidURI(msg='test', uri='test://'), RequestError),
        # Requires a response object. Should be covered by HTTP error tests.
        # (lambda: websockets.exceptions.InvalidStatus(), TransportError),
        (lambda: websockets.exceptions.InvalidHandshake(), TransportError),
        # These are subclasses of InvalidHandshake
        (lambda: websockets.exceptions.InvalidHeader(name='test'), TransportError),
        (lambda: websockets.exceptions.NegotiationError(), TransportError),
        # Catch-all
        (lambda: websockets.exceptions.WebSocketException(), TransportError),
        (lambda: TimeoutError(), TransportError),
        # These may be raised by our create_connection implementation, which should also be caught
        (lambda: OSError(), TransportError),
        (lambda: ssl.SSLError(), SSLError),
        (lambda: ssl.SSLCertVerificationError(), CertificateVerifyError),
        (lambda: socks.ProxyError(), ProxyError),
    ])
    def test_request_error_mapping(self, handler, monkeypatch, raised, expected):
        """Errors raised while connecting surface as the expected yt-dlp exception type."""
        import websockets.sync.client

        import yt_dlp.networking._websockets
        with handler() as rh:
            def fake_connect(*args, **kwargs):
                raise raised()
            # Stub out socket creation so only the patched connect() can fail
            monkeypatch.setattr(yt_dlp.networking._websockets, 'create_connection', lambda *args, **kwargs: None)
            monkeypatch.setattr(websockets.sync.client, 'connect', fake_connect)
            with pytest.raises(expected) as exc_info:
                rh.send(Request('ws://fake-url'))
            assert exc_info.type is expected

    @pytest.mark.parametrize('raised,expected,match', [
        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send
        (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
        (lambda: RuntimeError(), TransportError, None),
        (lambda: TimeoutError(), TransportError, None),
        (lambda: TypeError(), RequestError, None),
        (lambda: socks.ProxyError(), ProxyError, None),
        # Catch-all
        (lambda: websockets.exceptions.WebSocketException(), TransportError, None),
    ])
    def test_ws_send_error_mapping(self, handler, monkeypatch, raised, expected, match):
        """Errors raised by the connection's send() surface as the expected type."""
        from yt_dlp.networking._websockets import WebsocketsResponseAdapter
        ws = WebsocketsResponseAdapter(create_fake_ws_connection(raised), url='ws://fake-url')
        with pytest.raises(expected, match=match) as exc_info:
            # The payload is irrelevant: the fake connection raises on any send
            ws.send('test')
        assert exc_info.type is expected

    @pytest.mark.parametrize('raised,expected,match', [
        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv
        (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
        (lambda: RuntimeError(), TransportError, None),
        (lambda: TimeoutError(), TransportError, None),
        (lambda: socks.ProxyError(), ProxyError, None),
        # Catch-all
        (lambda: websockets.exceptions.WebSocketException(), TransportError, None),
    ])
    def test_ws_recv_error_mapping(self, handler, monkeypatch, raised, expected, match):
        """Errors raised by the connection's recv() surface as the expected type."""
        from yt_dlp.networking._websockets import WebsocketsResponseAdapter
        ws = WebsocketsResponseAdapter(create_fake_ws_connection(raised), url='ws://fake-url')
        with pytest.raises(expected, match=match) as exc_info:
            ws.recv()
        assert exc_info.type is expected