jfr.im git - yt-dlp.git/blobdiff - test/test_utils.py
[cleanup] Consistent style for file heads
[yt-dlp.git] / test / test_utils.py
index 2e33308c759764df462c827b9769324a10d9fe2e..8024a8e7c8da7977af90fdf578751adc60b66e79 100644 (file)
 #!/usr/bin/env python3
-# coding: utf-8
-
-from __future__ import unicode_literals
 
 # Allow direct execution
 import os
 import sys
 import unittest
+
 sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 
 
-# Various small unit tests
+import contextlib
 import io
 import itertools
 import json
 import xml.etree.ElementTree
 
+from yt_dlp.compat import (
+    compat_etree_fromstring,
+    compat_HTMLParseError,
+    compat_os_name,
+)
 from yt_dlp.utils import (
+    Config,
+    DateRange,
+    ExtractorError,
+    InAdvancePagedList,
+    LazyList,
+    OnDemandPagedList,
     age_restricted,
     args_to_str,
-    encode_base_n,
+    base_url,
     caesar,
     clean_html,
     clean_podcast_url,
+    cli_bool_option,
+    cli_option,
+    cli_valueless_option,
     date_from_str,
     datetime_from_str,
-    DateRange,
     detect_exe_version,
     determine_ext,
+    dfxp2srt,
     dict_get,
+    encode_base_n,
     encode_compat_str,
     encodeFilename,
     escape_rfc3986,
     escape_url,
+    expand_path,
     extract_attributes,
-    ExtractorError,
     find_xpath_attr,
     fix_xml_ampersands,
-    format_bytes,
     float_or_none,
-    get_element_by_class,
+    format_bytes,
     get_element_by_attribute,
-    get_elements_by_class,
+    get_element_by_class,
+    get_element_html_by_attribute,
+    get_element_html_by_class,
+    get_element_text_and_html_by_tag,
     get_elements_by_attribute,
-    InAdvancePagedList,
+    get_elements_by_class,
+    get_elements_html_by_attribute,
+    get_elements_html_by_class,
+    get_elements_text_and_html_by_attribute,
     int_or_none,
     intlist_to_bytes,
+    iri_to_uri,
     is_html,
     js_to_json,
     limit_length,
+    locked_file,
+    lowercase_escape,
+    match_str,
     merge_dicts,
     mimetype2ext,
     month_by_name,
     multipart_encode,
     ohdave_rsa_encrypt,
-    OnDemandPagedList,
     orderedSet,
     parse_age_limit,
+    parse_bitrate,
+    parse_codecs,
+    parse_count,
+    parse_dfxp_time_expr,
     parse_duration,
     parse_filesize,
-    parse_count,
     parse_iso8601,
-    parse_resolution,
-    parse_bitrate,
     parse_qs,
+    parse_resolution,
     pkcs1pad,
+    prepend_extension,
     read_batch_urls,
+    remove_end,
+    remove_quotes,
+    remove_start,
+    render_table,
+    replace_extension,
+    rot47,
     sanitize_filename,
     sanitize_path,
     sanitize_url,
     sanitized_Request,
-    expand_path,
-    prepend_extension,
-    replace_extension,
-    remove_start,
-    remove_end,
-    remove_quotes,
-    rot47,
     shell_quote,
     smuggle_url,
     str_to_int,
     unified_strdate,
     unified_timestamp,
     unsmuggle_url,
+    update_url_query,
     uppercase_escape,
-    lowercase_escape,
     url_basename,
     url_or_none,
-    base_url,
-    urljoin,
     urlencode_postdata,
+    urljoin,
     urshift,
-    update_url_query,
     version_tuple,
-    xpath_with_ns,
+    xpath_attr,
     xpath_element,
     xpath_text,
-    xpath_attr,
-    render_table,
-    match_str,
-    parse_dfxp_time_expr,
-    dfxp2srt,
-    cli_option,
-    cli_valueless_option,
-    cli_bool_option,
-    parse_codecs,
-    iri_to_uri,
-    LazyList,
-)
-from yt_dlp.compat import (
-    compat_chr,
-    compat_etree_fromstring,
-    compat_getenv,
-    compat_os_name,
-    compat_setenv,
+    xpath_with_ns,
 )
 
 
