# Reconstructed from a whitespace-flattened unified diff of yt_dlp/utils.py
# (8aa0e7cd..7b2c3f47). Only definitions that the diff adds or rewrites, and
# that are visible in full, are reproduced here in their post-patch form; the
# pre-patch implementations that the diff removes are intentionally omitted.
#
# NOTE(review): the same diff also rewrites one statement in the hunk headed
# `def _traverse_obj(obj, path, _current_depth=0):` to
#     type_test = expected_type or IDENTITY
# The enclosing function is not fully visible, so it is not reproduced here.

# Sentinel default; format_field() compares against it by identity
NO_DEFAULT = object()
# No-op callable used as the default transform
IDENTITY = lambda x: x


def _base_n_table(n, table):
    """Return the digit table to use for a base-n conversion

    @param n        Base, or None/0 to use the whole of `table`
    @param table    Sequence of digit symbols, or None to use the first `n`
                    characters of 0-9, a-z, A-Z
    @raises ValueError  if neither is given, or if `n` is larger than the
                        number of available symbols
    """
    if not table and not n:
        raise ValueError('Either table or n must be specified')
    # Slicing with n=None keeps the whole table; a table longer than n is
    # truncated to its first n symbols
    table = (table or '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ')[:n]
    if n and n != len(table):
        raise ValueError(f'base {n} exceeds table length {len(table)}')
    return table


def encode_base_n(num, n=None, table=None):
    """Convert given int to a base-n string

    @raises ValueError  if the base/table combination is invalid, if `num`
                        is negative, or if fewer than 2 digit symbols are
                        available for a non-zero `num`
    """
    table = _base_n_table(n, table)
    if not num:
        return table[0]
    if num < 0:
        # Floor division never reaches 0 from a negative number
        raise ValueError('num must be non-negative')

    base = len(table)
    if base < 2:
        # Dividing by 1 would never shrink num
        raise ValueError('base must be at least 2')

    digits = []
    while num:
        num, remainder = divmod(num, base)
        digits.append(table[remainder])
    # Digits were produced least-significant first
    return ''.join(reversed(digits))


def decode_base_n(string, n=None, table=None):
    """Convert given base-n string to int

    @raises ValueError  if the base/table combination is invalid
    @raises KeyError    if `string` contains a symbol that is not in the table
    """
    digits = _base_n_table(n, table)
    lookup = {char: index for index, char in enumerate(digits)}
    # Use the table length (not the dict size) as the base, like the
    # decode_base implementation that this function replaces
    result, base = 0, len(digits)
    for char in string:
        result = result * base + lookup[char]
    return result


def decode_base(value, digits):
    """Deprecated wrapper around decode_base_n(value, table=digits)"""
    # write_string is not defined in the visible part of the file;
    # presumably it is the module's own output helper - verify
    write_string('DeprecationWarning: yt_dlp.utils.decode_base is deprecated '
                 'and may be removed in a future version. Use yt_dlp.utils.decode_base_n instead')
    return decode_base_n(value, table=digits)


def format_field(obj, field=None, template='%s', ignore=NO_DEFAULT, default='', func=IDENTITY):
    """Look up `field` in `obj` and format it with the %-style `template`

    @param ignore   Value(s) for which `default` is returned instead. When
                    not given, every falsy value except 0 is ignored
    @param func     Callable applied to the value before formatting.
                    None is accepted and means "no transformation"
    """
    val = traverse_obj(obj, *variadic(field))
    if (not val and val != 0) if ignore is NO_DEFAULT else val in variadic(ignore):
        return default
    # func=None was the default before IDENTITY was introduced; keep accepting it
    return template % (func or IDENTITY)(val)


def join_nonempty(*values, delim='-', from_dict=None):
    """Join the str() of all truthy values with `delim`

    When `from_dict` is given, each value is instead treated as a path that
    is looked up in `from_dict` with traverse_obj before joining
    """
    if from_dict is not None:
        values = (traverse_obj(from_dict, variadic(v)) for v in values)
    return delim.join(map(str, filter(None, values)))