3 # Allow direct execution
9 from test
.helper
import verify_address_availability
# Put the parent of this file's directory at the front of the import path so the
# `test` and `yt_dlp` packages resolve when this file is executed directly
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
21 from yt_dlp
import socks
22 from yt_dlp
.cookies
import YoutubeDLCookieJar
23 from yt_dlp
.dependencies
import websockets
24 from yt_dlp
.networking
import Request
25 from yt_dlp
.networking
.exceptions
import (
26 CertificateVerifyError
,
33 from yt_dlp
.utils
.networking
import HTTPHeaderDict
35 from test
.conftest
import validate_and_send
# Absolute directory of this test file; test certificates (testcert.pem, testdata/...)
# are resolved relative to it
TEST_DIR = os.path.dirname(os.path.abspath(__file__))
def websocket_handler(websocket):
    """Reply to the first message received on *websocket*, then return.

    Binary ``b'bytes'`` is answered with ``'2'``. The text commands ``'headers'``,
    ``'path'``, ``'source_address'`` and ``'str'`` are answered with the request
    headers as JSON, the request path, the peer's IP address and ``'1'``
    respectively. Any other message is echoed back unchanged.

    Returns whatever ``websocket.send`` returns (``None`` if no message arrives).
    """
    # Text commands; values are computed lazily, only for the command received
    text_commands = {
        'headers': lambda: json.dumps(dict(websocket.request.headers)),
        'path': lambda: websocket.request.path,
        'source_address': lambda: websocket.remote_address[0],
        'str': lambda: '1',
    }
    for message in websocket:
        if isinstance(message, bytes):
            reply = '2' if message == b'bytes' else message
        elif isinstance(message, str) and message in text_commands:
            reply = text_commands[message]()
        else:
            reply = message
        # Only the first message is ever answered
        return websocket.send(reply)
def process_request(self, request):
    """Decide how the test server answers a websocket opening handshake.

    Paths of the form ``/gen_<code>`` make the server answer with HTTP status
    ``<code>`` instead of upgrading the connection:

    * redirection codes (3xx) get a ``Location: /`` header, so a client that
      wrongly followed redirects would be detectable;
    * every other code is rejected with that status and its standard phrase.

    Any other path is accepted as a normal websocket upgrade.

    :param self: the server-side connection object supplied by ``websockets``
        (presumably a ``ServerConnection``); only its ``protocol`` is used
    :param request: the handshake request; only ``request.path`` is read
    :raises ValueError: if the text after ``/gen_`` is not a known HTTP status code
    """
    prefix = '/gen_'
    if request.path.startswith(prefix):
        status = http.HTTPStatus(int(request.path[len(prefix):]))
        # Cover the whole 3xx class (301/302/303 are what the tests request),
        # not just 300 itself
        if 300 <= status.value < 400:
            return websockets.http11.Response(
                status.value, status.phrase, websockets.datastructures.Headers([('Location', '/')]), b'')
        return self.protocol.reject(status.value, status.phrase)
    return self.protocol.accept(request)
def create_websocket_server(**ws_kwargs):
    """Start a websocket server on 127.0.0.1 in a background thread.

    The server binds port 0, so the OS picks a free port. Connections are
    answered by ``websocket_handler`` and handshakes are routed through
    ``process_request``. Extra keyword arguments (e.g. ``ssl_context``) are
    forwarded unchanged to ``websockets.sync.server.serve``.

    :return: ``(thread, port)``: the daemon thread running ``serve_forever``
        and the port the server is bound to
    """
    import websockets.sync.server

    server = websockets.sync.server.serve(
        websocket_handler, '127.0.0.1', 0, process_request=process_request, **ws_kwargs)
    port = server.socket.getsockname()[1]
    # Daemon thread: the server is never shut down explicitly, so it must not
    # keep the interpreter alive
    thread = threading.Thread(target=server.serve_forever, daemon=True)
    thread.start()
    return thread, port
def create_ws_websocket_server():
    """Start a plain, unencrypted ``ws://`` test server and return ``(thread, port)``."""
    thread, port = create_websocket_server()
    return thread, port
def create_wss_websocket_server():
    """Start a TLS ``wss://`` test server presenting ``testcert.pem``.

    Client certificates are not requested. Returns ``(thread, port)``.
    """
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    context.load_cert_chain(os.path.join(TEST_DIR, 'testcert.pem'), None)
    return create_websocket_server(ssl_context=context)
# Certificates for the mutual-TLS tests: the CA the mTLS server trusts (ca.crt)
# and the client certificates/keys used by test_mtls
MTLS_CERT_DIR = os.path.join(TEST_DIR, 'testdata', 'certificate')
def create_mtls_wss_websocket_server():
    """Start a ``wss://`` test server that requires a client certificate.

    The server presents ``testcert.pem`` and verifies client certificates
    against ``MTLS_CERT_DIR/ca.crt``. Returns ``(thread, port)``.
    """
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    # Handshakes fail unless the client presents a certificate trusted by ca.crt
    context.verify_mode = ssl.CERT_REQUIRED
    context.load_verify_locations(cafile=os.path.join(MTLS_CERT_DIR, 'ca.crt'))
    context.load_cert_chain(os.path.join(TEST_DIR, 'testcert.pem'), None)
    return create_websocket_server(ssl_context=context)
@pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers')
class TestWebsSocketRequestHandlerConformance:
    """Conformance tests for the ``Websockets`` request handler.

    ``setup_class`` starts four local servers: plain ``ws://``, TLS ``wss://``
    presenting ``testcert.pem``, a TLS server whose context has no certificate
    loaded, and a mutual-TLS server. Most tests send a command understood by
    ``websocket_handler`` (which only ever replies to a received message) and
    check the reply. Each connection is closed explicitly afterwards.
    """

    @classmethod
    def setup_class(cls):
        cls.ws_thread, cls.ws_port = create_ws_websocket_server()
        cls.ws_base_url = f'ws://127.0.0.1:{cls.ws_port}'

        cls.wss_thread, cls.wss_port = create_wss_websocket_server()
        cls.wss_base_url = f'wss://127.0.0.1:{cls.wss_port}'

        # No certificate is loaded into this context; test_ssl_error expects the
        # TLS handshake against it to fail
        cls.bad_wss_thread, cls.bad_wss_port = create_websocket_server(ssl_context=ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER))
        cls.bad_wss_host = f'wss://127.0.0.1:{cls.bad_wss_port}'

        cls.mtls_wss_thread, cls.mtls_wss_port = create_mtls_wss_websocket_server()
        cls.mtls_wss_base_url = f'wss://127.0.0.1:{cls.mtls_wss_port}'

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_basic_websockets(self, handler):
        with handler() as rh:
            ws = validate_and_send(rh, Request(self.ws_base_url))
            assert 'upgrade' in ws.headers
            assert ws.status == 101
            # Unrecognised messages are echoed back by websocket_handler
            ws.send('foo')
            assert ws.recv() == 'foo'
            ws.close()

    # https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6
    @pytest.mark.parametrize('msg,opcode', [('str', 1), (b'bytes', 2)])
    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_send_types(self, handler, msg, opcode):
        with handler() as rh:
            ws = validate_and_send(rh, Request(self.ws_base_url))
            # The server answers 'str' with '1' and b'bytes' with '2'
            ws.send(msg)
            assert int(ws.recv()) == opcode
            ws.close()

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_verify_cert(self, handler):
        with handler() as rh:
            with pytest.raises(CertificateVerifyError):
                validate_and_send(rh, Request(self.wss_base_url))

        with handler(verify=False) as rh:
            ws = validate_and_send(rh, Request(self.wss_base_url))
            assert ws.status == 101
            ws.close()

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_ssl_error(self, handler):
        with handler(verify=False) as rh:
            with pytest.raises(SSLError, match=r'ssl(?:v3|/tls) alert handshake failure') as exc_info:
                validate_and_send(rh, Request(self.bad_wss_host))
            assert not issubclass(exc_info.type, CertificateVerifyError)

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    @pytest.mark.parametrize('path,expected', [
        # Unicode characters should be encoded with uppercase percent-encoding
        ('/中文', '/%E4%B8%AD%E6%96%87'),
        # don't normalize existing percent encodings
        ('/%c7%9f', '/%c7%9f'),
    ])
    def test_percent_encode(self, handler, path, expected):
        with handler() as rh:
            ws = validate_and_send(rh, Request(f'{self.ws_base_url}{path}'))
            # The server replies to 'path' with the request path it received
            ws.send('path')
            assert ws.recv() == expected
            assert ws.status == 101
            ws.close()

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_remove_dot_segments(self, handler):
        with handler() as rh:
            # This isn't a comprehensive test,
            # but it should be enough to check whether the handler is removing dot segments
            ws = validate_and_send(rh, Request(f'{self.ws_base_url}/a/b/./../../test'))
            assert ws.status == 101
            ws.send('path')
            assert ws.recv() == '/test'
            ws.close()

    # We are restricted to known HTTP status codes in http.HTTPStatus
    # Redirects are not supported for websockets
    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    @pytest.mark.parametrize('status', (200, 204, 301, 302, 303, 400, 500, 511))
    def test_raise_http_error(self, handler, status):
        with handler() as rh:
            with pytest.raises(HTTPError) as exc_info:
                validate_and_send(rh, Request(f'{self.ws_base_url}/gen_{status}'))
            assert exc_info.value.status == status

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    @pytest.mark.parametrize('params,extensions', [
        # The tiny timeout is given either to the handler or per request
        ({'timeout': 0.00001}, {}),
        ({}, {'timeout': 0.00001}),
    ])
    def test_timeout(self, handler, params, extensions):
        with handler(**params) as rh:
            with pytest.raises(TransportError):
                validate_and_send(rh, Request(self.ws_base_url, extensions=extensions))

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_cookies(self, handler):
        cookiejar = YoutubeDLCookieJar()
        cookiejar.set_cookie(http.cookiejar.Cookie(
            version=0, name='test', value='ytdlp', port=None, port_specified=False,
            domain='127.0.0.1', domain_specified=True, domain_initial_dot=False, path='/',
            path_specified=True, secure=False, expires=None, discard=False, comment=None,
            comment_url=None, rest={}))

        # Cookie jar given to the handler
        with handler(cookiejar=cookiejar) as rh:
            ws = validate_and_send(rh, Request(self.ws_base_url))
            ws.send('headers')
            assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
            ws.close()

        with handler() as rh:
            # No cookie jar at all
            ws = validate_and_send(rh, Request(self.ws_base_url))
            ws.send('headers')
            assert 'cookie' not in json.loads(ws.recv())
            ws.close()

            # Cookie jar given per request
            ws = validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': cookiejar}))
            ws.send('headers')
            assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
            ws.close()

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_source_address(self, handler):
        source_address = f'127.0.0.{random.randint(5, 255)}'
        verify_address_availability(source_address)
        with handler(source_address=source_address) as rh:
            ws = validate_and_send(rh, Request(self.ws_base_url))
            # The server replies with the IP address it sees the connection from
            ws.send('source_address')
            assert source_address == ws.recv()
            ws.close()

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_response_url(self, handler):
        with handler() as rh:
            url = f'{self.ws_base_url}/something'
            ws = validate_and_send(rh, Request(url))
            assert ws.url == url
            ws.close()

    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_request_headers(self, handler):
        with handler(headers=HTTPHeaderDict({'test1': 'test', 'test2': 'test2'})) as rh:
            # Global headers only
            ws = validate_and_send(rh, Request(self.ws_base_url))
            ws.send('headers')
            headers = HTTPHeaderDict(json.loads(ws.recv()))
            assert headers['test1'] == 'test'
            ws.close()

            # Per request headers, merged with global
            ws = validate_and_send(rh, Request(
                self.ws_base_url, headers={'test2': 'changed', 'test3': 'test3'}))
            ws.send('headers')
            headers = HTTPHeaderDict(json.loads(ws.recv()))
            assert headers['test1'] == 'test'
            assert headers['test2'] == 'changed'
            assert headers['test3'] == 'test3'
            ws.close()

    @pytest.mark.parametrize('client_cert', (
        {'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithkey.crt')},
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
            'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'client.key'),
        },
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithencryptedkey.crt'),
            'client_certificate_password': 'foobar',
        },
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
            'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'clientencrypted.key'),
            'client_certificate_password': 'foobar',
        },
    ))
    @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
    def test_mtls(self, handler, client_cert):
        with handler(
            # Disable client-side validation of unacceptable self-signed testcert.pem
            # The test is of a check on the server side, so unaffected
            verify=False,
            client_cert=client_cert,
        ) as rh:
            validate_and_send(rh, Request(self.mtls_wss_base_url)).close()
