]> jfr.im git - yt-dlp.git/blob - yt_dlp/utils/traversal.py
5a2f69fccde3940cea03aa2392b9c44b38b7eba9
[yt-dlp.git] / yt_dlp / utils / traversal.py
1 import collections.abc
2 import contextlib
3 import inspect
4 import itertools
5 import re
6
7 from ._utils import (
8 IDENTITY,
9 NO_DEFAULT,
10 LazyList,
11 deprecation_warning,
12 is_iterable_like,
13 try_call,
14 variadic,
15 )
16
17
def traverse_obj(
        obj, *paths, default=NO_DEFAULT, expected_type=None, get_all=True,
        casesense=True, is_user_input=NO_DEFAULT, traverse_string=False):
    """
    Safely traverse nested `dict`s and `Iterable`s

    >>> obj = [{}, {"key": "value"}]
    >>> traverse_obj(obj, (1, "key"))
    'value'

    Each of the provided `paths` is tested and the first producing a valid result will be returned.
    The next path will also be tested if the path branched but no results could be found.
    Supported values for traversal are `Mapping`, `Iterable` and `re.Match`.
    Unhelpful values (`{}`, `None`) are treated as the absence of a value and discarded.

    The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`.

    The keys in the path can be one of:
        - `None`:           Return the current object.
        - `set`:            Requires the only item in the set to be a type or function,
                            like `{type}`/`{func}`. If a `type`, returns only values
                            of this type. If a function, returns `func(obj)`.
        - `str`/`int`:      Return `obj[key]`. For `re.Match`, return `obj.group(key)`.
        - `slice`:          Branch out and return all values in `obj[key]`.
        - `Ellipsis`:       Branch out and return a list of all values.
        - `tuple`/`list`:   Branch out and return a list of all matching values.
                            Read as: `[traverse_obj(obj, branch) for branch in branches]`.
        - `function`:       Branch out and return values filtered by the function.
                            Read as: `[value for key, value in obj if function(key, value)]`.
                            For `Iterable`s, `key` is the index of the value.
                            For `re.Match`es, `key` is the group number (0 = full match)
                            as well as additionally any group names, if given.
        - `dict`:           Transform the current object and return a matching dict.
                            Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`.

        `tuple`, `list`, and `dict` all support nested paths and branches.

    @param paths            Paths which to traverse by.
    @param default          Value to return if the paths do not match.
                            If the last key in the path is a `dict`, it will apply to each value inside
                            the dict instead, depth first. Try to avoid if using nested `dict` keys.
    @param expected_type    If a `type`, only accept final values of this type.
                            If any other callable, try to call the function on each result.
                            If the last key in the path is a `dict`, it will apply to each value inside
                            the dict instead, recursively. This does respect branching paths.
    @param get_all          If `False`, return the first matching result, otherwise all matching ones.
    @param casesense        If `False`, consider string dictionary keys as case insensitive.
    @param is_user_input    Deprecated. Passing any value only emits a deprecation warning;
                            the value itself is not used.

    `traverse_string` is only meant to be used by YoutubeDL.prepare_outtmpl and is not part of the API

    @param traverse_string  Whether to traverse into objects as strings.
                            If `True`, any non-compatible object will first be
                            converted into a string and then traversed into.
                            The return value of that path will be a string instead,
                            not respecting any further branching.


    @returns                The result of the object traversal.
                            If successful, `get_all=True`, and the path branches at least once,
                            then a list of results is returned instead.
                            If no `default` is given and the last path branches, a `list` of results
                            is always returned. If a path ends on a `dict` that result will always be a `dict`.
    """
    if is_user_input is not NO_DEFAULT:
        deprecation_warning('The is_user_input parameter is deprecated and no longer works')

    # Casefold strings only; non-string keys are compared unchanged
    casefold = lambda k: k.casefold() if isinstance(k, str) else k

    # `type_test` maps a final value to itself/its converted form, or to `None` on rejection
    if isinstance(expected_type, type):
        type_test = lambda val: val if isinstance(val, expected_type) else None
    else:
        # NOTE(review): `try_call` presumably returns `None` if the callable raises - confirm in `_utils`
        type_test = lambda val: try_call(expected_type or IDENTITY, args=(val,))

    def apply_key(key, obj, is_last):
        """
        Apply a single path `key` to a single `obj`.

        `is_last` tells whether `key` is the last key of its path;
        it is forwarded as `test_type` to nested paths (branches and `dict` values).

        Returns `(branching, results)` where `results` is always iterable:
        the branch results if `branching`, otherwise the 1-tuple `(result,)`.
        """
        branching = False
        result = None

        if obj is None and traverse_string:
            # `None` is never stringified: branching keys give no results, all others give `None`
            if key is ... or callable(key) or isinstance(key, slice):
                branching = True
                result = ()

        elif key is None:
            result = obj

        elif isinstance(key, set):
            assert len(key) == 1, 'Set should only be used to wrap a single item'
            item = next(iter(key))
            if isinstance(item, type):
                # `{type}`: keep `obj` only if it is an instance of that type
                if isinstance(obj, item):
                    result = obj
            else:
                # `{func}`: transform `obj` by calling the function with it
                result = try_call(item, args=(obj,))

        elif isinstance(key, (list, tuple)):
            # Each element is its own sub-path applied to the same `obj`; results are concatenated lazily
            branching = True
            result = itertools.chain.from_iterable(
                apply_path(obj, branch, is_last)[0] for branch in key)

        elif key is ...:
            branching = True
            if isinstance(obj, collections.abc.Mapping):
                result = obj.values()
            elif is_iterable_like(obj):
                result = obj
            elif isinstance(obj, re.Match):
                result = obj.groups()
            elif traverse_string:
                # String traversal yields a single string and does not count as branching
                branching = False
                result = str(obj)
            else:
                result = ()

        elif callable(key):
            # Filter function: build `(key, value)` pairs matching the type of `obj`
            branching = True
            if isinstance(obj, collections.abc.Mapping):
                iter_obj = obj.items()
            elif is_iterable_like(obj):
                iter_obj = enumerate(obj)
            elif isinstance(obj, re.Match):
                # Numbered groups first (0 = full match), followed by the named groups
                iter_obj = itertools.chain(
                    enumerate((obj.group(), *obj.groups())),
                    obj.groupdict().items())
            elif traverse_string:
                branching = False
                iter_obj = enumerate(str(obj))
            else:
                iter_obj = ()

            result = (v for k, v in iter_obj if try_call(key, args=(k, v)))
            if not branching:  # string traversal
                result = ''.join(result)

        elif isinstance(key, dict):
            # Every value of the `dict` key is a path traversed from the current `obj`
            iter_obj = ((k, _traverse_obj(obj, v, False, is_last)) for k, v in key.items())
            # Missing entries are dropped unless a `default` was given, in which case they receive it.
            # An empty resulting dict collapses to `None`
            result = {
                k: v if v is not None else default for k, v in iter_obj
                if v is not None or default is not NO_DEFAULT
            } or None

        elif isinstance(obj, collections.abc.Mapping):
            # Direct lookup if case sensitive or if the key exists verbatim; otherwise take the first
            # value whose casefolded key matches (`key` was already casefolded in `apply_path`)
            result = (try_call(obj.get, args=(key,)) if casesense or try_call(obj.__contains__, args=(key,)) else
                      next((v for k, v in obj.items() if casefold(k) == key), None))

        elif isinstance(obj, re.Match):
            if isinstance(key, int) or casesense:
                # An unknown group raises `IndexError`; `result` then stays `None`
                with contextlib.suppress(IndexError):
                    result = obj.group(key)

            elif isinstance(key, str):
                # Case insensitive lookup among the named groups
                result = next((v for k, v in obj.groupdict().items() if casefold(k) == key), None)

        elif isinstance(key, (int, slice)):
            # NOTE(review): presumably restricts indexing to `Sequence`-like objects - confirm `is_iterable_like`
            if is_iterable_like(obj, collections.abc.Sequence):
                # Only a slice branches; an int index yields a single value
                branching = isinstance(key, slice)
                with contextlib.suppress(IndexError):
                    result = obj[key]
            elif traverse_string:
                with contextlib.suppress(IndexError):
                    result = str(obj)[key]

        # Normalize to an iterable so that callers can always chain the results
        return branching, result if branching else (result,)

    def lazy_last(iterable):
        """
        Yield `(is_last, item)` for every item of `iterable`.

        `is_last` is `True` only for the final item. Nothing is yielded for an empty iterable.
        Works by always holding back one item, so the iterable is consumed lazily.
        """
        iterator = iter(iterable)
        prev = next(iterator, NO_DEFAULT)
        if prev is NO_DEFAULT:
            return

        for item in iterator:
            yield False, prev
            prev = item

        yield True, prev

    def apply_path(start_obj, path, test_type):
        """
        Apply all keys of `path` in order, starting from `start_obj`.

        Returns `(objs, has_branched, is_dict)`: an iterable of the results,
        whether any applied key branched, and whether the last key was a `dict`.
        If `test_type` is set, `type_test` is mapped over the results
        unless the last key was a `dict`, `list` or `tuple`.
        """
        objs = (start_obj,)
        has_branched = False

        # Keep `key` defined for the checks after the loop in case the path is empty
        key = None
        for last, key in lazy_last(variadic(path, (str, bytes, dict, set))):
            if not casesense and isinstance(key, str):
                key = key.casefold()

            if __debug__ and callable(key):
                # Verify function signature
                # (`bind` raises `TypeError` if the callable cannot accept two positional arguments)
                inspect.signature(key).bind(None, None)

            new_objs = []
            for obj in objs:
                branching, results = apply_key(key, obj, last)
                has_branched |= branching
                new_objs.append(results)

            objs = itertools.chain.from_iterable(new_objs)

        # Branch and `dict` keys pass the type test down to their nested paths
        # (via `is_last`, which they hand on as `test_type`), so it is not mapped again here
        if test_type and not isinstance(key, (dict, list, tuple)):
            objs = map(type_test, objs)

        return objs, has_branched, isinstance(key, dict)

    def _traverse_obj(obj, path, allow_empty, test_type):
        """
        Traverse `obj` by a single `path` and reduce the results to the final return value.

        `None` and `{}` results are discarded. Returns `None` if nothing matched,
        unless `allow_empty` is set: then a branched path gives `[]` (or `default` if given)
        and a non-branched path ending on a `dict` gives `{}`.
        """
        results, has_branched, is_dict = apply_path(obj, path, test_type)
        # NOTE(review): `LazyList` is presumably evaluated on demand, with `exhaust()` forcing
        # full evaluation - confirm in `_utils`
        results = LazyList(item for item in results if item not in (None, {}))
        if get_all and has_branched:
            if results:
                return results.exhaust()
            if allow_empty:
                return [] if default is NO_DEFAULT else default
            return None

        return results[0] if results else {} if allow_empty and is_dict else None

    # Only the last path is allowed to produce an empty result; earlier ones fall through on `None`
    for index, path in enumerate(paths, 1):
        result = _traverse_obj(obj, path, index == len(paths), True)
        if result is not None:
            return result

    return None if default is NO_DEFAULT else default
