]> jfr.im git - yt-dlp.git/blobdiff - yt_dlp/utils.py
[utils] Add `join_nonempty`
[yt-dlp.git] / yt_dlp / utils.py
index 7a40258cf91ac65db9185e524f68df86625cdc66..75b4ed61b6d47c19e503d106ffc9b1876f782bb9 100644 (file)
@@ -18,7 +18,7 @@
 import gzip
 import hashlib
 import hmac
-import imp
+import importlib.util
 import io
 import itertools
 import json
@@ -2006,6 +2006,23 @@ def handle_starttag(self, tag, attrs):
         self.attrs = dict(attrs)
 
 
+class HTMLListAttrsParser(compat_HTMLParser):
+    """HTML parser to gather the attributes for the elements of a list"""
+
+    def __init__(self):
+        compat_HTMLParser.__init__(self)
+        self.items = []
+        self._level = 0
+
+    def handle_starttag(self, tag, attrs):
+        if tag == 'li' and self._level == 0:
+            self.items.append(dict(attrs))
+        self._level += 1
+
+    def handle_endtag(self, tag):
+        self._level -= 1
+
+
 def extract_attributes(html_element):
     """Given a string for an HTML element such as
     <el
@@ -2032,6 +2049,15 @@ def extract_attributes(html_element):
     return parser.attrs
 
 
+def parse_list(webpage):
+    """Given a string for an series of HTML <li> elements,
+    return a dictionary of their attributes"""
+    parser = HTMLListAttrsParser()
+    parser.feed(webpage)
+    parser.close()
+    return parser.items
+
+
 def clean_html(html):
     """Clean an HTML snippet into a readable string"""
 
@@ -2272,6 +2298,20 @@ def process_communicate_or_kill(p, *args, **kwargs):
         raise
 
 
+class Popen(subprocess.Popen):
+    if sys.platform == 'win32':
+        _startupinfo = subprocess.STARTUPINFO()
+        _startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW
+    else:
+        _startupinfo = None
+
+    def __init__(self, *args, **kwargs):
+        super(Popen, self).__init__(*args, **kwargs, startupinfo=self._startupinfo)
+
+    def communicate_or_kill(self, *args, **kwargs):
+        return process_communicate_or_kill(self, *args, **kwargs)
+
+
 def get_subprocess_encoding():
     if sys.platform == 'win32' and sys.getwindowsversion()[0] >= 5:
         # For subprocess calls, encode with locale encoding
@@ -2342,14 +2382,25 @@ def decodeOption(optval):
     return optval
 
 
+_timetuple = collections.namedtuple('Time', ('hours', 'minutes', 'seconds', 'milliseconds'))
+
+
+def timetuple_from_msec(msec):
+    secs, msec = divmod(msec, 1000)
+    mins, secs = divmod(secs, 60)
+    hrs, mins = divmod(mins, 60)
+    return _timetuple(hrs, mins, secs, msec)
+
+
 def formatSeconds(secs, delim=':', msec=False):
-    if secs > 3600:
-        ret = '%d%s%02d%s%02d' % (secs // 3600, delim, (secs % 3600) // 60, delim, secs % 60)
-    elif secs > 60:
-        ret = '%d%s%02d' % (secs // 60, delim, secs % 60)
+    time = timetuple_from_msec(secs * 1000)
+    if time.hours:
+        ret = '%d%s%02d%s%02d' % (time.hours, delim, time.minutes, delim, time.seconds)
+    elif time.minutes:
+        ret = '%d%s%02d' % (time.minutes, delim, time.seconds)
     else:
-        ret = '%d' % secs
-    return '%s.%03d' % (ret, secs % 1) if msec else ret
+        ret = '%d' % time.seconds
+    return '%s.%03d' % (ret, time.milliseconds) if msec else ret
 
 
 def _ssl_load_windows_store_certs(ssl_context, storename):
@@ -2467,9 +2518,9 @@ class GeoRestrictedError(ExtractorError):
     geographic location due to geographic restrictions imposed by a website.
     """
 
-    def __init__(self, msg, countries=None):
-        super(GeoRestrictedError, self).__init__(msg, expected=True)
-        self.msg = msg
+    def __init__(self, msg, countries=None, **kwargs):
+        kwargs['expected'] = True
+        super(GeoRestrictedError, self).__init__(msg, **kwargs)
         self.countries = countries
 
 
