]> jfr.im git - yt-dlp.git/blobdiff - yt_dlp/utils.py
Let `--match-filter` reject entries early
[yt-dlp.git] / yt_dlp / utils.py
index b04fbd22cf526e4ea672b04767ae59c8cf657589..6276ac726be379f628e056b8e76b5d7b8523d0b8 100644 (file)
@@ -1836,7 +1836,7 @@ def write_json_file(obj, fn):
 
     try:
         with tf:
-            json.dump(obj, tf, default=repr)
+            json.dump(obj, tf)
         if sys.platform == 'win32':
             # Need to remove existing file on Windows, else os.rename raises
             # WindowsError or FileExistsError.
@@ -4041,15 +4041,31 @@ def __str__(self):
         return repr(self.exhaust())
 
 
-class PagedList(object):
+class PagedList:
     def __len__(self):
         # This is only useful for tests
         return len(self.getslice())
 
-    def getslice(self, start, end):
+    def __init__(self, pagefunc, pagesize, use_cache=True):
+        self._pagefunc = pagefunc
+        self._pagesize = pagesize
+        self._use_cache = use_cache
+        self._cache = {}
+
+    def getpage(self, pagenum):
+        page_results = self._cache.get(pagenum) or list(self._pagefunc(pagenum))
+        if self._use_cache:
+            self._cache[pagenum] = page_results
+        return page_results
+
+    def getslice(self, start=0, end=None):
+        return list(self._getslice(start, end))
+
+    def _getslice(self, start, end):
         raise NotImplementedError('This method must be implemented by subclasses')
 
     def __getitem__(self, idx):
+        # NOTE: cache must be enabled if this is used
         if not isinstance(idx, int) or idx < 0:
             raise TypeError('indices must be non-negative integers')
         entries = self.getslice(idx, idx + 1)
@@ -4057,42 +4073,26 @@ def __getitem__(self, idx):
 
 
 class OnDemandPagedList(PagedList):
