]> jfr.im git - yt-dlp.git/commitdiff
[utils] Improve `traverse_obj`
authorpukkandan <redacted>
Thu, 15 Jul 2021 14:52:49 +0000 (20:22 +0530)
committerpukkandan <redacted>
Mon, 19 Jul 2021 21:12:11 +0000 (02:42 +0530)
* Allow skipping a level: `traverse_obj([{k:v1}, {k:v2}], (None, k))` => `[v1, v2]`
* Make keys variadic: `traverse_obj(obj, k1: str, k2: str)` => `traverse_obj(obj, (k1,), (k2,))`
* Fetch from multiple keys: `traverse_obj([{k1:[1], k2:[2], k3:[3]}], (0, (k1, k2), 0))` => `[1, 2]`

TODO: Add tests

yt_dlp/utils.py

index 795c5632ffe8f88e9d41b52932c34b5ac2538826..d1be485f8bceb9356c97afb2082d349a0cedc0bb 100644 (file)
@@ -6225,9 +6225,14 @@ def load_plugins(name, suffix, namespace):
 
 
 def traverse_obj(
-        obj, *key_list, default=None, expected_type=None,
+        obj, *path_list, default=None, expected_type=None,
         casesense=True, is_user_input=False, traverse_string=False):
     ''' Traverse nested list/dict/tuple
+    @param path_list        A list of paths which are checked one by one.
+                            Each path is a list of keys where each key is a string,
+                            a tuple of strings or "...". When a tuple is given,
+                            all the keys given in the tuple are traversed, and
+                            "..." traverses all the keys in the object
     @param default          Default value to return
     @param expected_type    Only accept final value of this type
     @param casesense        Whether to consider dictionary keys as case sensitive
@@ -6235,23 +6240,38 @@ def traverse_obj(
                             strings are converted to int/slice if necessary
     @param traverse_string  Whether to traverse inside strings. If True, any
                             non-compatible object will also be converted into a string
+    # TODO: Write tests
     '''
     if not casesense:
         _lower = lambda k: k.lower() if isinstance(k, str) else k
-        key_list = ((_lower(k) for k in keys) for keys in key_list)
-
-    def _traverse_obj(obj, keys):
-        for key in list(keys):
-            if isinstance(obj, dict):
+        path_list = (map(_lower, variadic(path)) for path in path_list)
+
+    def _traverse_obj(obj, path, _current_depth=0):
+        nonlocal depth
+        path = tuple(variadic(path))
+        for i, key in enumerate(path):
+            if isinstance(key, (list, tuple)):
+                obj = [_traverse_obj(obj, sub_key, _current_depth) for sub_key in key]
+                key = ...
+            if key is ...:
+                obj = (obj.values() if isinstance(obj, dict)
+                       else obj if isinstance(obj, (list, tuple, LazyList))
+                       else str(obj) if traverse_string else [])
+                _current_depth += 1
+                depth = max(depth, _current_depth)
+                return [_traverse_obj(inner_obj, path[i + 1:], _current_depth) for inner_obj in obj]
+            elif isinstance(obj, dict):
                 obj = (obj.get(key) if casesense or (key in obj)
                        else next((v for k, v in obj.items() if _lower(k) == key), None))
             else:
                 if is_user_input:
                     key = (int_or_none(key) if ':' not in key
                            else slice(*map(int_or_none, key.split(':'))))
+                    if key == slice(None):
+                        return _traverse_obj(obj, (..., *path[i + 1:]))
                 if not isinstance(key, (int, slice)):
                     return None
-                if not isinstance(obj, (list, tuple)):
+                if not isinstance(obj, (list, tuple, LazyList)):
                     if not traverse_string:
                         return None
                     obj = str(obj)
@@ -6261,10 +6281,18 @@ def _traverse_obj(obj, keys):
                     return None
         return obj
 
-    for keys in key_list:
-        val = _traverse_obj(obj, keys)
+    for path in path_list:
+        depth = 0
+        val = _traverse_obj(obj, path)
         if val is not None:
-            if expected_type is None or isinstance(val, expected_type):
+            if depth:
+                for _ in range(depth - 1):
+                    val = itertools.chain.from_iterable(filter(None, val))
+                val = (list(filter(None, val)) if expected_type is None
+                       else [v for v in val if isinstance(v, expected_type)])
+                if val:
+                    return val
+            elif expected_type is None or isinstance(val, expected_type):
                 return val
     return default