jfr.im git - yt-dlp.git/commitdiff
[utils] `traverse_obj`: Support `xml.etree.ElementTree.Element` (#8911)
author Simon Sawicki <redacted>
Fri, 5 Jan 2024 20:26:17 +0000 (21:26 +0100)
committer GitHub <redacted>
Fri, 5 Jan 2024 20:26:17 +0000 (21:26 +0100)
Authored by: Grub4K

test/test_utils.py
yt_dlp/utils/traversal.py

index c3e387cd0d6fb1088f07be08e49c8f21488d400e..09c648cf8939c84d0cf7e24079a6200f1d1955d5 100644 (file)
@@ -2340,6 +2340,58 @@ def test_traverse_obj(self):
         self.assertEqual(traverse_obj(mobj, lambda k, _: k in (0, 'group')), ['0123', '3'],
                          msg='function on a `re.Match` should give group name as well')
 
+        # Test xml.etree.ElementTree.Element as input obj
+        etree = xml.etree.ElementTree.fromstring('''<?xml version="1.0"?>
+        <data>
+            <country name="Liechtenstein">
+                <rank>1</rank>
+                <year>2008</year>
+                <gdppc>141100</gdppc>
+                <neighbor name="Austria" direction="E"/>
+                <neighbor name="Switzerland" direction="W"/>
+            </country>
+            <country name="Singapore">
+                <rank>4</rank>
+                <year>2011</year>
+                <gdppc>59900</gdppc>
+                <neighbor name="Malaysia" direction="N"/>
+            </country>
+            <country name="Panama">
+                <rank>68</rank>
+                <year>2011</year>
+                <gdppc>13600</gdppc>
+                <neighbor name="Costa Rica" direction="W"/>
+                <neighbor name="Colombia" direction="E"/>
+            </country>
+        </data>''')
+        self.assertEqual(traverse_obj(etree, ''), etree,
+                         msg='empty str key should return the element itself')
+        self.assertEqual(traverse_obj(etree, 'country'), list(etree),
+                         msg='str key should lead all children with that tag name')
+        self.assertEqual(traverse_obj(etree, ...), list(etree),
+                         msg='`...` as key should return all children')
+        self.assertEqual(traverse_obj(etree, lambda _, x: x[0].text == '4'), [etree[1]],
+                         msg='function as key should get element as value')
+        self.assertEqual(traverse_obj(etree, lambda i, _: i == 1), [etree[1]],
+                         msg='function as key should get index as key')
+        self.assertEqual(traverse_obj(etree, 0), etree[0],
+                         msg='int key should return the nth child')
+        self.assertEqual(traverse_obj(etree, './/neighbor/@name'),
+                         ['Austria', 'Switzerland', 'Malaysia', 'Costa Rica', 'Colombia'],
+                         msg='`@<attribute>` at end of path should give that attribute')
+        self.assertEqual(traverse_obj(etree, '//neighbor/@fail'), [None, None, None, None, None],
+                         msg='`@<nonexistant>` at end of path should give `None`')
+        self.assertEqual(traverse_obj(etree, ('//neighbor/@', 2)), {'name': 'Malaysia', 'direction': 'N'},
+                         msg='`@` should give the full attribute dict')
+        self.assertEqual(traverse_obj(etree, '//year/text()'), ['2008', '2011', '2011'],
+                         msg='`text()` at end of path should give the inner text')
+        self.assertEqual(traverse_obj(etree, '//*[@direction]/@direction'), ['E', 'W', 'N', 'W', 'E'],
+                         msg='full python xpath features should be supported')
+        self.assertEqual(traverse_obj(etree, (0, '@name')), 'Liechtenstein',
+                         msg='special transformations should act on current element')
+        self.assertEqual(traverse_obj(etree, ('country', 0, ..., 'text()', {int_or_none})), [1, 2008, 141100],
+                         msg='special transformations should act on current element')
+
     def test_http_header_dict(self):
         headers = HTTPHeaderDict()
         headers['ytdl-test'] = b'0'
index 5a2f69fccde3940cea03aa2392b9c44b38b7eba9..8938f4c78298f29bda65adfe1f1e44dba56b0506 100644 (file)
@@ -3,6 +3,7 @@
 import inspect
 import itertools
 import re
+import xml.etree.ElementTree
 
 from ._utils import (
     IDENTITY,
@@ -118,7 +119,7 @@ def apply_key(key, obj, is_last):
             branching = True
             if isinstance(obj, collections.abc.Mapping):
                 result = obj.values()
-            elif is_iterable_like(obj):
+            elif is_iterable_like(obj) or isinstance(obj, xml.etree.ElementTree.Element):
                 result = obj
             elif isinstance(obj, re.Match):
                 result = obj.groups()
@@ -132,7 +133,7 @@ def apply_key(key, obj, is_last):
             branching = True
             if isinstance(obj, collections.abc.Mapping):
                 iter_obj = obj.items()
-            elif is_iterable_like(obj):
+            elif is_iterable_like(obj) or isinstance(obj, xml.etree.ElementTree.Element):
                 iter_obj = enumerate(obj)
             elif isinstance(obj, re.Match):
                 iter_obj = itertools.chain(
@@ -168,7 +169,7 @@ def apply_key(key, obj, is_last):
                 result = next((v for k, v in obj.groupdict().items() if casefold(k) == key), None)
 
         elif isinstance(key, (int, slice)):
-            if is_iterable_like(obj, collections.abc.Sequence):
+            if is_iterable_like(obj, (collections.abc.Sequence, xml.etree.ElementTree.Element)):
                 branching = isinstance(key, slice)
                 with contextlib.suppress(IndexError):
                     result = obj[key]
@@ -176,6 +177,34 @@ def apply_key(key, obj, is_last):
                 with contextlib.suppress(IndexError):
                     result = str(obj)[key]
 
+        elif isinstance(obj, xml.etree.ElementTree.Element) and isinstance(key, str):
+            xpath, _, special = key.rpartition('/')
+            if not special.startswith('@') and special != 'text()':
+                xpath = key
+                special = None
+
+            # Allow abbreviations of relative paths, absolute paths error
+            if xpath.startswith('/'):
+                xpath = f'.{xpath}'
+            elif xpath and not xpath.startswith('./'):
+                xpath = f'./{xpath}'
+
+            def apply_specials(element):
+                if special is None:
+                    return element
+                if special == '@':
+                    return element.attrib
+                if special.startswith('@'):
+                    return try_call(element.attrib.get, args=(special[1:],))
+                if special == 'text()':
+                    return element.text
+                assert False, f'apply_specials is missing case for {special!r}'
+
+            if xpath:
+                result = list(map(apply_specials, obj.iterfind(xpath)))
+            else:
+                result = apply_specials(obj)
+
         return branching, result if branching else (result,)
 
     def lazy_last(iterable):