2 # Allow direct execution
7 sys
.path
.insert(0, os
.path
.dirname(os
.path
.dirname(os
.path
.abspath(__file__
))))
11 from test
.helper
import http_server_port
13 from yt_dlp
import YoutubeDL
14 from yt_dlp
.compat
import compat_http_server
, compat_urllib_request
16 TEST_DIR
= os
.path
.dirname(os
.path
.abspath(__file__
))
19 class HTTPTestRequestHandler(compat_http_server
.BaseHTTPRequestHandler
):
20 def log_message(self
, format
, *args
):
24 if self
.path
== '/video.html':
25 self
.send_response(200)
26 self
.send_header('Content-Type', 'text/html; charset=utf-8')
28 self
.wfile
.write(b
'<html><video src="/vid.mp4" /></html>')
29 elif self
.path
== '/vid.mp4':
30 self
.send_response(200)
31 self
.send_header('Content-Type', 'video/mp4')
33 self
.wfile
.write(b
'\x00\x00\x00\x00\x20\x66\x74[video]')
34 elif self
.path
== '/%E4%B8%AD%E6%96%87.html':
35 self
.send_response(200)
36 self
.send_header('Content-Type', 'text/html; charset=utf-8')
38 self
.wfile
.write(b
'<html><video src="/vid.mp4" /></html>')
47 def warning(self
, msg
):
54 class TestHTTP(unittest
.TestCase
):
56 self
.httpd
= compat_http_server
.HTTPServer(
57 ('127.0.0.1', 0), HTTPTestRequestHandler
)
58 self
.port
= http_server_port(self
.httpd
)
59 self
.server_thread
= threading
.Thread(target
=self
.httpd
.serve_forever
)
60 self
.server_thread
.daemon
= True
61 self
.server_thread
.start()
64 class TestHTTPS(unittest
.TestCase
):
66 certfn
= os
.path
.join(TEST_DIR
, 'testcert.pem')
67 self
.httpd
= compat_http_server
.HTTPServer(
68 ('127.0.0.1', 0), HTTPTestRequestHandler
)
69 sslctx
= ssl
.SSLContext(ssl
.PROTOCOL_TLS_SERVER
)
70 sslctx
.load_cert_chain(certfn
, None)
71 self
.httpd
.socket
= sslctx
.wrap_socket(self
.httpd
.socket
, server_side
=True)
72 self
.port
= http_server_port(self
.httpd
)
73 self
.server_thread
= threading
.Thread(target
=self
.httpd
.serve_forever
)
74 self
.server_thread
.daemon
= True
75 self
.server_thread
.start()
77 def test_nocheckcertificate(self
):
78 ydl
= YoutubeDL({'logger': FakeLogger()}
)
81 ydl
.extract_info
, 'https://127.0.0.1:%d/video.html' % self
.port
)
83 ydl
= YoutubeDL({'logger': FakeLogger(), 'nocheckcertificate': True}
)
84 r
= ydl
.extract_info('https://127.0.0.1:%d/video.html' % self
.port
)
85 self
.assertEqual(r
['entries'][0]['url'], 'https://127.0.0.1:%d/vid.mp4' % self
.port
)
88 class TestClientCert(unittest
.TestCase
):
90 certfn
= os
.path
.join(TEST_DIR
, 'testcert.pem')
91 self
.certdir
= os
.path
.join(TEST_DIR
, 'testdata', 'certificate')
92 cacertfn
= os
.path
.join(self
.certdir
, 'ca.crt')
93 self
.httpd
= compat_http_server
.HTTPServer(('127.0.0.1', 0), HTTPTestRequestHandler
)
94 sslctx
= ssl
.SSLContext(ssl
.PROTOCOL_TLS_SERVER
)
95 sslctx
.verify_mode
= ssl
.CERT_REQUIRED
96 sslctx
.load_verify_locations(cafile
=cacertfn
)
97 sslctx
.load_cert_chain(certfn
, None)
98 self
.httpd
.socket
= sslctx
.wrap_socket(self
.httpd
.socket
, server_side
=True)
99 self
.port
= http_server_port(self
.httpd
)
100 self
.server_thread
= threading
.Thread(target
=self
.httpd
.serve_forever
)
101 self
.server_thread
.daemon
= True
102 self
.server_thread
.start()
104 def _run_test(self
, **params
):
106 'logger': FakeLogger(),
107 # Disable client-side validation of unacceptable self-signed testcert.pem
108 # The test is of a check on the server side, so unaffected
109 'nocheckcertificate': True,
112 r
= ydl
.extract_info('https://127.0.0.1:%d/video.html' % self
.port
)
113 self
.assertEqual(r
['entries'][0]['url'], 'https://127.0.0.1:%d/vid.mp4' % self
.port
)
115 def test_certificate_combined_nopass(self
):
116 self
._run
_test
(client_certificate
=os
.path
.join(self
.certdir
, 'clientwithkey.crt'))
118 def test_certificate_nocombined_nopass(self
):
119 self
._run
_test
(client_certificate
=os
.path
.join(self
.certdir
, 'client.crt'),
120 client_certificate_key
=os
.path
.join(self
.certdir
, 'client.key'))
122 def test_certificate_combined_pass(self
):
123 self
._run
_test
(client_certificate
=os
.path
.join(self
.certdir
, 'clientwithencryptedkey.crt'),
124 client_certificate_password
='foobar')
126 def test_certificate_nocombined_pass(self
):
127 self
._run
_test
(client_certificate
=os
.path
.join(self
.certdir
, 'client.crt'),
128 client_certificate_key
=os
.path
.join(self
.certdir
, 'clientencrypted.key'),
129 client_certificate_password
='foobar')
132 def _build_proxy_handler(name
):
133 class HTTPTestRequestHandler(compat_http_server
.BaseHTTPRequestHandler
):
136 def log_message(self
, format
, *args
):
140 self
.send_response(200)
141 self
.send_header('Content-Type', 'text/plain; charset=utf-8')
143 self
.wfile
.write(f
'{self.proxy_name}: {self.path}'.encode())
144 return HTTPTestRequestHandler
147 class TestProxy(unittest
.TestCase
):
149 self
.proxy
= compat_http_server
.HTTPServer(
150 ('127.0.0.1', 0), _build_proxy_handler('normal'))
151 self
.port
= http_server_port(self
.proxy
)
152 self
.proxy_thread
= threading
.Thread(target
=self
.proxy
.serve_forever
)
153 self
.proxy_thread
.daemon
= True
154 self
.proxy_thread
.start()
156 self
.geo_proxy
= compat_http_server
.HTTPServer(
157 ('127.0.0.1', 0), _build_proxy_handler('geo'))
158 self
.geo_port
= http_server_port(self
.geo_proxy
)
159 self
.geo_proxy_thread
= threading
.Thread(target
=self
.geo_proxy
.serve_forever
)
160 self
.geo_proxy_thread
.daemon
= True
161 self
.geo_proxy_thread
.start()
163 def test_proxy(self
):
164 geo_proxy
= f
'127.0.0.1:{self.geo_port}'
166 'proxy': f
'127.0.0.1:{self.port}',
167 'geo_verification_proxy': geo_proxy
,
169 url
= 'http://foo.com/bar'
170 response
= ydl
.urlopen(url
).read().decode()
171 self
.assertEqual(response
, f
'normal: {url}')
173 req
= compat_urllib_request
.Request(url
)
174 req
.add_header('Ytdl-request-proxy', geo_proxy
)
175 response
= ydl
.urlopen(req
).read().decode()
176 self
.assertEqual(response
, f
'geo: {url}')
178 def test_proxy_with_idn(self
):
180 'proxy': f
'127.0.0.1:{self.port}',
182 url
= 'http://中文.tw/'
183 response
= ydl
.urlopen(url
).read().decode()
184 # b'xn--fiq228c' is '中文'.encode('idna')
185 self
.assertEqual(response
, 'normal: http://xn--fiq228c.tw/')
188 if __name__
== '__main__':