jfr.im git - yt-dlp.git/blobdiff - test/test_utils.py
Add new options `--impersonate` and `--list-impersonate-targets`
[yt-dlp.git] / test / test_utils.py
index 768edfd0cfea50cee64c95f97a7b8e54fa17e76e..a3073f0e0ac2590505ecf6b9d18ff4f86c8a26b2 100644 (file)
@@ -14,6 +14,7 @@
 import io
 import itertools
 import json
+import subprocess
 import xml.etree.ElementTree
 
 from yt_dlp.compat import (
@@ -28,6 +29,7 @@
     InAdvancePagedList,
     LazyList,
     OnDemandPagedList,
+    Popen,
     age_restricted,
     args_to_str,
     base_url,
@@ -47,8 +49,6 @@
     encode_base_n,
     encode_compat_str,
     encodeFilename,
-    escape_rfc3986,
-    escape_url,
     expand_path,
     extract_attributes,
     extract_basic_auth,
     xpath_text,
     xpath_with_ns,
 )
-from yt_dlp.utils.networking import HTTPHeaderDict
+from yt_dlp.utils.networking import (
+    HTTPHeaderDict,
+    escape_rfc3986,
+    normalize_url,
+    remove_dot_segments,
+)
 
 
 class TestUtil(unittest.TestCase):
@@ -655,6 +660,8 @@ def test_parse_duration(self):
         self.assertEqual(parse_duration('P0Y0M0DT0H4M20.880S'), 260.88)
         self.assertEqual(parse_duration('01:02:03:050'), 3723.05)
         self.assertEqual(parse_duration('103:050'), 103.05)