@@ -152,10 +156,12 @@ def test_sanitize_filename(self):
             sanitize_filename('New World record at 0:12:34'),
             'New World record at 0_12_34')
 
-        self.assertEqual(sanitize_filename('--gasdgf'), '_-gasdgf')
+        self.assertEqual(sanitize_filename('--gasdgf'), '--gasdgf')
         self.assertEqual(sanitize_filename('--gasdgf', is_id=True), '--gasdgf')
-        self.assertEqual(sanitize_filename('.gasdgf'), 'gasdgf')
+        self.assertEqual(sanitize_filename('--gasdgf', is_id=False), '_-gasdgf')
+        self.assertEqual(sanitize_filename('.gasdgf'), '.gasdgf')
         self.assertEqual(sanitize_filename('.gasdgf', is_id=True), '.gasdgf')
+        self.assertEqual(sanitize_filename('.gasdgf', is_id=False), 'gasdgf')
 
         forbidden = '"\0\\/'
         for fc in forbidden:
@@ -255,15 +261,22 @@ def test_extract_basic_auth(self):
 
     def test_expand_path(self):
         def env(var):
-            return '%{0}%'.format(var) if sys.platform == 'win32' else '${0}'.format(var)
+            return f'%{var}%' if sys.platform == 'win32' else f'${var}'
 
-        compat_setenv('yt_dlp_EXPATH_PATH', 'expanded')
+        os.environ['yt_dlp_EXPATH_PATH'] = 'expanded'
         self.assertEqual(expand_path(env('yt_dlp_EXPATH_PATH')), 'expanded')
-        self.assertEqual(expand_path(env('HOME')), compat_getenv('HOME'))
-        self.assertEqual(expand_path('~'), compat_getenv('HOME'))
-        self.assertEqual(
-            expand_path('~/%s' % env('yt_dlp_EXPATH_PATH')),
-            '%s/expanded' % compat_getenv('HOME'))
+
+        old_home = os.environ.get('HOME')
+        test_str = R'C:\Documents and Settings\тест\Application Data'
+        try:
+            os.environ['HOME'] = test_str
+            self.assertEqual(expand_path(env('HOME')), os.getenv('HOME'))
+            self.assertEqual(expand_path('~'), os.getenv('HOME'))
+            self.assertEqual(
+                expand_path('~/%s' % env('yt_dlp_EXPATH_PATH')),
+                '%s/expanded' % os.getenv('HOME'))
+        finally:
+            os.environ['HOME'] = old_home or ''
 
     def test_prepend_extension(self):
         self.assertEqual(prepend_extension('abc.ext', 'temp'), 'abc.temp.ext')
@@ -527,9 +540,6 @@ def test_str_to_int(self):
         self.assertEqual(str_to_int('123,456'), 123456)
         self.assertEqual(str_to_int('123.456'), 123456)
         self.assertEqual(str_to_int(523), 523)
-        # Python 3 has no long
-        if sys.version_info < (3, 0):
-            eval('self.assertEqual(str_to_int(123456L), 123456)')
         self.assertEqual(str_to_int('noninteger'), None)
         self.assertEqual(str_to_int([]), None)
 
@@ -617,6 +627,8 @@ def test_parse_duration(self):
         self.assertEqual(parse_duration('3h 11m 53s'), 11513)
         self.assertEqual(parse_duration('3 hours 11 minutes 53 seconds'), 11513)
         self.assertEqual(parse_duration('3 hours 11 mins 53 secs'), 11513)
