]> jfr.im git - yt-dlp.git/blobdiff - yt_dlp/utils.py
Let `--match-filter` reject entries early
[yt-dlp.git] / yt_dlp / utils.py
index 4f584596d1e19f9e975341c6fb45c6628fa1922d..6276ac726be379f628e056b8e76b5d7b8523d0b8 100644 (file)
@@ -1836,7 +1836,7 @@ def write_json_file(obj, fn):
 
     try:
         with tf:
-            json.dump(obj, tf, default=repr)
+            json.dump(obj, tf)
         if sys.platform == 'win32':
             # Need to remove existing file on Windows, else os.rename raises
             # WindowsError or FileExistsError.
@@ -4041,15 +4041,31 @@ def __str__(self):
         return repr(self.exhaust())
 
 
-class PagedList(object):
+class PagedList:
     def __len__(self):
         # This is only useful for tests
         return len(self.getslice())
 
-    def getslice(self, start, end):
+    def __init__(self, pagefunc, pagesize, use_cache=True):
+        self._pagefunc = pagefunc
+        self._pagesize = pagesize
+        self._use_cache = use_cache
+        self._cache = {}
+
+    def getpage(self, pagenum):
+        page_results = self._cache.get(pagenum) or list(self._pagefunc(pagenum))
+        if self._use_cache:
+            self._cache[pagenum] = page_results
+        return page_results
+
+    def getslice(self, start=0, end=None):
+        return list(self._getslice(start, end))
+
+    def _getslice(self, start, end):
         raise NotImplementedError('This method must be implemented by subclasses')
 
     def __getitem__(self, idx):
+        # NOTE: cache must be enabled if this is used
         if not isinstance(idx, int) or idx < 0:
             raise TypeError('indices must be non-negative integers')
         entries = self.getslice(idx, idx + 1)
@@ -4057,42 +4073,26 @@ def __getitem__(self, idx):
 
 
 class OnDemandPagedList(PagedList):
-    def __init__(self, pagefunc, pagesize, use_cache=True):
-        self._pagefunc = pagefunc
-        self._pagesize = pagesize
-        self._use_cache = use_cache
-        if use_cache:
-            self._cache = {}
-
-    def getslice(self, start=0, end=None):
-        res = []
+    def _getslice(self, start, end):
         for pagenum in itertools.count(start // self._pagesize):
             firstid = pagenum * self._pagesize
             nextfirstid = pagenum * self._pagesize + self._pagesize
             if start >= nextfirstid:
                 continue
 
-            page_results = None
-            if self._use_cache:
-                page_results = self._cache.get(pagenum)
-            if page_results is None:
-                page_results = list(self._pagefunc(pagenum))
-            if self._use_cache:
-                self._cache[pagenum] = page_results
-
             startv = (
                 start % self._pagesize
                 if firstid <= start < nextfirstid
                 else 0)
-
             endv = (
                 ((end - 1) % self._pagesize) + 1
                 if (end is not None and firstid <= end <= nextfirstid)
                 else None)
 
+            page_results = self.getpage(pagenum)
             if startv != 0 or endv is not None:
                 page_results = page_results[startv:endv]
-            res.extend(page_results)
+            yield from page_results
 
             # A little optimization - if current page is not "full", ie. does
             # not contain page_size videos then we can assume that this page
@@ -4105,36 +4105,31 @@ def getslice(self, start=0, end=None):
             # break out early as well
             if end == nextfirstid:
                 break
-        return res
 
 
 class InAdvancePagedList(PagedList):
     def __init__(self, pagefunc, pagecount, pagesize):
-        self._pagefunc = pagefunc
         self._pagecount = pagecount
-        self._pagesize = pagesize
+        PagedList.__init__(self, pagefunc, pagesize, True)
 
-    def getslice(self, start=0, end=None):
-        res = []
+    def _getslice(self, start, end):
         start_page = start // self._pagesize
         end_page = (
             self._pagecount if end is None else (end // self._pagesize + 1))
         skip_elems = start - start_page * self._pagesize
         only_more = None if end is None else end - start
         for pagenum in range(start_page, end_page):
-            page = list(self._pagefunc(pagenum))
+            page_results = self.getpage(pagenum)
             if skip_elems:
-                page = page[skip_elems:]
+                page_results = page_results[skip_elems:]
                 skip_elems = None
             if only_more is not None:
-                if len(page) < only_more:
-                    only_more -= len(page)
+                if len(page_results) < only_more:
+                    only_more -= len(page_results)
                 else:
-                    page = page[:only_more]
-                    res.extend(page)
+                    yield from page_results[:only_more]
                     break
-            res.extend(page)
-        return res
+            yield from page_results
 
 
 def uppercase_escape(s):
@@ -4544,7 +4539,7 @@ def parse_codecs(codecs_str):
     if not codecs_str:
         return {}
     split_codecs = list(filter(None, map(
-        lambda str: str.strip(), codecs_str.strip().strip(',').split(','))))
+        str.strip, codecs_str.strip().strip(',').split(','))))
     vcodec, acodec = None, None
     for full_codec in split_codecs:
         codec = full_codec.split('.')[0]