def create_fake_ws_connection(raised):
    """Build a stand-in ``ClientConnection`` whose ``send``/``recv`` raise ``raised()``.

    :param raised: zero-argument factory returning the exception instance to raise
    :return: a connection object that performs no I/O; ``close`` is a no-op
    """
    import websockets.sync.client

    class FakeWsConnection(websockets.sync.client.ClientConnection):
        def __init__(self, *args, **kwargs):
            # ClientConnection.__init__ is intentionally not called, so no socket
            # or protocol state is set up.

            class FakeResponse:
                # NOTE(review): minimal handshake response; these are presumably the
                # attributes WebsocketsResponseAdapter reads. Confirm against it.
                headers = {}
                status_code = 101
                reason_phrase = 'test'

            self.response = FakeResponse()

        def send(self, *args, **kwargs):
            raise raised()

        def recv(self, *args, **kwargs):
            raise raised()

        def close(self, *args, **kwargs):
            return

    return FakeWsConnection()
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
class TestWebsocketsRequestHandler:
    """Exception-mapping tests for the ``Websockets`` request handler.

    Each test makes a connection-level call raise ``raised()`` and checks that
    exactly the yt-dlp networking exception ``expected`` surfaces. The check is
    an exact type match, not a subclass match.
    """

    @pytest.mark.parametrize('raised,expected', [
        # https://websockets.readthedocs.io/en/stable/reference/exceptions.html
        (lambda: websockets.exceptions.InvalidURI(msg='test', uri='test://'), RequestError),
        # Requires a response object. Should be covered by HTTP error tests.
        # (lambda: websockets.exceptions.InvalidStatus(), TransportError),
        (lambda: websockets.exceptions.InvalidHandshake(), TransportError),
        # These are subclasses of InvalidHandshake
        (lambda: websockets.exceptions.InvalidHeader(name='test'), TransportError),
        (lambda: websockets.exceptions.NegotiationError(), TransportError),
        # Base class of the websockets exception hierarchy (catch-all)
        (lambda: websockets.exceptions.WebSocketException(), TransportError),
        (lambda: TimeoutError(), TransportError),
        # These may be raised by our create_connection implementation, which should also be caught
        (lambda: OSError(), TransportError),
        (lambda: ssl.SSLError(), SSLError),
        (lambda: ssl.SSLCertVerificationError(), CertificateVerifyError),
        (lambda: socks.ProxyError(), ProxyError),
    ])
    def test_request_error_mapping(self, handler, monkeypatch, raised, expected):
        import websockets.sync.client

        import yt_dlp.networking._websockets
        with handler() as rh:
            def fake_connect(*args, **kwargs):
                raise raised()
            # Stub out socket creation and websockets' connect() so the
            # connection attempt raises raised()
            monkeypatch.setattr(yt_dlp.networking._websockets, 'create_connection', lambda *args, **kwargs: None)
            monkeypatch.setattr(websockets.sync.client, 'connect', fake_connect)
            with pytest.raises(expected) as exc_info:
                rh.send(Request('ws://fake-url'))
            assert exc_info.type is expected

    @pytest.mark.parametrize('raised,expected,match', [
        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send
        (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
        (lambda: RuntimeError(), TransportError, None),
        (lambda: TimeoutError(), TransportError, None),
        (lambda: TypeError(), RequestError, None),
        (lambda: socks.ProxyError(), ProxyError, None),
        # Base class of the websockets exception hierarchy (catch-all)
        (lambda: websockets.exceptions.WebSocketException(), TransportError, None),
    ])
    def test_ws_send_error_mapping(self, handler, monkeypatch, raised, expected, match):
        from yt_dlp.networking._websockets import WebsocketsResponseAdapter
        ws = WebsocketsResponseAdapter(create_fake_ws_connection(raised), url='ws://fake-url')
        with pytest.raises(expected, match=match) as exc_info:
            # The fake connection's send() raises raised()
            ws.send('test')
        assert exc_info.type is expected

    @pytest.mark.parametrize('raised,expected,match', [
        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv
        (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
        (lambda: RuntimeError(), TransportError, None),
        (lambda: TimeoutError(), TransportError, None),
        (lambda: socks.ProxyError(), ProxyError, None),
        # Base class of the websockets exception hierarchy (catch-all)
        (lambda: websockets.exceptions.WebSocketException(), TransportError, None),
    ])
    def test_ws_recv_error_mapping(self, handler, monkeypatch, raised, expected, match):
        from yt_dlp.networking._websockets import WebsocketsResponseAdapter
        ws = WebsocketsResponseAdapter(create_fake_ws_connection(raised), url='ws://fake-url')
        with pytest.raises(expected, match=match) as exc_info:
            # The fake connection's recv() raises raised()
            ws.recv()
        assert exc_info.type is expected