237
238
def get_first(obj, *paths, **kwargs):
    """
    Return the first result of traversing all values of `obj` by the given `paths`.

    Every path is wrapped in `variadic` and prefixed with `...`, so the traversal
    first branches over the values of `obj` itself before applying the path.
    `get_all=False` is always passed to `traverse_obj`; any other keyword
    arguments are forwarded unchanged.
    """
    # Prefix each path with `...` to branch over the top level object
    branched_paths = [(..., *variadic(keys)) for keys in paths]
    return traverse_obj(obj, *branched_paths, **kwargs, get_all=False)
241
242
def dict_get(d, key_or_keys, default=None, skip_false_values=True):
    """
    Return the first usable value of `d` for one or several keys.

    @param d                    Object providing a `get` method, which is used for the lookup.
    @param key_or_keys          A single key or multiple keys (wrapped in `variadic`),
                                tried in the given order.
    @param default              Value to return if no key gives a usable value.
    @param skip_false_values    If `True`, falsy values are skipped in addition to `None`.

    @returns                    The first value that is not `None` (and truthy if
                                `skip_false_values` is set), otherwise `default`.
    """
    # Bind the lookup before expanding the keys, matching the original evaluation order
    lookup = d.get
    for key in variadic(key_or_keys):
        candidate = lookup(key)
        if candidate is None:
            continue
        # Truthiness of the value is tested before consulting the flag
        if not candidate and skip_false_values:
            continue
        return candidate
    return default
247 return default