@@ -2517,23 +2568,33 @@ def __init__(self, msg):
         self.msg = msg
 
 
-class ExistingVideoReached(YoutubeDLError):
-    """ --max-downloads limit has been reached. """
-    pass
+class DownloadCancelled(YoutubeDLError):
+    """ Exception raised when the download queue should be interrupted """
+    msg = 'The download was cancelled'
 
+    def __init__(self, msg=None):
+        if msg is not None:
+            self.msg = msg
+        YoutubeDLError.__init__(self, self.msg)
 
-class RejectedVideoReached(YoutubeDLError):
-    """ --max-downloads limit has been reached. """
-    pass
 
+class ExistingVideoReached(DownloadCancelled):
+    """ --break-on-existing triggered """
+    msg = 'Encountered a video that is already in the archive, stopping due to --break-on-existing'
 
-class ThrottledDownload(YoutubeDLError):
-    """ Download speed below --throttled-rate. """
-    pass
 
+class RejectedVideoReached(DownloadCancelled):
+    """ --break-on-reject triggered """
+    msg = 'Encountered a video that did not match filter, stopping due to --break-on-reject'
 
-class MaxDownloadsReached(YoutubeDLError):
+
+class MaxDownloadsReached(DownloadCancelled):
     """ --max-downloads limit has been reached. """
+    msg = 'Maximum number of downloads reached, stopping due to --max-downloads'
+
+
+class ThrottledDownload(YoutubeDLError):
+    """ Download speed below --throttled-rate. """
     pass
 
 
@@ -3689,14 +3750,14 @@ def parse_resolution(s):
     if s is None:
         return {}
 
-    mobj = re.search(r'\b(?P<w>\d+)\s*[xX×]\s*(?P<h>\d+)\b', s)
+    mobj = re.search(r'(?<![a-zA-Z0-9])(?P<w>\d+)\s*[xX×,]\s*(?P<h>\d+)(?![a-zA-Z0-9])', s)
     if mobj:
         return {
             'width': int(mobj.group('w')),
             'height': int(mobj.group('h')),
         }
 
-    mobj = re.search(r'\b(\d+)[pPiI]\b', s)
+    mobj = re.search(r'(?<![a-zA-Z0-9])(\d+)[pPiI](?![a-zA-Z0-9])', s)
     if mobj:
         return {'height': int(mobj.group(1))}
 
@@ -3836,7 +3897,7 @@ def int_or_none(v, scale=1, default=None, get_attr=None, invscale=1):
         return default
     try:
         return int(v) * invscale // scale
-    except (ValueError, TypeError):
+    except (ValueError, TypeError, OverflowError):
         return default
 
 
@@ -3966,30 +4027,25 @@ def check_executable(exe, args=[]):
     """ Checks if the given binary is installed somewhere in PATH, and returns its name.
     args can be a list of arguments for a short output (like -version) """
     try:
-        process_communicate_or_kill(subprocess.Popen(
-            [exe] + args, stdout=subprocess.PIPE, stderr=subprocess.PIPE))
+        Popen([exe] + args, stdout=subprocess.PIPE, stderr=subprocess.PIPE).communicate_or_kill()
     except OSError:
         return False
     return exe
 
 
-def get_exe_version(exe, args=['--version'],
-                    version_re=None, unrecognized='present'):
-    """ Returns the version of the specified executable,
-    or False if the executable is not present """
+def _get_exe_version_output(exe, args):
     try:
         # STDIN should be redirected too. On UNIX-like systems, ffmpeg triggers
         # SIGTTOU if yt-dlp is run in the background.
         # See https://github.com/ytdl-org/youtube-dl/issues/955#issuecomment-209789656
-        out, _ = process_communicate_or_kill(subprocess.Popen(
-            [encodeArgument(exe)] + args,
-            stdin=subprocess.PIPE,
-            stdout=subprocess.PIPE, stderr=subprocess.STDOUT))
+        out, _ = Popen(
+            [encodeArgument(exe)] + args, stdin=subprocess.PIPE,
+            stdout=subprocess.PIPE, stderr=subprocess.STDOUT).communicate_or_kill()
     except OSError:
         return False
     if isinstance(out, bytes):  # Python 2.x
         out = out.decode('ascii', 'ignore')
