]> jfr.im git - yt-dlp.git/blobdiff - yt_dlp/utils.py
[utils] Add `try_call`
[yt-dlp.git] / yt_dlp / utils.py
index 9b130e109b1c535008491226f8e7589e1f52d4a0..22062f85f8d118e7fb1aba8beb66b6062ee9177f 100644 (file)
     sockssocket,
 )
 
+try:
+    import certifi
+    has_certifi = True
+except ImportError:
+    has_certifi = False
+
 
 def register_socks_protocols():
     # "Register" SOCKS protocols
@@ -153,7 +159,6 @@ def random_user_agent():
 std_headers = {
     'User-Agent': random_user_agent(),
     'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8',
-    'Accept-Encoding': ', '.join(SUPPORTED_ENCODINGS),
     'Accept-Language': 'en-us,en;q=0.5',
     'Sec-Fetch-Mode': 'navigate',
 }
@@ -700,36 +705,40 @@ def timeconvert(timestr):
     return timestamp
 
 
-def sanitize_filename(s, restricted=False, is_id=False):
+def sanitize_filename(s, restricted=False, is_id=NO_DEFAULT):
     """Sanitizes a string so it could be used as part of a filename.
-    If restricted is set, use a stricter subset of allowed characters.
-    Set is_id if this is not an arbitrary string, but an ID that should be kept
-    if possible.
+    @param restricted   Use a stricter subset of allowed characters
+    @param is_id        Whether this is an ID that should be kept unchanged if possible.
+                        If unset, yt-dlp's new sanitization rules are in effect
     """
+    if s == '':
+        return ''
+
     def replace_insane(char):
         if restricted and char in ACCENT_CHARS:
             return ACCENT_CHARS[char]
         elif not restricted and char == '\n':
-            return ' '
+            return '\0 '
         elif char == '?' or ord(char) < 32 or ord(char) == 127:
             return ''
         elif char == '"':
             return '' if restricted else '\''
         elif char == ':':
-            return '_-' if restricted else ' -'
+            return '\0_\0-' if restricted else '\0 \0-'
         elif char in '\\/|*<>':
-            return '_'
-        if restricted and (char in '!&\'()[]{}$;`^,#' or char.isspace()):
-            return '_'
-        if restricted and ord(char) > 127:
-            return '_'
+            return '\0_'
+        if restricted and (char in '!&\'()[]{}$;`^,#' or char.isspace() or ord(char) > 127):
+            return '\0_'
         return char
 
-    if s == '':
-        return ''
-    # Handle timestamps
-    s = re.sub(r'[0-9]+(?::[0-9]+)+', lambda m: m.group(0).replace(':', '_'), s)
+    s = re.sub(r'[0-9]+(?::[0-9]+)+', lambda m: m.group(0).replace(':', '_'), s)  # Handle timestamps
     result = ''.join(map(replace_insane, s))
+    if is_id is NO_DEFAULT:
+        result = re.sub('(\0.)(?:(?=\\1)..)+', r'\1', result)  # Remove repeated substitute chars
+        STRIP_RE = '(?:\0.|[ _-])*'
+        result = re.sub(f'^\0.{STRIP_RE}|{STRIP_RE}\0.$', '', result)  # Remove substitute chars from start/end
+    result = result.replace('\0', '') or '_'
+
     if not is_id:
         while '__' in result:
             result = result.replace('__', '_')
@@ -1010,26 +1019,29 @@ def make_HTTPS_handler(params, **kwargs):
         context.options |= 4  # SSL_OP_LEGACY_SERVER_CONNECT
     context.verify_mode = ssl.CERT_REQUIRED if opts_check_certificate else ssl.CERT_NONE
     if opts_check_certificate:
-        try:
-            context.load_default_certs()
-            # Work around the issue in load_default_certs when there are bad certificates. See:
-            # https://github.com/yt-dlp/yt-dlp/issues/1060,
-            # https://bugs.python.org/issue35665, https://bugs.python.org/issue45312
-        except ssl.SSLError:
-            # enum_certificates is not present in mingw python. See https://github.com/yt-dlp/yt-dlp/issues/1151
-            if sys.platform == 'win32' and hasattr(ssl, 'enum_certificates'):
-                # Create a new context to discard any certificates that were already loaded
-                context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
-                context.check_hostname, context.verify_mode = True, ssl.CERT_REQUIRED
-                for storename in ('CA', 'ROOT'):
-                    _ssl_load_windows_store_certs(context, storename)
-            context.set_default_verify_paths()
+        if has_certifi and 'no-certifi' not in params.get('compat_opts', []):
+            context.load_verify_locations(cafile=certifi.where())
+        else:
+            try:
+                context.load_default_certs()
+                # Work around the issue in load_default_certs when there are bad certificates. See:
+                # https://github.com/yt-dlp/yt-dlp/issues/1060,
+                # https://bugs.python.org/issue35665, https://bugs.python.org/issue45312
+            except ssl.SSLError:
+                # enum_certificates is not present in mingw python. See https://github.com/yt-dlp/yt-dlp/issues/1151
+                if sys.platform == 'win32' and hasattr(ssl, 'enum_certificates'):
+                    # Create a new context to discard any certificates that were already loaded
+                    context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+                    context.check_hostname, context.verify_mode = True, ssl.CERT_REQUIRED
+                    for storename in ('CA', 'ROOT'):
+                        _ssl_load_windows_store_certs(context, storename)
+                context.set_default_verify_paths()
     return YoutubeDLHTTPSHandler(params, context=context, **kwargs)
 
 
 def bug_reports_message(before=';'):
     msg = ('please report this issue on  https://github.com/yt-dlp/yt-dlp , '
-           'filling out the "Broken site" issue template properly. '
+           'filling out the appropriate issue template. '
            'Confirm you are on the latest version using  yt-dlp -U')
 
     before = before.rstrip()
