X-Git-Url: https://jfr.im/git/yt-dlp.git/blobdiff_plain/b69fd25c25f23a859aefae69a1cc4116896536b8..1c1b2f96ae9696ef16b1b27d1a007bf89c683a0c:/yt_dlp/utils.py

diff --git a/yt_dlp/utils.py b/yt_dlp/utils.py
index fdcb350f2..72f11691f 100644
--- a/yt_dlp/utils.py
+++ b/yt_dlp/utils.py
@@ -3,6 +3,8 @@
 
 from __future__ import unicode_literals
 
+import asyncio
+import atexit
 import base64
 import binascii
 import calendar
@@ -45,6 +47,7 @@
     compat_HTMLParser,
     compat_HTTPError,
     compat_basestring,
+    compat_brotli,
     compat_chr,
     compat_cookiejar,
     compat_ctypes_WINFUNCTYPE,
@@ -58,6 +61,7 @@
     compat_kwargs,
     compat_os_name,
     compat_parse_qs,
+    compat_shlex_split,
     compat_shlex_quote,
     compat_str,
     compat_struct_pack,
@@ -72,6 +76,7 @@
     compat_urllib_parse_unquote_plus,
     compat_urllib_request,
     compat_urlparse,
+    compat_websockets,
     compat_xpath,
 )
 
@@ -80,6 +85,12 @@
     sockssocket,
 )
 
+try:
+    import certifi
+    has_certifi = True
+except ImportError:
+    has_certifi = False
+
 
 def register_socks_protocols():
     # "Register" SOCKS protocols
@@ -139,11 +150,17 @@ def random_user_agent():
     return _USER_AGENT_TPL % random.choice(_CHROME_VERSIONS)
 
 
+SUPPORTED_ENCODINGS = [
+    'gzip', 'deflate'
+]
+if compat_brotli:
+    SUPPORTED_ENCODINGS.append('br')
+
 std_headers = {
     'User-Agent': random_user_agent(),
     'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8',
-    'Accept-Encoding': 'gzip, deflate',
     'Accept-Language': 'en-us,en;q=0.5',
+    'Sec-Fetch-Mode': 'navigate',
 }
 
 
@@ -210,6 +227,7 @@ def random_user_agent():
     '%Y/%m/%d %H:%M:%S',
     '%Y%m%d%H%M',
     '%Y%m%d%H%M%S',
+    '%Y%m%d',
     '%Y-%m-%d %H:%M',
     '%Y-%m-%d %H:%M:%S',
     '%Y-%m-%d %H:%M:%S.%f',
@@ -304,7 +322,7 @@ def write_json_file(obj, fn):
 
     try:
         with tf:
-            json.dump(obj, tf)
+            json.dump(obj, tf, ensure_ascii=False)
         if sys.platform == 'win32':
             # Need to remove existing file on Windows, else os.rename raises
             # WindowsError or FileExistsError.
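
The `ensure_ascii=False` change above keeps non-ASCII characters human-readable in the written info-JSON. A minimal standard-library illustration (not part of the diff):

    import io
    import json

    buf = io.StringIO()
    json.dump({'title': 'Café 動画'}, buf, ensure_ascii=False)
    buf.getvalue()  # '{"title": "Café 動画"}' rather than '{"title": "Caf\\u00e9 \\u52d5\\u753b"}'
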
@@ -414,17 +432,33 @@ def get_element_by_id(id, html):
     return get_element_by_attribute('id', id, html)
 
 
+def get_element_html_by_id(id, html):
+    """Return the html of the tag with the specified ID in the passed HTML document"""
+    return get_element_html_by_attribute('id', id, html)
+
+
 def get_element_by_class(class_name, html):
     """Return the content of the first tag with the specified class in the passed HTML document"""
     retval = get_elements_by_class(class_name, html)
     return retval[0] if retval else None
 
 
+def get_element_html_by_class(class_name, html):
+    """Return the html of the first tag with the specified class in the passed HTML document"""
+    retval = get_elements_html_by_class(class_name, html)
+    return retval[0] if retval else None
+
+
 def get_element_by_attribute(attribute, value, html, escape_value=True):
     retval = get_elements_by_attribute(attribute, value, html, escape_value)
     return retval[0] if retval else None
 
 
+def get_element_html_by_attribute(attribute, value, html, escape_value=True):
+    retval = get_elements_html_by_attribute(attribute, value, html, escape_value)
+    return retval[0] if retval else None
+
+
 def get_elements_by_class(class_name, html):
     """Return the content of all tags with the specified class in the passed HTML document as a list"""
     return get_elements_by_attribute(
@@ -432,29 +466,123 @@ def get_elements_by_class(class_name, html):
         html, escape_value=False)
 
 
-def get_elements_by_attribute(attribute, value, html, escape_value=True):
+def get_elements_html_by_class(class_name, html):
+    """Return the html of all tags with the specified class in the passed HTML document as a list"""
+    return get_elements_html_by_attribute(
+        'class', r'[^\'"]*\b%s\b[^\'"]*' % re.escape(class_name),
+        html, escape_value=False)
+
+
+def get_elements_by_attribute(*args, **kwargs):
     """Return the content of the tag with the specified attribute in the passed HTML document"""
+    return [content for content, _ in get_elements_text_and_html_by_attribute(*args, **kwargs)]
+
+
+def get_elements_html_by_attribute(*args, **kwargs):
+    """Return the html of the tag with the specified attribute in the passed HTML document"""
+    return [whole for _, whole in get_elements_text_and_html_by_attribute(*args, **kwargs)]
+
+
+def get_elements_text_and_html_by_attribute(attribute, value, html, escape_value=True):
+    """
+    Return the text (content) and the html (whole) of the tag with the specified
+    attribute in the passed HTML document
+    """
+
+    value_quote_optional = '' if re.match(r'''[\s"'`=<>]''', value) else '?'
     value = re.escape(value) if escape_value else value
 
-    retlist = []
-    for m in re.finditer(r'''(?xs)
-        <([a-zA-Z0-9:._-]+)
-         (?:\s+[a-zA-Z0-9:._-]+(?:=[a-zA-Z0-9:._-]*|="[^"]*"|='[^']*'|))*?
-         \s+%s=['"]?%s['"]?
-         (?:\s+[a-zA-Z0-9:._-]+(?:=[a-zA-Z0-9:._-]*|="[^"]*"|='[^']*'|))*?
-        \s*>
-        (?P<content>.*?)
-        </\1>
-    ''' % (re.escape(attribute), value), html):
-        res = m.group('content')
+    partial_element_re = r'''(?x)
+        <(?P<tag>[a-zA-Z0-9:._-]+)
+         (?:\s(?:[^>"']|"[^"]*"|'[^']*')*)?
+         \s%(attribute)s\s*=\s*(?P<_q>['"]%(vqo)s)(?-x:%(value)s)(?P=_q)
+        ''' % {'attribute': re.escape(attribute), 'value': value, 'vqo': value_quote_optional}
 
-        if res.startswith('"') or res.startswith("'"):
-            res = res[1:-1]
+    for m in re.finditer(partial_element_re, html):
+        content, whole = get_element_text_and_html_by_tag(m.group('tag'), html[m.start():])
 
-        retlist.append(unescapeHTML(res))
+        yield (
+            unescapeHTML(re.sub(r'^(?P<q>["\'])(?P<content>.*)(?P=q)$', r'\g<content>', content, flags=re.DOTALL)),
+            whole
+        )
+
+
+class HTMLBreakOnClosingTagParser(compat_HTMLParser):
+    """
+    HTML parser which raises HTMLBreakOnClosingTagException upon reaching the
+    closing tag for the first opening tag it has encountered, and can be used
+    as a context manager
+    """
+
+    class HTMLBreakOnClosingTagException(Exception):
+        pass
+
+    def __init__(self):
+        self.tagstack = collections.deque()
+        compat_HTMLParser.__init__(self)
 
-    return retlist
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *_):
+        self.close()
+
+    def close(self):
+        # handle_endtag does not return upon raising HTMLBreakOnClosingTagException,
+        # so data remains buffered; we no longer have any interest in it, thus
+        # override this method to discard it
+        pass
+
+    def handle_starttag(self, tag, _):
+        self.tagstack.append(tag)
+
+    def handle_endtag(self, tag):
+        if not self.tagstack:
+            raise compat_HTMLParseError('no tags in the stack')
+        while self.tagstack:
+            inner_tag = self.tagstack.pop()
+            if inner_tag == tag:
+                break
+        else:
+            raise compat_HTMLParseError(f'matching opening tag for closing {tag} tag not found')
+        if not self.tagstack:
+            raise self.HTMLBreakOnClosingTagException()
+
+
+def get_element_text_and_html_by_tag(tag, html):
+    """
+    For the first element with the specified tag in the passed HTML document
+    return its' content (text) and the whole element (html)
+    """
+    def find_or_raise(haystack, needle, exc):
+        try:
+            return haystack.index(needle)
+        except ValueError:
+            raise exc
+    closing_tag = f'</{tag}>'
+    whole_start = find_or_raise(
+        html, f'<{tag}', compat_HTMLParseError(f'opening {tag} tag not found'))
+    content_start = find_or_raise(
+        html[whole_start:], '>', compat_HTMLParseError(f'malformed opening {tag} tag'))
+    content_start += whole_start + 1
+    with HTMLBreakOnClosingTagParser() as parser:
+        parser.feed(html[whole_start:content_start])
+        if not parser.tagstack or parser.tagstack[0] != tag:
+            raise compat_HTMLParseError(f'parser did not match opening {tag} tag')
+        offset = content_start
+        while offset < len(html):
+            next_closing_tag_start = find_or_raise(
+                html[offset:], closing_tag,
+                compat_HTMLParseError(f'closing {tag} tag not found'))
+            next_closing_tag_end = next_closing_tag_start + len(closing_tag)
+            try:
+                parser.feed(html[offset:offset + next_closing_tag_end])
+                offset += next_closing_tag_end
+            except HTMLBreakOnClosingTagParser.HTMLBreakOnClosingTagException:
+                return html[content_start:offset + next_closing_tag_start], \
+                    html[whole_start:offset + next_closing_tag_end]
+        raise compat_HTMLParseError('unexpected end of html')
 
 
 class HTMLAttributeParser(compat_HTMLParser):
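
A minimal usage sketch of the new tag-matching helper introduced above (inputs are hypothetical; the expected results follow from the implementation shown):

    from yt_dlp.utils import get_element_text_and_html_by_tag

    text, whole = get_element_text_and_html_by_tag(
        'span', '<div><span class="a">hi</span></div>')
    # text == 'hi'
    # whole == '<span class="a">hi</span>'
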
@@ -526,10 +654,9 @@ def clean_html(html):
     if html is None:  # Convenience for sanitizing descriptions etc.
         return html
 
-    # Newline vs <br />
-    html = html.replace('\n', ' ')
-    html = re.sub(r'(?u)\s*<\s*br\s*/?\s*>\s*', '\n', html)
-    html = re.sub(r'(?u)<\s*/\s*p\s*>\s*<\s*p[^>]*>', '\n', html)
+    html = re.sub(r'\s+', ' ', html)
+    html = re.sub(r'(?u)\s?<\s?br\s?/?\s?>\s?', '\n', html)
+    html = re.sub(r'(?u)<\s?/\s?p\s?>\s?<\s?p[^>]*>', '\n', html)
     # Strip html tags
     html = re.sub('<.*?>', '', html)
     # Replace html entities
@@ -553,7 +680,7 @@ def sanitize_open(filename, open_mode):
             import msvcrt
             msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY)
             return (sys.stdout.buffer if hasattr(sys.stdout, 'buffer') else sys.stdout, filename)
-        stream = open(encodeFilename(filename), open_mode)
+        stream = locked_file(filename, open_mode, block=False).open()
         return (stream, filename)
     except (IOError, OSError) as err:
         if err.errno in (errno.EACCES,):
@@ -565,7 +692,7 @@ def sanitize_open(filename, open_mode):
             raise
         else:
             # An exception here should be caught in the caller
-            stream = open(encodeFilename(alt_filename), open_mode)
+            stream = locked_file(filename, open_mode, block=False).open()
             return (stream, alt_filename)
 
 
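A rough before/after sketch of the reworked clean_html whitespace handling (hypothetical input; result derived from the substitutions above):

    from yt_dlp.utils import clean_html

    clean_html('<p>line one</p>\n<p>line\ntwo</p>')
    # -> 'line one\nline two'
    # all runs of whitespace collapse to single spaces before
    # </p><p> boundaries are converted to newlines
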
@@ -578,36 +705,40 @@ def timeconvert(timestr):
     return timestamp
 
 
-def sanitize_filename(s, restricted=False, is_id=False):
+def sanitize_filename(s, restricted=False, is_id=NO_DEFAULT):
     """Sanitizes a string so it could be used as part of a filename.
-    If restricted is set, use a stricter subset of allowed characters.
-    Set is_id if this is not an arbitrary string, but an ID that should be kept
-    if possible.
+    @param restricted   Use a stricter subset of allowed characters
+    @param is_id        Whether this is an ID that should be kept unchanged if possible.
+                        If unset, yt-dlp's new sanitization rules are in effect
     """
+    if s == '':
+        return ''
+
     def replace_insane(char):
         if restricted and char in ACCENT_CHARS:
             return ACCENT_CHARS[char]
         elif not restricted and char == '\n':
-            return ' '
+            return '\0 '
         elif char == '?' or ord(char) < 32 or ord(char) == 127:
             return ''
         elif char == '"':
             return '' if restricted else '\''
         elif char == ':':
-            return '_-' if restricted else ' -'
+            return '\0_\0-' if restricted else '\0 \0-'
        elif char in '\\/|*<>':
-            return '_'
-        if restricted and (char in '!&\'()[]{}$;`^,#' or char.isspace()):
-            return '_'
-        if restricted and ord(char) > 127:
-            return '_'
+            return '\0_'
+        if restricted and (char in '!&\'()[]{}$;`^,#' or char.isspace() or ord(char) > 127):
+            return '\0_'
         return char
 
-    if s == '':
-        return ''
-    # Handle timestamps
-    s = re.sub(r'[0-9]+(?::[0-9]+)+', lambda m: m.group(0).replace(':', '_'), s)
+    s = re.sub(r'[0-9]+(?::[0-9]+)+', lambda m: m.group(0).replace(':', '_'), s)  # Handle timestamps
     result = ''.join(map(replace_insane, s))
+    if is_id is NO_DEFAULT:
+        result = re.sub('(\0.)(?:(?=\\1)..)+', r'\1', result)  # Remove repeated substitute chars
+        STRIP_RE = '(?:\0.|[ _-])*'
+        result = re.sub(f'^\0.{STRIP_RE}|{STRIP_RE}\0.$', '', result)  # Remove substitute chars from start/end
+    result = result.replace('\0', '') or '_'
+
     if not is_id:
         while '__' in result:
             result = result.replace('__', '_')
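
A worked example of the NUL-placeholder trick used above: substituted characters are tagged with '\0' so that repeats can be collapsed and leading/trailing substitutes stripped before the markers are dropped. The input is hypothetical:

    import re

    result = 'a\0_\0_\0_b'                                 # three adjacent substitute markers
    result = re.sub('(\0.)(?:(?=\\1)..)+', r'\1', result)  # -> 'a\0_b' (repeats collapsed)
    result = result.replace('\0', '') or '_'               # -> 'a_b' (markers removed)
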
@@ -884,33 +1015,34 @@ def make_HTTPS_handler(params, **kwargs):
     opts_check_certificate = not params.get('nocheckcertificate')
     context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
     context.check_hostname = opts_check_certificate
+    if params.get('legacyserverconnect'):
+        context.options |= 4  # SSL_OP_LEGACY_SERVER_CONNECT
     context.verify_mode = ssl.CERT_REQUIRED if opts_check_certificate else ssl.CERT_NONE
     if opts_check_certificate:
-        try:
-            context.load_default_certs()
-            # Work around the issue in load_default_certs when there are bad certificates. See:
-            # https://github.com/yt-dlp/yt-dlp/issues/1060,
-            # https://bugs.python.org/issue35665, https://bugs.python.org/issue45312
-        except ssl.SSLError:
-            # enum_certificates is not present in mingw python. See https://github.com/yt-dlp/yt-dlp/issues/1151
-            if sys.platform == 'win32' and hasattr(ssl, 'enum_certificates'):
-                # Create a new context to discard any certificates that were already loaded
-                context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
-                context.check_hostname, context.verify_mode = True, ssl.CERT_REQUIRED
-                for storename in ('CA', 'ROOT'):
-                    _ssl_load_windows_store_certs(context, storename)
-            context.set_default_verify_paths()
+        if has_certifi and 'no-certifi' not in params.get('compat_opts', []):
+            context.load_verify_locations(cafile=certifi.where())
+        else:
+            try:
+                context.load_default_certs()
+                # Work around the issue in load_default_certs when there are bad certificates. See:
+                # https://github.com/yt-dlp/yt-dlp/issues/1060,
+                # https://bugs.python.org/issue35665, https://bugs.python.org/issue45312
+            except ssl.SSLError:
+                # enum_certificates is not present in mingw python. See https://github.com/yt-dlp/yt-dlp/issues/1151
+                if sys.platform == 'win32' and hasattr(ssl, 'enum_certificates'):
+                    # Create a new context to discard any certificates that were already loaded
+                    context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+                    context.check_hostname, context.verify_mode = True, ssl.CERT_REQUIRED
+                    for storename in ('CA', 'ROOT'):
+                        _ssl_load_windows_store_certs(context, storename)
+                context.set_default_verify_paths()
     return YoutubeDLHTTPSHandler(params, context=context, **kwargs)
 
 
 def bug_reports_message(before=';'):
-    if ytdl_is_updateable():
-        update_cmd = 'type yt-dlp -U to update'
-    else:
-        update_cmd = 'see https://github.com/yt-dlp/yt-dlp on how to update'
-    msg = 'please report this issue on https://github.com/yt-dlp/yt-dlp .'
-    msg += ' Make sure you are using the latest version; %s.' % update_cmd
-    msg += ' Be sure to call yt-dlp with the --verbose flag and include its complete output.'
+    msg = ('please report this issue on https://github.com/yt-dlp/yt-dlp , '
+           'filling out the appropriate issue template. '
+           'Confirm you are on the latest version using yt-dlp -U')
 
     before = before.rstrip()
     if not before or before.endswith(('.', '!', '?')):
@@ -947,7 +1079,7 @@ def __init__(self, msg, tb=None, expected=False, cause=None, video_id=None, ie=N
         if sys.exc_info()[0] in network_exceptions:
             expected = True
 
-        self.msg = str(msg)
+        self.orig_msg = str(msg)
         self.traceback = tb
         self.expected = expected
         self.cause = cause
@@ -958,14 +1090,15 @@ def __init__(self, msg, tb=None, expected=False, cause=None, video_id=None, ie=N
         super(ExtractorError, self).__init__(''.join((
             format_field(ie, template='[%s] '),
             format_field(video_id, template='%s: '),
-            self.msg,
+            msg,
             format_field(cause, template=' (caused by %r)'),
             '' if expected else bug_reports_message())))
 
     def format_traceback(self):
-        if self.traceback is None:
-            return None
-        return ''.join(traceback.format_tb(self.traceback))
+        return join_nonempty(
+            self.traceback and ''.join(traceback.format_tb(self.traceback)),
+            self.cause and ''.join(traceback.format_exception(None, self.cause, self.cause.__traceback__)[1:]),
+            delim='\n') or None
 
 
 class UnsupportedError(ExtractorError):
@@ -1243,6 +1376,12 @@ def deflate(data):
         except zlib.error:
             return zlib.decompress(data)
 
+    @staticmethod
+    def brotli(data):
+        if not data:
+            return data
+        return compat_brotli.decompress(data)
+
     def http_request(self, req):
         # According to RFC 3986, URLs can not contain non-ASCII characters, however this is not
         # always respected by websites, some tend to give out URLs with non percent-encoded
@@ -1259,12 +1398,15 @@ def http_request(self, req):
         if url != url_escaped:
             req = update_Request(req, url=url_escaped)
 
-        for h, v in std_headers.items():
+        for h, v in self._params.get('http_headers', std_headers).items():
             # Capitalize is needed because of Python bug 2275: http://bugs.python.org/issue2275
             # The dict keys are capitalized because of this bug by urllib
             if h.capitalize() not in req.headers:
                 req.add_header(h, v)
 
+        if 'Accept-encoding' not in req.headers:
+            req.add_header('Accept-encoding', ', '.join(SUPPORTED_ENCODINGS))
+
         req.headers = handle_youtubedl_headers(req.headers)
 
         if sys.version_info < (2, 7) and '#' in req.get_full_url():
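
A condensed, standalone approximation of the certificate-loading logic the hunk above implements (this is a sketch, not the handler itself; the 'no-certifi' compat-option check is omitted):

    import ssl

    try:
        import certifi
        has_certifi = True
    except ImportError:
        has_certifi = False

    context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    if has_certifi:
        # prefer certifi's CA bundle when it is installed
        context.load_verify_locations(cafile=certifi.where())
    else:
        # fall back to the system certificate store
        context.load_default_certs()
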
@@ -1303,6 +1445,12 @@ def http_response(self, req, resp):
                 resp = compat_urllib_request.addinfourl(gz, old_resp.headers, old_resp.url, old_resp.code)
                 resp.msg = old_resp.msg
                 del resp.headers['Content-encoding']
+        # brotli
+        if resp.headers.get('Content-encoding', '') == 'br':
+            resp = compat_urllib_request.addinfourl(
+                io.BytesIO(self.brotli(resp.read())), old_resp.headers, old_resp.url, old_resp.code)
+            resp.msg = old_resp.msg
+            del resp.headers['Content-encoding']
         # Percent-encode redirect URL of Location HTTP header to satisfy RFC 3986 (see
         # https://github.com/ytdl-org/youtube-dl/issues/6457).
         if 300 <= resp.code < 400:
@@ -1722,7 +1870,7 @@ def subtitles_filename(filename, sub_lang, sub_format, expected_real_ext=None):
 def datetime_from_str(date_str, precision='auto', format='%Y%m%d'):
     """
     Return a datetime object from a string in the format YYYYMMDD or
-    (now|today|date)[+-][0-9](microsecond|second|minute|hour|day|week|month|year)(s)?
+    (now|today|yesterday|date)[+-][0-9](microsecond|second|minute|hour|day|week|month|year)(s)?
 
     format: string date format used to return datetime object from
     precision: round the time portion of a datetime object.
@@ -1733,7 +1881,7 @@ def datetime_from_str(date_str, precision='auto', format='%Y%m%d'):
     if precision == 'auto':
         auto_precision = True
         precision = 'microsecond'
-    today = datetime_round(datetime.datetime.now(), precision)
+    today = datetime_round(datetime.datetime.utcnow(), precision)
     if date_str in ('now', 'today'):
         return today
     if date_str == 'yesterday':
@@ -1761,13 +1909,17 @@ def datetime_from_str(date_str, precision='auto', format='%Y%m%d'):
         return datetime_round(datetime.datetime.strptime(date_str, format), precision)
 
 
-def date_from_str(date_str, format='%Y%m%d'):
+def date_from_str(date_str, format='%Y%m%d', strict=False):
     """
     Return a datetime object from a string in the format YYYYMMDD or
-    (now|today|date)[+-][0-9](microsecond|second|minute|hour|day|week|month|year)(s)?
+    (now|today|yesterday|date)[+-][0-9](microsecond|second|minute|hour|day|week|month|year)(s)?
+
+    If "strict", only (now|today)[+-][0-9](day|week|month|year)(s)? is allowed
 
     format: string date format used to return datetime object from
     """
+    if strict and not re.fullmatch(r'\d{8}|(now|today)[+-]\d+(day|week|month|year)(s)?', date_str):
+        raise ValueError(f'Invalid date format {date_str}')
     return datetime_from_str(date_str, precision='microsecond', format=format).date()
 
 
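Assuming the strict pattern above, usage looks like this (illustrative examples; used by DateRange in the next hunk):

    from yt_dlp.utils import date_from_str

    date_from_str('now-2weeks', strict=True)  # OK: day/week/month/year offsets pass
    date_from_str('now-2hours', strict=True)  # raises ValueError: hour offsets fail the strict pattern
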
@@ -1814,11 +1966,11 @@ class DateRange(object):
     def __init__(self, start=None, end=None):
         """start and end must be strings in the format accepted by date"""
         if start is not None:
-            self.start = date_from_str(start)
+            self.start = date_from_str(start, strict=True)
         else:
             self.start = datetime.datetime.min.date()
         if end is not None:
-            self.end = date_from_str(end)
+            self.end = date_from_str(end, strict=True)
         else:
             self.end = datetime.datetime.max.date()
         if self.start > self.end:
@@ -2005,38 +2157,52 @@ class OVERLAPPED(ctypes.Structure):
         whole_low = 0xffffffff
         whole_high = 0x7fffffff
 
-        def _lock_file(f, exclusive):
+        def _lock_file(f, exclusive, block):
             overlapped = OVERLAPPED()
             overlapped.Offset = 0
             overlapped.OffsetHigh = 0
             overlapped.hEvent = 0
             f._lock_file_overlapped_p = ctypes.pointer(overlapped)
-            handle = msvcrt.get_osfhandle(f.fileno())
-            if not LockFileEx(handle, 0x2 if exclusive else 0x0, 0,
-                              whole_low, whole_high, f._lock_file_overlapped_p):
-                raise OSError('Locking file failed: %r' % ctypes.FormatError())
+
+            if not LockFileEx(msvcrt.get_osfhandle(f.fileno()),
+                              (0x2 if exclusive else 0x0) | (0x0 if block else 0x1),
+                              0, whole_low, whole_high, f._lock_file_overlapped_p):
+                raise BlockingIOError('Locking file failed: %r' % ctypes.FormatError())
 
         def _unlock_file(f):
             assert f._lock_file_overlapped_p
             handle = msvcrt.get_osfhandle(f.fileno())
-            if not UnlockFileEx(handle, 0,
-                                whole_low, whole_high, f._lock_file_overlapped_p):
+            if not UnlockFileEx(handle, 0, whole_low, whole_high, f._lock_file_overlapped_p):
                 raise OSError('Unlocking file failed: %r' % ctypes.FormatError())
 
 else:
-    # Some platforms, such as Jython, is missing fcntl
     try:
         import fcntl
 
-        def _lock_file(f, exclusive):
-            fcntl.flock(f, fcntl.LOCK_EX if exclusive else fcntl.LOCK_SH)
+        def _lock_file(f, exclusive, block):
+            try:
+                fcntl.flock(f,
+                            fcntl.LOCK_SH if not exclusive
+                            else fcntl.LOCK_EX if block
+                            else fcntl.LOCK_EX | fcntl.LOCK_NB)
+            except BlockingIOError:
+                raise
+            except OSError:  # AOSP does not have flock()
+                fcntl.lockf(f,
+                            fcntl.LOCK_SH if not exclusive
+                            else fcntl.LOCK_EX if block
+                            else fcntl.LOCK_EX | fcntl.LOCK_NB)
 
         def _unlock_file(f):
-            fcntl.flock(f, fcntl.LOCK_UN)
+            try:
+                fcntl.flock(f, fcntl.LOCK_UN)
+            except OSError:
+                fcntl.lockf(f, fcntl.LOCK_UN)
+
     except ImportError:
         UNSUPPORTED_MSG = 'file locking is not supported on this platform'
 
-        def _lock_file(f, exclusive):
+        def _lock_file(f, exclusive, block):
             raise IOError(UNSUPPORTED_MSG)
 
         def _unlock_file(f):
@@ -2044,15 +2210,18 @@ def _unlock_file(f):
 
 
 class locked_file(object):
-    def __init__(self, filename, mode, encoding=None):
-        assert mode in ['r', 'a', 'w']
+    _closed = False
+
+    def __init__(self, filename, mode, block=True, encoding=None):
+        assert mode in ['r', 'rb', 'a', 'ab', 'w', 'wb']
         self.f = io.open(filename, mode, encoding=encoding)
         self.mode = mode
+        self.block = block
 
     def __enter__(self):
-        exclusive = self.mode != 'r'
+        exclusive = 'r' not in self.mode
         try:
-            _lock_file(self.f, exclusive)
+            _lock_file(self.f, exclusive, self.block)
         except IOError:
             self.f.close()
             raise
@@ -2060,9 +2229,11 @@ def __enter__(self):
 
     def __exit__(self, etype, value, traceback):
         try:
-            _unlock_file(self.f)
+            if not self._closed:
+                _unlock_file(self.f)
         finally:
             self.f.close()
+            self._closed = True
 
     def __iter__(self):
         return iter(self.f)
@@ -2073,6 +2244,15 @@ def write(self, *args):
     def read(self, *args):
         return self.f.read(*args)
 
+    def flush(self):
+        self.f.flush()
+
+    def open(self):
+        return self.__enter__()
+
+    def close(self, *args):
+        self.__exit__(self, *args, value=False, traceback=False)
+
 
 def get_filesystem_encoding():
     encoding = sys.getfilesystemencoding()
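
A usage sketch of the new non-blocking mode (file name hypothetical); the rewritten sanitize_open earlier in this diff relies on exactly this open()/close() pair:

    from yt_dlp.utils import locked_file

    f = locked_file('output.part', 'ab', block=False).open()  # raises BlockingIOError if already locked
    try:
        f.write(b'chunk')
    finally:
        f.close()
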
@@ -2112,16 +2292,19 @@ def unsmuggle_url(smug_url, default=None):
 def format_decimal_suffix(num, fmt='%d%s', *, factor=1000):
     """ Formats numbers with decimal sufixes like K, M, etc """
     num, factor = float_or_none(num), float(factor)
-    if num is None:
+    if num is None or num < 0:
         return None
-    exponent = 0 if num == 0 else int(math.log(num, factor))
-    suffix = ['', *'KMGTPEZY'][exponent]
+    POSSIBLE_SUFFIXES = 'kMGTPEZY'
+    exponent = 0 if num == 0 else min(int(math.log(num, factor)), len(POSSIBLE_SUFFIXES))
+    suffix = ['', *POSSIBLE_SUFFIXES][exponent]
+    if factor == 1024:
+        suffix = {'k': 'Ki', '': ''}.get(suffix, f'{suffix}i')
     converted = num / (factor ** exponent)
     return fmt % (converted, suffix)
 
 
 def format_bytes(bytes):
-    return format_decimal_suffix(bytes, '%.2f%siB', factor=1024) or 'N/A'
+    return format_decimal_suffix(bytes, '%.2f%sB', factor=1024) or 'N/A'
 
 
 def lookup_unit_table(unit_table, s):
@@ -2210,7 +2393,7 @@ def parse_count(s):
     if s is None:
         return None
 
-    s = s.strip()
+    s = re.sub(r'^[^\d]+\s', '', s).strip()
 
     if re.match(r'^[\d,.]+$', s):
         return str_to_int(s)
@@ -2222,9 +2405,17 @@ def parse_count(s):
         'M': 1000 ** 2,
         'kk': 1000 ** 2,
         'KK': 1000 ** 2,
+        'b': 1000 ** 3,
+        'B': 1000 ** 3,
     }
 
-    return lookup_unit_table(_UNIT_TABLE, s)
+    ret = lookup_unit_table(_UNIT_TABLE, s)
+    if ret is not None:
+        return ret
+
+    mobj = re.match(r'([\d,.]+)(?:$|\s)', s)
+    if mobj:
+        return str_to_int(mobj.group(1))
 
 
 def parse_resolution(s):
@@ -2369,13 +2560,8 @@ def get_method(self):
 
 
 def int_or_none(v, scale=1, default=None, get_attr=None, invscale=1):
-    if get_attr:
-        if v is not None:
-            v = getattr(v, get_attr, None)
-    if v == '':
-        v = None
-    if v is None:
-        return default
+    if get_attr and v is not None:
+        v = getattr(v, get_attr, None)
     try:
         return int(v) * invscale // scale
     except (ValueError, TypeError, OverflowError):
@@ -2419,6 +2605,13 @@ def url_or_none(url):
     return url if re.match(r'^(?:(?:https?|rt(?:m(?:pt?[es]?|fp)|sp[su]?)|mms|ftps?):)?//', url) else None
 
 
+def request_to_url(req):
+    if isinstance(req, compat_urllib_request.Request):
+        return req.get_full_url()
+    else:
+        return req
+
+
 def strftime_or_none(timestamp, date_format, default=None):
     datetime_object = None
     try:
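
Illustrative values based on the implementations above (the binary 'i' infix and the count-prefix stripping are the new behaviors):

    from yt_dlp.utils import format_bytes, format_decimal_suffix, parse_count

    format_decimal_suffix(1234, '%.1f%s')  # -> '1.2k' (lowercase SI-style 'k')
    format_bytes(123456789)                # -> '117.74MiB' (factor 1024 now yields Ki/Mi/Gi...)
    parse_count('View count: 1,234')       # -> 1234 (leading non-digit prefix is stripped)
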
@@ -2439,30 +2632,35 @@ def parse_duration(s):
         return None
 
     days, hours, mins, secs, ms = [None] * 5
-    m = re.match(r'(?:(?:(?:(?P<days>[0-9]+):)?(?P<hours>[0-9]+):)?(?P<mins>[0-9]+):)?(?P<secs>[0-9]+)(?P<ms>\.[0-9]+)?Z?$', s)
+    m = re.match(r'''(?x)
+            (?P<before_secs>
+                (?:(?:(?P<days>[0-9]+):)?(?P<hours>[0-9]+):)?(?P<mins>[0-9]+):)?
+            (?P<secs>(?(before_secs)[0-9]{1,2}|[0-9]+))
+            (?P<ms>[.:][0-9]+)?Z?$
+        ''', s)
     if m:
-        days, hours, mins, secs, ms = m.groups()
+        days, hours, mins, secs, ms = m.group('days', 'hours', 'mins', 'secs', 'ms')
     else:
         m = re.match(
             r'''(?ix)(?:P?
                 (?:
-                    [0-9]+\s*y(?:ears?)?\s*
+                    [0-9]+\s*y(?:ears?)?,?\s*
                 )?
                 (?:
-                    [0-9]+\s*m(?:onths?)?\s*
+                    [0-9]+\s*m(?:onths?)?,?\s*
                 )?
                 (?:
-                    [0-9]+\s*w(?:eeks?)?\s*
+                    [0-9]+\s*w(?:eeks?)?,?\s*
                 )?
                 (?:
-                    (?P<days>[0-9]+)\s*d(?:ays?)?\s*
+                    (?P<days>[0-9]+)\s*d(?:ays?)?,?\s*
                 )?
                 T)?
                 (?:
-                    (?P<hours>[0-9]+)\s*h(?:ours?)?\s*
+                    (?P<hours>[0-9]+)\s*h(?:ours?)?,?\s*
                 )?
                 (?:
-                    (?P<mins>[0-9]+)\s*m(?:in(?:ute)?s?)?\s*
+                    (?P<mins>[0-9]+)\s*m(?:in(?:ute)?s?)?,?\s*
                 )?
                 (?:
                     (?P<secs>[0-9]+)(?P<ms>\.[0-9]+)?\s*s(?:ec(?:ond)?s?)?\s*
@@ -2486,7 +2684,7 @@ def parse_duration(s):
     if days:
         duration += float(days) * 24 * 60 * 60
     if ms:
-        duration += float(ms)
+        duration += float(ms.replace(':', '.'))
     return duration
 
 
@@ -2651,13 +2849,14 @@ def __len__(self):
     def __init__(self, pagefunc, pagesize, use_cache=True):
         self._pagefunc = pagefunc
         self._pagesize = pagesize
+        self._pagecount = float('inf')
         self._use_cache = use_cache
         self._cache = {}
 
     def getpage(self, pagenum):
         page_results = self._cache.get(pagenum)
         if page_results is None:
-            page_results = list(self._pagefunc(pagenum))
+            page_results = [] if pagenum > self._pagecount else list(self._pagefunc(pagenum))
         if self._use_cache:
             self._cache[pagenum] = page_results
         return page_results
@@ -2669,7 +2868,7 @@ def _getslice(self, start, end):
         raise NotImplementedError('This method must be implemented by subclasses')
 
     def __getitem__(self, idx):
-        # NOTE: cache must be enabled if this is used
+        assert self._use_cache, 'Indexing PagedList requires cache'
         if not isinstance(idx, int) or idx < 0:
             raise TypeError('indices must be non-negative integers')
         entries = self.getslice(idx, idx + 1)
@@ -2695,7 +2894,11 @@ def _getslice(self, start, end):
                 if (end is not None and firstid <= end <= nextfirstid)
                 else None)
 
-            page_results = self.getpage(pagenum)
+            try:
+                page_results = self.getpage(pagenum)
+            except Exception:
+                self._pagecount = pagenum - 1
+                raise
             if startv != 0 or endv is not None:
                 page_results = page_results[startv:endv]
             yield from page_results
@@ -2715,13 +2918,12 @@ def _getslice(self, start, end):
 
 class InAdvancePagedList(PagedList):
     def __init__(self, pagefunc, pagecount, pagesize):
-        self._pagecount = pagecount
         PagedList.__init__(self, pagefunc, pagesize, True)
+        self._pagecount = pagecount
 
     def _getslice(self, start, end):
         start_page = start // self._pagesize
-        end_page = (
-            self._pagecount if end is None else (end // self._pagesize + 1))
+        end_page = self._pagecount if end is None else min(self._pagecount, end // self._pagesize + 1)
         skip_elems = start - start_page * self._pagesize
         only_more = None if end is None else end - start
         for pagenum in range(start_page, end_page):
@@ -3004,6 +3206,8 @@ def fix_kv(m):
 
         return '"%s"' % v
 
+    code = re.sub(r'new Date\((".+")\)', r'\g<1>', code)
+
     return re.sub(r'''(?sx)
         "(?:[^"\\]*(?:\\\\|\\['"nurtbfx/\n]))*[^"\\]*"|
         '(?:[^'\\]*(?:\\\\|\\['"nurtbfx/\n]))*[^'\\]*'|
@@ -3025,6 +3229,9 @@ def q(qid):
     return q
 
 
+POSTPROCESS_WHEN = {'pre_process', 'after_filter', 'before_dl', 'after_move', 'post_process', 'after_video', 'playlist'}
+
+
 DEFAULT_OUTTMPL = {
     'default': '%(title)s [%(id)s].%(ext)s',
     'chapter': '%(title)s - %(section_number)03d %(section_title)s [%(id)s].%(ext)s',
@@ -3037,6 +3244,7 @@ def q(qid):
     'annotation': 'annotations.xml',
     'infojson': 'info.json',
     'link': None,
+    'pl_video': None,
     'pl_thumbnail': None,
     'pl_description': 'description',
     'pl_infojson': 'info.json',
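
Examples of what the extended duration grammar now accepts (values follow from the regexes above):

    from yt_dlp.utils import parse_duration

    parse_duration('1:02:03.500')         # -> 3723.5
    parse_duration('1 hour, 30 minutes')  # -> 5400.0 (commas between components are now tolerated)
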
@@ -3185,7 +3393,7 @@ def parse_codecs(codecs_str):
         return {}
     split_codecs = list(filter(None, map(
         str.strip, codecs_str.strip().strip(',').split(','))))
-    vcodec, acodec, hdr = None, None, None
+    vcodec, acodec, tcodec, hdr = None, None, None, None
     for full_codec in split_codecs:
         parts = full_codec.split('.')
         codec = parts[0].replace('0', '')
@@ -3202,13 +3410,17 @@ def parse_codecs(codecs_str):
         elif codec in ('flac', 'mp4a', 'opus', 'vorbis', 'mp3', 'aac', 'ac-3', 'ec-3', 'eac3', 'dtsc', 'dtse', 'dtsh', 'dtsl'):
             if not acodec:
                 acodec = full_codec
+        elif codec in ('stpp', 'wvtt',):
+            if not tcodec:
+                tcodec = full_codec
         else:
             write_string('WARNING: Unknown codec %s\n' % full_codec, sys.stderr)
-    if vcodec or acodec:
+    if vcodec or acodec or tcodec:
         return {
             'vcodec': vcodec or 'none',
             'acodec': acodec or 'none',
             'dynamic_range': hdr,
+            **({'tcodec': tcodec} if tcodec is not None else {}),
         }
     elif len(split_codecs) == 2:
         return {
@@ -3298,19 +3510,18 @@ def get_max_lens(table):
         return [max(width(str(v)) for v in col) for col in zip(*table)]
 
     def filter_using_list(row, filterArray):
-        return [col for (take, col) in zip(filterArray, row) if take]
+        return [col for take, col in itertools.zip_longest(filterArray, row, fillvalue=True) if take]
 
-    if hide_empty:
-        max_lens = get_max_lens(data)
-        header_row = filter_using_list(header_row, max_lens)
-        data = [filter_using_list(row, max_lens) for row in data]
+    max_lens = get_max_lens(data) if hide_empty else []
+    header_row = filter_using_list(header_row, max_lens)
+    data = [filter_using_list(row, max_lens) for row in data]
 
     table = [header_row] + data
     max_lens = get_max_lens(table)
     extra_gap += 1
     if delim:
         table = [header_row, [delim * (ml + extra_gap) for ml in max_lens]] + data
-        table[1][-1] = table[1][-1][:-extra_gap]  # Remove extra_gap from end of delimiter
+        table[1][-1] = table[1][-1][:-extra_gap * len(delim)]  # Remove extra_gap from end of delimiter
     for row in table:
         for pos, text in enumerate(map(str, row)):
             if '\t' in text:
@@ -3338,6 +3549,11 @@ def _match_one(filter_part, dct, incomplete):
         '=': operator.eq,
     }
 
+    if isinstance(incomplete, bool):
+        is_incomplete = lambda _: incomplete
+    else:
+        is_incomplete = lambda k: k in incomplete
+
     operator_rex = re.compile(r'''(?x)\s*
         (?P<key>[a-z_]+)
         \s*(?P<negation>!\s*)?(?P<op>%s)(?P<none_inclusive>\s*\?)?\s*
@@ -3376,7 +3592,7 @@ def _match_one(filter_part, dct, incomplete):
         if numeric_comparison is not None and m['op'] in STRING_OPERATORS:
             raise ValueError('Operator %s only supports string values!' % m['op'])
         if actual_value is None:
-            return incomplete or m['none_inclusive']
+            return is_incomplete(m['key']) or m['none_inclusive']
         return op(actual_value, comparison_value if numeric_comparison is None else numeric_comparison)
 
     UNARY_OPERATORS = {
@@ -3391,7 +3607,7 @@ def _match_one(filter_part, dct, incomplete):
     if m:
         op = UNARY_OPERATORS[m.group('op')]
         actual_value = dct.get(m.group('key'))
-        if incomplete and actual_value is None:
+        if is_incomplete(m.group('key')) and actual_value is None:
             return True
         return op(actual_value)
 
@@ -3399,21 +3615,29 @@ def _match_one(filter_part, dct, incomplete):
 
 
 def match_str(filter_str, dct, incomplete=False):
-    """ Filter a dictionary with a simple string syntax. Returns True (=passes filter) or false
-    When incomplete, all conditions passes on missing fields
+    """ Filter a dictionary with a simple string syntax.
+    @returns           Whether the filter passes
+    @param incomplete  Set of keys that is expected to be missing from dct.
+                       Can be True/False to indicate all/none of the keys may be missing.
+                       All conditions on incomplete keys pass if the key is missing
     """
 
     return all(
         _match_one(filter_part.replace(r'\&', '&'), dct, incomplete)
        for filter_part in re.split(r'(?<!\\)&', filter_str))
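
A sketch of the extended `incomplete` parameter, which may now be a set of keys rather than a bool (filter string and dict are hypothetical):

    from yt_dlp.utils import match_str

    # 'like_count' is declared incomplete, so its condition passes while the
    # key is missing; the 'title' condition is still evaluated normally
    match_str('like_count>100 & title*=foo',
              {'title': 'foobar'}, incomplete={'like_count'})  # -> True
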
+        eqre = re.compile('^(?P<key>' + ('|'.join(re.escape(po) for po in PRIVATE_OPTS)) + ')=.+$')
+
+        def _scrub_eq(o):
+            m = eqre.match(o)
+            if m:
+                return m.group('key') + '=PRIVATE'
+            else:
+                return o
+
+        opts = list(map(_scrub_eq, opts))
+        for idx, opt in enumerate(opts):
+            if opt in PRIVATE_OPTS and idx + 1 < len(opts):
+                opts[idx + 1] = 'PRIVATE'
+        return opts
+
+    def append_config(self, *args, label=None):
+        config = type(self)(self._parser, label)
+        config._loaded_paths = self._loaded_paths
+        if config.init(*args):
+            self.configs.append(config)
+
+    @property
+    def all_args(self):
+        for config in reversed(self.configs):
+            yield from config.all_args
+        yield from self.own_args or []
+
+    def parse_args(self):
+        return self._parser.parse_args(list(self.all_args))
+
+
+class WebSocketsWrapper():
+    """Wraps websockets module to use in non-async scopes"""
+    pool = None
+
+    def __init__(self, url, headers=None, connect=True):
+        self.loop = asyncio.events.new_event_loop()
+        self.conn = compat_websockets.connect(
+            url, extra_headers=headers, ping_interval=None,
+            close_timeout=float('inf'), loop=self.loop, ping_timeout=float('inf'))
+        if connect:
+            self.__enter__()
+        atexit.register(self.__exit__, None, None, None)
+
+    def __enter__(self):
+        if not self.pool:
+            self.pool = self.run_with_loop(self.conn.__aenter__(), self.loop)
+        return self
+
+    def send(self, *args):
+        self.run_with_loop(self.pool.send(*args), self.loop)
+
+    def recv(self, *args):
+        return self.run_with_loop(self.pool.recv(*args), self.loop)
+
+    def __exit__(self, type, value, traceback):
+        try:
+            return self.run_with_loop(self.conn.__aexit__(type, value, traceback), self.loop)
+        finally:
+            self.loop.close()
+            self._cancel_all_tasks(self.loop)
+
+    # taken from https://github.com/python/cpython/blob/3.9/Lib/asyncio/runners.py with modifications
+    # for contributors: If there's any new library using asyncio needs to be run in non-async, move these function out of this class
+    @staticmethod
+    def run_with_loop(main, loop):
+        if not asyncio.coroutines.iscoroutine(main):
+            raise ValueError(f'a coroutine was expected, got {main!r}')
+
+        try:
+            return loop.run_until_complete(main)
+        finally:
+            loop.run_until_complete(loop.shutdown_asyncgens())
+            if hasattr(loop, 'shutdown_default_executor'):
+                loop.run_until_complete(loop.shutdown_default_executor())
+
+    @staticmethod
+    def _cancel_all_tasks(loop):
+        to_cancel = asyncio.tasks.all_tasks(loop)
+
+        if not to_cancel:
+            return
+
+        for task in to_cancel:
+            task.cancel()
+
+        loop.run_until_complete(
+            asyncio.tasks.gather(*to_cancel, loop=loop, return_exceptions=True))
+
+        for task in to_cancel:
+            if task.cancelled():
+                continue
+            if task.exception() is not None:
+                loop.call_exception_handler({
+                    'message': 'unhandled exception during asyncio.run() shutdown',
+                    'exception': task.exception(),
+                    'task': task,
+                })
+
+
+has_websockets = bool(compat_websockets)
+
+
+def merge_headers(*dicts):
+    """Merge dicts of http headers case insensitively, prioritizing the latter ones"""
+    return {k.title(): v for k, v in itertools.chain.from_iterable(map(dict.items, dicts))}
+
+
+class classproperty:
+    def __init__(self, f):
+        self.f = f
+
+    def __get__(self, _, cls):
+        return self.f(cls)
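
Minimal sketches of the two new helpers defined above (the class name is hypothetical):

    from yt_dlp.utils import classproperty, merge_headers

    merge_headers({'X-Foo': 'a', 'x-bar': 'b'}, {'X-FOO': 'c'})
    # -> {'X-Foo': 'c', 'X-Bar': 'b'}  (keys title-cased, later dicts win)

    class MyIE:
        @classproperty
        def IE_NAME(cls):
            return cls.__name__

    MyIE.IE_NAME  # -> 'MyIE' (computed on the class itself, no instance needed)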