]> jfr.im git - yt-dlp.git/blobdiff - yt_dlp/utils.py
[utils] Add `get_first`
[yt-dlp.git] / yt_dlp / utils.py
index be0c69d8fb30a75830b72baf2e850c6f3bd3b7b6..9b130e109b1c535008491226f8e7589e1f52d4a0 100644 (file)
@@ -47,6 +47,7 @@
     compat_HTMLParser,
     compat_HTTPError,
     compat_basestring,
+    compat_brotli,
     compat_chr,
     compat_cookiejar,
     compat_ctypes_WINFUNCTYPE,
@@ -143,10 +144,16 @@ def random_user_agent():
     return _USER_AGENT_TPL % random.choice(_CHROME_VERSIONS)
 
 
+SUPPORTED_ENCODINGS = [
+    'gzip', 'deflate'
+]
+if compat_brotli:
+    SUPPORTED_ENCODINGS.append('br')
+
 std_headers = {
     'User-Agent': random_user_agent(),
     'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8',
-    'Accept-Encoding': 'gzip, deflate',
+    'Accept-Encoding': ', '.join(SUPPORTED_ENCODINGS),
     'Accept-Language': 'en-us,en;q=0.5',
     'Sec-Fetch-Mode': 'navigate',
 }
@@ -1023,7 +1030,7 @@ def make_HTTPS_handler(params, **kwargs):
 def bug_reports_message(before=';'):
     msg = ('please report this issue on  https://github.com/yt-dlp/yt-dlp , '
            'filling out the "Broken site" issue template properly. '
-           'Confirm you are on the latest version using -U')
+           'Confirm you are on the latest version using  yt-dlp -U')
 
     before = before.rstrip()
     if not before or before.endswith(('.', '!', '?')):
@@ -1076,9 +1083,10 @@ def __init__(self, msg, tb=None, expected=False, cause=None, video_id=None, ie=N
             '' if expected else bug_reports_message())))
 
     def format_traceback(self):
-        if self.traceback is None:
-            return None
-        return ''.join(traceback.format_tb(self.traceback))
+        return join_nonempty(
+            self.traceback and ''.join(traceback.format_tb(self.traceback)),
+            self.cause and ''.join(traceback.format_exception(self.cause)[1:]),
+            delim='\n') or None
 
 
 class UnsupportedError(ExtractorError):
@@ -1356,6 +1364,12 @@ def deflate(data):
         except zlib.error:
             return zlib.decompress(data)
 
+    @staticmethod
+    def brotli(data):
+        if not data:
+            return data
+        return compat_brotli.decompress(data)
+
     def http_request(self, req):
         # According to RFC 3986, URLs can not contain non-ASCII characters, however this is not
         # always respected by websites, some tend to give out URLs with non percent-encoded
@@ -1416,6 +1430,12 @@ def http_response(self, req, resp):
             resp = compat_urllib_request.addinfourl(gz, old_resp.headers, old_resp.url, old_resp.code)
             resp.msg = old_resp.msg
             del resp.headers['Content-encoding']
+        # brotli
+        if resp.headers.get('Content-encoding', '') == 'br':
+            resp = compat_urllib_request.addinfourl(
+                io.BytesIO(self.brotli(resp.read())), old_resp.headers, old_resp.url, old_resp.code)
+            resp.msg = old_resp.msg
+            del resp.headers['Content-encoding']
         # Percent-encode redirect URL of Location HTTP header to satisfy RFC 3986 (see
         # https://github.com/ytdl-org/youtube-dl/issues/6457).
         if 300 <= resp.code < 400:
@@ -3485,7 +3505,7 @@ def filter_using_list(row, filterArray):
     extra_gap += 1
     if delim:
         table = [header_row, [delim * (ml + extra_gap) for ml in max_lens]] + data
-        table[1][-1] = table[1][-1][:-extra_gap]  # Remove extra_gap from end of delimiter
+        table[1][-1] = table[1][-1][:-extra_gap * len(delim)]  # Remove extra_gap from end of delimiter
     for row in table:
         for pos, text in enumerate(map(str, row)):
             if '\t' in text:
@@ -3583,6 +3603,9 @@ def match_str(filter_str, dct, incomplete=False):
 
 
 def match_filter_func(filter_str):
+    if filter_str is None:
+        return None
+
     def _match_func(info_dict, *args, **kwargs):
         if match_str(filter_str, info_dict, *args, **kwargs):
             return None
@@ -5195,6 +5218,10 @@ def traverse_dict(dictn, keys, casesense=True):
     return traverse_obj(dictn, keys, casesense=casesense, is_user_input=True, traverse_string=True)
 
 
+def get_first(obj, keys, **kwargs):
+    return traverse_obj(obj, (..., *variadic(keys)), **kwargs, get_all=False)
+
+
 def variadic(x, allowed_types=(str, bytes, dict)):
     return x if isinstance(x, collections.abc.Iterable) and not isinstance(x, allowed_types) else (x,)
 
@@ -5271,6 +5298,28 @@ def join_nonempty(*values, delim='-', from_dict=None):
     return delim.join(map(str, filter(None, values)))
 
 
+def scale_thumbnails_to_max_format_width(formats, thumbnails, url_width_re):
+    """
+    Find the largest format dimensions in terms of video width and, for each thumbnail:
+    * Modify the URL: Match the width with the provided regex and replace with the former width
+    * Update dimensions
+
+    This function is useful with video services that scale the provided thumbnails on demand
+    """
+    _keys = ('width', 'height')
+    max_dimensions = max(
+        [tuple(format.get(k) or 0 for k in _keys) for format in formats],
+        default=(0, 0))
+    if not max_dimensions[0]:
+        return thumbnails
+    return [
+        merge_dicts(
+            {'url': re.sub(url_width_re, str(max_dimensions[0]), thumbnail['url'])},
+            dict(zip(_keys, max_dimensions)), thumbnail)
+        for thumbnail in thumbnails
+    ]
+
+
 def parse_http_range(range):
     """ Parse value of "Range" or "Content-Range" HTTP header into tuple. """
     if not range:
@@ -5439,5 +5488,5 @@ def _cancel_all_tasks(loop):
 
 
 def merge_headers(*dicts):
-    """Merge dicts of network headers case insensitively, prioritizing the latter ones"""
+    """Merge dicts of http headers case insensitively, prioritizing the latter ones"""
     return {k.capitalize(): v for k, v in itertools.chain.from_iterable(map(dict.items, dicts))}