X-Git-Url: https://jfr.im/git/yt-dlp.git/blobdiff_plain/6586bca9b9a3d30e3e76ee27bcd98ea5c8c7a57f..34921b43451a23d8cd7350f8511269bdfd35cf61:/yt_dlp/utils.py diff --git a/yt_dlp/utils.py b/yt_dlp/utils.py index 4d3cbc7b4..75b4ed61b 100644 --- a/yt_dlp/utils.py +++ b/yt_dlp/utils.py @@ -16,7 +16,9 @@ import errno import functools import gzip -import imp +import hashlib +import hmac +import importlib.util import io import itertools import json @@ -1740,9 +1742,12 @@ def random_user_agent(): '%b %dth %Y %I:%M', '%Y %m %d', '%Y-%m-%d', + '%Y.%m.%d.', '%Y/%m/%d', '%Y/%m/%d %H:%M', '%Y/%m/%d %H:%M:%S', + '%Y%m%d%H%M', + '%Y%m%d%H%M%S', '%Y-%m-%d %H:%M', '%Y-%m-%d %H:%M:%S', '%Y-%m-%d %H:%M:%S.%f', @@ -1759,6 +1764,7 @@ def random_user_agent(): '%b %d %Y at %H:%M:%S', '%B %d %Y at %H:%M', '%B %d %Y at %H:%M:%S', + '%H:%M %d-%b-%Y', ) DATE_FORMATS_DAY_FIRST = list(DATE_FORMATS) @@ -1836,7 +1842,7 @@ def write_json_file(obj, fn): try: with tf: - json.dump(obj, tf, default=repr) + json.dump(obj, tf) if sys.platform == 'win32': # Need to remove existing file on Windows, else os.rename raises # WindowsError or FileExistsError. @@ -2000,6 +2006,23 @@ def handle_starttag(self, tag, attrs): self.attrs = dict(attrs) +class HTMLListAttrsParser(compat_HTMLParser): + """HTML parser to gather the attributes for the elements of a list""" + + def __init__(self): + compat_HTMLParser.__init__(self) + self.items = [] + self._level = 0 + + def handle_starttag(self, tag, attrs): + if tag == 'li' and self._level == 0: + self.items.append(dict(attrs)) + self._level += 1 + + def handle_endtag(self, tag): + self._level -= 1 + + def extract_attributes(html_element): """Given a string for an HTML element such as elements, + return a dictionary of their attributes""" + parser = HTMLListAttrsParser() + parser.feed(webpage) + parser.close() + return parser.items + + def clean_html(html): """Clean an HTML snippet into a readable string""" @@ -2093,7 +2125,9 @@ def sanitize_filename(s, restricted=False, is_id=False): def replace_insane(char): if restricted and char in ACCENT_CHARS: return ACCENT_CHARS[char] - if char == '?' or ord(char) < 32 or ord(char) == 127: + elif not restricted and char == '\n': + return ' ' + elif char == '?' or ord(char) < 32 or ord(char) == 127: return '' elif char == '"': return '' if restricted else '\'' @@ -2264,6 +2298,20 @@ def process_communicate_or_kill(p, *args, **kwargs): raise +class Popen(subprocess.Popen): + if sys.platform == 'win32': + _startupinfo = subprocess.STARTUPINFO() + _startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW + else: + _startupinfo = None + + def __init__(self, *args, **kwargs): + super(Popen, self).__init__(*args, **kwargs, startupinfo=self._startupinfo) + + def communicate_or_kill(self, *args, **kwargs): + return process_communicate_or_kill(self, *args, **kwargs) + + def get_subprocess_encoding(): if sys.platform == 'win32' and sys.getwindowsversion()[0] >= 5: # For subprocess calls, encode with locale encoding @@ -2334,39 +2382,63 @@ def decodeOption(optval): return optval +_timetuple = collections.namedtuple('Time', ('hours', 'minutes', 'seconds', 'milliseconds')) + + +def timetuple_from_msec(msec): + secs, msec = divmod(msec, 1000) + mins, secs = divmod(secs, 60) + hrs, mins = divmod(mins, 60) + return _timetuple(hrs, mins, secs, msec) + + def formatSeconds(secs, delim=':', msec=False): - if secs > 3600: - ret = '%d%s%02d%s%02d' % (secs // 3600, delim, (secs % 3600) // 60, delim, secs % 60) - elif secs > 60: - ret = '%d%s%02d' % (secs // 60, delim, secs % 60) + time = timetuple_from_msec(secs * 1000) + if time.hours: + ret = '%d%s%02d%s%02d' % (time.hours, delim, time.minutes, delim, time.seconds) + elif time.minutes: + ret = '%d%s%02d' % (time.minutes, delim, time.seconds) else: - ret = '%d' % secs - return '%s.%03d' % (ret, secs % 1) if msec else ret + ret = '%d' % time.seconds + return '%s.%03d' % (ret, time.milliseconds) if msec else ret -def make_HTTPS_handler(params, **kwargs): - opts_no_check_certificate = params.get('nocheckcertificate', False) - if hasattr(ssl, 'create_default_context'): # Python >= 3.4 or 2.7.9 - context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) - if opts_no_check_certificate: - context.check_hostname = False - context.verify_mode = ssl.CERT_NONE +def _ssl_load_windows_store_certs(ssl_context, storename): + # Code adapted from _load_windows_store_certs in https://github.com/python/cpython/blob/main/Lib/ssl.py + try: + certs = [cert for cert, encoding, trust in ssl.enum_certificates(storename) + if encoding == 'x509_asn' and ( + trust is True or ssl.Purpose.SERVER_AUTH.oid in trust)] + except PermissionError: + return + for cert in certs: try: - return YoutubeDLHTTPSHandler(params, context=context, **kwargs) - except TypeError: - # Python 2.7.8 - # (create_default_context present but HTTPSHandler has no context=) + ssl_context.load_verify_locations(cadata=cert) + except ssl.SSLError: pass - if sys.version_info < (3, 2): - return YoutubeDLHTTPSHandler(params, **kwargs) - else: # Python < 3.4 - context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - context.verify_mode = (ssl.CERT_NONE - if opts_no_check_certificate - else ssl.CERT_REQUIRED) - context.set_default_verify_paths() - return YoutubeDLHTTPSHandler(params, context=context, **kwargs) + +def make_HTTPS_handler(params, **kwargs): + opts_check_certificate = not params.get('nocheckcertificate') + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + context.check_hostname = opts_check_certificate + context.verify_mode = ssl.CERT_REQUIRED if opts_check_certificate else ssl.CERT_NONE + if opts_check_certificate: + try: + context.load_default_certs() + # Work around the issue in load_default_certs when there are bad certificates. See: + # https://github.com/yt-dlp/yt-dlp/issues/1060, + # https://bugs.python.org/issue35665, https://bugs.python.org/issue45312 + except ssl.SSLError: + # enum_certificates is not present in mingw python. See https://github.com/yt-dlp/yt-dlp/issues/1151 + if sys.platform == 'win32' and hasattr(ssl, 'enum_certificates'): + # Create a new context to discard any certificates that were already loaded + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + context.check_hostname, context.verify_mode = True, ssl.CERT_REQUIRED + for storename in ('CA', 'ROOT'): + _ssl_load_windows_store_certs(context, storename) + context.set_default_verify_paths() + return YoutubeDLHTTPSHandler(params, context=context, **kwargs) def bug_reports_message(before=';'): @@ -2399,25 +2471,27 @@ class YoutubeDLError(Exception): class ExtractorError(YoutubeDLError): """Error during info extraction.""" - def __init__(self, msg, tb=None, expected=False, cause=None, video_id=None): + def __init__(self, msg, tb=None, expected=False, cause=None, video_id=None, ie=None): """ tb, if given, is the original traceback (so that it can be printed out). If expected is set, this is a normal error message and most likely not a bug in yt-dlp. """ - if sys.exc_info()[0] in network_exceptions: expected = True - if video_id is not None: - msg = video_id + ': ' + msg - if cause: - msg += ' (caused by %r)' % cause - if not expected: - msg += bug_reports_message() - super(ExtractorError, self).__init__(msg) + self.msg = str(msg) self.traceback = tb - self.exc_info = sys.exc_info() # preserve original exception + self.expected = expected self.cause = cause self.video_id = video_id + self.ie = ie + self.exc_info = sys.exc_info() # preserve original exception + + super(ExtractorError, self).__init__(''.join(( + format_field(ie, template='[%s] '), + format_field(video_id, template='%s: '), + self.msg, + format_field(cause, template=' (caused by %r)'), + '' if expected else bug_reports_message()))) def format_traceback(self): if self.traceback is None: @@ -2444,9 +2518,9 @@ class GeoRestrictedError(ExtractorError): geographic location due to geographic restrictions imposed by a website. """ - def __init__(self, msg, countries=None): - super(GeoRestrictedError, self).__init__(msg, expected=True) - self.msg = msg + def __init__(self, msg, countries=None, **kwargs): + kwargs['expected'] = True + super(GeoRestrictedError, self).__init__(msg, **kwargs) self.countries = countries @@ -2494,23 +2568,33 @@ def __init__(self, msg): self.msg = msg -class ExistingVideoReached(YoutubeDLError): - """ --max-downloads limit has been reached. """ - pass +class DownloadCancelled(YoutubeDLError): + """ Exception raised when the download queue should be interrupted """ + msg = 'The download was cancelled' + def __init__(self, msg=None): + if msg is not None: + self.msg = msg + YoutubeDLError.__init__(self, self.msg) -class RejectedVideoReached(YoutubeDLError): - """ --max-downloads limit has been reached. """ - pass +class ExistingVideoReached(DownloadCancelled): + """ --break-on-existing triggered """ + msg = 'Encountered a video that is already in the archive, stopping due to --break-on-existing' -class ThrottledDownload(YoutubeDLError): - """ Download speed below --throttled-rate. """ - pass + +class RejectedVideoReached(DownloadCancelled): + """ --break-on-reject triggered """ + msg = 'Encountered a video that did not match filter, stopping due to --break-on-reject' -class MaxDownloadsReached(YoutubeDLError): +class MaxDownloadsReached(DownloadCancelled): """ --max-downloads limit has been reached. """ + msg = 'Maximum number of downloads reached, stopping due to --max-downloads' + + +class ThrottledDownload(YoutubeDLError): + """ Download speed below --throttled-rate. """ pass @@ -3029,8 +3113,16 @@ def redirect_request(self, req, fp, code, msg, headers, newurl): def extract_timezone(date_str): m = re.search( - r'^.{8,}?(?PZ$| ?(?P\+|-)(?P[0-9]{2}):?(?P[0-9]{2})$)', - date_str) + r'''(?x) + ^.{8,}? # >=8 char non-TZ prefix, if present + (?PZ| # just the UTC Z, or + (?:(?<=.\b\d{4}|\b\d{2}:\d\d)| # preceded by 4 digits or hh:mm or + (?= 4 alpha or 2 digits + [ ]? # optional space + (?P\+|-) # +/- + (?P[0-9]{2}):?(?P[0-9]{2}) # hh[:]mm + $) + ''', date_str) if not m: timezone = datetime.timedelta() else: @@ -3276,6 +3368,14 @@ def platform_name(): return res +def get_windows_version(): + ''' Get Windows version. None if it's not running on Windows ''' + if compat_os_name == 'nt': + return version_tuple(platform.win32_ver()[1]) + else: + return None + + def _windows_write_string(s, out): """ Returns True if the string was written using special methods, False if it has yet to be written out.""" @@ -3650,14 +3750,14 @@ def parse_resolution(s): if s is None: return {} - mobj = re.search(r'\b(?P\d+)\s*[xX×]\s*(?P\d+)\b', s) + mobj = re.search(r'(?\d+)\s*[xX×,]\s*(?P\d+)(?![a-zA-Z0-9])', s) if mobj: return { 'width': int(mobj.group('w')), 'height': int(mobj.group('h')), } - mobj = re.search(r'\b(\d+)[pPiI]\b', s) + mobj = re.search(r'(? 0 else -1 - stop = idx.stop if idx.stop is not None else -1 if step > 0 else 0 if self.__reversed: - (start, stop), step = map(self.__reverse_index, (start, stop)), -step - idx = slice(start, stop, step) + idx = slice(self.__reverse_index(idx.start), self.__reverse_index(idx.stop), -(idx.step or 1)) + start, stop, step = idx.start, idx.stop, idx.step or 1 elif isinstance(idx, int): if self.__reversed: idx = self.__reverse_index(idx) - start = stop = idx + start, stop, step = idx, idx, 0 else: raise TypeError('indices must be integers or slices') - if start < 0 or stop < 0: + if ((start or 0) < 0 or (stop or 0) < 0 + or (start is None and step < 0) + or (stop is None and step > 0)): # We need to consume the entire iterable to be able to slice from the end # Obviously, never use this with infinite iterables - return self.__exhaust()[idx] - - n = max(start, stop) - len(self.__cache) + 1 + self.__exhaust() + try: + return self.__cache[idx] + except IndexError as e: + raise self.IndexError(e) from e + n = max(start or 0, stop or 0) - len(self.__cache) + 1 if n > 0: self.__cache.extend(itertools.islice(self.__iterable, n)) - return self.__cache[idx] + try: + return self.__cache[idx] + except IndexError as e: + raise self.IndexError(e) from e def __bool__(self): try: self[-1] if self.__reversed else self[0] - except IndexError: + except self.IndexError: return False return True def __len__(self): - self.exhaust() + self.__exhaust() return len(self.__cache) def reverse(self): @@ -4042,15 +4155,31 @@ def __str__(self): return repr(self.exhaust()) -class PagedList(object): +class PagedList: def __len__(self): # This is only useful for tests return len(self.getslice()) - def getslice(self, start, end): + def __init__(self, pagefunc, pagesize, use_cache=True): + self._pagefunc = pagefunc + self._pagesize = pagesize + self._use_cache = use_cache + self._cache = {} + + def getpage(self, pagenum): + page_results = self._cache.get(pagenum) or list(self._pagefunc(pagenum)) + if self._use_cache: + self._cache[pagenum] = page_results + return page_results + + def getslice(self, start=0, end=None): + return list(self._getslice(start, end)) + + def _getslice(self, start, end): raise NotImplementedError('This method must be implemented by subclasses') def __getitem__(self, idx): + # NOTE: cache must be enabled if this is used if not isinstance(idx, int) or idx < 0: raise TypeError('indices must be non-negative integers') entries = self.getslice(idx, idx + 1) @@ -4058,42 +4187,26 @@ def __getitem__(self, idx): class OnDemandPagedList(PagedList): - def __init__(self, pagefunc, pagesize, use_cache=True): - self._pagefunc = pagefunc - self._pagesize = pagesize - self._use_cache = use_cache - if use_cache: - self._cache = {} - - def getslice(self, start=0, end=None): - res = [] + def _getslice(self, start, end): for pagenum in itertools.count(start // self._pagesize): firstid = pagenum * self._pagesize nextfirstid = pagenum * self._pagesize + self._pagesize if start >= nextfirstid: continue - page_results = None - if self._use_cache: - page_results = self._cache.get(pagenum) - if page_results is None: - page_results = list(self._pagefunc(pagenum)) - if self._use_cache: - self._cache[pagenum] = page_results - startv = ( start % self._pagesize if firstid <= start < nextfirstid else 0) - endv = ( ((end - 1) % self._pagesize) + 1 if (end is not None and firstid <= end <= nextfirstid) else None) + page_results = self.getpage(pagenum) if startv != 0 or endv is not None: page_results = page_results[startv:endv] - res.extend(page_results) + yield from page_results # A little optimization - if current page is not "full", ie. does # not contain page_size videos then we can assume that this page @@ -4106,36 +4219,31 @@ def getslice(self, start=0, end=None): # break out early as well if end == nextfirstid: break - return res class InAdvancePagedList(PagedList): def __init__(self, pagefunc, pagecount, pagesize): - self._pagefunc = pagefunc self._pagecount = pagecount - self._pagesize = pagesize + PagedList.__init__(self, pagefunc, pagesize, True) - def getslice(self, start=0, end=None): - res = [] + def _getslice(self, start, end): start_page = start // self._pagesize end_page = ( self._pagecount if end is None else (end // self._pagesize + 1)) skip_elems = start - start_page * self._pagesize only_more = None if end is None else end - start for pagenum in range(start_page, end_page): - page = list(self._pagefunc(pagenum)) + page_results = self.getpage(pagenum) if skip_elems: - page = page[skip_elems:] + page_results = page_results[skip_elems:] skip_elems = None if only_more is not None: - if len(page) < only_more: - only_more -= len(page) + if len(page_results) < only_more: + only_more -= len(page_results) else: - page = page[:only_more] - res.extend(page) + yield from page_results[:only_more] break - res.extend(page) - return res + yield from page_results def uppercase_escape(s): @@ -4173,6 +4281,10 @@ def escape_url(url): ).geturl() +def parse_qs(url): + return compat_parse_qs(compat_urllib_parse_urlparse(url).query) + + def read_batch_urls(batch_fd): def fixup(url): if not isinstance(url, compat_str): @@ -4376,6 +4488,8 @@ def fix_kv(m): v = m.group(0) if v in ('true', 'false', 'null'): return v + elif v in ('undefined', 'void 0'): + return 'null' elif v.startswith('/*') or v.startswith('//') or v.startswith('!') or v == ',': return "" @@ -4402,7 +4516,7 @@ def fix_kv(m): "(?:[^"\\]*(?:\\\\|\\['"nurtbfx/\n]))*[^"\\]*"| '(?:[^'\\]*(?:\\\\|\\['"nurtbfx/\n]))*[^'\\]*'| {comment}|,(?={skip}[\]}}])| - (?:(?(?:%%)*) % - (?P\((?P{0})\))? # mapping key + (?P\((?P{0})\))? (?P - (?:[#0\-+ ]+)? # conversion flags (optional) - (?:\d+)? # minimum field width (optional) - (?:\.\d+)? # precision (optional) - [hlL]? # length modifier (optional) - [diouxXeEfFgGcrs] # conversion type + (?P[#0\-+ ]+)? + (?P\d+)? + (?P\.\d+)? + (?P[hlL])? # unused in python + {1} # conversion type ) ''' +STR_FORMAT_TYPES = 'diouxXeEfFgGcrs' + + def limit_length(s, length): """ Add ellipses to overly long strings """ if s is None: @@ -4477,11 +4595,10 @@ def is_outdated_version(version, limit, assume_new=True): def ytdl_is_updateable(): """ Returns if yt-dlp can be updated with -U """ - return False - from zipimport import zipimporter + from .update import is_non_updateable - return isinstance(globals().get('__loader__'), zipimporter) or hasattr(sys, 'frozen') + return not is_non_updateable() def args_to_str(args): @@ -4502,20 +4619,24 @@ def mimetype2ext(mt): if mt is None: return None - ext = { + mt, _, params = mt.partition(';') + mt = mt.strip() + + FULL_MAP = { 'audio/mp4': 'm4a', # Per RFC 3003, audio/mpeg can be .mp1, .mp2 or .mp3. Here use .mp3 as # it's the most popular one 'audio/mpeg': 'mp3', 'audio/x-wav': 'wav', - }.get(mt) + 'audio/wav': 'wav', + 'audio/wave': 'wav', + } + + ext = FULL_MAP.get(mt) if ext is not None: return ext - _, _, res = mt.rpartition('/') - res = res.split(';')[0].strip().lower() - - return { + SUBTYPE_MAP = { '3gpp': '3gp', 'smptett+xml': 'tt', 'ttaf+xml': 'dfxp', @@ -4534,7 +4655,28 @@ def mimetype2ext(mt): 'quicktime': 'mov', 'mp2t': 'ts', 'x-wav': 'wav', - }.get(res, res) + 'filmstrip+json': 'fs', + 'svg+xml': 'svg', + } + + _, _, subtype = mt.rpartition('/') + ext = SUBTYPE_MAP.get(subtype.lower()) + if ext is not None: + return ext + + SUFFIX_MAP = { + 'json': 'json', + 'xml': 'xml', + 'zip': 'zip', + 'gzip': 'gz', + } + + _, _, suffix = subtype.partition('+') + ext = SUFFIX_MAP.get(suffix) + if ext is not None: + return ext + + return subtype.replace('+', '.') def parse_codecs(codecs_str): @@ -4542,13 +4684,21 @@ def parse_codecs(codecs_str): if not codecs_str: return {} split_codecs = list(filter(None, map( - lambda str: str.strip(), codecs_str.strip().strip(',').split(',')))) - vcodec, acodec = None, None + str.strip, codecs_str.strip().strip(',').split(',')))) + vcodec, acodec, hdr = None, None, None for full_codec in split_codecs: - codec = full_codec.split('.')[0] - if codec in ('avc1', 'avc2', 'avc3', 'avc4', 'vp9', 'vp8', 'hev1', 'hev2', 'h263', 'h264', 'mp4v', 'hvc1', 'av01', 'theora'): + parts = full_codec.split('.') + codec = parts[0].replace('0', '') + if codec in ('avc1', 'avc2', 'avc3', 'avc4', 'vp9', 'vp8', 'hev1', 'hev2', + 'h263', 'h264', 'mp4v', 'hvc1', 'av1', 'theora', 'dvh1', 'dvhe'): if not vcodec: - vcodec = full_codec + vcodec = '.'.join(parts[:4]) if codec in ('vp9', 'av1') else full_codec + if codec in ('dvh1', 'dvhe'): + hdr = 'DV' + elif codec == 'av1' and len(parts) > 3 and parts[3] == '10': + hdr = 'HDR10' + elif full_codec.replace('0', '').startswith('vp9.2'): + hdr = 'HDR10' elif codec in ('mp4a', 'opus', 'vorbis', 'mp3', 'aac', 'ac-3', 'ec-3', 'eac3', 'dtsc', 'dtse', 'dtsh', 'dtsl'): if not acodec: acodec = full_codec @@ -4564,6 +4714,7 @@ def parse_codecs(codecs_str): return { 'vcodec': vcodec or 'none', 'acodec': acodec or 'none', + 'dynamic_range': hdr, } return {} @@ -4621,7 +4772,7 @@ def determine_protocol(info_dict): if protocol is not None: return protocol - url = info_dict['url'] + url = sanitize_url(info_dict['url']) if url.startswith('rtmp'): return 'rtmp' elif url.startswith('mms'): @@ -4640,9 +4791,11 @@ def determine_protocol(info_dict): def render_table(header_row, data, delim=False, extraGap=0, hideEmpty=False): """ Render a list of rows, each as a list of values """ + def width(string): + return len(remove_terminal_sequences(string)) def get_max_lens(table): - return [max(len(compat_str(v)) for v in col) for col in zip(*table)] + return [max(width(str(v)) for v in col) for col in zip(*table)] def filter_using_list(row, filterArray): return [col for (take, col) in zip(filterArray, row) if take] @@ -4654,64 +4807,74 @@ def filter_using_list(row, filterArray): table = [header_row] + data max_lens = get_max_lens(table) + extraGap += 1 if delim: - table = [header_row] + [['-' * ml for ml in max_lens]] + data - format_str = ' '.join('%-' + compat_str(ml + extraGap) + 's' for ml in max_lens[:-1]) + ' %s' - return '\n'.join(format_str % tuple(row) for row in table) + table = [header_row] + [[delim * (ml + extraGap) for ml in max_lens]] + data + max_lens[-1] = 0 + for row in table: + for pos, text in enumerate(map(str, row)): + row[pos] = text + (' ' * (max_lens[pos] - width(text) + extraGap)) + ret = '\n'.join(''.join(row) for row in table) + return ret -def _match_one(filter_part, dct): +def _match_one(filter_part, dct, incomplete): + # TODO: Generalize code with YoutubeDL._build_format_filter + STRING_OPERATORS = { + '*=': operator.contains, + '^=': lambda attr, value: attr.startswith(value), + '$=': lambda attr, value: attr.endswith(value), + '~=': lambda attr, value: re.search(value, attr), + } COMPARISON_OPERATORS = { + **STRING_OPERATORS, + '<=': operator.le, # "<=" must be defined above "<" '<': operator.lt, - '<=': operator.le, - '>': operator.gt, '>=': operator.ge, + '>': operator.gt, '=': operator.eq, - '!=': operator.ne, } + operator_rex = re.compile(r'''(?x)\s* (?P[a-z_]+) - \s*(?P%s)(?P\s*\?)?\s* + \s*(?P!\s*)?(?P%s)(?P\s*\?)?\s* (?: - (?P[0-9.]+(?:[kKmMgGtTpPeEzZyY]i?[Bb]?)?)| - (?P["\'])(?P(?:\\.|(?!(?P=quote)|\\).)+?)(?P=quote)| - (?P(?![0-9.])[a-z0-9A-Z]*) + (?P["\'])(?P.+?)(?P=quote)| + (?P.+?) ) \s*$ ''' % '|'.join(map(re.escape, COMPARISON_OPERATORS.keys()))) m = operator_rex.search(filter_part) if m: - op = COMPARISON_OPERATORS[m.group('op')] - actual_value = dct.get(m.group('key')) - if (m.group('quotedstrval') is not None - or m.group('strval') is not None + m = m.groupdict() + unnegated_op = COMPARISON_OPERATORS[m['op']] + if m['negation']: + op = lambda attr, value: not unnegated_op(attr, value) + else: + op = unnegated_op + comparison_value = m['quotedstrval'] or m['strval'] or m['intval'] + if m['quote']: + comparison_value = comparison_value.replace(r'\%s' % m['quote'], m['quote']) + actual_value = dct.get(m['key']) + numeric_comparison = None + if isinstance(actual_value, compat_numeric_types): # If the original field is a string and matching comparisonvalue is # a number we should respect the origin of the original field # and process comparison value as a string (see - # https://github.com/ytdl-org/youtube-dl/issues/11082). - or actual_value is not None and m.group('intval') is not None - and isinstance(actual_value, compat_str)): - if m.group('op') not in ('=', '!='): - raise ValueError( - 'Operator %s does not support string values!' % m.group('op')) - comparison_value = m.group('quotedstrval') or m.group('strval') or m.group('intval') - quote = m.group('quote') - if quote is not None: - comparison_value = comparison_value.replace(r'\%s' % quote, quote) - else: + # https://github.com/ytdl-org/youtube-dl/issues/11082) try: - comparison_value = int(m.group('intval')) + numeric_comparison = int(comparison_value) except ValueError: - comparison_value = parse_filesize(m.group('intval')) - if comparison_value is None: - comparison_value = parse_filesize(m.group('intval') + 'B') - if comparison_value is None: - raise ValueError( - 'Invalid integer value %r in filter part %r' % ( - m.group('intval'), filter_part)) + numeric_comparison = parse_filesize(comparison_value) + if numeric_comparison is None: + numeric_comparison = parse_filesize(f'{comparison_value}B') + if numeric_comparison is None: + numeric_comparison = parse_duration(comparison_value) + if numeric_comparison is not None and m['op'] in STRING_OPERATORS: + raise ValueError('Operator %s only supports string values!' % m['op']) if actual_value is None: - return m.group('none_inclusive') - return op(actual_value, comparison_value) + return incomplete or m['none_inclusive'] + return op(actual_value, comparison_value if numeric_comparison is None else numeric_comparison) UNARY_OPERATORS = { '': lambda v: (v is True) if isinstance(v, bool) else (v is not None), @@ -4725,21 +4888,25 @@ def _match_one(filter_part, dct): if m: op = UNARY_OPERATORS[m.group('op')] actual_value = dct.get(m.group('key')) + if incomplete and actual_value is None: + return True return op(actual_value) raise ValueError('Invalid filter part %r' % filter_part) -def match_str(filter_str, dct): - """ Filter a dictionary with a simple string syntax. Returns True (=passes filter) or false """ - +def match_str(filter_str, dct, incomplete=False): + """ Filter a dictionary with a simple string syntax. Returns True (=passes filter) or false + When incomplete, all conditions passes on missing fields + """ return all( - _match_one(filter_part, dct) for filter_part in filter_str.split('&')) + _match_one(filter_part.replace(r'\&', '&'), dct, incomplete) + for filter_part in re.split(r'(?