+        self.assertEqual(parse_duration('1HR 3MIN'), 3780)
+        self.assertEqual(parse_duration('2hrs 3mins'), 7380)
 
     def test_fix_xml_ampersands(self):
         self.assertEqual(
@@ -931,24 +938,45 @@ def test_escape_rfc3986(self):
         self.assertEqual(escape_rfc3986('foo bar'), 'foo%20bar')
         self.assertEqual(escape_rfc3986('foo%20bar'), 'foo%20bar')
 
-    def test_escape_url(self):
+    def test_normalize_url(self):
         self.assertEqual(
-            escape_url('http://wowza.imust.org/srv/vod/telemb/new/UPLOAD/UPLOAD/20224_IncendieHavré_FD.mp4'),
+            normalize_url('http://wowza.imust.org/srv/vod/telemb/new/UPLOAD/UPLOAD/20224_IncendieHavré_FD.mp4'),
             'http://wowza.imust.org/srv/vod/telemb/new/UPLOAD/UPLOAD/20224_IncendieHavre%CC%81_FD.mp4'
         )
         self.assertEqual(
-            escape_url('http://www.ardmediathek.de/tv/Sturm-der-Liebe/Folge-2036-Zu-Mann-und-Frau-erklärt/Das-Erste/Video?documentId=22673108&bcastId=5290'),
+            normalize_url('http://www.ardmediathek.de/tv/Sturm-der-Liebe/Folge-2036-Zu-Mann-und-Frau-erklärt/Das-Erste/Video?documentId=22673108&bcastId=5290'),
             'http://www.ardmediathek.de/tv/Sturm-der-Liebe/Folge-2036-Zu-Mann-und-Frau-erkl%C3%A4rt/Das-Erste/Video?documentId=22673108&bcastId=5290'
         )
         self.assertEqual(
-            escape_url('http://тест.рф/фрагмент'),
+            normalize_url('http://тест.рф/фрагмент'),
             'http://xn--e1aybc.xn--p1ai/%D1%84%D1%80%D0%B0%D0%B3%D0%BC%D0%B5%D0%BD%D1%82'
         )
         self.assertEqual(
-            escape_url('http://тест.рф/абв?абв=абв#абв'),
+            normalize_url('http://тест.рф/абв?абв=абв#абв'),
             'http://xn--e1aybc.xn--p1ai/%D0%B0%D0%B1%D0%B2?%D0%B0%D0%B1%D0%B2=%D0%B0%D0%B1%D0%B2#%D0%B0%D0%B1%D0%B2'
         )
-        self.assertEqual(escape_url('http://vimeo.com/56015672#at=0'), 'http://vimeo.com/56015672#at=0')
+        self.assertEqual(normalize_url('http://vimeo.com/56015672#at=0'), 'http://vimeo.com/56015672#at=0')
+
+        self.assertEqual(normalize_url('http://www.example.com/../a/b/../c/./d.html'), 'http://www.example.com/a/c/d.html')
+
+    def test_remove_dot_segments(self):  # first two cases are the examples from RFC 3986 section 5.2.4
+        self.assertEqual(remove_dot_segments('/a/b/c/./../../g'), '/a/g')  # '.' is dropped, each '..' removes the segment before it
+        self.assertEqual(remove_dot_segments('mid/content=5/../6'), 'mid/6')  # relative path: '..' replaces the last segment
+        self.assertEqual(remove_dot_segments('/ad/../cd'), '/cd')
+        self.assertEqual(remove_dot_segments('/ad/../cd/'), '/cd/')  # trailing slash is preserved
+        self.assertEqual(remove_dot_segments('/..'), '/')  # '..' cannot climb above the root
+        self.assertEqual(remove_dot_segments('/./'), '/')
+        self.assertEqual(remove_dot_segments('/./a'), '/a')
+        self.assertEqual(remove_dot_segments('/abc/./.././d/././e/.././f/./../../ghi'), '/ghi')  # long mix of '.' and '..' collapses fully
+        self.assertEqual(remove_dot_segments('/'), '/')  # paths without dot segments come back unchanged
+        self.assertEqual(remove_dot_segments('/t'), '/t')
+        self.assertEqual(remove_dot_segments('t'), 't')
+        self.assertEqual(remove_dot_segments(''), '')  # empty path stays empty
+        self.assertEqual(remove_dot_segments('/../a/b/c'), '/a/b/c')  # leading '/..' is dropped
+        self.assertEqual(remove_dot_segments('../a'), 'a')  # leading '../' on a relative path is dropped
+        self.assertEqual(remove_dot_segments('./a'), 'a')  # leading './' on a relative path is dropped
+        self.assertEqual(remove_dot_segments('.'), '')  # a lone '.' yields the empty path
+        self.assertEqual(remove_dot_segments('////'), '////')  # repeated slashes (empty segments) are left untouched
 
     def test_js_to_json_vars_strings(self):
         self.assertDictEqual(
@@ -1181,6 +1209,9 @@ def test_js_to_json_edgecases(self):
         on = js_to_json('\'"\\""\'')
         self.assertEqual(json.loads(on), '"""', msg='Unnecessary quote escape should be escaped')
 
+        on = js_to_json('[new Date("spam"), \'("eggs")\']')
+        self.assertEqual(json.loads(on), ['spam', '("eggs")'], msg='Date regex should match a single string')
+
     def test_js_to_json_malformed(self):
         self.assertEqual(js_to_json('42a1'), '42"a1"')
         self.assertEqual(js_to_json('42a-1'), '42"a"-1')
@@ -1192,6 +1223,14 @@ def test_js_to_json_template_literal(self):
         self.assertEqual(js_to_json('`${name}"${name}"`', {'name': '5'}), '"5\\"5\\""')
         self.assertEqual(js_to_json('`${name}`', {}), '"name"')
 
+    def test_js_to_json_common_constructors(self):  # JS constructor calls must come out as plain JSON values
+        self.assertEqual(json.loads(js_to_json('new Map([["a", 5]])')), {'a': 5})  # Map built from [key, value] pairs -> object
+        self.assertEqual(json.loads(js_to_json('Array(5, 10)')), [5, 10])  # Array(...) without `new` -> list
+        self.assertEqual(json.loads(js_to_json('new Array(15,5)')), [15, 5])  # `new Array(...)` -> list
+        self.assertEqual(json.loads(js_to_json('new Map([Array(5, 10),new Array(15,5)])')), {'5': 10, '15': 5})  # nested constructors; numeric Map keys decode as strings
+        self.assertEqual(json.loads(js_to_json('new Date("123")')), "123")  # Date(...) is reduced to its string argument
+        self.assertEqual(json.loads(js_to_json('new Date(\'2023-10-19\')')), "2023-10-19")  # same with a single-quoted argument
+
     def test_extract_attributes(self):
         self.assertEqual(extract_attributes('<e x="y">'), {'x': 'y'})
         self.assertEqual(extract_attributes("<e x='y'>"), {'x': 'y'})
@@ -2071,6 +2110,8 @@ def test_traverse_obj(self):
         self.assertEqual(traverse_obj(_TEST_DATA, (..., {str_or_none})),
                          [item for item in map(str_or_none, _TEST_DATA.values()) if item is not None],
                          msg='Function in set should be a transformation')
+        self.assertEqual(traverse_obj(_TEST_DATA, ('fail', {lambda _: 'const'})), 'const',
+                         msg='Function in set should always be called')
         if __debug__:
             with self.assertRaises(Exception, msg='Sets with length != 1 should raise in debug'):
                 traverse_obj(_TEST_DATA, set())
@@ -2278,23 +2319,6 @@ def test_traverse_obj(self):
         self.assertEqual(traverse_obj({}, (0, slice(1)), traverse_string=True), [],
                          msg='branching should result in list if `traverse_string`')
 
-        # Test is_user_input behavior
-        _IS_USER_INPUT_DATA = {'range8': list(range(8))}
-        self.assertEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', '3'),
-                                      is_user_input=True), 3,
-                         msg='allow for string indexing if `is_user_input`')
-        self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', '3:'),
-                                           is_user_input=True), tuple(range(8))[3:],
-                              msg='allow for string slice if `is_user_input`')
-        self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':4:2'),
-                                           is_user_input=True), tuple(range(8))[:4:2],
-                              msg='allow step in string slice if `is_user_input`')
-        self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':'),
-                                           is_user_input=True), range(8),
-                              msg='`:` should be treated as `...` if `is_user_input`')
-        with self.assertRaises(TypeError, msg='too many params should result in error'):
-            traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':::'), is_user_input=True)
-
         # Test re.Match as input obj
         mobj = re.fullmatch(r'0(12)(?P<group>3)(4)?', '0123')
         self.assertEqual(traverse_obj(mobj, ...), [x for x in mobj.groups() if x is not None],
@@ -2316,8 +2340,62 @@ def test_traverse_obj(self):
         self.assertEqual(traverse_obj(mobj, lambda k, _: k in (0, 'group')), ['0123', '3'],
                          msg='function on a `re.Match` should give group name as well')
 
+        # Test xml.etree.ElementTree.Element as input obj
+        etree = xml.etree.ElementTree.fromstring('''<?xml version="1.0"?>
+        <data>
+            <country name="Liechtenstein">
+                <rank>1</rank>
+                <year>2008</year>
+                <gdppc>141100</gdppc>
+                <neighbor name="Austria" direction="E"/>
+                <neighbor name="Switzerland" direction="W"/>
+            </country>
+            <country name="Singapore">
+                <rank>4</rank>
+                <year>2011</year>
+                <gdppc>59900</gdppc>
+                <neighbor name="Malaysia" direction="N"/>
+            </country>
+            <country name="Panama">
+                <rank>68</rank>
+                <year>2011</year>
+                <gdppc>13600</gdppc>
+                <neighbor name="Costa Rica" direction="W"/>
+                <neighbor name="Colombia" direction="E"/>
+            </country>
+        </data>''')
+        self.assertEqual(traverse_obj(etree, ''), etree,
+                         msg='empty str key should return the element itself')
+        self.assertEqual(traverse_obj(etree, 'country'), list(etree),
+                         msg='str key should lead all children with that tag name')
+        self.assertEqual(traverse_obj(etree, ...), list(etree),
+                         msg='`...` as key should return all children')
+        self.assertEqual(traverse_obj(etree, lambda _, x: x[0].text == '4'), [etree[1]],
+                         msg='function as key should get element as value')
+        self.assertEqual(traverse_obj(etree, lambda i, _: i == 1), [etree[1]],
+                         msg='function as key should get index as key')
+        self.assertEqual(traverse_obj(etree, 0), etree[0],
+                         msg='int key should return the nth child')
+        self.assertEqual(traverse_obj(etree, './/neighbor/@name'),
+                         ['Austria', 'Switzerland', 'Malaysia', 'Costa Rica', 'Colombia'],
+                         msg='`@<attribute>` at end of path should give that attribute')
+        self.assertEqual(traverse_obj(etree, '//neighbor/@fail'), [None, None, None, None, None],
+                         msg='`@<nonexistant>` at end of path should give `None`')
+        self.assertEqual(traverse_obj(etree, ('//neighbor/@', 2)), {'name': 'Malaysia', 'direction': 'N'},
+                         msg='`@` should give the full attribute dict')
+        self.assertEqual(traverse_obj(etree, '//year/text()'), ['2008', '2011', '2011'],
+                         msg='`text()` at end of path should give the inner text')
+        self.assertEqual(traverse_obj(etree, '//*[@direction]/@direction'), ['E', 'W', 'N', 'W', 'E'],
+                         msg='full Python xpath features should be supported')
+        self.assertEqual(traverse_obj(etree, (0, '@name')), 'Liechtenstein',
+                         msg='special transformations should act on current element')
+        self.assertEqual(traverse_obj(etree, ('country', 0, ..., 'text()', {int_or_none})), [1, 2008, 141100],
+                         msg='special transformations should act on current element')
+
     def test_http_header_dict(self):
         headers = HTTPHeaderDict()
+        headers['ytdl-test'] = b'0'
+        self.assertEqual(list(headers.items()), [('Ytdl-Test', '0')])
         headers['ytdl-test'] = 1
         self.assertEqual(list(headers.items()), [('Ytdl-Test', '1')])
         headers['Ytdl-test'] = '2'
@@ -2346,6 +2424,11 @@ def test_http_header_dict(self):
         headers4 = HTTPHeaderDict({'ytdl-test': 'data;'})
         self.assertEqual(set(headers4.items()), {('Ytdl-Test', 'data;')})
 
+        # common mistake: values passed with surrounding whitespace; it must be stripped
+        # https://github.com/yt-dlp/yt-dlp/issues/8729
+        headers5 = HTTPHeaderDict({'ytdl-test': ' data; '})
+        self.assertEqual(set(headers5.items()), {('Ytdl-Test', 'data;')})
+
     def test_extract_basic_auth(self):
         assert extract_basic_auth('http://:foo.bar') == ('http://:foo.bar', None)
         assert extract_basic_auth('http://foo.bar') == ('http://foo.bar', None)
@@ -2354,6 +2437,21 @@ def test_extract_basic_auth(self):
         assert extract_basic_auth('http://user:@foo.bar') == ('http://foo.bar', 'Basic dXNlcjo=')
         assert extract_basic_auth('http://user:pass@foo.bar') == ('http://foo.bar', 'Basic dXNlcjpwYXNz')
 
+    @unittest.skipUnless(compat_os_name == 'nt', 'Only relevant on Windows')
+    def test_Popen_windows_escaping(self):  # checks what the shell receives when Popen.run is used with shell=True
+        def run_shell(args):  # run `args` through the shell and return stdout; stderr and error must both be falsy
+            stdout, stderr, error = Popen.run(
+                args, text=True, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+            assert not stderr
+            assert not error
+            return stdout
+
+        # Test escaping: the embedded quote comes back doubled, the argument quoted, and `&` echoed literally
+        assert run_shell(['echo', 'test"&']) == '"test""&"\n'
+        # Test if delayed expansion is disabled: the `^!` sequence must survive as-is
+        assert run_shell(['echo', '^!']) == '"^!"\n'  # list form: the argument comes back wrapped in quotes
+        assert run_shell('echo "^!"') == '"^!"\n'  # string form
+
 
 if __name__ == '__main__':
     unittest.main()