+        self.assertEqual(parse_duration('3 hours, 11 minutes, 53 seconds'), 11513)
+        self.assertEqual(parse_duration('3 hours, 11 mins, 53 secs'), 11513)
         self.assertEqual(parse_duration('62m45s'), 3765)
         self.assertEqual(parse_duration('6m59s'), 419)
         self.assertEqual(parse_duration('49s'), 49)
@@ -635,6 +647,8 @@ def test_parse_duration(self):
         self.assertEqual(parse_duration('PT1H0.040S'), 3600.04)
         self.assertEqual(parse_duration('PT00H03M30SZ'), 210)
         self.assertEqual(parse_duration('P0Y0M0DT0H4M20.880S'), 260.88)
+        self.assertEqual(parse_duration('01:02:03:050'), 3723.05)
+        self.assertEqual(parse_duration('103:050'), 103.05)
 
     def test_fix_xml_ampersands(self):
         self.assertEqual(
@@ -654,8 +668,7 @@ def testPL(size, pagesize, sliceargs, expected):
             def get_page(pagenum):
                 firstid = pagenum * pagesize
                 upto = min(size, pagenum * pagesize + pagesize)
-                for i in range(firstid, upto):
-                    yield i
+                yield from range(firstid, upto)
 
             pl = OnDemandPagedList(get_page, pagesize)
             got = pl.getslice(*sliceargs)
@@ -724,7 +737,7 @@ def test_multipart_encode(self):
             multipart_encode({b'field': b'value'}, boundary='AAAAAA')[0],
             b'--AAAAAA\r\nContent-Disposition: form-data; name="field"\r\n\r\nvalue\r\n--AAAAAA--\r\n')
         self.assertEqual(
-            multipart_encode({'欄位'.encode('utf-8'): '值'.encode('utf-8')}, boundary='AAAAAA')[0],
+            multipart_encode({'欄位'.encode(): '值'.encode()}, boundary='AAAAAA')[0],
             b'--AAAAAA\r\nContent-Disposition: form-data; name="\xe6\xac\x84\xe4\xbd\x8d"\r\n\r\n\xe5\x80\xbc\r\n--AAAAAA--\r\n')
         self.assertRaises(
             ValueError, multipart_encode, {b'field': b'value'}, boundary='value')
@@ -1112,7 +1125,7 @@ def test_extract_attributes(self):
         self.assertEqual(extract_attributes('<e x="décompose&#769;">'), {'x': 'décompose\u0301'})
         # "Narrow" Python builds don't support unicode code points outside BMP.
         try:
-            compat_chr(0x10000)
+            chr(0x10000)
             supports_outside_bmp = True
         except ValueError:
             supports_outside_bmp = False
@@ -1123,7 +1136,7 @@ def test_extract_attributes(self):
 
     def test_clean_html(self):
         self.assertEqual(clean_html('a:\nb'), 'a: b')
-        self.assertEqual(clean_html('a:\n   "b"'), 'a:    "b"')
+        self.assertEqual(clean_html('a:\n   "b"'), 'a: "b"')
         self.assertEqual(clean_html('a<br>\xa0b'), 'a\nb')
 
     def test_intlist_to_bytes(self):
@@ -1385,7 +1398,7 @@ def test_dfxp2srt(self):
                     <p begin="3" dur="-1">Ignored, three</p>
                 </div>
             </body>
-            </tt>'''.encode('utf-8')
+            </tt>'''.encode()
         srt_data = '''1
 00:00:00,000 --> 00:00:01,000
 The following line contains Chinese characters and special symbols
@@ -1403,14 +1416,14 @@ def test_dfxp2srt(self):
 '''
         self.assertEqual(dfxp2srt(dfxp_data), srt_data)
 
