X-Git-Url: https://jfr.im/git/yt-dlp.git/blobdiff_plain/a3125791c7a5cdf2c8c025b99788bf686edd1a8a..ff91cf748343c41a74b09120896feccd390f91ce:/yt_dlp/utils.py

diff --git a/yt_dlp/utils.py b/yt_dlp/utils.py
index e39a5b29e..9b130e109 100644
--- a/yt_dlp/utils.py
+++ b/yt_dlp/utils.py
@@ -3,6 +3,8 @@
 
 from __future__ import unicode_literals
 
+import asyncio
+import atexit
 import base64
 import binascii
 import calendar
@@ -45,6 +47,7 @@
     compat_HTMLParser,
     compat_HTTPError,
     compat_basestring,
+    compat_brotli,
     compat_chr,
     compat_cookiejar,
     compat_ctypes_WINFUNCTYPE,
@@ -73,6 +76,7 @@
     compat_urllib_parse_unquote_plus,
     compat_urllib_request,
     compat_urlparse,
+    compat_websockets,
     compat_xpath,
 )
 
@@ -140,10 +144,16 @@ def random_user_agent():
     return _USER_AGENT_TPL % random.choice(_CHROME_VERSIONS)
 
 
+SUPPORTED_ENCODINGS = [
+    'gzip', 'deflate'
+]
+if compat_brotli:
+    SUPPORTED_ENCODINGS.append('br')
+
 std_headers = {
     'User-Agent': random_user_agent(),
     'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8',
-    'Accept-Encoding': 'gzip, deflate',
+    'Accept-Encoding': ', '.join(SUPPORTED_ENCODINGS),
     'Accept-Language': 'en-us,en;q=0.5',
     'Sec-Fetch-Mode': 'navigate',
 }
@@ -1020,7 +1030,7 @@ def make_HTTPS_handler(params, **kwargs):
 def bug_reports_message(before=';'):
     msg = ('please report this issue on https://github.com/yt-dlp/yt-dlp , '
            'filling out the "Broken site" issue template properly. '
-           'Confirm you are on the latest version using -U')
+           'Confirm you are on the latest version using yt-dlp -U')
 
     before = before.rstrip()
     if not before or before.endswith(('.', '!', '?')):
@@ -1057,7 +1067,7 @@ def __init__(self, msg, tb=None, expected=False, cause=None, video_id=None, ie=N
         if sys.exc_info()[0] in network_exceptions:
             expected = True
 
-        self.msg = str(msg)
+        self.orig_msg = str(msg)
         self.traceback = tb
         self.expected = expected
         self.cause = cause
@@ -1068,14 +1078,15 @@ def __init__(self, msg, tb=None, expected=False, cause=None, video_id=None, ie=N
         super(ExtractorError, self).__init__(''.join((
             format_field(ie, template='[%s] '),
             format_field(video_id, template='%s: '),
-            self.msg,
+            msg,
             format_field(cause, template=' (caused by %r)'),
             '' if expected else bug_reports_message())))
 
     def format_traceback(self):
-        if self.traceback is None:
-            return None
-        return ''.join(traceback.format_tb(self.traceback))
+        return join_nonempty(
+            self.traceback and ''.join(traceback.format_tb(self.traceback)),
+            self.cause and ''.join(traceback.format_exception(self.cause)[1:]),
+            delim='\n') or None
 
 
 class UnsupportedError(ExtractorError):
@@ -1353,6 +1364,12 @@ def deflate(data):
         except zlib.error:
             return zlib.decompress(data)
 
+    @staticmethod
+    def brotli(data):
+        if not data:
+            return data
+        return compat_brotli.decompress(data)
+
     def http_request(self, req):
         # According to RFC 3986, URLs can not contain non-ASCII characters, however this is not
         # always respected by websites, some tend to give out URLs with non percent-encoded
@@ -1369,7 +1386,7 @@ def http_request(self, req):
         if url != url_escaped:
             req = update_Request(req, url=url_escaped)
 
-        for h, v in std_headers.items():
+        for h, v in self._params.get('http_headers', std_headers).items():
            # Capitalize is needed because of Python bug 2275: http://bugs.python.org/issue2275
             # The dict keys are capitalized because of this bug by urllib
             if h.capitalize() not in req.headers:
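
Note: with this patch the advertised Accept-Encoding is assembled from SUPPORTED_ENCODINGS at import time, so brotli is only offered when a decoder is importable. A minimal standalone sketch of the same negotiation (using the brotli package directly as a stand-in for compat_brotli):

    try:
        import brotli  # stand-in for compat_brotli (brotli or brotlicffi)
    except ImportError:
        brotli = None

    SUPPORTED_ENCODINGS = ['gzip', 'deflate']
    if brotli:
        SUPPORTED_ENCODINGS.append('br')

    print(', '.join(SUPPORTED_ENCODINGS))  # 'gzip, deflate, br' when brotli is installed
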
@@ -1413,6 +1430,12 @@ def http_response(self, req, resp):
                 resp = compat_urllib_request.addinfourl(gz, old_resp.headers, old_resp.url, old_resp.code)
                 resp.msg = old_resp.msg
                 del resp.headers['Content-encoding']
+        # brotli
+        if resp.headers.get('Content-encoding', '') == 'br':
+            resp = compat_urllib_request.addinfourl(
+                io.BytesIO(self.brotli(resp.read())), old_resp.headers, old_resp.url, old_resp.code)
+            resp.msg = old_resp.msg
+            del resp.headers['Content-encoding']
         # Percent-encode redirect URL of Location HTTP header to satisfy RFC 3986 (see
         # https://github.com/ytdl-org/youtube-dl/issues/6457).
         if 300 <= resp.code < 400:
@@ -1832,7 +1855,7 @@ def subtitles_filename(filename, sub_lang, sub_format, expected_real_ext=None):
 def datetime_from_str(date_str, precision='auto', format='%Y%m%d'):
     """
     Return a datetime object from a string in the format YYYYMMDD or
-    (now|today|date)[+-][0-9](microsecond|second|minute|hour|day|week|month|year)(s)?
+    (now|today|yesterday|date)[+-][0-9](microsecond|second|minute|hour|day|week|month|year)(s)?
 
     format: string date format used to return datetime object from
     precision: round the time portion of a datetime object.
@@ -1871,13 +1894,17 @@ def datetime_from_str(date_str, precision='auto', format='%Y%m%d'):
     return datetime_round(datetime.datetime.strptime(date_str, format), precision)
 
 
-def date_from_str(date_str, format='%Y%m%d'):
+def date_from_str(date_str, format='%Y%m%d', strict=False):
     """
     Return a datetime object from a string in the format YYYYMMDD or
-    (now|today|date)[+-][0-9](microsecond|second|minute|hour|day|week|month|year)(s)?
+    (now|today|yesterday|date)[+-][0-9](microsecond|second|minute|hour|day|week|month|year)(s)?
+
+    If "strict", only (now|today)[+-][0-9](day|week|month|year)(s)? is allowed
 
     format: string date format used to return datetime object from
     """
+    if strict and not re.fullmatch(r'\d{8}|(now|today)[+-]\d+(day|week|month|year)(s)?', date_str):
+        raise ValueError(f'Invalid date format {date_str}')
     return datetime_from_str(date_str, precision='microsecond', format=format).date()
 
 
@@ -1924,11 +1951,11 @@ class DateRange(object):
     def __init__(self, start=None, end=None):
         """start and end must be strings in the format accepted by date"""
         if start is not None:
-            self.start = date_from_str(start)
+            self.start = date_from_str(start, strict=True)
         else:
             self.start = datetime.datetime.min.date()
         if end is not None:
-            self.end = date_from_str(end)
+            self.end = date_from_str(end, strict=True)
         else:
             self.end = datetime.datetime.max.date()
         if self.start > self.end:
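
For reference, strict mode only admits absolute YYYYMMDD dates and coarse relative offsets, which DateRange (used for the date-filtering options) now enforces. A standalone illustration of the pattern above (re-stated locally, not imported from yt-dlp):

    import re

    STRICT_DATE_RE = r'\d{8}|(now|today)[+-]\d+(day|week|month|year)(s)?'

    for date_str in ('20220301', 'today-2weeks', 'now+1month', 'yesterday', 'now-3hours'):
        ok = bool(re.fullmatch(STRICT_DATE_RE, date_str))
        print(date_str, '->', 'accepted' if ok else 'rejected')
    # 'yesterday' and hour-level offsets remain valid only in the non-strict form
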
@@ -2115,37 +2142,47 @@ class OVERLAPPED(ctypes.Structure):
     whole_low = 0xffffffff
     whole_high = 0x7fffffff
 
-    def _lock_file(f, exclusive, block):  # todo: block unused on win32
+    def _lock_file(f, exclusive, block):
         overlapped = OVERLAPPED()
         overlapped.Offset = 0
         overlapped.OffsetHigh = 0
         overlapped.hEvent = 0
         f._lock_file_overlapped_p = ctypes.pointer(overlapped)
-        handle = msvcrt.get_osfhandle(f.fileno())
-        if not LockFileEx(handle, 0x2 if exclusive else 0x0, 0,
-                          whole_low, whole_high, f._lock_file_overlapped_p):
-            raise OSError('Locking file failed: %r' % ctypes.FormatError())
+
+        if not LockFileEx(msvcrt.get_osfhandle(f.fileno()),
+                          (0x2 if exclusive else 0x0) | (0x0 if block else 0x1),
+                          0, whole_low, whole_high, f._lock_file_overlapped_p):
+            raise BlockingIOError('Locking file failed: %r' % ctypes.FormatError())
 
     def _unlock_file(f):
         assert f._lock_file_overlapped_p
         handle = msvcrt.get_osfhandle(f.fileno())
-        if not UnlockFileEx(handle, 0,
-                            whole_low, whole_high, f._lock_file_overlapped_p):
+        if not UnlockFileEx(handle, 0, whole_low, whole_high, f._lock_file_overlapped_p):
             raise OSError('Unlocking file failed: %r' % ctypes.FormatError())
 
 else:
-    # Some platforms, such as Jython, is missing fcntl
     try:
         import fcntl
 
         def _lock_file(f, exclusive, block):
-            fcntl.flock(f,
-                        fcntl.LOCK_SH if not exclusive
-                        else fcntl.LOCK_EX if block
-                        else fcntl.LOCK_EX | fcntl.LOCK_NB)
+            try:
+                fcntl.flock(f,
+                            fcntl.LOCK_SH if not exclusive
+                            else fcntl.LOCK_EX if block
+                            else fcntl.LOCK_EX | fcntl.LOCK_NB)
+            except BlockingIOError:
+                raise
+            except OSError:  # AOSP does not have flock()
+                fcntl.lockf(f,
+                            fcntl.LOCK_SH if not exclusive
+                            else fcntl.LOCK_EX if block
+                            else fcntl.LOCK_EX | fcntl.LOCK_NB)
 
         def _unlock_file(f):
-            fcntl.flock(f, fcntl.LOCK_UN)
+            try:
+                fcntl.flock(f, fcntl.LOCK_UN)
+            except OSError:
+                fcntl.lockf(f, fcntl.LOCK_UN)
 
     except ImportError:
         UNSUPPORTED_MSG = 'file locking is not supported on this platform'
@@ -2158,6 +2195,8 @@ def _unlock_file(f):
 
 
 class locked_file(object):
+    _closed = False
+
     def __init__(self, filename, mode, block=True, encoding=None):
         assert mode in ['r', 'rb', 'a', 'ab', 'w', 'wb']
         self.f = io.open(filename, mode, encoding=encoding)
@@ -2175,9 +2214,11 @@ def __enter__(self):
 
     def __exit__(self, etype, value, traceback):
         try:
-            _unlock_file(self.f)
+            if not self._closed:
+                _unlock_file(self.f)
         finally:
             self.f.close()
+            self._closed = True
 
     def __iter__(self):
         return iter(self.f)
@@ -2236,7 +2277,7 @@ def unsmuggle_url(smug_url, default=None):
 def format_decimal_suffix(num, fmt='%d%s', *, factor=1000):
     """ Formats numbers with decimal sufixes like K, M, etc """
     num, factor = float_or_none(num), float(factor)
-    if num is None:
+    if num is None or num < 0:
         return None
     exponent = 0 if num == 0 else int(math.log(num, factor))
     suffix = ['', *'kMGTPEZY'][exponent]
@@ -2548,6 +2589,13 @@ def url_or_none(url):
     return url if re.match(r'^(?:(?:https?|rt(?:m(?:pt?[es]?|fp)|sp[su]?)|mms|ftps?):)?//', url) else None
 
 
+def request_to_url(req):
+    if isinstance(req, compat_urllib_request.Request):
+        return req.get_full_url()
+    else:
+        return req
+
+
 def strftime_or_none(timestamp, date_format, default=None):
     datetime_object = None
     try:
@@ -2785,13 +2833,14 @@ def __len__(self):
     def __init__(self, pagefunc, pagesize, use_cache=True):
         self._pagefunc = pagefunc
         self._pagesize = pagesize
+        self._pagecount = float('inf')
         self._use_cache = use_cache
         self._cache = {}
 
     def getpage(self, pagenum):
         page_results = self._cache.get(pagenum)
         if page_results is None:
-            page_results = list(self._pagefunc(pagenum))
+            page_results = [] if pagenum > self._pagecount else list(self._pagefunc(pagenum))
         if self._use_cache:
             self._cache[pagenum] = page_results
         return page_results
@@ -2803,7 +2852,7 @@ def _getslice(self, start, end):
         raise NotImplementedError('This method must be implemented by subclasses')
 
     def __getitem__(self, idx):
-        # NOTE: cache must be enabled if this is used
+        assert self._use_cache, 'Indexing PagedList requires cache'
         if not isinstance(idx, int) or idx < 0:
             raise TypeError('indices must be non-negative integers')
         entries = self.getslice(idx, idx + 1)
@@ -2829,7 +2878,11 @@ def _getslice(self, start, end):
                 if (end is not None and firstid <= end <= nextfirstid)
                 else None)
 
-            page_results = self.getpage(pagenum)
+            try:
+                page_results = self.getpage(pagenum)
+            except Exception:
+                self._pagecount = pagenum - 1
+                raise
             if startv != 0 or endv is not None:
                 page_results = page_results[startv:endv]
             yield from page_results
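
The PagedList changes record a page count the first time a page fetch raises, so later slices past that point come back empty instead of re-hitting the backend. A toy sketch (hypothetical page function; assumes OnDemandPagedList and the getslice API shown above):

    from yt_dlp.utils import OnDemandPagedList

    def fetch_page(pagenum):
        if pagenum > 2:  # hypothetical backend with only pages 0-2
            raise OSError('no such page')
        return [pagenum * 10 + i for i in range(3)]

    pages = OnDemandPagedList(fetch_page, 3)
    try:
        pages.getslice(0, 12)  # page 3 raises; _pagecount is clamped to 2
    except OSError:
        pass
    print(pages.getslice(9, 12))  # [] -- pages beyond the recorded page count are skipped
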
@@ -2849,8 +2902,8 @@ def _getslice(self, start, end):
 
 class InAdvancePagedList(PagedList):
     def __init__(self, pagefunc, pagecount, pagesize):
-        self._pagecount = pagecount
         PagedList.__init__(self, pagefunc, pagesize, True)
+        self._pagecount = pagecount
 
     def _getslice(self, start, end):
         start_page = start // self._pagesize
@@ -3137,6 +3190,8 @@ def fix_kv(m):
 
         return '"%s"' % v
 
+    code = re.sub(r'new Date\((".+")\)', r'\g<1>', code)
+
     return re.sub(r'''(?sx)
         "(?:[^"\\]*(?:\\\\|\\['"nurtbfx/\n]))*[^"\\]*"|
         '(?:[^'\\]*(?:\\\\|\\['"nurtbfx/\n]))*[^'\\]*'|
@@ -3158,7 +3213,7 @@ def q(qid):
     return q
 
 
-POSTPROCESS_WHEN = {'pre_process', 'before_dl', 'after_move', 'post_process', 'after_video', 'playlist'}
+POSTPROCESS_WHEN = {'pre_process', 'after_filter', 'before_dl', 'after_move', 'post_process', 'after_video', 'playlist'}
 
 
 DEFAULT_OUTTMPL = {
@@ -3450,7 +3505,7 @@ def filter_using_list(row, filterArray):
     extra_gap += 1
     if delim:
         table = [header_row, [delim * (ml + extra_gap) for ml in max_lens]] + data
-        table[1][-1] = table[1][-1][:-extra_gap]  # Remove extra_gap from end of delimiter
+        table[1][-1] = table[1][-1][:-extra_gap * len(delim)]  # Remove extra_gap from end of delimiter
     for row in table:
         for pos, text in enumerate(map(str, row)):
             if '\t' in text:
@@ -3548,6 +3603,9 @@ def match_str(filter_str, dct, incomplete=False):
 
 
 def match_filter_func(filter_str):
+    if filter_str is None:
+        return None
+
     def _match_func(info_dict, *args, **kwargs):
         if match_str(filter_str, info_dict, *args, **kwargs):
             return None
@@ -5160,10 +5218,30 @@ def traverse_dict(dictn, keys, casesense=True):
     return traverse_obj(dictn, keys, casesense=casesense, is_user_input=True, traverse_string=True)
 
 
+def get_first(obj, keys, **kwargs):
+    return traverse_obj(obj, (..., *variadic(keys)), **kwargs, get_all=False)
+
+
 def variadic(x, allowed_types=(str, bytes, dict)):
     return x if isinstance(x, collections.abc.Iterable) and not isinstance(x, allowed_types) else (x,)
 
 
+def decode_base(value, digits):
+    # This will convert given base-x string to scalar (long or int)
+    table = {char: index for index, char in enumerate(digits)}
+    result = 0
+    base = len(digits)
+    for chr in value:
+        result *= base
+        result += table[chr]
+    return result
+
+
+def time_seconds(**kwargs):
+    t = datetime.datetime.now(datetime.timezone(datetime.timedelta(**kwargs)))
+    return t.timestamp()
+
+
 # create a JSON Web Signature (jws) with HS256 algorithm
 # the resulting format is in JWS Compact Serialization
 # implemented following JWT https://www.rfc-editor.org/rfc/rfc7519.html
@@ -5220,6 +5298,38 @@ def join_nonempty(*values, delim='-', from_dict=None):
     return delim.join(map(str, filter(None, values)))
 
 
+def scale_thumbnails_to_max_format_width(formats, thumbnails, url_width_re):
+    """
+    Find the largest format dimensions in terms of video width and, for each thumbnail:
+    * Modify the URL: Match the width with the provided regex and replace with the former width
+    * Update dimensions
+
+    This function is useful with video services that scale the provided thumbnails on demand
+    """
+    _keys = ('width', 'height')
+    max_dimensions = max(
+        [tuple(format.get(k) or 0 for k in _keys) for format in formats],
+        default=(0, 0))
+    if not max_dimensions[0]:
+        return thumbnails
+    return [
+        merge_dicts(
+            {'url': re.sub(url_width_re, str(max_dimensions[0]), thumbnail['url'])},
+            dict(zip(_keys, max_dimensions)), thumbnail)
+        for thumbnail in thumbnails
+    ]
+
+
+def parse_http_range(range):
+    """ Parse value of "Range" or "Content-Range" HTTP header into tuple. """
+    if not range:
+        return None, None, None
+    crg = re.search(r'bytes[ =](\d+)-(\d+)?(?:/(\d+))?', range)
+    if not crg:
+        return None, None, None
+    return int(crg.group(1)), int_or_none(crg.group(2)), int_or_none(crg.group(3))
+
+
 class Config:
     own_args = None
     filename = None
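
As a usage sketch, parse_http_range returns a (start, end, document_size) tuple with None for any absent part (illustrative header values; assumes the patched yt_dlp.utils is importable):

    from yt_dlp.utils import parse_http_range

    print(parse_http_range('bytes=500-999'))       # (500, 999, None) -- "Range" request form
    print(parse_http_range('bytes 500-999/1234'))  # (500, 999, 1234) -- "Content-Range" response form
    print(parse_http_range('bytes=500-'))          # (500, None, None) -- open-ended range
    print(parse_http_range(None))                  # (None, None, None) -- header absent
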
""" + if not range: + return None, None, None + crg = re.search(r'bytes[ =](\d+)-(\d+)?(?:/(\d+))?', range) + if not crg: + return None, None, None + return int(crg.group(1)), int_or_none(crg.group(2)), int_or_none(crg.group(3)) + + class Config: own_args = None filename = None @@ -5307,3 +5417,76 @@ def all_args(self): def parse_args(self): return self._parser.parse_args(list(self.all_args)) + + +class WebSocketsWrapper(): + """Wraps websockets module to use in non-async scopes""" + + def __init__(self, url, headers=None): + self.loop = asyncio.events.new_event_loop() + self.conn = compat_websockets.connect( + url, extra_headers=headers, ping_interval=None, + close_timeout=float('inf'), loop=self.loop, ping_timeout=float('inf')) + atexit.register(self.__exit__, None, None, None) + + def __enter__(self): + self.pool = self.run_with_loop(self.conn.__aenter__(), self.loop) + return self + + def send(self, *args): + self.run_with_loop(self.pool.send(*args), self.loop) + + def recv(self, *args): + return self.run_with_loop(self.pool.recv(*args), self.loop) + + def __exit__(self, type, value, traceback): + try: + return self.run_with_loop(self.conn.__aexit__(type, value, traceback), self.loop) + finally: + self.loop.close() + self._cancel_all_tasks(self.loop) + + # taken from https://github.com/python/cpython/blob/3.9/Lib/asyncio/runners.py with modifications + # for contributors: If there's any new library using asyncio needs to be run in non-async, move these function out of this class + @staticmethod + def run_with_loop(main, loop): + if not asyncio.coroutines.iscoroutine(main): + raise ValueError(f'a coroutine was expected, got {main!r}') + + try: + return loop.run_until_complete(main) + finally: + loop.run_until_complete(loop.shutdown_asyncgens()) + if hasattr(loop, 'shutdown_default_executor'): + loop.run_until_complete(loop.shutdown_default_executor()) + + @staticmethod + def _cancel_all_tasks(loop): + to_cancel = asyncio.tasks.all_tasks(loop) + + if not to_cancel: + return + + for task in to_cancel: + task.cancel() + + loop.run_until_complete( + asyncio.tasks.gather(*to_cancel, loop=loop, return_exceptions=True)) + + for task in to_cancel: + if task.cancelled(): + continue + if task.exception() is not None: + loop.call_exception_handler({ + 'message': 'unhandled exception during asyncio.run() shutdown', + 'exception': task.exception(), + 'task': task, + }) + + +has_websockets = bool(compat_websockets) + + +def merge_headers(*dicts): + """Merge dicts of http headers case insensitively, prioritizing the latter ones""" + return {k.capitalize(): v for k, v in itertools.chain.from_iterable(map(dict.items, dicts))}