]> jfr.im git - yt-dlp.git/blobdiff - yt_dlp/utils.py
Improved progress reporting (See desc) (#1125)
[yt-dlp.git] / yt_dlp / utils.py
index cdf4c0755bbd629f962a250200ed7b2a6a16f6f9..027387897446def7fb3d94b5cf58f7283bbae32c 100644 (file)
@@ -16,6 +16,8 @@
 import errno
 import functools
 import gzip
+import hashlib
+import hmac
 import imp
 import io
 import itertools
@@ -1740,6 +1742,7 @@ def random_user_agent():
     '%b %dth %Y %I:%M',
     '%Y %m %d',
     '%Y-%m-%d',
+    '%Y.%m.%d.',
     '%Y/%m/%d',
     '%Y/%m/%d %H:%M',
     '%Y/%m/%d %H:%M:%S',
@@ -1761,6 +1764,7 @@ def random_user_agent():
     '%b %d %Y at %H:%M:%S',
     '%B %d %Y at %H:%M',
     '%B %d %Y at %H:%M:%S',
+    '%H:%M %d-%b-%Y',
 )
 
 DATE_FORMATS_DAY_FIRST = list(DATE_FORMATS)
@@ -2095,7 +2099,9 @@ def sanitize_filename(s, restricted=False, is_id=False):
     def replace_insane(char):
         if restricted and char in ACCENT_CHARS:
             return ACCENT_CHARS[char]
-        if char == '?' or ord(char) < 32 or ord(char) == 127:
+        elif not restricted and char == '\n':
+            return ' '
+        elif char == '?' or ord(char) < 32 or ord(char) == 127:
             return ''
         elif char == '"':
             return '' if restricted else '\''
@@ -2346,29 +2352,42 @@ def formatSeconds(secs, delim=':', msec=False):
     return '%s.%03d' % (ret, secs % 1) if msec else ret
 
 
-def make_HTTPS_handler(params, **kwargs):
-    opts_no_check_certificate = params.get('nocheckcertificate', False)
-    if hasattr(ssl, 'create_default_context'):  # Python >= 3.4 or 2.7.9
-        context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
-        if opts_no_check_certificate:
-            context.check_hostname = False
-            context.verify_mode = ssl.CERT_NONE
+def _ssl_load_windows_store_certs(ssl_context, storename):
+    # Code adapted from _load_windows_store_certs in https://github.com/python/cpython/blob/main/Lib/ssl.py
+    try:
+        certs = [cert for cert, encoding, trust in ssl.enum_certificates(storename)
+                 if encoding == 'x509_asn' and (
+                     trust is True or ssl.Purpose.SERVER_AUTH.oid in trust)]
+    except PermissionError:
+        return
+    for cert in certs:
         try:
-            return YoutubeDLHTTPSHandler(params, context=context, **kwargs)
-        except TypeError:
-            # Python 2.7.8
-            # (create_default_context present but HTTPSHandler has no context=)
+            ssl_context.load_verify_locations(cadata=cert)
+        except ssl.SSLError:
             pass
 
-    if sys.version_info < (3, 2):
-        return YoutubeDLHTTPSHandler(params, **kwargs)
-    else:  # Python < 3.4
-        context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
-        context.verify_mode = (ssl.CERT_NONE
-                               if opts_no_check_certificate
-                               else ssl.CERT_REQUIRED)
-        context.set_default_verify_paths()
-        return YoutubeDLHTTPSHandler(params, context=context, **kwargs)
+
+def make_HTTPS_handler(params, **kwargs):
+    opts_check_certificate = not params.get('nocheckcertificate')
+    context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+    context.check_hostname = opts_check_certificate
+    context.verify_mode = ssl.CERT_REQUIRED if opts_check_certificate else ssl.CERT_NONE
+    if opts_check_certificate:
+        try:
+            context.load_default_certs()
+            # Work around the issue in load_default_certs when there are bad certificates. See:
+            # https://github.com/yt-dlp/yt-dlp/issues/1060,
+            # https://bugs.python.org/issue35665, https://bugs.python.org/issue45312
+        except ssl.SSLError:
+            # enum_certificates is not present in mingw python. See https://github.com/yt-dlp/yt-dlp/issues/1151
+            if sys.platform == 'win32' and hasattr(ssl, 'enum_certificates'):
+                # Create a new context to discard any certificates that were already loaded
+                context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+                context.check_hostname, context.verify_mode = True, ssl.CERT_REQUIRED
+                for storename in ('CA', 'ROOT'):
+                    _ssl_load_windows_store_certs(context, storename)
+            context.set_default_verify_paths()
+    return YoutubeDLHTTPSHandler(params, context=context, **kwargs)
 
 
 def bug_reports_message(before=';'):
