jfr.im git - yt-dlp.git/commitdiff
[core] Support decoding multiple content encodings (#7142)
author coletdjnz <redacted>
Sat, 27 May 2023 10:40:05 +0000 (22:40 +1200)
committer GitHub <redacted>
Sat, 27 May 2023 10:40:05 +0000 (10:40 +0000)
Authored by: coletdjnz

test/test_http.py
yt_dlp/utils/_utils.py

index d684905da59585d98c905a096a84278d0cc4b34d..3941a6e7762e4f23c55f22793b7e0ec755d38ec0 100644 (file)
 import threading
 import urllib.error
 import urllib.request
+import zlib
 
 from test.helper import http_server_port
 from yt_dlp import YoutubeDL
+from yt_dlp.dependencies import brotli
 from yt_dlp.utils import sanitized_Request, urlencode_postdata
 
 from .helper import FakeYDL
@@ -148,6 +150,31 @@ def do_GET(self):
             self.send_header('Location', new_url)
             self.send_header('Content-Length', '0')
             self.end_headers()
+        elif self.path == '/content-encoding':
+            encodings = self.headers.get('ytdl-encoding', '')
+            payload = b'<html><video src="/vid.mp4" /></html>'
+            for encoding in filter(None, (e.strip() for e in encodings.split(','))):
+                if encoding == 'br' and brotli:
+                    payload = brotli.compress(payload)
+                elif encoding == 'gzip':
+                    buf = io.BytesIO()
+                    with gzip.GzipFile(fileobj=buf, mode='wb') as f:
+                        f.write(payload)
+                    payload = buf.getvalue()
+                elif encoding == 'deflate':
+                    payload = zlib.compress(payload)
+                elif encoding == 'unsupported':
+                    payload = b'raw'
+                    break
+                else:
+                    self._status(415)
+                    return
+            self.send_response(200)
+            self.send_header('Content-Encoding', encodings)
+            self.send_header('Content-Length', str(len(payload)))
+            self.end_headers()
+            self.wfile.write(payload)
+
         else:
             self._status(404)
 
@@ -302,6 +329,55 @@ def test_gzip_trailing_garbage(self):
             data = ydl.urlopen(sanitized_Request(f'http://localhost:{self.http_port}/trailing_garbage')).read().decode('utf-8')
             self.assertEqual(data, '<html><video src="/vid.mp4" /></html>')
 
+    @unittest.skipUnless(brotli, 'brotli support is not installed')
+    def test_brotli(self):
+        # The server brotli-compresses the payload; urlopen must return the plain
+        # bytes while reporting the original Content-Encoding header.
+        url = f'http://127.0.0.1:{self.http_port}/content-encoding'
+        with FakeYDL() as ydl:
+            response = ydl.urlopen(sanitized_Request(url, headers={'ytdl-encoding': 'br'}))
+            self.assertEqual(response.headers.get('Content-Encoding'), 'br')
+            self.assertEqual(response.read(), b'<html><video src="/vid.mp4" /></html>')
+
+    def test_deflate(self):
+        # The server zlib-compresses the payload; urlopen must return the plain
+        # bytes while reporting the original Content-Encoding header.
+        request = sanitized_Request(
+            f'http://127.0.0.1:{self.http_port}/content-encoding', headers={'ytdl-encoding': 'deflate'})
+        with FakeYDL() as ydl:
+            res = ydl.urlopen(request)
+            self.assertEqual(res.headers.get('Content-Encoding'), 'deflate')
+            self.assertEqual(res.read(), b'<html><video src="/vid.mp4" /></html>')
+
+    def test_gzip(self):
+        # The server gzip-compresses the payload; urlopen must return the plain
+        # bytes while reporting the original Content-Encoding header.
+        expected_body = b'<html><video src="/vid.mp4" /></html>'
+        with FakeYDL() as ydl:
+            url = f'http://127.0.0.1:{self.http_port}/content-encoding'
+            res = ydl.urlopen(sanitized_Request(url, headers={'ytdl-encoding': 'gzip'}))
+            self.assertEqual(res.headers.get('Content-Encoding'), 'gzip')
+            self.assertEqual(res.read(), expected_body)
+
+    def test_multiple_encodings(self):
+        # Content-Encoding lists codings in the order they were applied, so the
+        # client has to undo them in reverse order.
+        # https://www.rfc-editor.org/rfc/rfc9110.html#section-8.4
+        with FakeYDL() as ydl:
+            for pair in ('gzip,deflate', 'deflate, gzip', 'gzip, gzip', 'deflate, deflate'):
+                # subTest: report which combination failed and keep checking the rest
+                with self.subTest(encodings=pair):
+                    res = ydl.urlopen(
+                        sanitized_Request(
+                            f'http://127.0.0.1:{self.http_port}/content-encoding',
+                            headers={'ytdl-encoding': pair}))
+                    self.assertEqual(res.headers.get('Content-Encoding'), pair)
+                    self.assertEqual(res.read(), b'<html><video src="/vid.mp4" /></html>')
+
+    def test_unsupported_encoding(self):
+        # A coding the client cannot decode must be passed through untouched:
+        # the body comes back exactly as the server sent it.
+        with FakeYDL() as ydl:
+            url = f'http://127.0.0.1:{self.http_port}/content-encoding'
+            res = ydl.urlopen(sanitized_Request(url, headers={'ytdl-encoding': 'unsupported'}))
+            self.assertEqual(res.headers.get('Content-Encoding'), 'unsupported')
+            self.assertEqual(res.read(), b'raw')
+
 
 class TestClientCert(unittest.TestCase):
     def setUp(self):
index 6f4f22bb315efb46c72660bc3579314ff226a894..7c91faff8671c679e2abdb64dd8dec102722c8b8 100644 (file)
@@ -1361,6 +1361,23 @@ def brotli(data):
             return data
         return brotli.decompress(data)
 
+    @staticmethod
+    def gz(data):
+        """Decompress gzip-encoded ``data``, tolerating trailing junk.
+
+        Some servers append garbage after the gzip stream
+        (see http://stackoverflow.com/q/4928560/35070). If decompressing the
+        whole payload raises OSError, retry with 1..1023 bytes trimmed from
+        the end, always keeping at least one byte. If no trimmed variant
+        decompresses, the original OSError is re-raised.
+        """
+        try:
+            with gzip.GzipFile(fileobj=io.BytesIO(data), mode='rb') as gzf:
+                return gzf.read()
+        except OSError as original_oserror:
+            # There may be junk at the end of the file.
+            # Never trim down to an empty buffer: an empty stream "decompresses"
+            # to b'' without error, which would silently hide a corrupt body.
+            for i in range(1, min(1024, len(data))):
+                try:
+                    with gzip.GzipFile(fileobj=io.BytesIO(data[:-i]), mode='rb') as gzf:
+                        return gzf.read()
+                except (OSError, EOFError):
+                    # EOFError: the trimmed tail left a truncated gzip member behind
+                    continue
+            raise original_oserror
+
     def http_request(self, req):
         # According to RFC 3986, URLs can not contain non-ASCII characters, however this is not
         # always respected by websites, some tend to give out URLs with non percent-encoded
@@ -1394,35 +1411,21 @@ def http_request(self, req):
 
     def http_response(self, req, resp):
         old_resp = resp
-        # gzip
-        if resp.headers.get('Content-encoding', '') == 'gzip':
-            content = resp.read()
-            gz = gzip.GzipFile(fileobj=io.BytesIO(content), mode='rb')
-            try:
-                uncompressed = io.BytesIO(gz.read())
-            except OSError as original_ioerror:
-                # There may be junk add the end of the file
-                # See http://stackoverflow.com/q/4928560/35070 for details
-                for i in range(1, 1024):
-                    try:
-                        gz = gzip.GzipFile(fileobj=io.BytesIO(content[:-i]), mode='rb')
-                        uncompressed = io.BytesIO(gz.read())
-                    except OSError:
-                        continue
-                    break
-                else:
-                    raise original_ioerror
-            resp = urllib.request.addinfourl(uncompressed, old_resp.headers, old_resp.url, old_resp.code)
-            resp.msg = old_resp.msg
-        # deflate
-        if resp.headers.get('Content-encoding', '') == 'deflate':
-            gz = io.BytesIO(self.deflate(resp.read()))
-            resp = urllib.request.addinfourl(gz, old_resp.headers, old_resp.url, old_resp.code)
-            resp.msg = old_resp.msg
-        # brotli
-        if resp.headers.get('Content-encoding', '') == 'br':
-            resp = urllib.request.addinfourl(
-                io.BytesIO(self.brotli(resp.read())), old_resp.headers, old_resp.url, old_resp.code)
+
+        # Content-Encoding header lists the encodings in order that they were applied [1].
+        # To decompress, we simply do the reverse.
+        # [1]: https://datatracker.ietf.org/doc/html/rfc9110#name-content-encoding
+        decoded_response = None
+        for encoding in (e.strip() for e in reversed(resp.headers.get('Content-encoding', '').split(','))):
+            if encoding == 'gzip':
+                decoded_response = self.gz(decoded_response or resp.read())
+            elif encoding == 'deflate':
+                decoded_response = self.deflate(decoded_response or resp.read())
+            elif encoding == 'br' and brotli:
+                decoded_response = self.brotli(decoded_response or resp.read())
+
+        if decoded_response is not None:
+            resp = urllib.request.addinfourl(io.BytesIO(decoded_response), old_resp.headers, old_resp.url, old_resp.code)
             resp.msg = old_resp.msg
         # Percent-encode redirect URL of Location HTTP header to satisfy RFC 3986 (see
         # https://github.com/ytdl-org/youtube-dl/issues/6457).