-        dfxp_data_no_default_namespace = '''<?xml version="1.0" encoding="UTF-8"?>
+        dfxp_data_no_default_namespace = b'''<?xml version="1.0" encoding="UTF-8"?>
             <tt xml:lang="en" xmlns:tts="http://www.w3.org/ns/ttml#parameter">
             <body>
                 <div xml:lang="en">
                     <p begin="0" end="1">The first line</p>
                 </div>
             </body>
-            </tt>'''.encode('utf-8')
+            </tt>'''
         srt_data = '''1
 00:00:00,000 --> 00:00:01,000
 The first line
@@ -1418,7 +1431,7 @@ def test_dfxp2srt(self):
 '''
         self.assertEqual(dfxp2srt(dfxp_data_no_default_namespace), srt_data)
 
-        dfxp_data_with_style = '''<?xml version="1.0" encoding="utf-8"?>
+        dfxp_data_with_style = b'''<?xml version="1.0" encoding="utf-8"?>
 <tt xmlns="http://www.w3.org/2006/10/ttaf1" xmlns:ttp="http://www.w3.org/2006/10/ttaf1#parameter" ttp:timeBase="media" xmlns:tts="http://www.w3.org/2006/10/ttaf1#style" xml:lang="en" xmlns:ttm="http://www.w3.org/2006/10/ttaf1#metadata">
   <head>
     <styling>
@@ -1436,7 +1449,7 @@ def test_dfxp2srt(self):
       <p style="s1" tts:textDecoration="underline" begin="00:00:09.56" id="p2" end="00:00:12.36"><span style="s2" tts:color="lime">inner<br /> </span>style</p>
     </div>
   </body>