@@ -2408,7 +2427,7 @@ def __init__(self, msg, tb=None, expected=False, cause=None, video_id=None, ie=N
         if sys.exc_info()[0] in network_exceptions:
             expected = True
 
-        self.msg = msg
+        self.msg = str(msg)
         self.traceback = tb
         self.expected = expected
         self.cause = cause
@@ -2419,7 +2438,7 @@ def __init__(self, msg, tb=None, expected=False, cause=None, video_id=None, ie=N
         super(ExtractorError, self).__init__(''.join((
             format_field(ie, template='[%s] '),
             format_field(video_id, template='%s: '),
-            msg,
+            self.msg,
             format_field(cause, template=' (caused by %r)'),
             '' if expected else bug_reports_message())))
 
@@ -3033,8 +3052,16 @@ def redirect_request(self, req, fp, code, msg, headers, newurl):
 
 def extract_timezone(date_str):
     m = re.search(
-        r'^.{8,}?(?P<tz>Z$| ?(?P<sign>\+|-)(?P<hours>[0-9]{2}):?(?P<minutes>[0-9]{2})$)',
-        date_str)
+        r'''(?x)
+            ^.{8,}?                                              # >=8 char non-TZ prefix, if present
+            (?P<tz>Z|                                            # just the UTC Z, or
+                (?:(?<=.\b\d{4}|\b\d{2}:\d\d)|                   # preceded by 4 digits or hh:mm or
+                   (?<!.\b[a-zA-Z]{3}|[a-zA-Z]{4}|..\b\d\d))     # not preceded by 3 alpha word or >= 4 alpha or 2 digits
+                   [ ]?                                          # optional space
+                (?P<sign>\+|-)                                   # +/-
+                (?P<hours>[0-9]{2}):?(?P<minutes>[0-9]{2})       # hh[:]mm
+            $)
+        ''', date_str)
     if not m:
         timezone = datetime.timedelta()
     else:
@@ -3280,6 +3307,14 @@ def platform_name():
     return res
 
 
+def get_windows_version():
+    ''' Get Windows version. None if it's not running on Windows '''
+    if compat_os_name == 'nt':
+        return version_tuple(platform.win32_ver()[1])
+    else:
+        return None
+
+
 def _windows_write_string(s, out):
     """ Returns True if the string was written using special methods,
     False if it has yet to be written out."""
@@ -4454,12 +4489,12 @@ def q(qid):
 STR_FORMAT_RE_TMPL = r'''(?x)
     (?<!%)(?P<prefix>(?:%%)*)
     %
-    (?P<has_key>\((?P<key>{0})\))?  # mapping key
+    (?P<has_key>\((?P<key>{0})\))?
     (?P<format>
-        (?:[#0\-+ ]+)?  # conversion flags (optional)
-        (?:\d+)?  # minimum field width (optional)
-        (?:\.\d+)?  # precision (optional)
-        [hlL]?  # length modifier (optional)
+        (?P<conversion>[#0\-+ ]+)?
+        (?P<min_width>\d+)?
+        (?P<precision>\.\d+)?
+        (?P<len_mod>[hlL])?  # unused in python
         {1}  # conversion type
     )
 '''
@@ -4493,11 +4528,10 @@ def is_outdated_version(version, limit, assume_new=True):
 
 def ytdl_is_updateable():
     """ Returns if yt-dlp can be updated with -U """
-    return False
 
-    from zipimport import zipimporter
+    from .update import is_non_updateable
 
-    return isinstance(globals().get('__loader__'), zipimporter) or hasattr(sys, 'frozen')
+    return not is_non_updateable()
 
 
 def args_to_str(args):
@@ -4518,20 +4552,24 @@ def mimetype2ext(mt):
     if mt is None:
         return None
 
-    ext = {
+    mt, _, params = mt.partition(';')
+    mt = mt.strip()
+
+    FULL_MAP = {
         'audio/mp4': 'm4a',
         # Per RFC 3003, audio/mpeg can be .mp1, .mp2 or .mp3. Here use .mp3 as
         # it's the most popular one
         'audio/mpeg': 'mp3',
         'audio/x-wav': 'wav',
-    }.get(mt)
+        'audio/wav': 'wav',
+        'audio/wave': 'wav',
+    }
+
+    ext = FULL_MAP.get(mt)
     if ext is not None:
         return ext
 
-    _, _, res = mt.rpartition('/')
-    res = res.split(';')[0].strip().lower()
-
-    return {
+    SUBTYPE_MAP = {
         '3gpp': '3gp',
         'smptett+xml': 'tt',
         'ttaf+xml': 'dfxp',
@@ -4550,7 +4588,28 @@ def mimetype2ext(mt):
         'quicktime': 'mov',
         'mp2t': 'ts',
         'x-wav': 'wav',
-    }.get(res, res)
+        'filmstrip+json': 'fs',
+        'svg+xml': 'svg',
+    }
+
+    _, _, subtype = mt.rpartition('/')
+    ext = SUBTYPE_MAP.get(subtype.lower())
+    if ext is not None:
+        return ext
+
+    SUFFIX_MAP = {
+        'json': 'json',
+        'xml': 'xml',
+        'zip': 'zip',
+        'gzip': 'gz',
+    }
+
+    _, _, suffix = subtype.partition('+')
+    ext = SUFFIX_MAP.get(suffix)
+    if ext is not None:
+        return ext
+
+    return subtype.replace('+', '.')
 
 
 def parse_codecs(codecs_str):
@@ -6250,7 +6309,7 @@ def get_executable_path():
 
 def load_plugins(name, suffix, namespace):
     plugin_info = [None]
-    classes = []
+    classes = {}
     try:
         plugin_info = imp.find_module(
             name, [os.path.join(get_executable_path(), 'ytdlp_plugins')])
@@ -6261,8 +6320,7 @@ def load_plugins(name, suffix, namespace):
             if not name.endswith(suffix):
                 continue
             klass = getattr(plugins, name)
-            classes.append(klass)
-            namespace[name] = klass
+            classes[name] = namespace[name] = klass
     except ImportError:
         pass
     finally:
@@ -6363,3 +6421,45 @@ def traverse_dict(dictn, keys, casesense=True):
 
 def variadic(x, allowed_types=(str, bytes)):
     return x if isinstance(x, collections.abc.Iterable) and not isinstance(x, allowed_types) else (x,)
+
+
+# create a JSON Web Signature (jws) with HS256 algorithm
+# the resulting format is in JWS Compact Serialization
+# implemented following JWT https://www.rfc-editor.org/rfc/rfc7519.html
+# implemented following JWS https://www.rfc-editor.org/rfc/rfc7515.html
+def jwt_encode_hs256(payload_data, key, headers={}):
+    header_data = {
+        'alg': 'HS256',
+        'typ': 'JWT',
+    }
+    if headers:
+        header_data.update(headers)
+    header_b64 = base64.b64encode(json.dumps(header_data).encode('utf-8'))
+    payload_b64 = base64.b64encode(json.dumps(payload_data).encode('utf-8'))
+    h = hmac.new(key.encode('utf-8'), header_b64 + b'.' + payload_b64, hashlib.sha256)
+    signature_b64 = base64.b64encode(h.digest())
+    token = header_b64 + b'.' + payload_b64 + b'.' + signature_b64
+    return token
+
+
+def supports_terminal_sequences(stream):
+    if compat_os_name == 'nt':
+        if get_windows_version() < (10, ):
+            return False
+    elif not os.getenv('TERM'):
+        return False
+    try:
+        return stream.isatty()
+    except BaseException:
+        return False
+
+
+TERMINAL_SEQUENCES = {
+    'DOWN': '\n',
+    'UP': '\x1b[A',
+    'ERASE_LINE': '\x1b[K',
+    'RED': '\033[0;31m',
+    'YELLOW': '\033[0;33m',
+    'BLUE': '\033[0;34m',
+    'RESET_STYLE': '\033[0m',
+}