-    return detect_exe_version(out, version_re, unrecognized)
+    return out
 
 
 def detect_exe_version(output, version_re=None, unrecognized='present'):
@@ -4003,6 +4059,14 @@ def detect_exe_version(output, version_re=None, unrecognized='present'):
         return unrecognized
 
 
+def get_exe_version(exe, args=['--version'],
+                    version_re=None, unrecognized='present'):
+    """ Returns the version of the specified executable,
+    or False if the executable is not present """
+    out = _get_exe_version_output(exe, args)
+    return detect_exe_version(out, version_re, unrecognized) if out else False
+
+
 class LazyList(collections.abc.Sequence):
     ''' Lazy immutable list from an iterable
     Note that slices of a LazyList are lists and not LazyList'''
@@ -4027,6 +4091,8 @@ def __iter__(self):
 
     def __exhaust(self):
         self.__cache.extend(self.__iterable)
+        # Discard the emptied iterable to make it pickle-able
+        self.__iterable = []
         return self.__cache
 
     def exhaust(self):
@@ -4478,6 +4544,7 @@ def q(qid):
     'description': 'description',
     'annotation': 'annotations.xml',
     'infojson': 'info.json',
+    'link': None,
     'pl_thumbnail': None,
     'pl_description': 'description',
     'pl_infojson': 'info.json',
@@ -4618,12 +4685,20 @@ def parse_codecs(codecs_str):
         return {}
     split_codecs = list(filter(None, map(
         str.strip, codecs_str.strip().strip(',').split(','))))
-    vcodec, acodec = None, None
+    vcodec, acodec, hdr = None, None, None
     for full_codec in split_codecs:
-        codec = full_codec.split('.')[0]
-        if codec in ('avc1', 'avc2', 'avc3', 'avc4', 'vp9', 'vp8', 'hev1', 'hev2', 'h263', 'h264', 'mp4v', 'hvc1', 'av01', 'theora', 'dvh1', 'dvhe'):
+        parts = full_codec.split('.')
+        codec = parts[0].replace('0', '')
+        if codec in ('avc1', 'avc2', 'avc3', 'avc4', 'vp9', 'vp8', 'hev1', 'hev2',
+                     'h263', 'h264', 'mp4v', 'hvc1', 'av1', 'theora', 'dvh1', 'dvhe'):
             if not vcodec:
-                vcodec = full_codec
+                vcodec = '.'.join(parts[:4]) if codec in ('vp9', 'av1') else full_codec
+                if codec in ('dvh1', 'dvhe'):
+                    hdr = 'DV'
+                elif codec == 'av1' and len(parts) > 3 and parts[3] == '10':
+                    hdr = 'HDR10'
+                elif full_codec.replace('0', '').startswith('vp9.2'):
+                    hdr = 'HDR10'
         elif codec in ('mp4a', 'opus', 'vorbis', 'mp3', 'aac', 'ac-3', 'ec-3', 'eac3', 'dtsc', 'dtse', 'dtsh', 'dtsl'):
             if not acodec:
                 acodec = full_codec
@@ -4639,6 +4714,7 @@ def parse_codecs(codecs_str):
         return {
             'vcodec': vcodec or 'none',
             'acodec': acodec or 'none',
+            'dynamic_range': hdr,
         }
     return {}
 
@@ -4696,7 +4772,7 @@ def determine_protocol(info_dict):
     if protocol is not None:
         return protocol
 
-    url = info_dict['url']
+    url = sanitize_url(info_dict['url'])
     if url.startswith('rtmp'):
         return 'rtmp'
     elif url.startswith('mms'):
@@ -4715,9 +4791,11 @@ def determine_protocol(info_dict):
 
 def render_table(header_row, data, delim=False, extraGap=0, hideEmpty=False):
     """ Render a list of rows, each as a list of values """
+    def width(string):
+        return len(remove_terminal_sequences(string))
 
     def get_max_lens(table):
-        return [max(len(compat_str(v)) for v in col) for col in zip(*table)]
+        return [max(width(str(v)) for v in col) for col in zip(*table)]
 
     def filter_using_list(row, filterArray):
         return [col for (take, col) in zip(filterArray, row) if take]
@@ -4729,10 +4807,15 @@ def filter_using_list(row, filterArray):
 
     table = [header_row] + data
     max_lens = get_max_lens(table)
+    extraGap += 1
     if delim:
-        table = [header_row] + [['-' * ml for ml in max_lens]] + data
-    format_str = ' '.join('%-' + compat_str(ml + extraGap) + 's' for ml in max_lens[:-1]) + ' %s'
-    return '\n'.join(format_str % tuple(row) for row in table)
+        table = [header_row] + [[delim * (ml + extraGap) for ml in max_lens]] + data
+    max_lens[-1] = 0
+    for row in table:
+        for pos, text in enumerate(map(str, row)):
+            row[pos] = text + (' ' * (max_lens[pos] - width(text) + extraGap))
+    ret = '\n'.join(''.join(row) for row in table)
+    return ret
 
 
 def _match_one(filter_part, dct, incomplete):
@@ -4756,7 +4839,6 @@ def _match_one(filter_part, dct, incomplete):
         (?P<key>[a-z_]+)
         \s*(?P<negation>!\s*)?(?P<op>%s)(?P<none_inclusive>\s*\?)?\s*
         (?:
-            (?P<intval>[0-9.]+(?:[kKmMgGtTpPeEzZyY]i?[Bb]?)?)|
             (?P<quote>["\'])(?P<quotedstrval>.+?)(?P=quote)|
             (?P<strval>.+?)
         )
@@ -4764,40 +4846,35 @@ def _match_one(filter_part, dct, incomplete):
         ''' % '|'.join(map(re.escape, COMPARISON_OPERATORS.keys())))
     m = operator_rex.search(filter_part)
     if m:
-        unnegated_op = COMPARISON_OPERATORS[m.group('op')]
-        if m.group('negation'):
+        m = m.groupdict()
+        unnegated_op = COMPARISON_OPERATORS[m['op']]
+        if m['negation']:
             op = lambda attr, value: not unnegated_op(attr, value)
         else:
             op = unnegated_op
-        actual_value = dct.get(m.group('key'))
-        if (m.group('quotedstrval') is not None
-            or m.group('strval') is not None
+        comparison_value = m['quotedstrval'] or m['strval'] or m['intval']
+        if m['quote']:
+            comparison_value = comparison_value.replace(r'\%s' % m['quote'], m['quote'])
+        actual_value = dct.get(m['key'])
+        numeric_comparison = None
+        if isinstance(actual_value, compat_numeric_types):
             # If the original field is a string and matching comparisonvalue is
             # a number we should respect the origin of the original field
             # and process comparison value as a string (see
-            # https://github.com/ytdl-org/youtube-dl/issues/11082).
-            or actual_value is not None and m.group('intval') is not None
-                and isinstance(actual_value, compat_str)):
-            comparison_value = m.group('quotedstrval') or m.group('strval') or m.group('intval')
-            quote = m.group('quote')
-            if quote is not None:
-                comparison_value = comparison_value.replace(r'\%s' % quote, quote)
-        else:
-            if m.group('op') in STRING_OPERATORS:
-                raise ValueError('Operator %s only supports string values!' % m.group('op'))
+            # https://github.com/ytdl-org/youtube-dl/issues/11082)
             try:
-                comparison_value = int(m.group('intval'))
+                numeric_comparison = int(comparison_value)
             except ValueError:
-                comparison_value = parse_filesize(m.group('intval'))
-                if comparison_value is None:
-                    comparison_value = parse_filesize(m.group('intval') + 'B')
-                if comparison_value is None:
-                    raise ValueError(
-                        'Invalid integer value %r in filter part %r' % (
-                            m.group('intval'), filter_part))
+                numeric_comparison = parse_filesize(comparison_value)
+                if numeric_comparison is None:
+                    numeric_comparison = parse_filesize(f'{comparison_value}B')
+                if numeric_comparison is None:
+                    numeric_comparison = parse_duration(comparison_value)
+        if numeric_comparison is not None and m['op'] in STRING_OPERATORS:
+            raise ValueError('Operator %s only supports string values!' % m['op'])
         if actual_value is None:
-            return incomplete or m.group('none_inclusive')
-        return op(actual_value, comparison_value)
+            return incomplete or m['none_inclusive']
+        return op(actual_value, comparison_value if numeric_comparison is None else numeric_comparison)
 
     UNARY_OPERATORS = {
         '': lambda v: (v is True) if isinstance(v, bool) else (v is not None),
@@ -4851,7 +4928,12 @@ def parse_dfxp_time_expr(time_expr):
 
 
 def srt_subtitles_timecode(seconds):
-    return '%02d:%02d:%02d,%03d' % (seconds / 3600, (seconds % 3600) / 60, seconds % 60, (seconds % 1) * 1000)
+    return '%02d:%02d:%02d,%03d' % timetuple_from_msec(seconds * 1000)
+
+
+def ass_subtitles_timecode(seconds):
+    time = timetuple_from_msec(seconds * 1000)
+    return '%01d:%02d:%02d.%02d' % (*time[:-1], time.milliseconds / 10)
 
 
 def dfxp2srt(dfxp_data):
@@ -6135,11 +6217,11 @@ def write_xattr(path, key, value):
                        + [encodeFilename(path, True)])
 
                 try:
-                    p = subprocess.Popen(
+                    p = Popen(
                         cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=subprocess.PIPE)
                 except EnvironmentError as e:
                     raise XAttrMetadataError(e.errno, e.strerror)
-                stdout, stderr = process_communicate_or_kill(p)
+                stdout, stderr = p.communicate_or_kill()
                 stderr = stderr.decode('utf-8', 'replace')
                 if p.returncode != 0:
                     raise XAttrMetadataError(p.returncode, stderr)
@@ -6197,6 +6279,12 @@ def random_birthday(year_field, month_field, day_field):
 Icon=text-html
 '''.lstrip()
 
+LINK_TEMPLATES = {
+    'url': DOT_URL_LINK_TEMPLATE,
+    'desktop': DOT_DESKTOP_LINK_TEMPLATE,
+    'webloc': DOT_WEBLOC_LINK_TEMPLATE,
+}
+
 
 def iri_to_uri(iri):
     """
@@ -6308,12 +6396,13 @@ def get_executable_path():
 
 
 def load_plugins(name, suffix, namespace):
-    plugin_info = [None]
     classes = {}
     try:
-        plugin_info = imp.find_module(
-            name, [os.path.join(get_executable_path(), 'ytdlp_plugins')])
-        plugins = imp.load_module(name, *plugin_info)
+        plugins_spec = importlib.util.spec_from_file_location(
+            name, os.path.join(get_executable_path(), 'ytdlp_plugins', name, '__init__.py'))
+        plugins = importlib.util.module_from_spec(plugins_spec)
+        sys.modules[plugins_spec.name] = plugins
+        plugins_spec.loader.exec_module(plugins)
         for name in dir(plugins):
             if name in namespace:
                 continue
@@ -6321,11 +6410,8 @@ def load_plugins(name, suffix, namespace):
                 continue
             klass = getattr(plugins, name)
             classes[name] = namespace[name] = klass
-    except ImportError:
+    except FileNotFoundError:
         pass
-    finally:
-        if plugin_info[0] is not None:
-            plugin_info[0].close()
     return classes
 
 
@@ -6456,6 +6542,13 @@ def jwt_encode_hs256(payload_data, key, headers={}):
     return token
 
 
+# can be extended in future to verify the signature and parse header and return the algorithm used if it's not HS256
+def jwt_decode_hs256(jwt):
+    header_b64, payload_b64, signature_b64 = jwt.split('.')
+    payload_data = json.loads(base64.urlsafe_b64decode(payload_b64))
+    return payload_data
+
+
 def supports_terminal_sequences(stream):
     if compat_os_name == 'nt':
         if get_windows_version() < (10, 0, 10586):
@@ -6468,12 +6561,18 @@ def supports_terminal_sequences(stream):
         return False
 
 
-TERMINAL_SEQUENCES = {
-    'DOWN': '\n',
-    'UP': '\x1b[A',
-    'ERASE_LINE': '\x1b[K',
-    'RED': '\033[0;31m',
-    'YELLOW': '\033[0;33m',
-    'BLUE': '\033[0;34m',
-    'RESET_STYLE': '\033[0m',
-}
+_terminal_sequences_re = re.compile('\033\\[[^m]+m')
+
+
+def remove_terminal_sequences(string):
+    return _terminal_sequences_re.sub('', string)
+
+
+def number_of_digits(number):
+    return len('%d' % number)
+
+
+def join_nonempty(*values, delim='-', from_dict=None):
+    if from_dict is not None:
+        values = operator.itemgetter(values)(from_dict)
+    return delim.join(map(str, filter(None, values)))