]> jfr.im git - yt-dlp.git/blobdiff - yt_dlp/utils/traversal.py
[ie/brightcove] Upgrade requests to HTTPS (#10202)
[yt-dlp.git] / yt_dlp / utils / traversal.py
index ff5703198aea65e844f7a5831e492015e89f4e28..96eb2eddf5296d6af7c374111a0fc672892a3542 100644 (file)
@@ -1,8 +1,10 @@
 import collections.abc
 import contextlib
+import http.cookies
 import inspect
 import itertools
 import re
+import xml.etree.ElementTree
 
 from ._utils import (
     IDENTITY,
@@ -23,11 +25,12 @@ def traverse_obj(
 
     >>> obj = [{}, {"key": "value"}]
     >>> traverse_obj(obj, (1, "key"))
-    "value"
+    'value'
 
     Each of the provided `paths` is tested and the first producing a valid result will be returned.
     The next path will also be tested if the path branched but no results could be found.
-    Supported values for traversal are `Mapping`, `Iterable` and `re.Match`.
+    Supported values for traversal are `Mapping`, `Iterable`, `re.Match`,
+    `xml.etree.ElementTree.Element` (xpath) and `http.cookies.Morsel`.
     Unhelpful values (`{}`, `None`) are treated as the absence of a value and discarded.
 
     The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`.
@@ -35,8 +38,8 @@ def traverse_obj(
     The keys in the path can be one of:
         - `None`:           Return the current object.
         - `set`:            Requires the only item in the set to be a type or function,
-                            like `{type}`/`{func}`. If a `type`, returns only values
-                            of this type. If a function, returns `func(obj)`.
+                            like `{type}`/`{type, type, ...}`/`{func}`. If `type`s, return only
+                            values of those types. If a function, return `func(obj)`.
         - `str`/`int`:      Return `obj[key]`. For `re.Match`, return `obj.group(key)`.
         - `slice`:          Branch out and return all values in `obj[key]`.
         - `Ellipsis`:       Branch out and return a list of all values.
@@ -47,8 +50,10 @@ def traverse_obj(
                             For `Iterable`s, `key` is the index of the value.
                             For `re.Match`es, `key` is the group number (0 = full match)
                             as well as additionally any group names, if given.
-        - `dict`            Transform the current object and return a matching dict.
+        - `dict`:           Transform the current object and return a matching dict.
                             Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`.
+        - `any`-builtin:    Take the first matching object and return it, resetting branching.
+        - `all`-builtin:    Take all matching objects and return them as a list, resetting branching.
 
         `tuple`, `list`, and `dict` all support nested paths and branches.
 
@@ -101,10 +106,10 @@ def apply_key(key, obj, is_last):
             result = obj
 
         elif isinstance(key, set):
-            assert len(key) == 1, 'Set should only be used to wrap a single item'
             item = next(iter(key))
-            if isinstance(item, type):
-                if isinstance(obj, item):
+            if len(key) > 1 or isinstance(item, type):
+                assert all(isinstance(item, type) for item in key)
+                if isinstance(obj, tuple(key)):
                     result = obj
             else:
                 result = try_call(item, args=(obj,))
@@ -116,9 +121,11 @@ def apply_key(key, obj, is_last):
 
         elif key is ...:
             branching = True
+            if isinstance(obj, http.cookies.Morsel):
+                obj = dict(obj, key=obj.key, value=obj.value)
             if isinstance(obj, collections.abc.Mapping):
                 result = obj.values()
-            elif is_iterable_like(obj):
+            elif is_iterable_like(obj) or isinstance(obj, xml.etree.ElementTree.Element):
                 result = obj
             elif isinstance(obj, re.Match):
                 result = obj.groups()
@@ -130,9 +137,11 @@ def apply_key(key, obj, is_last):
 
         elif callable(key):
             branching = True
+            if isinstance(obj, http.cookies.Morsel):
+                obj = dict(obj, key=obj.key, value=obj.value)
             if isinstance(obj, collections.abc.Mapping):
                 iter_obj = obj.items()
-            elif is_iterable_like(obj):
+            elif is_iterable_like(obj) or isinstance(obj, xml.etree.ElementTree.Element):
                 iter_obj = enumerate(obj)
             elif isinstance(obj, re.Match):
                 iter_obj = itertools.chain(
@@ -156,6 +165,8 @@ def apply_key(key, obj, is_last):
             } or None
 
         elif isinstance(obj, collections.abc.Mapping):
+            if isinstance(obj, http.cookies.Morsel):
+                obj = dict(obj, key=obj.key, value=obj.value)
             result = (try_call(obj.get, args=(key,)) if casesense or try_call(obj.__contains__, args=(key,)) else
                       next((v for k, v in obj.items() if casefold(k) == key), None))
 
@@ -168,7 +179,7 @@ def apply_key(key, obj, is_last):
                 result = next((v for k, v in obj.groupdict().items() if casefold(k) == key), None)
 
         elif isinstance(key, (int, slice)):
-            if is_iterable_like(obj, collections.abc.Sequence):
+            if is_iterable_like(obj, (collections.abc.Sequence, xml.etree.ElementTree.Element)):
                 branching = isinstance(key, slice)
                 with contextlib.suppress(IndexError):
                     result = obj[key]
@@ -176,6 +187,34 @@ def apply_key(key, obj, is_last):
                 with contextlib.suppress(IndexError):
                     result = str(obj)[key]
 
+        elif isinstance(obj, xml.etree.ElementTree.Element) and isinstance(key, str):
+            xpath, _, special = key.rpartition('/')
+            if not special.startswith('@') and not special.endswith('()'):
+                xpath = key
+                special = None
+
+            # Allow abbreviations of relative paths, absolute paths error
+            if xpath.startswith('/'):
+                xpath = f'.{xpath}'
+            elif xpath and not xpath.startswith('./'):
+                xpath = f'./{xpath}'
+
+            def apply_specials(element):
+                if special is None:
+                    return element
+                if special == '@':
+                    return element.attrib
+                if special.startswith('@'):
+                    return try_call(element.attrib.get, args=(special[1:],))
+                if special == 'text()':
+                    return element.text
+                raise SyntaxError(f'apply_specials is missing case for {special!r}')
+
+            if xpath:
+                result = list(map(apply_specials, obj.iterfind(xpath)))
+            else:
+                result = apply_specials(obj)
+
         return branching, result if branching else (result,)
 
     def lazy_last(iterable):
@@ -199,6 +238,15 @@ def apply_path(start_obj, path, test_type):
             if not casesense and isinstance(key, str):
                 key = key.casefold()
 
+            if key in (any, all):
+                has_branched = False
+                filtered_objs = (obj for obj in objs if obj not in (None, {}))
+                if key is any:
+                    objs = (next(filtered_objs, None),)
+                else:
+                    objs = (list(filtered_objs),)
+                continue
+
             if __debug__ and callable(key):
                 # Verify function signature
                 inspect.signature(key).bind(None, None)