]> jfr.im git - yt-dlp.git/blobdiff - yt_dlp/utils.py
[compat/asyncio] Use `asyncio.all_tasks`
[yt-dlp.git] / yt_dlp / utils.py
index 35e8d1d5b685d8035259c2dc6fd749c2a2a90697..0171394fc1dbf9ab879cabc1c59e40d9d330a429 100644 (file)
@@ -1,5 +1,4 @@
 #!/usr/bin/env python3
-import asyncio
 import atexit
 import base64
 import binascii
@@ -41,7 +40,7 @@
 import zlib
 
 from .compat import (
-    compat_brotli,
+    asyncio,
     compat_chr,
     compat_cookiejar,
     compat_etree_fromstring,
     compat_urllib_parse_urlparse,
     compat_urllib_request,
     compat_urlparse,
-    compat_websockets,
 )
+from .dependencies import brotli, certifi, websockets
 from .socks import ProxyType, sockssocket
 
-try:
-    import certifi
-    has_certifi = True
-except ImportError:
-    has_certifi = False
-
 
 def register_socks_protocols():
     # "Register" SOCKS protocols
@@ -136,7 +129,7 @@ def random_user_agent():
 SUPPORTED_ENCODINGS = [
     'gzip', 'deflate'
 ]
-if compat_brotli:
+if brotli:
     SUPPORTED_ENCODINGS.append('br')
 
 std_headers = {
@@ -281,22 +274,16 @@ def write_json_file(obj, fn):
         if sys.platform == 'win32':
             # Need to remove existing file on Windows, else os.rename raises
             # WindowsError or FileExistsError.
-            try:
+            with contextlib.suppress(OSError):
                 os.unlink(fn)
-            except OSError:
-                pass
-        try:
+        with contextlib.suppress(OSError):
             mask = os.umask(0)
             os.umask(mask)
             os.chmod(tf.name, 0o666 & ~mask)
-        except OSError:
-            pass
         os.rename(tf.name, fn)
     except Exception:
-        try:
+        with contextlib.suppress(OSError):
             os.remove(tf.name)
-        except OSError:
-            pass
         raise
 
 
@@ -574,12 +561,9 @@ def extract_attributes(html_element):
     }.
     """
     parser = HTMLAttributeParser()
-    try:
+    with contextlib.suppress(compat_HTMLParseError):
         parser.feed(html_element)
         parser.close()
-    # Older Python may throw HTMLParseError in case of malformed HTML
-    except compat_HTMLParseError:
-        pass
     return parser.attrs
 
 
@@ -799,10 +783,8 @@ def _htmlentity_transform(entity_with_semicolon):
         else:
             base = 10
         # See https://github.com/ytdl-org/youtube-dl/issues/7518
-        try:
+        with contextlib.suppress(ValueError):
             return compat_chr(int(numstr, base))
-        except ValueError:
-            pass
 
     # Unknown entity in name, return its literal representation
     return '&%s;' % entity
@@ -811,7 +793,7 @@ def _htmlentity_transform(entity_with_semicolon):
 def unescapeHTML(s):
     if s is None:
         return None
-    assert type(s) == compat_str
+    assert isinstance(s, str)
 
     return re.sub(
         r'&([^&;]+;)', lambda m: _htmlentity_transform(m.group(1)), s)
@@ -864,7 +846,7 @@ def get_subprocess_encoding():
 
 
 def encodeFilename(s, for_subprocess=False):
-    assert type(s) == str
+    assert isinstance(s, str)
     return s
 
 
@@ -923,10 +905,8 @@ def _ssl_load_windows_store_certs(ssl_context, storename):
     except PermissionError:
         return
     for cert in certs:
-        try:
+        with contextlib.suppress(ssl.SSLError):
             ssl_context.load_verify_locations(cadata=cert)
-        except ssl.SSLError:
-            pass
 
 
 def make_HTTPS_handler(params, **kwargs):
@@ -1278,7 +1258,7 @@ def deflate(data):
     def brotli(data):
         if not data:
             return data
-        return compat_brotli.decompress(data)
+        return brotli.decompress(data)
 
     def http_request(self, req):
         # According to RFC 3986, URLs can not contain non-ASCII characters, however this is not
@@ -1390,7 +1370,7 @@ class SocksConnection(base_class):
         def connect(self):
             self.sock = sockssocket()
             self.sock.setproxy(*proxy_args)
-            if type(self.timeout) in (int, float):
+            if isinstance(self.timeout, (int, float)):
                 self.sock.settimeout(self.timeout)
             self.sock.connect((self.host, self.port))
 
@@ -1525,9 +1505,7 @@ def prepare_line(line):
                 try:
                     cf.write(prepare_line(line))
                 except compat_cookiejar.LoadError as e:
-                    write_string(
-                        'WARNING: skipping cookie file entry due to %s: %r\n'
-                        % (e, line), sys.stderr)
+                    write_string(f'WARNING: skipping cookie file entry due to {e}: {line!r}\n')
                     continue
         cf.seek(0)
         self._really_load(cf, filename, ignore_discard, ignore_expires)
@@ -1645,12 +1623,10 @@ def parse_iso8601(date_str, delimiter='T', timezone=None):
     if timezone is None:
         timezone, date_str = extract_timezone(date_str)
 
-    try:
+    with contextlib.suppress(ValueError):
         date_format = f'%Y-%m-%d{delimiter}%H:%M:%S'
         dt = datetime.datetime.strptime(date_str, date_format) - timezone
         return calendar.timegm(dt.timetuple())
-    except ValueError:
-        pass
 
 
 def date_formats(day_first=True):
@@ -1670,17 +1646,13 @@ def unified_strdate(date_str, day_first=True):
     _, date_str = extract_timezone(date_str)
 
     for expression in date_formats(day_first):
-        try:
+        with contextlib.suppress(ValueError):
             upload_date = datetime.datetime.strptime(date_str, expression).strftime('%Y%m%d')
-        except ValueError:
-            pass
     if upload_date is None:
         timetuple = email.utils.parsedate_tz(date_str)
         if timetuple:
-            try:
+            with contextlib.suppress(ValueError):
                 upload_date = datetime.datetime(*timetuple[:6]).strftime('%Y%m%d')
-            except ValueError:
-                pass
     if upload_date is not None:
         return compat_str(upload_date)
 
@@ -1708,11 +1680,9 @@ def unified_timestamp(date_str, day_first=True):
         date_str = m.group(1)
 
     for expression in date_formats(day_first):
-        try:
+        with contextlib.suppress(ValueError):
             dt = datetime.datetime.strptime(date_str, expression) - timezone + datetime.timedelta(hours=pm_delta)
             return calendar.timegm(dt.timetuple())
-        except ValueError:
-            pass
     timetuple = email.utils.parsedate_tz(date_str)
     if timetuple:
         return calendar.timegm(timetuple) + pm_delta * 3600
@@ -1878,9 +1848,8 @@ def get_windows_version():
 
 
 def write_string(s, out=None, encoding=None):
-    if out is None:
-        out = sys.stderr
-    assert type(s) == compat_str
+    assert isinstance(s, str)
+    out = out or sys.stderr
 
     if 'b' in getattr(out, 'mode', ''):
         byt = s.encode(encoding or preferredencoding(), 'ignore')
@@ -2482,18 +2451,10 @@ def parse_duration(s):
             else:
                 return None
 
-    duration = 0
-    if secs:
-        duration += float(secs)
-    if mins:
-        duration += float(mins) * 60
-    if hours:
-        duration += float(hours) * 60 * 60
-    if days:
-        duration += float(days) * 24 * 60 * 60
     if ms:
-        duration += float(ms.replace(':', '.'))
-    return duration
+        ms = ms.replace(':', '.')
+    return sum(float(part or 0) * mult for part, mult in (
+        (days, 86400), (hours, 3600), (mins, 60), (secs, 1), (ms, 1)))
 
 
 def prepend_extension(filename, ext, expected_real_ext=None):
@@ -2956,9 +2917,10 @@ def encode_compat_str(string, encoding=preferredencoding(), errors='strict'):
 
 
 def parse_age_limit(s):
-    if type(s) == int:
+    # isinstance(False, int) is True. So type() must be used instead
+    if type(s) is int:
         return s if 0 <= s <= 21 else None
-    if not isinstance(s, str):
+    elif not isinstance(s, str):
         return None
     m = re.match(r'^(?P<age>\d{1,2})\+?$', s)
     if m:
@@ -3042,7 +3004,7 @@ def q(qid):
     return q
 
 
-POSTPROCESS_WHEN = {'pre_process', 'after_filter', 'before_dl', 'after_move', 'post_process', 'after_video', 'playlist'}
+POSTPROCESS_WHEN = ('pre_process', 'after_filter', 'before_dl', 'after_move', 'post_process', 'after_video', 'playlist')
 
 
 DEFAULT_OUTTMPL = {
@@ -3226,7 +3188,7 @@ def parse_codecs(codecs_str):
             if not tcodec:
                 tcodec = full_codec
         else:
-            write_string('WARNING: Unknown codec %s\n' % full_codec, sys.stderr)
+            write_string(f'WARNING: Unknown codec {full_codec}\n')
     if vcodec or acodec or tcodec:
         return {
             'vcodec': vcodec or 'none',
@@ -4933,7 +4895,7 @@ def get_executable_path():
 
 def load_plugins(name, suffix, namespace):
     classes = {}
-    try:
+    with contextlib.suppress(FileNotFoundError):
         plugins_spec = importlib.util.spec_from_file_location(
             name, os.path.join(get_executable_path(), 'ytdlp_plugins', name, '__init__.py'))
         plugins = importlib.util.module_from_spec(plugins_spec)
@@ -4946,8 +4908,6 @@ def load_plugins(name, suffix, namespace):
                 continue
             klass = getattr(plugins, name)
             classes[name] = namespace[name] = klass
-    except FileNotFoundError:
-        pass
     return classes
 
 
@@ -4956,13 +4916,14 @@ def traverse_obj(
         casesense=True, is_user_input=False, traverse_string=False):
     ''' Traverse nested list/dict/tuple
     @param path_list        A list of paths which are checked one by one.
-                            Each path is a list of keys where each key is a string,
-                            a function, a tuple of strings/None or "...".
-                            When a fuction is given, it takes the key and value as arguments
-                            and returns whether the key matches or not. When a tuple is given,
-                            all the keys given in the tuple are traversed, and
-                            "..." traverses all the keys in the object
-                            "None" returns the object without traversal
+                            Each path is a list of keys where each key is a:
+                              - None:     Do nothing
+                              - string:   A dictionary key
+                              - int:      An index into a list
+                              - tuple:    A list of keys all of which will be traversed
+                              - Ellipsis: Fetch all values in the object
+                              - Function: Takes the key and value as arguments
+                                          and returns whether the key matches or not
     @param default          Default value to return
     @param expected_type    Only accept final value of this type (Can also be any callable)
     @param get_all          Return all the values obtained from a path or only the first one
@@ -5252,15 +5213,17 @@ def all_args(self):
         yield from self.own_args or []
 
     def parse_args(self):
-        return self._parser.parse_args(list(self.all_args))
+        return self._parser.parse_args(self.all_args)
 
 
 class WebSocketsWrapper():
     """Wraps websockets module to use in non-async scopes"""
+    pool = None
 
     def __init__(self, url, headers=None, connect=True):
-        self.loop = asyncio.events.new_event_loop()
-        self.conn = compat_websockets.connect(
+        self.loop = asyncio.new_event_loop()
+        # XXX: "loop" is deprecated
+        self.conn = websockets.connect(
             url, extra_headers=headers, ping_interval=None,
             close_timeout=float('inf'), loop=self.loop, ping_timeout=float('inf'))
         if connect:
@@ -5289,7 +5252,7 @@ def __exit__(self, type, value, traceback):
     # for contributors: If there's any new library using asyncio needs to be run in non-async, move these function out of this class
     @staticmethod
     def run_with_loop(main, loop):
-        if not asyncio.coroutines.iscoroutine(main):
+        if not asyncio.iscoroutine(main):
             raise ValueError(f'a coroutine was expected, got {main!r}')
 
         try:
@@ -5301,7 +5264,7 @@ def run_with_loop(main, loop):
 
     @staticmethod
     def _cancel_all_tasks(loop):
-        to_cancel = asyncio.tasks.all_tasks(loop)
+        to_cancel = asyncio.all_tasks(loop)
 
         if not to_cancel:
             return
@@ -5309,8 +5272,9 @@ def _cancel_all_tasks(loop):
         for task in to_cancel:
             task.cancel()
 
+        # XXX: "loop" is removed in python 3.10+
         loop.run_until_complete(
-            asyncio.tasks.gather(*to_cancel, loop=loop, return_exceptions=True))
+            asyncio.gather(*to_cancel, loop=loop, return_exceptions=True))
 
         for task in to_cancel:
             if task.cancelled():
@@ -5323,9 +5287,6 @@ def _cancel_all_tasks(loop):
                 })
 
 
-has_websockets = bool(compat_websockets)
-
-
 def merge_headers(*dicts):
     """Merge dicts of http headers case insensitively, prioritizing the latter ones"""
     return {k.title(): v for k, v in itertools.chain.from_iterable(map(dict.items, dicts))}
@@ -5337,3 +5298,12 @@ def __init__(self, f):
 
     def __get__(self, _, cls):
         return self.f(cls)
+
+
+def Namespace(**kwargs):
+    return collections.namedtuple('Namespace', kwargs)(**kwargs)
+
+
+# Deprecated
+has_certifi = bool(certifi)
+has_websockets = bool(websockets)