-    def __init__(self, pagefunc, pagesize, use_cache=True):
-        self._pagefunc = pagefunc
-        self._pagesize = pagesize
-        self._use_cache = use_cache
-        if use_cache:
-            self._cache = {}
-
-    def getslice(self, start=0, end=None):
-        res = []
+    def _getslice(self, start, end):
         for pagenum in itertools.count(start // self._pagesize):
             firstid = pagenum * self._pagesize
             nextfirstid = pagenum * self._pagesize + self._pagesize
             if start >= nextfirstid:
                 continue
 
-            page_results = None
-            if self._use_cache:
-                page_results = self._cache.get(pagenum)
-            if page_results is None:
-                page_results = list(self._pagefunc(pagenum))
-            if self._use_cache:
-                self._cache[pagenum] = page_results
-
             startv = (
                 start % self._pagesize
                 if firstid <= start < nextfirstid
                 else 0)
-
             endv = (
                 ((end - 1) % self._pagesize) + 1
                 if (end is not None and firstid <= end <= nextfirstid)
                 else None)
 
+            page_results = self.getpage(pagenum)
             if startv != 0 or endv is not None:
                 page_results = page_results[startv:endv]
-            res.extend(page_results)
+            yield from page_results
 
             # A little optimization - if current page is not "full", ie. does
             # not contain page_size videos then we can assume that this page
@@ -4105,36 +4105,31 @@ def getslice(self, start=0, end=None):
             # break out early as well
             if end == nextfirstid:
                 break
-        return res
 
 
 class InAdvancePagedList(PagedList):
     def __init__(self, pagefunc, pagecount, pagesize):
-        self._pagefunc = pagefunc
         self._pagecount = pagecount
-        self._pagesize = pagesize
+        PagedList.__init__(self, pagefunc, pagesize, True)
 
-    def getslice(self, start=0, end=None):
-        res = []
+    def _getslice(self, start, end):
         start_page = start // self._pagesize
         end_page = (
             self._pagecount if end is None else (end // self._pagesize + 1))
         skip_elems = start - start_page * self._pagesize
         only_more = None if end is None else end - start
         for pagenum in range(start_page, end_page):
-            page = list(self._pagefunc(pagenum))
+            page_results = self.getpage(pagenum)
             if skip_elems:
-                page = page[skip_elems:]
+                page_results = page_results[skip_elems:]
                 skip_elems = None
             if only_more is not None:
-                if len(page) < only_more:
-                    only_more -= len(page)
+                if len(page_results) < only_more:
+                    only_more -= len(page_results)
                 else:
-                    page = page[:only_more]
-                    res.extend(page)
+                    yield from page_results[:only_more]
                     break
-            res.extend(page)
-        return res
+            yield from page_results
 
 
 def uppercase_escape(s):
@@ -4662,7 +4657,7 @@ def filter_using_list(row, filterArray):
     return '\n'.join(format_str % tuple(row) for row in table)
 
 
-def _match_one(filter_part, dct):
+def _match_one(filter_part, dct, incomplete):
     # TODO: Generalize code with YoutubeDL._build_format_filter
     STRING_OPERATORS = {
         '*=': operator.contains,
@@ -4723,7 +4718,7 @@ def _match_one(filter_part, dct):
                         'Invalid integer value %r in filter part %r' % (
                             m.group('intval'), filter_part))
         if actual_value is None:
-            return m.group('none_inclusive')
+            return incomplete or m.group('none_inclusive')
         return op(actual_value, comparison_value)
 
     UNARY_OPERATORS = {
@@ -4738,22 +4733,25 @@ def _match_one(filter_part, dct):
     if m:
         op = UNARY_OPERATORS[m.group('op')]
         actual_value = dct.get(m.group('key'))
+        if incomplete and actual_value is None:
+            return True
         return op(actual_value)
 
     raise ValueError('Invalid filter part %r' % filter_part)
 
 
-def match_str(filter_str, dct):
-    """ Filter a dictionary with a simple string syntax. Returns True (=passes filter) or false """
-
+def match_str(filter_str, dct, incomplete=False):
+    """ Filter a dictionary with a simple string syntax. Returns True (=passes filter) or false
+        When incomplete, all conditions passes on missing fields
+    """
     return all(
-        _match_one(filter_part.replace(r'\&', '&'), dct)
+        _match_one(filter_part.replace(r'\&', '&'), dct, incomplete)
         for filter_part in re.split(r'(?<!\\)&', filter_str))
 
 
 def match_filter_func(filter_str):
-    def _match_func(info_dict):
-        if match_str(filter_str, info_dict):
+    def _match_func(info_dict, *args, **kwargs):
+        if match_str(filter_str, info_dict, *args, **kwargs):
             return None
         else:
             video_title = info_dict.get('title', info_dict.get('id', 'video'))
@@ -6161,8 +6159,11 @@ def to_high_limit_path(path):
     return path
 
 
-def format_field(obj, field, template='%s', ignore=(None, ''), default='', func=None):
-    val = obj.get(field, default)
+def format_field(obj, field=None, template='%s', ignore=(None, ''), default='', func=None):
+    if field is None:
+        val = obj if obj is not None else default
+    else:
+        val = obj.get(field, default)
     if func and val not in ignore:
         val = func(val)
     return template % val if val not in ignore else default
@@ -6263,6 +6264,8 @@ def traverse_obj(
 
     def _traverse_obj(obj, path, _current_depth=0):
         nonlocal depth
+        if obj is None:
+            return None
         path = tuple(variadic(path))
         for i, key in enumerate(path):
             if isinstance(key, (list, tuple)):
@@ -6275,7 +6278,7 @@ def _traverse_obj(obj, path, _current_depth=0):
                 _current_depth += 1
                 depth = max(depth, _current_depth)
                 return [_traverse_obj(inner_obj, path[i + 1:], _current_depth) for inner_obj in obj]
-            elif isinstance(obj, dict):
+            elif isinstance(obj, dict) and not (is_user_input and key == ':'):
                 obj = (obj.get(key) if casesense or (key in obj)
                        else next((v for k, v in obj.items() if _lower(k) == key), None))
             else:
@@ -6283,7 +6286,7 @@ def _traverse_obj(obj, path, _current_depth=0):
                     key = (int_or_none(key) if ':' not in key
                            else slice(*map(int_or_none, key.split(':'))))
                     if key == slice(None):
-                        return _traverse_obj(obj, (..., *path[i + 1:]))
+                        return _traverse_obj(obj, (..., *path[i + 1:]), _current_depth)
                 if not isinstance(key, (int, slice)):
                     return None
                 if not isinstance(obj, (list, tuple, LazyList)):