-</tt>'''.encode('utf-8')
+</tt>'''
         srt_data = '''1
 00:00:02,080 --> 00:00:05,840
 <font color="white" face="sansSerif" size="16">default style<font color="red">custom style</font></font>
@@ -1574,46 +1587,116 @@ def test_urshift(self):
         self.assertEqual(urshift(3, 1), 1)
         self.assertEqual(urshift(-3, 1), 2147483646)
 
+    GET_ELEMENT_BY_CLASS_TEST_STRING = '''
+        <span class="foo bar">nice</span>
+    '''
+
     def test_get_element_by_class(self):
-        html = '''
-            <span class="foo bar">nice</span>
-        '''
+        html = self.GET_ELEMENT_BY_CLASS_TEST_STRING
 
         self.assertEqual(get_element_by_class('foo', html), 'nice')
         self.assertEqual(get_element_by_class('no-such-class', html), None)
 
+    def test_get_element_html_by_class(self):  # checks get_element_html_by_class on the shared single-<span> fixture
+        html = self.GET_ELEMENT_BY_CLASS_TEST_STRING
+
+        self.assertEqual(get_element_html_by_class('foo', html), html.strip())  # fixture holds only the <span> plus surrounding whitespace
+        self.assertEqual(get_element_html_by_class('no-such-class', html), None)  # negative case must call the html variant under test
+
+    GET_ELEMENT_BY_ATTRIBUTE_TEST_STRING = '''
+        <div itemprop="author" itemscope>foo</div>
+    '''
+
     def test_get_element_by_attribute(self):
-        html = '''
-            <span class="foo bar">nice</span>
-        '''
+        html = self.GET_ELEMENT_BY_CLASS_TEST_STRING
 
         self.assertEqual(get_element_by_attribute('class', 'foo bar', html), 'nice')
         self.assertEqual(get_element_by_attribute('class', 'foo', html), None)
         self.assertEqual(get_element_by_attribute('class', 'no-such-foo', html), None)
 
-        html = '''
-            <div itemprop="author" itemscope>foo</div>
-        '''
+        html = self.GET_ELEMENT_BY_ATTRIBUTE_TEST_STRING
 
         self.assertEqual(get_element_by_attribute('itemprop', 'author', html), 'foo')
 
+    def test_get_element_html_by_attribute(self):
+        html = self.GET_ELEMENT_BY_CLASS_TEST_STRING
+
+        self.assertEqual(get_element_html_by_attribute('class', 'foo bar', html), html.strip())
+        self.assertEqual(get_element_html_by_attribute('class', 'foo', html), None)
+        self.assertEqual(get_element_html_by_attribute('class', 'no-such-foo', html), None)
+
+        html = self.GET_ELEMENT_BY_ATTRIBUTE_TEST_STRING
+
+        self.assertEqual(get_element_html_by_attribute('itemprop', 'author', html), html.strip())
+
+    GET_ELEMENTS_BY_CLASS_TEST_STRING = '''
+        <span class="foo bar">nice</span><span class="foo bar">also nice</span>
+    '''
+    GET_ELEMENTS_BY_CLASS_RES = ['<span class="foo bar">nice</span>', '<span class="foo bar">also nice</span>']
+
     def test_get_elements_by_class(self):
-        html = '''
-            <span class="foo bar">nice</span><span class="foo bar">also nice</span>
-        '''
+        html = self.GET_ELEMENTS_BY_CLASS_TEST_STRING
 
         self.assertEqual(get_elements_by_class('foo', html), ['nice', 'also nice'])
         self.assertEqual(get_elements_by_class('no-such-class', html), [])
 
+    def test_get_elements_html_by_class(self):
+        html = self.GET_ELEMENTS_BY_CLASS_TEST_STRING
+
+        self.assertEqual(get_elements_html_by_class('foo', html), self.GET_ELEMENTS_BY_CLASS_RES)
+        self.assertEqual(get_elements_html_by_class('no-such-class', html), [])
+
     def test_get_elements_by_attribute(self):
-        html = '''
-            <span class="foo bar">nice</span><span class="foo bar">also nice</span>
-        '''
+        html = self.GET_ELEMENTS_BY_CLASS_TEST_STRING
 
         self.assertEqual(get_elements_by_attribute('class', 'foo bar', html), ['nice', 'also nice'])
         self.assertEqual(get_elements_by_attribute('class', 'foo', html), [])
         self.assertEqual(get_elements_by_attribute('class', 'no-such-foo', html), [])
 
+    def test_get_elements_html_by_attribute(self):
+        html = self.GET_ELEMENTS_BY_CLASS_TEST_STRING
+
+        self.assertEqual(get_elements_html_by_attribute('class', 'foo bar', html), self.GET_ELEMENTS_BY_CLASS_RES)
+        self.assertEqual(get_elements_html_by_attribute('class', 'foo', html), [])
+        self.assertEqual(get_elements_html_by_attribute('class', 'no-such-foo', html), [])
+
+    def test_get_elements_text_and_html_by_attribute(self):
+        html = self.GET_ELEMENTS_BY_CLASS_TEST_STRING
+
+        self.assertEqual(
+            list(get_elements_text_and_html_by_attribute('class', 'foo bar', html)),
+            list(zip(['nice', 'also nice'], self.GET_ELEMENTS_BY_CLASS_RES)))
+        self.assertEqual(list(get_elements_text_and_html_by_attribute('class', 'foo', html)), [])
+        self.assertEqual(list(get_elements_text_and_html_by_attribute('class', 'no-such-foo', html)), [])
+
+    GET_ELEMENT_BY_TAG_TEST_STRING = '''
+    random text lorem ipsum</p>
+    <div>
+        this should be returned
+        <span>this should also be returned</span>
+        <div>
+            this should also be returned
+        </div>
+        closing tag above should not trick, so this should also be returned
+    </div>
+    but this text should not be returned
+    '''
+    GET_ELEMENT_BY_TAG_RES_OUTERDIV_HTML = GET_ELEMENT_BY_TAG_TEST_STRING.strip()[32:276]
+    GET_ELEMENT_BY_TAG_RES_OUTERDIV_TEXT = GET_ELEMENT_BY_TAG_RES_OUTERDIV_HTML[5:-6]
+    GET_ELEMENT_BY_TAG_RES_INNERSPAN_HTML = GET_ELEMENT_BY_TAG_TEST_STRING.strip()[78:119]
+    GET_ELEMENT_BY_TAG_RES_INNERSPAN_TEXT = GET_ELEMENT_BY_TAG_RES_INNERSPAN_HTML[6:-7]
+
+    def test_get_element_text_and_html_by_tag(self):
+        html = self.GET_ELEMENT_BY_TAG_TEST_STRING
+
+        self.assertEqual(
+            get_element_text_and_html_by_tag('div', html),
+            (self.GET_ELEMENT_BY_TAG_RES_OUTERDIV_TEXT, self.GET_ELEMENT_BY_TAG_RES_OUTERDIV_HTML))
+        self.assertEqual(
+            get_element_text_and_html_by_tag('span', html),
+            (self.GET_ELEMENT_BY_TAG_RES_INNERSPAN_TEXT, self.GET_ELEMENT_BY_TAG_RES_INNERSPAN_HTML))
+        self.assertRaises(compat_HTMLParseError, get_element_text_and_html_by_tag, 'article', html)
+
     def test_iri_to_uri(self):
         self.assertEqual(
             iri_to_uri('https://www.google.com/search?q=foo&ie=utf-8&oe=utf-8&client=firefox-b'),
@@ -1673,7 +1756,7 @@ def test_LazyList_laziness(self):
 
         def test(ll, idx, val, cache):
             self.assertEqual(ll[idx], val)
-            self.assertEqual(getattr(ll, '_LazyList__cache'), list(cache))
+            self.assertEqual(ll._cache, list(cache))
 
         ll = LazyList(range(10))
         test(ll, 0, 0, range(1))
@@ -1700,6 +1783,44 @@ def test_format_bytes(self):
         self.assertEqual(format_bytes(1024**6), '1.00EiB')
         self.assertEqual(format_bytes(1024**7), '1.00ZiB')
         self.assertEqual(format_bytes(1024**8), '1.00YiB')
+        self.assertEqual(format_bytes(1024**9), '1024.00YiB')
+
+    def test_hide_login_info(self):
+        self.assertEqual(Config.hide_login_info(['-u', 'foo', '-p', 'bar']),
+                         ['-u', 'PRIVATE', '-p', 'PRIVATE'])
+        self.assertEqual(Config.hide_login_info(['-u']), ['-u'])
+        self.assertEqual(Config.hide_login_info(['-u', 'foo', '-u', 'bar']),
+                         ['-u', 'PRIVATE', '-u', 'PRIVATE'])
+        self.assertEqual(Config.hide_login_info(['--username=foo']),
+                         ['--username=PRIVATE'])
+
+    def test_locked_file(self):
+        TEXT = 'test_locked_file\n'
+        FILE = 'test_locked_file.ytdl'
+        MODES = 'war'  # Order is important
+
+        try:
+            for lock_mode in MODES:
+                with locked_file(FILE, lock_mode, False) as f:
+                    if lock_mode == 'r':
+                        self.assertEqual(f.read(), TEXT * 2, 'Wrong file content')
+                    else:
+                        f.write(TEXT)
+                    for test_mode in MODES:
+                        testing_write = test_mode != 'r'
+                        try:
+                            with locked_file(FILE, test_mode, False):
+                                pass
+                        except (BlockingIOError, PermissionError):
+                            if not testing_write:  # FIXME
+                                print(f'Known issue: Exclusive lock ({lock_mode}) blocks read access ({test_mode})')
+                                continue
+                            self.assertTrue(testing_write, f'{test_mode} is blocked by {lock_mode}')
+                        else:
+                            self.assertFalse(testing_write, f'{test_mode} is not blocked by {lock_mode}')
+        finally:
+            with contextlib.suppress(OSError):
+                os.remove(FILE)
 
 
 if __name__ == '__main__':