Use certificates from `certifi` if installed (#3115)
diff --git a/yt_dlp/utils.py b/yt_dlp/utils.py
index 012a115baf218870b6dfccd7f60799f7398b4ae7..a08dc3c1172b9a8b4e392cc4e7e74882c6b539cd 100644
--- a/yt_dlp/utils.py
+++ b/yt_dlp/utils.py
@@ -4,6 +4,7 @@
 from __future__ import unicode_literals
 
 import asyncio
+import atexit
 import base64
 import binascii
 import calendar
@@ -46,6 +47,7 @@
     compat_HTMLParser,
     compat_HTTPError,
     compat_basestring,
+    compat_brotli,
     compat_chr,
     compat_cookiejar,
     compat_ctypes_WINFUNCTYPE,
     sockssocket,
 )
 
+try:
+    import certifi
+    has_certifi = True
+except ImportError:
+    has_certifi = False
+
 
 def register_socks_protocols():
     # "Register" SOCKS protocols
@@ -142,10 +150,16 @@ def random_user_agent():
     return _USER_AGENT_TPL % random.choice(_CHROME_VERSIONS)
 
 
+SUPPORTED_ENCODINGS = [
+    'gzip', 'deflate'
+]
+if compat_brotli:
+    SUPPORTED_ENCODINGS.append('br')
+
 std_headers = {
     'User-Agent': random_user_agent(),
     'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8',
-    'Accept-Encoding': 'gzip, deflate',
+    'Accept-Encoding': ', '.join(SUPPORTED_ENCODINGS),
     'Accept-Language': 'en-us,en;q=0.5',
     'Sec-Fetch-Mode': 'navigate',
 }
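
A quick sketch (not part of the diff) of what the advertised header becomes, assuming `compat_brotli` resolves to an importable `brotli`/`brotlicffi` module:

    encodings = ['gzip', 'deflate']
    try:
        import brotli  # stand-in for compat_brotli
        encodings.append('br')
    except ImportError:
        pass
    print({'Accept-Encoding': ', '.join(encodings)})
    # {'Accept-Encoding': 'gzip, deflate, br'} when a brotli module is available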
@@ -1002,27 +1016,30 @@ def make_HTTPS_handler(params, **kwargs):
         context.options |= 4  # SSL_OP_LEGACY_SERVER_CONNECT
     context.verify_mode = ssl.CERT_REQUIRED if opts_check_certificate else ssl.CERT_NONE
     if opts_check_certificate:
-        try:
-            context.load_default_certs()
-            # Work around the issue in load_default_certs when there are bad certificates. See:
-            # https://github.com/yt-dlp/yt-dlp/issues/1060,
-            # https://bugs.python.org/issue35665, https://bugs.python.org/issue45312
-        except ssl.SSLError:
-            # enum_certificates is not present in mingw python. See https://github.com/yt-dlp/yt-dlp/issues/1151
-            if sys.platform == 'win32' and hasattr(ssl, 'enum_certificates'):
-                # Create a new context to discard any certificates that were already loaded
-                context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
-                context.check_hostname, context.verify_mode = True, ssl.CERT_REQUIRED
-                for storename in ('CA', 'ROOT'):
-                    _ssl_load_windows_store_certs(context, storename)
-            context.set_default_verify_paths()
+        if has_certifi and 'no-certifi' not in params.get('compat_opts', []):
+            context.load_verify_locations(cafile=certifi.where())
+        else:
+            try:
+                context.load_default_certs()
+                # Work around the issue in load_default_certs when there are bad certificates. See:
+                # https://github.com/yt-dlp/yt-dlp/issues/1060,
+                # https://bugs.python.org/issue35665, https://bugs.python.org/issue45312
+            except ssl.SSLError:
+                # enum_certificates is not present in mingw python. See https://github.com/yt-dlp/yt-dlp/issues/1151
+                if sys.platform == 'win32' and hasattr(ssl, 'enum_certificates'):
+                    # Create a new context to discard any certificates that were already loaded
+                    context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+                    context.check_hostname, context.verify_mode = True, ssl.CERT_REQUIRED
+                    for storename in ('CA', 'ROOT'):
+                        _ssl_load_windows_store_certs(context, storename)
+                context.set_default_verify_paths()
     return YoutubeDLHTTPSHandler(params, context=context, **kwargs)
 
 
 def bug_reports_message(before=';'):
     msg = ('please report this issue on  https://github.com/yt-dlp/yt-dlp , '
-           'filling out the "Broken site" issue template properly. '
-           'Confirm you are on the latest version using -U')
+           'filling out the appropriate issue template. '
+           'Confirm you are on the latest version using  yt-dlp -U')
 
     before = before.rstrip()
     if not before or before.endswith(('.', '!', '?')):
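
A minimal standalone sketch of the new certificate-loading preference, using only the stdlib `ssl` module and an optionally installed `certifi`:

    import ssl

    try:
        import certifi
        has_certifi = True
    except ImportError:
        has_certifi = False

    context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    context.verify_mode = ssl.CERT_REQUIRED
    if has_certifi:
        # certifi ships Mozilla's CA bundle; certifi.where() is the path to it
        context.load_verify_locations(cafile=certifi.where())
    else:
        context.load_default_certs()  # fall back to the system trust store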
@@ -1059,7 +1076,7 @@ def __init__(self, msg, tb=None, expected=False, cause=None, video_id=None, ie=N
         if sys.exc_info()[0] in network_exceptions:
             expected = True
 
-        self.msg = str(msg)
+        self.orig_msg = str(msg)
         self.traceback = tb
         self.expected = expected
         self.cause = cause
@@ -1070,14 +1087,15 @@ def __init__(self, msg, tb=None, expected=False, cause=None, video_id=None, ie=N
         super(ExtractorError, self).__init__(''.join((
             format_field(ie, template='[%s] '),
             format_field(video_id, template='%s: '),
-            self.msg,
+            msg,
             format_field(cause, template=' (caused by %r)'),
             '' if expected else bug_reports_message())))
 
     def format_traceback(self):
-        if self.traceback is None:
-            return None
-        return ''.join(traceback.format_tb(self.traceback))
+        return join_nonempty(
+            self.traceback and ''.join(traceback.format_tb(self.traceback)),
+            self.cause and ''.join(traceback.format_exception(None, self.cause, self.cause.__traceback__)[1:]),
+            delim='\n') or None
 
 
 class UnsupportedError(ExtractorError):
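
To illustrate the new `format_traceback` output, a hedged sketch with a simplified copy of `join_nonempty` (defined later in this file); note `traceback.format_exception` ignores its first argument since Python 3.5:

    import traceback

    def join_nonempty(*values, delim='-'):
        # simplified copy of the helper defined later in utils.py
        return delim.join(map(str, filter(None, values)))

    try:
        1 / 0
    except ZeroDivisionError as cause:
        cause_tb = ''.join(traceback.format_exception(None, cause, cause.__traceback__)[1:])
        # with no extractor traceback of its own, only the cause's traceback is returned
        print(join_nonempty(None, cause_tb, delim='\n') or None)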
@@ -1355,6 +1373,12 @@ def deflate(data):
         except zlib.error:
             return zlib.decompress(data)
 
+    @staticmethod
+    def brotli(data):
+        if not data:
+            return data
+        return compat_brotli.decompress(data)
+
     def http_request(self, req):
         # According to RFC 3986, URLs can not contain non-ASCII characters, however this is not
         # always respected by websites, some tend to give out URLs with non percent-encoded
@@ -1371,7 +1395,7 @@ def http_request(self, req):
         if url != url_escaped:
             req = update_Request(req, url=url_escaped)
 
-        for h, v in std_headers.items():
+        for h, v in self._params.get('http_headers', std_headers).items():
             # Capitalize is needed because of Python bug 2275: http://bugs.python.org/issue2275
             # The dict keys are capitalized because of this bug by urllib
             if h.capitalize() not in req.headers:
@@ -1415,6 +1439,12 @@ def http_response(self, req, resp):
             resp = compat_urllib_request.addinfourl(gz, old_resp.headers, old_resp.url, old_resp.code)
             resp.msg = old_resp.msg
             del resp.headers['Content-encoding']
+        # brotli
+        if resp.headers.get('Content-encoding', '') == 'br':
+            resp = compat_urllib_request.addinfourl(
+                io.BytesIO(self.brotli(resp.read())), old_resp.headers, old_resp.url, old_resp.code)
+            resp.msg = old_resp.msg
+            del resp.headers['Content-encoding']
         # Percent-encode redirect URL of Location HTTP header to satisfy RFC 3986 (see
         # https://github.com/ytdl-org/youtube-dl/issues/6457).
         if 300 <= resp.code < 400:
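
A round-trip sketch of the decoding path, assuming `compat_brotli` resolves to the third-party `brotli` (or `brotlicffi`) module:

    import brotli  # stand-in for compat_brotli

    def decode_brotli(data):
        # mirrors the static method above: empty bodies pass through untouched
        if not data:
            return data
        return brotli.decompress(data)

    payload = brotli.compress(b'<html>hello</html>')
    assert decode_brotli(payload) == b'<html>hello</html>'
    assert decode_brotli(b'') == b''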
@@ -2121,37 +2151,47 @@ class OVERLAPPED(ctypes.Structure):
     whole_low = 0xffffffff
     whole_high = 0x7fffffff
 
-    def _lock_file(f, exclusive, block):  # todo: block unused on win32
+    def _lock_file(f, exclusive, block):
         overlapped = OVERLAPPED()
         overlapped.Offset = 0
         overlapped.OffsetHigh = 0
         overlapped.hEvent = 0
         f._lock_file_overlapped_p = ctypes.pointer(overlapped)
-        handle = msvcrt.get_osfhandle(f.fileno())
-        if not LockFileEx(handle, 0x2 if exclusive else 0x0, 0,
-                          whole_low, whole_high, f._lock_file_overlapped_p):
-            raise OSError('Locking file failed: %r' % ctypes.FormatError())
+
+        if not LockFileEx(msvcrt.get_osfhandle(f.fileno()),
+                          (0x2 if exclusive else 0x0) | (0x0 if block else 0x1),
+                          0, whole_low, whole_high, f._lock_file_overlapped_p):
+            raise BlockingIOError('Locking file failed: %r' % ctypes.FormatError())
 
     def _unlock_file(f):
         assert f._lock_file_overlapped_p
         handle = msvcrt.get_osfhandle(f.fileno())
-        if not UnlockFileEx(handle, 0,
-                            whole_low, whole_high, f._lock_file_overlapped_p):
+        if not UnlockFileEx(handle, 0, whole_low, whole_high, f._lock_file_overlapped_p):
             raise OSError('Unlocking file failed: %r' % ctypes.FormatError())
 
 else:
-    # Some platforms, such as Jython, is missing fcntl
     try:
         import fcntl
 
         def _lock_file(f, exclusive, block):
-            fcntl.flock(f,
-                        fcntl.LOCK_SH if not exclusive
-                        else fcntl.LOCK_EX if block
-                        else fcntl.LOCK_EX | fcntl.LOCK_NB)
+            try:
+                fcntl.flock(f,
+                            fcntl.LOCK_SH if not exclusive
+                            else fcntl.LOCK_EX if block
+                            else fcntl.LOCK_EX | fcntl.LOCK_NB)
+            except BlockingIOError:
+                raise
+            except OSError:  # AOSP does not have flock()
+                fcntl.lockf(f,
+                            fcntl.LOCK_SH if not exclusive
+                            else fcntl.LOCK_EX if block
+                            else fcntl.LOCK_EX | fcntl.LOCK_NB)
 
         def _unlock_file(f):
-            fcntl.flock(f, fcntl.LOCK_UN)
+            try:
+                fcntl.flock(f, fcntl.LOCK_UN)
+            except OSError:
+                fcntl.lockf(f, fcntl.LOCK_UN)
 
     except ImportError:
         UNSUPPORTED_MSG = 'file locking is not supported on this platform'
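
A POSIX-only sketch of the non-blocking semantics both platforms now share: a contended `LOCK_NB` request surfaces as `BlockingIOError` rather than a generic `OSError`:

    import fcntl
    import tempfile

    with tempfile.NamedTemporaryFile() as tmp:
        a = open(tmp.name, 'w')
        b = open(tmp.name, 'w')
        fcntl.flock(a, fcntl.LOCK_EX)
        try:
            fcntl.flock(b, fcntl.LOCK_EX | fcntl.LOCK_NB)
        except BlockingIOError:
            print('lock is held; non-blocking acquire fails fast')
        a.close()
        b.close()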
@@ -2164,6 +2204,8 @@ def _unlock_file(f):
 
 
 class locked_file(object):
+    _closed = False
+
     def __init__(self, filename, mode, block=True, encoding=None):
         assert mode in ['r', 'rb', 'a', 'ab', 'w', 'wb']
         self.f = io.open(filename, mode, encoding=encoding)
@@ -2181,9 +2223,11 @@ def __enter__(self):
 
     def __exit__(self, etype, value, traceback):
         try:
-            _unlock_file(self.f)
+            if not self._closed:
+                _unlock_file(self.f)
         finally:
             self.f.close()
+            self._closed = True
 
     def __iter__(self):
         return iter(self.f)
@@ -2242,10 +2286,11 @@ def unsmuggle_url(smug_url, default=None):
 def format_decimal_suffix(num, fmt='%d%s', *, factor=1000):
     """ Formats numbers with decimal sufixes like K, M, etc """
     num, factor = float_or_none(num), float(factor)
-    if num is None:
+    if num is None or num < 0:
         return None
-    exponent = 0 if num == 0 else int(math.log(num, factor))
-    suffix = ['', *'kMGTPEZY'][exponent]
+    POSSIBLE_SUFFIXES = 'kMGTPEZY'
+    exponent = 0 if num == 0 else min(int(math.log(num, factor)), len(POSSIBLE_SUFFIXES))
+    suffix = ['', *POSSIBLE_SUFFIXES][exponent]
     if factor == 1024:
         suffix = {'k': 'Ki', '': ''}.get(suffix, f'{suffix}i')
     converted = num / (factor ** exponent)
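
A standalone sketch of the clamped logic, with yt-dlp's `float_or_none` simplified away:

    import math

    POSSIBLE_SUFFIXES = 'kMGTPEZY'

    def fmt(num, factor=1000):
        if num is None or num < 0:
            return None  # negative sizes are now rejected instead of crashing math.log
        exponent = 0 if num == 0 else min(int(math.log(num, factor)), len(POSSIBLE_SUFFIXES))
        suffix = ['', *POSSIBLE_SUFFIXES][exponent]
        if factor == 1024:
            suffix = {'k': 'Ki', '': ''}.get(suffix, f'{suffix}i')
        return '%.2f%s' % (num / factor ** exponent, suffix)

    print(fmt(3_000_000))            # 3.00M
    print(fmt(3 * 1024 ** 2, 1024))  # 3.00Mi
    print(fmt(10 ** 30))             # clamps to the largest suffix instead of IndexError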
@@ -2798,13 +2843,14 @@ def __len__(self):
     def __init__(self, pagefunc, pagesize, use_cache=True):
         self._pagefunc = pagefunc
         self._pagesize = pagesize
+        self._pagecount = float('inf')
         self._use_cache = use_cache
         self._cache = {}
 
     def getpage(self, pagenum):
         page_results = self._cache.get(pagenum)
         if page_results is None:
-            page_results = list(self._pagefunc(pagenum))
+            page_results = [] if pagenum > self._pagecount else list(self._pagefunc(pagenum))
         if self._use_cache:
             self._cache[pagenum] = page_results
         return page_results
@@ -2816,7 +2862,7 @@ def _getslice(self, start, end):
         raise NotImplementedError('This method must be implemented by subclasses')
 
     def __getitem__(self, idx):
-        # NOTE: cache must be enabled if this is used
+        assert self._use_cache, 'Indexing PagedList requires cache'
         if not isinstance(idx, int) or idx < 0:
             raise TypeError('indices must be non-negative integers')
         entries = self.getslice(idx, idx + 1)
@@ -2842,7 +2888,11 @@ def _getslice(self, start, end):
                 if (end is not None and firstid <= end <= nextfirstid)
                 else None)
 
-            page_results = self.getpage(pagenum)
+            try:
+                page_results = self.getpage(pagenum)
+            except Exception:
+                self._pagecount = pagenum - 1
+                raise
             if startv != 0 or endv is not None:
                 page_results = page_results[startv:endv]
             yield from page_results
@@ -2862,8 +2912,8 @@ def _getslice(self, start, end):
 
 class InAdvancePagedList(PagedList):
     def __init__(self, pagefunc, pagecount, pagesize):
-        self._pagecount = pagecount
         PagedList.__init__(self, pagefunc, pagesize, True)
+        self._pagecount = pagecount
 
     def _getslice(self, start, end):
         start_page = start // self._pagesize
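
The reordering matters because the base `__init__` now assigns `self._pagecount = float('inf')`; a tiny sketch of the pattern:

    class Base:
        def __init__(self):
            self._pagecount = float('inf')  # default added to PagedList.__init__

    class Child(Base):
        def __init__(self, pagecount):
            super().__init__()              # run the base first...
            self._pagecount = pagecount     # ...so its default cannot clobber ours

    print(Child(42)._pagecount)  # 42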
@@ -3465,7 +3515,7 @@ def filter_using_list(row, filterArray):
     extra_gap += 1
     if delim:
         table = [header_row, [delim * (ml + extra_gap) for ml in max_lens]] + data
-        table[1][-1] = table[1][-1][:-extra_gap]  # Remove extra_gap from end of delimiter
+        table[1][-1] = table[1][-1][:-extra_gap * len(delim)]  # Remove extra_gap from end of delimiter
     for row in table:
         for pos, text in enumerate(map(str, row)):
             if '\t' in text:
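
Why the `len(delim)` multiplier is needed, in a hypothetical two-character-delimiter case:

    delim, ml, extra_gap = '=-', 3, 2
    cell = delim * (ml + extra_gap)         # 10 characters long
    print(cell[:-extra_gap])                # '=-=-=-=-' -- trims too little
    print(cell[:-extra_gap * len(delim)])   # '=-=-=-'   -- matches ml * len(delim)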
@@ -3563,6 +3613,9 @@ def match_str(filter_str, dct, incomplete=False):
 
 
 def match_filter_func(filter_str):
+    if filter_str is None:
+        return None
+
     def _match_func(info_dict, *args, **kwargs):
         if match_str(filter_str, info_dict, *args, **kwargs):
             return None
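
A usage sketch, assuming yt-dlp is importable; `None` now means "no filter", so callers can skip installing a match function entirely:

    from yt_dlp.utils import match_filter_func  # assumes yt-dlp is installed

    assert match_filter_func(None) is None       # nothing to filter with
    accept_long = match_filter_func('duration > 60')
    print(accept_long({'duration': 90}))         # None -> the entry passes
    print(accept_long({'duration': 30}))         # a human-readable skip reason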
@@ -5175,6 +5228,10 @@ def traverse_dict(dictn, keys, casesense=True):
     return traverse_obj(dictn, keys, casesense=casesense, is_user_input=True, traverse_string=True)
 
 
+def get_first(obj, keys, **kwargs):
+    return traverse_obj(obj, (..., *variadic(keys)), **kwargs, get_all=False)
+
+
 def variadic(x, allowed_types=(str, bytes, dict)):
     return x if isinstance(x, collections.abc.Iterable) and not isinstance(x, allowed_types) else (x,)
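
A sketch of `get_first`, assuming yt-dlp is importable: it scans a collection of dicts and returns the first non-empty match for the (variadic) key path:

    from yt_dlp.utils import get_first  # assumes yt-dlp is installed

    entries = [{'meta': {}}, {'meta': {'title': 'first hit'}}, {'meta': {'title': 'ignored'}}]
    print(get_first(entries, ('meta', 'title')))  # 'first hit'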
 
@@ -5251,6 +5308,38 @@ def join_nonempty(*values, delim='-', from_dict=None):
     return delim.join(map(str, filter(None, values)))
 
 
+def scale_thumbnails_to_max_format_width(formats, thumbnails, url_width_re):
+    """
+    Find the largest format dimensions in terms of video width and, for each thumbnail:
+    * Modify the URL: Match the width with the provided regex and replace with the former width
+    * Update dimensions
+
+    This function is useful with video services that scale the provided thumbnails on demand
+    """
+    _keys = ('width', 'height')
+    max_dimensions = max(
+        [tuple(format.get(k) or 0 for k in _keys) for format in formats],
+        default=(0, 0))
+    if not max_dimensions[0]:
+        return thumbnails
+    return [
+        merge_dicts(
+            {'url': re.sub(url_width_re, str(max_dimensions[0]), thumbnail['url'])},
+            dict(zip(_keys, max_dimensions)), thumbnail)
+        for thumbnail in thumbnails
+    ]
+
+
+def parse_http_range(range):
+    """ Parse value of "Range" or "Content-Range" HTTP header into tuple. """
+    if not range:
+        return None, None, None
+    crg = re.search(r'bytes[ =](\d+)-(\d+)?(?:/(\d+))?', range)
+    if not crg:
+        return None, None, None
+    return int(crg.group(1)), int_or_none(crg.group(2)), int_or_none(crg.group(3))
+
+
 class Config:
     own_args = None
     filename = None
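
A standalone sketch of the new range parser, with `int_or_none` inlined:

    import re

    def parse_http_range(rng):
        if not rng:
            return None, None, None
        crg = re.search(r'bytes[ =](\d+)-(\d+)?(?:/(\d+))?', rng)
        if not crg:
            return None, None, None
        end, total = crg.group(2), crg.group(3)
        return int(crg.group(1)), int(end) if end else None, int(total) if total else None

    print(parse_http_range('bytes=0-499'))         # (0, 499, None)   -- request header
    print(parse_http_range('bytes 500-999/1234'))  # (500, 999, 1234) -- response header
    print(parse_http_range(None))                  # (None, None, None)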
@@ -5348,6 +5437,7 @@ def __init__(self, url, headers=None):
         self.conn = compat_websockets.connect(
             url, extra_headers=headers, ping_interval=None,
             close_timeout=float('inf'), loop=self.loop, ping_timeout=float('inf'))
+        atexit.register(self.__exit__, None, None, None)
 
     def __enter__(self):
         self.pool = self.run_with_loop(self.conn.__aenter__(), self.loop)
@@ -5364,7 +5454,7 @@ def __exit__(self, type, value, traceback):
             return self.run_with_loop(self.conn.__aexit__(type, value, traceback), self.loop)
         finally:
             self.loop.close()
-            self.r_cancel_all_tasks(self.loop)
+            self._cancel_all_tasks(self.loop)
 
     # taken from https://github.com/python/cpython/blob/3.9/Lib/asyncio/runners.py with modifications
     # for contributors: If there's any new library using asyncio needs to be run in non-async, move these function out of this class
@@ -5405,3 +5495,8 @@ def _cancel_all_tasks(loop):
 
 
 has_websockets = bool(compat_websockets)
+
+
+def merge_headers(*dicts):
+    """Merge dicts of http headers case insensitively, prioritizing the latter ones"""
+    return {k.title(): v for k, v in itertools.chain.from_iterable(map(dict.items, dicts))}
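
A usage sketch: `str.title()` normalizes header casing, and later dicts win on conflicts:

    import itertools

    def merge_headers(*dicts):  # copy of the helper above
        return {k.title(): v for k, v in itertools.chain.from_iterable(map(dict.items, dicts))}

    print(merge_headers({'accept-encoding': 'gzip', 'user-agent': 'old'},
                        {'Accept-Encoding': 'br'}))
    # {'Accept-Encoding': 'br', 'User-Agent': 'old'}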