@@ -4662,28 +4657,40 @@ def filter_using_list(row, filterArray):
     return '\n'.join(format_str % tuple(row) for row in table)
 
 
-def _match_one(filter_part, dct):
+def _match_one(filter_part, dct, incomplete):
+    # TODO: Generalize code with YoutubeDL._build_format_filter
+    STRING_OPERATORS = {
+        '*=': operator.contains,
+        '^=': lambda attr, value: attr.startswith(value),
+        '$=': lambda attr, value: attr.endswith(value),
+        '~=': lambda attr, value: re.search(value, attr),
+    }
     COMPARISON_OPERATORS = {
+        **STRING_OPERATORS,
+        '<=': operator.le,  # "<=" must be defined above "<"
         '<': operator.lt,
-        '<=': operator.le,
-        '>': operator.gt,
         '>=': operator.ge,
+        '>': operator.gt,
         '=': operator.eq,
-        '!=': operator.ne,
     }
+
     operator_rex = re.compile(r'''(?x)\s*
         (?P<key>[a-z_]+)
-        \s*(?P<op>%s)(?P<none_inclusive>\s*\?)?\s*
+        \s*(?P<negation>!\s*)?(?P<op>%s)(?P<none_inclusive>\s*\?)?\s*
         (?:
             (?P<intval>[0-9.]+(?:[kKmMgGtTpPeEzZyY]i?[Bb]?)?)|
-            (?P<quote>["\'])(?P<quotedstrval>(?:\\.|(?!(?P=quote)|\\).)+?)(?P=quote)|
-            (?P<strval>(?![0-9.])[a-z0-9A-Z]*)
+            (?P<quote>["\'])(?P<quotedstrval>.+?)(?P=quote)|
+            (?P<strval>.+?)
         )
         \s*$
         ''' % '|'.join(map(re.escape, COMPARISON_OPERATORS.keys())))
     m = operator_rex.search(filter_part)
     if m:
-        op = COMPARISON_OPERATORS[m.group('op')]
+        unnegated_op = COMPARISON_OPERATORS[m.group('op')]
+        if m.group('negation'):
+            op = lambda attr, value: not unnegated_op(attr, value)
+        else:
+            op = unnegated_op
         actual_value = dct.get(m.group('key'))
         if (m.group('quotedstrval') is not None
             or m.group('strval') is not None
@@ -4693,14 +4700,13 @@ def _match_one(filter_part, dct):
             # https://github.com/ytdl-org/youtube-dl/issues/11082).
             or actual_value is not None and m.group('intval') is not None
                 and isinstance(actual_value, compat_str)):
-            if m.group('op') not in ('=', '!='):
-                raise ValueError(
-                    'Operator %s does not support string values!' % m.group('op'))
             comparison_value = m.group('quotedstrval') or m.group('strval') or m.group('intval')
             quote = m.group('quote')
             if quote is not None:
                 comparison_value = comparison_value.replace(r'\%s' % quote, quote)
         else:
+            if m.group('op') in STRING_OPERATORS:
+                raise ValueError('Operator %s only supports string values!' % m.group('op'))
             try:
                 comparison_value = int(m.group('intval'))
             except ValueError:
@@ -4712,7 +4718,7 @@ def _match_one(filter_part, dct):
                         'Invalid integer value %r in filter part %r' % (
                             m.group('intval'), filter_part))
         if actual_value is None:
-            return m.group('none_inclusive')
+            return incomplete or m.group('none_inclusive')
         return op(actual_value, comparison_value)
 
     UNARY_OPERATORS = {
@@ -4727,21 +4733,25 @@ def _match_one(filter_part, dct):
     if m:
         op = UNARY_OPERATORS[m.group('op')]
         actual_value = dct.get(m.group('key'))
+        if incomplete and actual_value is None:
+            return True
         return op(actual_value)
 
     raise ValueError('Invalid filter part %r' % filter_part)
 
 
-def match_str(filter_str, dct):
-    """ Filter a dictionary with a simple string syntax. Returns True (=passes filter) or false """
-
+def match_str(filter_str, dct, incomplete=False):
+    """ Filter a dictionary with a simple string syntax. Returns True (=passes filter) or false
+        When incomplete, all conditions passes on missing fields
+    """
     return all(
-        _match_one(filter_part, dct) for filter_part in filter_str.split('&'))
+        _match_one(filter_part.replace(r'\&', '&'), dct, incomplete)
+        for filter_part in re.split(r'(?<!\\)&', filter_str))
 
 
 def match_filter_func(filter_str):
-    def _match_func(info_dict):
-        if match_str(filter_str, info_dict):
+    def _match_func(info_dict, *args, **kwargs):
+        if match_str(filter_str, info_dict, *args, **kwargs):
             return None
         else:
             video_title = info_dict.get('title', info_dict.get('id', 'video'))
@@ -6149,8 +6159,11 @@ def to_high_limit_path(path):
     return path
 
 
-def format_field(obj, field, template='%s', ignore=(None, ''), default='', func=None):
-    val = obj.get(field, default)
+def format_field(obj, field=None, template='%s', ignore=(None, ''), default='', func=None):
+    if field is None:
+        val = obj if obj is not None else default
+    else:
+        val = obj.get(field, default)
     if func and val not in ignore:
         val = func(val)
     return template % val if val not in ignore else default
@@ -6246,11 +6259,13 @@ def traverse_obj(
     # TODO: Write tests
     '''
     if not casesense:
-        _lower = lambda k: k.lower() if isinstance(k, str) else k
+        _lower = lambda k: (k.lower() if isinstance(k, str) else k)
         path_list = (map(_lower, variadic(path)) for path in path_list)
 
     def _traverse_obj(obj, path, _current_depth=0):
         nonlocal depth
+        if obj is None:
+            return None
         path = tuple(variadic(path))
         for i, key in enumerate(path):
             if isinstance(key, (list, tuple)):
@@ -6263,7 +6278,7 @@ def _traverse_obj(obj, path, _current_depth=0):
                 _current_depth += 1
                 depth = max(depth, _current_depth)
                 return [_traverse_obj(inner_obj, path[i + 1:], _current_depth) for inner_obj in obj]
-            elif isinstance(obj, dict):
+            elif isinstance(obj, dict) and not (is_user_input and key == ':'):
                 obj = (obj.get(key) if casesense or (key in obj)
                        else next((v for k, v in obj.items() if _lower(k) == key), None))
             else:
@@ -6271,7 +6286,7 @@ def _traverse_obj(obj, path, _current_depth=0):
                     key = (int_or_none(key) if ':' not in key
                            else slice(*map(int_or_none, key.split(':'))))
                     if key == slice(None):
-                        return _traverse_obj(obj, (..., *path[i + 1:]))
+                        return _traverse_obj(obj, (..., *path[i + 1:]), _current_depth)
                 if not isinstance(key, (int, slice)):
                     return None
                 if not isinstance(obj, (list, tuple, LazyList)):