@@ -1085,7 +1097,7 @@ def __init__(self, msg, tb=None, expected=False, cause=None, video_id=None, ie=N
     def format_traceback(self):
         return join_nonempty(
             self.traceback and ''.join(traceback.format_tb(self.traceback)),
-            self.cause and ''.join(traceback.format_exception(self.cause)[1:]),
+            self.cause and ''.join(traceback.format_exception(None, self.cause, self.cause.__traceback__)[1:]),
             delim='\n') or None
 
 
@@ -1392,6 +1404,9 @@ def http_request(self, req):
             if h.capitalize() not in req.headers:
                 req.add_header(h, v)
 
+        if 'Accept-encoding' not in req.headers:
+            req.add_header('Accept-encoding', ', '.join(SUPPORTED_ENCODINGS))
+
         req.headers = handle_youtubedl_headers(req.headers)
 
         if sys.version_info < (2, 7) and '#' in req.get_full_url():
@@ -2279,8 +2294,9 @@ def format_decimal_suffix(num, fmt='%d%s', *, factor=1000):
     num, factor = float_or_none(num), float(factor)
     if num is None or num < 0:
         return None
-    exponent = 0 if num == 0 else int(math.log(num, factor))
-    suffix = ['', *'kMGTPEZY'][exponent]
+    POSSIBLE_SUFFIXES = 'kMGTPEZY'
+    exponent = 0 if num == 0 else min(int(math.log(num, factor)), len(POSSIBLE_SUFFIXES))
+    suffix = ['', *POSSIBLE_SUFFIXES][exponent]
     if factor == 1024:
         suffix = {'k': 'Ki', '': ''}.get(suffix, f'{suffix}i')
     converted = num / (factor ** exponent)
@@ -2628,23 +2644,23 @@ def parse_duration(s):
         m = re.match(
             r'''(?ix)(?:P?
                 (?:
-                    [0-9]+\s*y(?:ears?)?\s*
+                    [0-9]+\s*y(?:ears?)?,?\s*
                 )?
                 (?:
-                    [0-9]+\s*m(?:onths?)?\s*
+                    [0-9]+\s*m(?:onths?)?,?\s*
                 )?
                 (?:
-                    [0-9]+\s*w(?:eeks?)?\s*
+                    [0-9]+\s*w(?:eeks?)?,?\s*
                 )?
                 (?:
-                    (?P<days>[0-9]+)\s*d(?:ays?)?\s*
+                    (?P<days>[0-9]+)\s*d(?:ays?)?,?\s*
                 )?
                 T)?
                 (?:
-                    (?P<hours>[0-9]+)\s*h(?:ours?)?\s*
+                    (?P<hours>[0-9]+)\s*h(?:ours?)?,?\s*
                 )?
                 (?:
-                    (?P<mins>[0-9]+)\s*m(?:in(?:ute)?s?)?\s*
+                    (?P<mins>[0-9]+)\s*m(?:in(?:ute)?s?)?,?\s*
                 )?
                 (?:
                     (?P<secs>[0-9]+)(?P<ms>\.[0-9]+)?\s*s(?:ec(?:ond)?s?)?\s*
@@ -2697,7 +2713,9 @@ def check_executable(exe, args=[]):
     return exe
 
 
-def _get_exe_version_output(exe, args):
+def _get_exe_version_output(exe, args, *, to_screen=None):
+    if to_screen:
+        to_screen(f'Checking exe version: {shell_quote([exe] + args)}')
     try:
         # STDIN should be redirected too. On UNIX-like systems, ffmpeg triggers
         # SIGTTOU if yt-dlp is run in the background.
@@ -3078,27 +3096,31 @@ def dict_get(d, key_or_keys, default=None, skip_false_values=True):
     return d.get(key_or_keys, default)
 
 
-def try_get(src, getter, expected_type=None):
-    for get in variadic(getter):
+def try_call(*funcs, expected_type=None, args=[], kwargs={}):
+    for f in funcs:
         try:
-            v = get(src)
-        except (AttributeError, KeyError, TypeError, IndexError):
+            val = f(*args, **kwargs)
+        except (AttributeError, KeyError, TypeError, IndexError, ZeroDivisionError):
             pass
         else:
-            if expected_type is None or isinstance(v, expected_type):
-                return v
+            if expected_type is None or isinstance(val, expected_type):
+                return val
+
+
+def try_get(src, getter, expected_type=None):
+    return try_call(*variadic(getter), args=(src,), expected_type=expected_type)
+
+
+def filter_dict(dct, cndn=lambda _, v: v is not None):
+    return {k: v for k, v in dct.items() if cndn(k, v)}
 
 
 def merge_dicts(*dicts):
     merged = {}
     for a_dict in dicts:
         for k, v in a_dict.items():
-            if v is None:
-                continue
-            if (k not in merged
-                    or (isinstance(v, compat_str) and v
-                        and isinstance(merged[k], compat_str)
-                        and not merged[k])):
+            if (v is not None and k not in merged
+                    or isinstance(v, str) and merged[k] == ''):
                 merged[k] = v
     return merged
 
@@ -3533,6 +3555,11 @@ def _match_one(filter_part, dct, incomplete):
         '=': operator.eq,
     }
 
+    if isinstance(incomplete, bool):
+        is_incomplete = lambda _: incomplete
+    else:
+        is_incomplete = lambda k: k in incomplete
+
     operator_rex = re.compile(r'''(?x)\s*
         (?P<key>[a-z_]+)
         \s*(?P<negation>!\s*)?(?P<op>%s)(?P<none_inclusive>\s*\?)?\s*
@@ -3571,7 +3598,7 @@ def _match_one(filter_part, dct, incomplete):
         if numeric_comparison is not None and m['op'] in STRING_OPERATORS:
             raise ValueError('Operator %s only supports string values!' % m['op'])
         if actual_value is None:
-            return incomplete or m['none_inclusive']
+            return is_incomplete(m['key']) or m['none_inclusive']
         return op(actual_value, comparison_value if numeric_comparison is None else numeric_comparison)
 
     UNARY_OPERATORS = {
@@ -3586,7 +3613,7 @@ def _match_one(filter_part, dct, incomplete):
     if m:
         op = UNARY_OPERATORS[m.group('op')]
         actual_value = dct.get(m.group('key'))
-        if incomplete and actual_value is None:
+        if is_incomplete(m.group('key')) and actual_value is None:
             return True
         return op(actual_value)
 
@@ -3594,24 +3621,29 @@ def _match_one(filter_part, dct, incomplete):
 
 
 def match_str(filter_str, dct, incomplete=False):
-    """ Filter a dictionary with a simple string syntax. Returns True (=passes filter) or false
-        When incomplete, all conditions passes on missing fields
+    """ Filter a dictionary with a simple string syntax.
+    @returns           Whether the filter passes
+    @param incomplete  Set of keys that is expected to be missing from dct.
+                       Can be True/False to indicate all/none of the keys may be missing.
+                       All conditions on incomplete keys pass if the key is missing
     """
     return all(
         _match_one(filter_part.replace(r'\&', '&'), dct, incomplete)
         for filter_part in re.split(r'(?<!\\)&', filter_str))
 
 
-def match_filter_func(filter_str):
-    if filter_str is None:
+def match_filter_func(filters):
+    if not filters:
         return None
+    filters = variadic(filters)
 
     def _match_func(info_dict, *args, **kwargs):
-        if match_str(filter_str, info_dict, *args, **kwargs):
+        if any(match_str(f, info_dict, *args, **kwargs) for f in filters):
             return None
         else:
-            video_title = info_dict.get('title', info_dict.get('id', 'video'))
-            return '%s does not pass filter %s, skipping ..' % (video_title, filter_str)
+            video_title = info_dict.get('title') or info_dict.get('id') or 'video'
+            filter_str = ') | ('.join(map(str.strip, filters))
+            return f'{video_title} does not pass filter ({filter_str}), skipping ..'
     return _match_func
 
 
@@ -5422,15 +5454,18 @@ def parse_args(self):
 class WebSocketsWrapper():
     """Wraps websockets module to use in non-async scopes"""
 
-    def __init__(self, url, headers=None):
+    def __init__(self, url, headers=None, connect=True):
         self.loop = asyncio.events.new_event_loop()
         self.conn = compat_websockets.connect(
             url, extra_headers=headers, ping_interval=None,
             close_timeout=float('inf'), loop=self.loop, ping_timeout=float('inf'))
+        if connect:
+            self.__enter__()
         atexit.register(self.__exit__, None, None, None)
 
     def __enter__(self):
-        self.pool = self.run_with_loop(self.conn.__aenter__(), self.loop)
+        if not self.pool:
+            self.pool = self.run_with_loop(self.conn.__aenter__(), self.loop)
         return self
 
     def send(self, *args):
@@ -5489,4 +5524,12 @@ def _cancel_all_tasks(loop):
 
 def merge_headers(*dicts):
     """Merge dicts of http headers case insensitively, prioritizing the latter ones"""
-    return {k.capitalize(): v for k, v in itertools.chain.from_iterable(map(dict.items, dicts))}
+    return {k.title(): v for k, v in itertools.chain.from_iterable(map(dict.items, dicts))}
+
+
+class classproperty:
+    def __init__(self, f):
+        self.f = f
+
+    def __get__(self, _, cls):
+        return self.f(cls)