]> jfr.im git - yt-dlp.git/blobdiff - yt_dlp/utils.py
Let `--match-filter` reject entries early
[yt-dlp.git] / yt_dlp / utils.py
index fd13febd6da5ef0822231661f4dca579fda43833..6276ac726be379f628e056b8e76b5d7b8523d0b8 100644 (file)
@@ -4041,15 +4041,31 @@ def __str__(self):
         return repr(self.exhaust())
 
 
-class PagedList(object):
+class PagedList:
     def __len__(self):
         # This is only useful for tests
         return len(self.getslice())
 
-    def getslice(self, start, end):
+    def __init__(self, pagefunc, pagesize, use_cache=True):
+        self._pagefunc = pagefunc
+        self._pagesize = pagesize
+        self._use_cache = use_cache
+        self._cache = {}
+
+    def getpage(self, pagenum):
+        page_results = self._cache.get(pagenum) or list(self._pagefunc(pagenum))
+        if self._use_cache:
+            self._cache[pagenum] = page_results
+        return page_results
+
+    def getslice(self, start=0, end=None):
+        return list(self._getslice(start, end))
+
+    def _getslice(self, start, end):
         raise NotImplementedError('This method must be implemented by subclasses')
 
     def __getitem__(self, idx):
+        # NOTE: cache must be enabled if this is used
         if not isinstance(idx, int) or idx < 0:
             raise TypeError('indices must be non-negative integers')
         entries = self.getslice(idx, idx + 1)
@@ -4057,42 +4073,26 @@ def __getitem__(self, idx):
 
 
 class OnDemandPagedList(PagedList):
-    def __init__(self, pagefunc, pagesize, use_cache=True):
-        self._pagefunc = pagefunc
-        self._pagesize = pagesize
-        self._use_cache = use_cache
-        if use_cache:
-            self._cache = {}
-
-    def getslice(self, start=0, end=None):
-        res = []
+    def _getslice(self, start, end):
         for pagenum in itertools.count(start // self._pagesize):
             firstid = pagenum * self._pagesize
             nextfirstid = pagenum * self._pagesize + self._pagesize
             if start >= nextfirstid:
                 continue
 
-            page_results = None
-            if self._use_cache:
-                page_results = self._cache.get(pagenum)
-            if page_results is None:
-                page_results = list(self._pagefunc(pagenum))
-            if self._use_cache:
-                self._cache[pagenum] = page_results
-
             startv = (
                 start % self._pagesize
                 if firstid <= start < nextfirstid
                 else 0)
-
             endv = (
                 ((end - 1) % self._pagesize) + 1
                 if (end is not None and firstid <= end <= nextfirstid)
                 else None)
 
+            page_results = self.getpage(pagenum)
             if startv != 0 or endv is not None:
                 page_results = page_results[startv:endv]
-            res.extend(page_results)
+            yield from page_results
 
             # A little optimization - if current page is not "full", ie. does
             # not contain page_size videos then we can assume that this page
@@ -4105,36 +4105,31 @@ def getslice(self, start=0, end=None):
             # break out early as well
             if end == nextfirstid:
                 break
-        return res
 
 
 class InAdvancePagedList(PagedList):
     def __init__(self, pagefunc, pagecount, pagesize):
-        self._pagefunc = pagefunc
         self._pagecount = pagecount
-        self._pagesize = pagesize
+        PagedList.__init__(self, pagefunc, pagesize, True)
 
-    def getslice(self, start=0, end=None):
-        res = []
+    def _getslice(self, start, end):
         start_page = start // self._pagesize
         end_page = (
             self._pagecount if end is None else (end // self._pagesize + 1))
         skip_elems = start - start_page * self._pagesize
         only_more = None if end is None else end - start
         for pagenum in range(start_page, end_page):
-            page = list(self._pagefunc(pagenum))
+            page_results = self.getpage(pagenum)
             if skip_elems:
-                page = page[skip_elems:]
+                page_results = page_results[skip_elems:]
                 skip_elems = None
             if only_more is not None:
-                if len(page) < only_more:
-                    only_more -= len(page)
+                if len(page_results) < only_more:
+                    only_more -= len(page_results)
                 else:
-                    page = page[:only_more]
-                    res.extend(page)
+                    yield from page_results[:only_more]
                     break
-            res.extend(page)
-        return res
+            yield from page_results
 
 
 def uppercase_escape(s):
@@ -4662,7 +4657,7 @@ def filter_using_list(row, filterArray):
     return '\n'.join(format_str % tuple(row) for row in table)
 
 
-def _match_one(filter_part, dct):
+def _match_one(filter_part, dct, incomplete):
     # TODO: Generalize code with YoutubeDL._build_format_filter
     STRING_OPERATORS = {
         '*=': operator.contains,
@@ -4723,7 +4718,7 @@ def _match_one(filter_part, dct):
                         'Invalid integer value %r in filter part %r' % (
                             m.group('intval'), filter_part))
         if actual_value is None:
-            return m.group('none_inclusive')
+            return incomplete or m.group('none_inclusive')
         return op(actual_value, comparison_value)
 
     UNARY_OPERATORS = {
@@ -4738,22 +4733,25 @@ def _match_one(filter_part, dct):
     if m:
         op = UNARY_OPERATORS[m.group('op')]
         actual_value = dct.get(m.group('key'))
+        if incomplete and actual_value is None:
+            return True
         return op(actual_value)
 
     raise ValueError('Invalid filter part %r' % filter_part)
 
 
-def match_str(filter_str, dct):
-    """ Filter a dictionary with a simple string syntax. Returns True (=passes filter) or false """
-
+def match_str(filter_str, dct, incomplete=False):
+    """ Filter a dictionary with a simple string syntax. Returns True (=passes filter) or false
+        When incomplete, all conditions passes on missing fields
+    """
     return all(
-        _match_one(filter_part.replace(r'\&', '&'), dct)
+        _match_one(filter_part.replace(r'\&', '&'), dct, incomplete)
         for filter_part in re.split(r'(?<!\\)&', filter_str))
 
 
 def match_filter_func(filter_str):
-    def _match_func(info_dict):
-        if match_str(filter_str, info_dict):
+    def _match_func(info_dict, *args, **kwargs):
+        if match_str(filter_str, info_dict, *args, **kwargs):
             return None
         else:
             video_title = info_dict.get('title', info_dict.get('id', 'video'))
@@ -6161,8 +6159,11 @@ def to_high_limit_path(path):
     return path
 
 
-def format_field(obj, field, template='%s', ignore=(None, ''), default='', func=None):
-    val = obj.get(field, default)
+def format_field(obj, field=None, template='%s', ignore=(None, ''), default='', func=None):
+    if field is None:
+        val = obj if obj is not None else default
+    else:
+        val = obj.get(field, default)
     if func and val not in ignore:
         val = func(val)
     return template % val if val not in ignore else default