3 # Allow direct execution
9 sys
.path
.insert(0, os
.path
.dirname(os
.path
.dirname(os
.path
.abspath(__file__
))))
19 from yt_dlp
import socks
20 from yt_dlp
.cookies
import YoutubeDLCookieJar
21 from yt_dlp
.dependencies
import websockets
22 from yt_dlp
.networking
import Request
23 from yt_dlp
.networking
.exceptions
import (
24 CertificateVerifyError
,
31 from yt_dlp
.utils
.networking
import HTTPHeaderDict
33 from test
.conftest
import validate_and_send
# Absolute directory containing this test module; certificate fixtures are
# resolved relative to it.
TEST_DIR = os.path.split(os.path.abspath(__file__))[0]
def websocket_handler(websocket):
    """Answer the first message received on ``websocket``, then return.

    Control messages get canned replies describing the connection:

    * ``b'bytes'`` -> ``'2'`` and ``'str'`` -> ``'1'`` (the RFC 6455 binary/text
      opcodes the tests compare against);
    * ``'headers'`` -> JSON object of the handshake request headers;
    * ``'path'`` -> the request path as seen by the server;
    * ``'source_address'`` -> the client's IP address.

    Any other message is echoed back unchanged.

    :returns: whatever ``websocket.send`` returns, or ``None`` if the
        connection yields no message.
    """
    # Built lazily so connection attributes are only read when requested.
    text_replies = {
        'headers': lambda: json.dumps(dict(websocket.request.headers)),
        'path': lambda: websocket.request.path,
        'source_address': lambda: websocket.remote_address[0],
        'str': lambda: '1',
    }
    for incoming in websocket:
        if isinstance(incoming, bytes):
            reply = '2' if incoming == b'bytes' else incoming
        elif isinstance(incoming, str) and incoming in text_replies:
            reply = text_replies[incoming]()
        else:
            reply = incoming
        # Only the first message is handled; returning ends this handler.
        return websocket.send(reply)
def process_request(self, request):
    """Handshake hook for the test server (passed as ``process_request`` to ``serve``).

    A request path of the form ``/gen_<code>`` makes the server answer with HTTP
    status ``<code>`` instead of upgrading to a websocket:

    * 3xx codes get a bare response whose ``Location`` header points to ``/``,
      i.e. a real redirect, so tests can check that it is not followed;
    * any other code is rejected through ``self.protocol.reject``.

    Every other path is accepted with ``self.protocol.accept``.

    :param self: first argument supplied by the websockets server; presumably
        the server connection (it must expose ``.protocol``).
    :param request: the parsed handshake request (its ``.path`` is read).
    :returns: the response the server should send.
    :raises ValueError: if ``<code>`` is not an integer or is not an
        ``http.HTTPStatus`` member.
    """
    if not request.path.startswith('/gen_'):
        return self.protocol.accept(request)

    status = http.HTTPStatus(int(request.path[len('/gen_'):]))
    # Cover the whole 3xx range: ``300 <= value <= 300`` only matched 300, so
    # /gen_301, /gen_302 and /gen_303 never received a Location header.
    if 300 <= status.value < 400:
        return websockets.http11.Response(
            status.value, status.phrase, websockets.datastructures.Headers([('Location', '/')]), b'')
    return self.protocol.reject(status.value, status.phrase)
def create_websocket_server(**ws_kwargs):
    """Start a websocket test server on an OS-assigned localhost port.

    The server runs ``websocket_handler`` with ``process_request`` as its
    handshake hook and is served forever from a daemon thread.

    :param ws_kwargs: extra keyword arguments forwarded to
        ``websockets.sync.server.serve`` (e.g. ``ssl_context``).
    :returns: ``(thread, port)``, the serving thread and the bound port.
    """
    import websockets.sync.server

    # Port 0 lets the OS pick a free port; read it back from the bound socket.
    server = websockets.sync.server.serve(
        websocket_handler, '127.0.0.1', 0, process_request=process_request, **ws_kwargs)
    bound_port = server.socket.getsockname()[1]

    # Daemon thread, so a still-running server never blocks interpreter exit.
    serving_thread = threading.Thread(target=server.serve_forever, daemon=True)
    serving_thread.start()
    return serving_thread, bound_port
, ws_port
def create_ws_websocket_server():
    """Start a plain (``ws://``, no TLS) test server; returns ``(thread, port)``."""
    thread, port = create_websocket_server()
    return thread, port
def create_wss_websocket_server():
    """Start a TLS (``wss://``) test server that presents ``testcert.pem``.

    No separate key file is given, so ``load_cert_chain`` reads the private
    key from the same PEM file. Returns ``(thread, port)``.
    """
    server_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    server_ctx.load_cert_chain(os.path.join(TEST_DIR, 'testcert.pem'), None)
    return create_websocket_server(ssl_context=server_ctx)
# Holds the CA certificate and the client certificates/keys used by the
# mutual-TLS tests.
MTLS_CERT_DIR = os.path.join(os.path.join(TEST_DIR, 'testdata'), 'certificate')
def create_mtls_wss_websocket_server():
    """Start a ``wss://`` test server that requires a client certificate.

    The server presents ``testcert.pem`` and only completes the TLS handshake
    with clients whose certificate verifies against ``MTLS_CERT_DIR/ca.crt``.
    Returns ``(thread, port)``.
    """
    ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    # Mutual TLS: a client certificate is mandatory and is checked against
    # the test CA.
    ctx.verify_mode = ssl.CERT_REQUIRED
    ctx.load_verify_locations(cafile=os.path.join(MTLS_CERT_DIR, 'ca.crt'))
    ctx.load_cert_chain(os.path.join(TEST_DIR, 'testcert.pem'), None)
    return create_websocket_server(ssl_context=ctx)
# Conformance tests for the 'Websockets' request handler, skipped when the
# websockets package is unavailable. setup_class starts four local servers via
# the create_*_websocket_server helpers: plain ws://, wss:// presenting
# testcert.pem, wss:// whose TLS context has no certificate loaded ("bad"),
# and a mutual-TLS wss:// server.
# NOTE(review): this view of the class is missing lines (several send() calls,
# closing brackets and a context-manager opener are flagged below); confirm
# against the full file.
101 @pytest.mark.skipif(not websockets
, reason
='websockets must be installed to test websocket request handlers')
102 class TestWebsSocketRequestHandlerConformance
:
# NOTE(review): no @classmethod decorator is visible on setup_class; confirm
# whether a line is missing here.
104 def setup_class(cls
):
105 cls
.ws_thread
, cls
.ws_port
= create_ws_websocket_server()
106 cls
.ws_base_url
= f
'ws://127.0.0.1:{cls.ws_port}'
108 cls
.wss_thread
, cls
.wss_port
= create_wss_websocket_server()
109 cls
.wss_base_url
= f
'wss://127.0.0.1:{cls.wss_port}'
# TLS server context with no certificate loaded: handshakes against it are
# expected to fail (see test_ssl_error).
111 cls
.bad_wss_thread
, cls
.bad_wss_port
= create_websocket_server(ssl_context
=ssl
.SSLContext(ssl
.PROTOCOL_TLS_SERVER
))
112 cls
.bad_wss_host
= f
'wss://127.0.0.1:{cls.bad_wss_port}'
114 cls
.mtls_wss_thread
, cls
.mtls_wss_port
= create_mtls_wss_websocket_server()
115 cls
.mtls_wss_base_url
= f
'wss://127.0.0.1:{cls.mtls_wss_port}'
# Plain ws:// handshake: expect a 101 response with an 'upgrade' header, then a
# message round-trip.
117 @pytest.mark.parametrize('handler', ['Websockets'], indirect
=True)
118 def test_basic_websockets(self
, handler
):
119 with handler() as rh
:
120 ws
= validate_and_send(rh
, Request(self
.ws_base_url
))
121 assert 'upgrade' in ws
.headers
122 assert ws
.status
== 101
# NOTE(review): recv() expects 'foo', but no send('foo') is visible; websocket_handler
# only replies after receiving a message, so a send line appears to be missing.
124 assert ws
.recv() == 'foo'
127 # https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6
# Text vs binary messages: websocket_handler answers 'str' with '1' and b'bytes'
# with '2', matching the RFC 6455 text/binary opcodes.
128 @pytest.mark.parametrize('msg,opcode', [('str', 1), (b
'bytes', 2)])
129 @pytest.mark.parametrize('handler', ['Websockets'], indirect
=True)
130 def test_send_types(self
, handler
, msg
, opcode
):
131 with handler() as rh
:
132 ws
= validate_and_send(rh
, Request(self
.ws_base_url
))
# NOTE(review): no send(msg) is visible before this recv(); confirm a line is
# missing here.
134 assert int(ws
.recv()) == opcode
# The wss server presents the self-signed testcert.pem: default verification must
# raise CertificateVerifyError, while verify=False must allow the upgrade.
137 @pytest.mark.parametrize('handler', ['Websockets'], indirect
=True)
138 def test_verify_cert(self
, handler
):
139 with handler() as rh
:
140 with pytest
.raises(CertificateVerifyError
):
141 validate_and_send(rh
, Request(self
.wss_base_url
))
143 with handler(verify
=False) as rh
:
144 ws
= validate_and_send(rh
, Request(self
.wss_base_url
))
145 assert ws
.status
== 101
# With verification disabled, connecting to the certificate-less "bad" server must
# still fail in the TLS handshake: an SSLError that is not a CertificateVerifyError.
148 @pytest.mark.parametrize('handler', ['Websockets'], indirect
=True)
149 def test_ssl_error(self
, handler
):
150 with handler(verify
=False) as rh
:
151 with pytest
.raises(SSLError
, match
=r
'ssl(?:v3|/tls) alert handshake failure') as exc_info
:
152 validate_and_send(rh
, Request(self
.bad_wss_host
))
153 assert not issubclass(exc_info
.type, CertificateVerifyError
)
# Request paths must reach the server percent-encoded (websocket_handler echoes the
# path in reply to 'path').
155 @pytest.mark.parametrize('handler', ['Websockets'], indirect
=True)
156 @pytest.mark.parametrize('path,expected', [
157 # Unicode characters should be encoded with uppercase percent-encoding
158 ('/中文', '/%E4%B8%AD%E6%96%87'),
159 # don't normalize existing percent encodings
160 ('/%c7%9f', '/%c7%9f'),
# NOTE(review): the closing `])` of the parametrize list above is not visible;
# confirm it is present in the full file.
162 def test_percent_encode(self
, handler
, path
, expected
):
163 with handler() as rh
:
164 ws
= validate_and_send(rh
, Request(f
'{self.ws_base_url}{path}'))
# NOTE(review): recv() expects the path, which websocket_handler only sends in
# reply to 'path'; no such send() is visible here.
166 assert ws
.recv() == expected
167 assert ws
.status
== 101
170 @pytest.mark.parametrize('handler', ['Websockets'], indirect
=True)
171 def test_remove_dot_segments(self
, handler
):
172 with handler() as rh
:
173 # This isn't a comprehensive test,
174 # but it should be enough to check whether the handler is removing dot segments
175 ws
= validate_and_send(rh
, Request(f
'{self.ws_base_url}/a/b/./../../test'))
176 assert ws
.status
== 101
# NOTE(review): as above, no send('path') is visible before this recv().
178 assert ws
.recv() == '/test'
181 # We are restricted to known HTTP status codes in http.HTTPStatus
182 # Redirects are not supported for websockets
# /gen_<status> makes process_request answer with that status instead of upgrading;
# the handler must raise HTTPError carrying the same status.
183 @pytest.mark.parametrize('handler', ['Websockets'], indirect
=True)
184 @pytest.mark.parametrize('status', (200, 204, 301, 302, 303, 400, 500, 511))
185 def test_raise_http_error(self
, handler
, status
):
186 with handler() as rh
:
187 with pytest
.raises(HTTPError
) as exc_info
:
188 validate_and_send(rh
, Request(f
'{self.ws_base_url}/gen_{status}'))
189 assert exc_info
.value
.status
== status
# A tiny timeout, set either on the handler or per request via extensions, must make
# the request fail with TransportError.
191 @pytest.mark.parametrize('handler', ['Websockets'], indirect
=True)
192 @pytest.mark.parametrize('params,extensions', [
193 ({'timeout': 0.00001}
, {}),
194 ({}, {'timeout': 0.00001}
),
# NOTE(review): the closing `])` of the parametrize list above is not visible.
196 def test_timeout(self
, handler
, params
, extensions
):
197 with handler(**params
) as rh
:
198 with pytest
.raises(TransportError
):
199 validate_and_send(rh
, Request(self
.ws_base_url
, extensions
=extensions
))
# A cookie for 127.0.0.1 must appear in the handshake headers when the jar is given to
# the handler, must be absent without a jar, and must appear again when the jar is
# passed per request through extensions.
201 @pytest.mark.parametrize('handler', ['Websockets'], indirect
=True)
202 def test_cookies(self
, handler
):
203 cookiejar
= YoutubeDLCookieJar()
204 cookiejar
.set_cookie(http
.cookiejar
.Cookie(
205 version
=0, name
='test', value
='ytdlp', port
=None, port_specified
=False,
206 domain
='127.0.0.1', domain_specified
=True, domain_initial_dot
=False, path
='/',
207 path_specified
=True, secure
=False, expires
=None, discard
=False, comment
=None,
208 comment_url
=None, rest
={}))
210 with handler(cookiejar
=cookiejar
) as rh
:
211 ws
= validate_and_send(rh
, Request(self
.ws_base_url
))
# NOTE(review): the recv() calls in this test parse a JSON headers object, which
# websocket_handler only sends in reply to 'headers'; no send('headers') is
# visible before any of them.
213 assert json
.loads(ws
.recv())['cookie'] == 'test=ytdlp'
216 with handler() as rh
:
217 ws
= validate_and_send(rh
, Request(self
.ws_base_url
))
219 assert 'cookie' not in json
.loads(ws
.recv())
222 ws
= validate_and_send(rh
, Request(self
.ws_base_url
, extensions
={'cookiejar': cookiejar}
))
224 assert json
.loads(ws
.recv())['cookie'] == 'test=ytdlp'
# Binding to a random 127.0.0.x source address must be reflected in the peer address the
# server reports (websocket_handler replies to 'source_address' with it).
227 @pytest.mark.parametrize('handler', ['Websockets'], indirect
=True)
228 def test_source_address(self
, handler
):
229 source_address
= f
'127.0.0.{random.randint(5, 255)}'
230 with handler(source_address
=source_address
) as rh
:
231 ws
= validate_and_send(rh
, Request(self
.ws_base_url
))
232 ws
.send('source_address')
233 assert source_address
== ws
.recv()
236 @pytest.mark.parametrize('handler', ['Websockets'], indirect
=True)
237 def test_response_url(self
, handler
):
238 with handler() as rh
:
239 url
= f
'{self.ws_base_url}/something'
# NOTE(review): no assertion is visible after this request, so as shown the test only
# checks that the request succeeds; an assertion on the response URL may be missing.
240 ws
= validate_and_send(rh
, Request(url
))
# Handler-level headers must reach the server; per-request headers are merged on top,
# overriding duplicates (test2) and adding new ones (test3).
244 @pytest.mark.parametrize('handler', ['Websockets'], indirect
=True)
245 def test_request_headers(self
, handler
):
246 with handler(headers
=HTTPHeaderDict({'test1': 'test', 'test2': 'test2'}
)) as rh
:
248 ws
= validate_and_send(rh
, Request(self
.ws_base_url
))
# NOTE(review): both recv() calls below expect the JSON headers reply; no
# send('headers') is visible before either.
250 headers
= HTTPHeaderDict(json
.loads(ws
.recv()))
251 assert headers
['test1'] == 'test'
254 # Per request headers, merged with global
255 ws
= validate_and_send(rh
, Request(
256 self
.ws_base_url
, headers
={'test2': 'changed', 'test3': 'test3'}
))
258 headers
= HTTPHeaderDict(json
.loads(ws
.recv()))
259 assert headers
['test1'] == 'test'
260 assert headers
['test2'] == 'changed'
261 assert headers
['test3'] == 'test3'
# Client certificate variants for the mutual-TLS server, judging by the file names:
# combined cert+key, separate cert and key, cert with an encrypted embedded key plus
# password, and a separate encrypted key plus password.
# NOTE(review): the dict literals after the first have no visible opening `{` or
# closing `},`, and the tuple has no visible closing `))`.
264 @pytest.mark.parametrize('client_cert', (
265 {'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithkey.crt')}
,
267 'client_certificate': os
.path
.join(MTLS_CERT_DIR
, 'client.crt'),
268 'client_certificate_key': os
.path
.join(MTLS_CERT_DIR
, 'client.key'),
271 'client_certificate': os
.path
.join(MTLS_CERT_DIR
, 'clientwithencryptedkey.crt'),
272 'client_certificate_password': 'foobar',
275 'client_certificate': os
.path
.join(MTLS_CERT_DIR
, 'client.crt'),
276 'client_certificate_key': os
.path
.join(MTLS_CERT_DIR
, 'clientencrypted.key'),
277 'client_certificate_password': 'foobar',
280 @pytest.mark.parametrize('handler', ['Websockets'], indirect
=True)
281 def test_mtls(self
, handler
, client_cert
):
# NOTE(review): the statement that opens the handler(...) context and binds `rh` is not
# visible; only the comments and the `client_cert=client_cert` keyword remain.
283 # Disable client-side validation of unacceptable self-signed testcert.pem
284 # The test is of a check on the server side, so unaffected
286 client_cert
=client_cert
288 validate_and_send(rh
, Request(self
.mtls_wss_base_url
)).close()
# Build a stand-in subclass of websockets' sync ClientConnection for the response-adapter
# error-mapping tests, exposing a canned `response` object. `raised` is presumably a
# zero-argument factory for the exception the fake methods raise (the parametrized tests
# pass lambdas); confirm against the method bodies.
# NOTE(review): the FakeResponse class definition (only `reason_phrase = 'test'` is
# visible) and the send/recv/close bodies are missing from this view.
291 def create_fake_ws_connection(raised
):
292 import websockets
.sync
.client
294 class FakeWsConnection(websockets
.sync
.client
.ClientConnection
):
295 def __init__(self
, *args
, **kwargs
):
300 reason_phrase
= 'test'
302 self
.response
= FakeResponse()
304 def send(self
, *args
, **kwargs
):
307 def recv(self
, *args
, **kwargs
):
310 def close(self
, *args
, **kwargs
):
313 return FakeWsConnection()
# Unit tests for how the Websockets handler maps exceptions raised while connecting,
# sending and receiving onto yt-dlp networking errors. Each assertion checks the exact
# exception type (`exc_info.type is expected`), so a subclass does not count.
# NOTE(review): several lines (closing `])`, the fake_connect body and the calls inside
# pytest.raises) are missing from this view; confirm against the full file.
316 @pytest.mark.parametrize('handler', ['Websockets'], indirect
=True)
317 class TestWebsocketsRequestHandler
:
318 @pytest.mark.parametrize('raised,expected', [
319 # https://websockets.readthedocs.io/en/stable/reference/exceptions.html
320 (lambda: websockets
.exceptions
.InvalidURI(msg
='test', uri
='test://'), RequestError
),
321 # Requires a response object. Should be covered by HTTP error tests.
322 # (lambda: websockets.exceptions.InvalidStatus(), TransportError),
323 (lambda: websockets
.exceptions
.InvalidHandshake(), TransportError
),
324 # These are subclasses of InvalidHandshake
325 (lambda: websockets
.exceptions
.InvalidHeader(name
='test'), TransportError
),
326 (lambda: websockets
.exceptions
.NegotiationError(), TransportError
),
328 (lambda: websockets
.exceptions
.WebSocketException(), TransportError
),
329 (lambda: TimeoutError(), TransportError
),
330 # These may be raised by our create_connection implementation, which should also be caught
331 (lambda: OSError(), TransportError
),
332 (lambda: ssl
.SSLError(), SSLError
),
333 (lambda: ssl
.SSLCertVerificationError(), CertificateVerifyError
),
334 (lambda: socks
.ProxyError(), ProxyError
),
# NOTE(review): the closing `])` of the parametrize list above is not visible.
336 def test_request_error_mapping(self
, handler
, monkeypatch
, raised
, expected
):
337 import websockets
.sync
.client
339 import yt_dlp
.networking
._websockets
340 with handler() as rh
:
# NOTE(review): fake_connect's body is not visible; presumably it raises raised().
341 def fake_connect(*args
, **kwargs
):
# Stub out socket creation and websockets.sync.client.connect so rh.send does not
# open a real connection; the exception under test then comes from fake_connect.
343 monkeypatch
.setattr(yt_dlp
.networking
._websockets
, 'create_connection', lambda *args
, **kwargs
: None)
344 monkeypatch
.setattr(websockets
.sync
.client
, 'connect', fake_connect
)
345 with pytest
.raises(expected
) as exc_info
:
346 rh
.send(Request('ws://fake-url'))
347 assert exc_info
.type is expected
# Errors raised by the underlying connection's send() must be mapped as listed.
349 @pytest.mark.parametrize('raised,expected,match', [
350 # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send
351 (lambda: websockets
.exceptions
.ConnectionClosed(None, None), TransportError
, None),
352 (lambda: RuntimeError(), TransportError
, None),
353 (lambda: TimeoutError(), TransportError
, None),
354 (lambda: TypeError(), RequestError
, None),
355 (lambda: socks
.ProxyError(), ProxyError
, None),
357 (lambda: websockets
.exceptions
.WebSocketException(), TransportError
, None),
# NOTE(review): the closing `])` of the parametrize list above is not visible.
359 def test_ws_send_error_mapping(self
, handler
, monkeypatch
, raised
, expected
, match
):
360 from yt_dlp
.networking
._websockets
import WebsocketsResponseAdapter
361 ws
= WebsocketsResponseAdapter(create_fake_ws_connection(raised
), url
='ws://fake-url')
362 with pytest
.raises(expected
, match
=match
) as exc_info
:
# NOTE(review): the call inside pytest.raises (presumably ws.send(...)) is not visible;
# as shown the with-block has no body.
364 assert exc_info
.type is expected
# Errors raised by the underlying connection's recv() must be mapped as listed.
366 @pytest.mark.parametrize('raised,expected,match', [
367 # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv
368 (lambda: websockets
.exceptions
.ConnectionClosed(None, None), TransportError
, None),
369 (lambda: RuntimeError(), TransportError
, None),
370 (lambda: TimeoutError(), TransportError
, None),
371 (lambda: socks
.ProxyError(), ProxyError
, None),
373 (lambda: websockets
.exceptions
.WebSocketException(), TransportError
, None),
# NOTE(review): the closing `])` of the parametrize list above is not visible.
375 def test_ws_recv_error_mapping(self
, handler
, monkeypatch
, raised
, expected
, match
):
376 from yt_dlp
.networking
._websockets
import WebsocketsResponseAdapter
377 ws
= WebsocketsResponseAdapter(create_fake_ws_connection(raised
), url
='ws://fake-url')
378 with pytest
.raises(expected
, match
=match
) as exc_info
:
# NOTE(review): the call inside pytest.raises (presumably ws.recv()) is not visible;
# as shown the with-block has no body.
380 assert exc_info
.type is expected