jfr.im git - yt-dlp.git/commitdiff
Make outtmpl more robust and catch errors early
author pukkandan <redacted>
Tue, 8 Jun 2021 14:41:00 +0000 (20:11 +0530)
committer pukkandan <redacted>
Tue, 8 Jun 2021 14:41:00 +0000 (20:11 +0530)
test/test_YoutubeDL.py
yt_dlp/YoutubeDL.py
yt_dlp/__init__.py

index 30c48c78fa002233aebbeea6728965e16418aa45..e77597d3c8f592b6a3be2f05d76d761fc72ab80f 100644 (file)
@@ -669,6 +669,9 @@ def out(tmpl, **params):
             params['outtmpl'] = tmpl
             ydl = YoutubeDL(params)
             ydl._num_downloads = 1
+            err = ydl.validate_outtmpl(tmpl)
+            if err:
+                raise err
             outtmpl, tmpl_dict = ydl.prepare_outtmpl(tmpl, self.outtmpl_info)
             return outtmpl % tmpl_dict
 
@@ -686,6 +689,9 @@ def out(tmpl, **params):
         self.assertEqual(out('%(invalid@tmpl|def)s', outtmpl_na_placeholder='none'), 'none')
         self.assertEqual(out('%()s'), 'NA')
         self.assertEqual(out('%s'), '%s')
+        self.assertEqual(out('%d'), '%d')
+        self.assertRaises(ValueError, out, '%')
+        self.assertRaises(ValueError, out, '%(title)')
 
         NA_TEST_OUTTMPL = '%(uploader_date)s-%(width)d-%(x|def)s-%(id)s.%(ext)s'
         self.assertEqual(out(NA_TEST_OUTTMPL), 'NA-NA-def-1234.mp4')
@@ -705,6 +711,8 @@ def out(tmpl, **params):
         self.assertEqual(out(FMT_TEST_OUTTMPL % '   0   6d'), ' 01080.mp4')
 
         self.assertEqual(out('%(id)d'), '1234')
+        self.assertEqual(out('%(height)c'), '1')
+        self.assertEqual(out('%(ext)c'), 'm')
         self.assertEqual(out('%(id)d %(id)r'), "1234 '1234'")
         self.assertEqual(out('%(ext)s-%(ext|def)d'), 'mp4-def')
         self.assertEqual(out('%(width|0)04d'), '0000')
@@ -715,6 +723,7 @@ def out(tmpl, **params):
         self.assertEqual(out('%(id+1-height+3)05d'), '00158')
         self.assertEqual(out('%(width+100)05d'), 'NA')
         self.assertEqual(out('%(formats.0)s'), str(FORMATS[0]))
+        self.assertEqual(out('%(height.0)03d'), '001')
         self.assertEqual(out('%(formats.-1.id)s'), str(FORMATS[-1]['id']))
         self.assertEqual(out('%(formats.3)s'), 'NA')
         self.assertEqual(out('%(formats.:2:-1)r'), repr(FORMATS[:2:-1]))
index 1643649fbaad1d74ec4ec00eb5141546422309ab..ad96cebcd80ce5f3cbb624e4244cafaca376bc21 100644 (file)
@@ -813,6 +813,19 @@ def parse_outtmpl(self):
                     'Put  from __future__ import unicode_literals  at the top of your code file or consider switching to Python 3.x.')
         return outtmpl_dict
 
+    @staticmethod
+    def validate_outtmpl(tmpl):
+        ''' @return None or Exception object '''
+        try:
+            re.sub(
+                STR_FORMAT_RE.format(''),
+                lambda mobj: ('%' if not mobj.group('has_key') else '') + mobj.group(0),
+                tmpl
+            ) % collections.defaultdict(int)
+            return None
+        except ValueError as err:
+            return err
+
     def prepare_outtmpl(self, outtmpl, info_dict, sanitize=None):
         """ Make the template and info_dict suitable for substitution (outtmpl % info_dict)"""
         info_dict = dict(info_dict)
@@ -852,10 +865,12 @@ def prepare_outtmpl(self, outtmpl, info_dict, sanitize=None):
         }
         tmpl_dict = {}
 
+        get_key = lambda k: traverse_obj(
+            info_dict, k.split('.'), is_user_input=True, traverse_string=True)
+
         def get_value(mdict):
             # Object traversal
-            fields = mdict['fields'].split('.')
-            value = traverse_obj(info_dict, fields)
+            value = get_key(mdict['fields'])
             # Negative
             if mdict['negate']:
                 value = float_or_none(value)
@@ -872,7 +887,7 @@ def get_value(mdict):
                         item, multiplier = (item[1:], -1) if item[0] == '-' else (item, 1)
                         offset = float_or_none(item)
                         if offset is None:
-                            offset = float_or_none(traverse_obj(info_dict, item.split('.')))
+                            offset = float_or_none(get_key(item))
                         try:
                             value = operator(value, multiplier * offset)
                         except (TypeError, ZeroDivisionError):
@@ -906,7 +921,13 @@ def create_key(outer_mobj):
             value = default if value is None else value
             key += '\0%s' % fmt
 
-            if fmt[-1] not in 'crs':  # numeric
+            if fmt == 'c':
+                value = compat_str(value)
+                if value is None:
+                    value, fmt = default, 's'
+                else:
+                    value = value[0]
+            elif fmt[-1] not in 'rs':  # numeric
                 value = float_or_none(value)
                 if value is None:
                     value, fmt = default, 's'
index 45a29d3c766f79350fe0b3d1ed7d30614144b335..6d6b0dd66cf2b19975581bbe00557a00b732a14e 100644 (file)
@@ -24,6 +24,7 @@
     DateRange,
     decodeOption,
     DownloadError,
+    error_to_compat_str,
     ExistingVideoReached,
     expand_path,
     match_filter_func,
@@ -307,6 +308,16 @@ def set_default_compat(compat_name, opt_name, default=True, remove_compat=False)
         else:
             _unused_compat_opt('filename')
 
+    def validate_outtmpl(tmpl, msg):
+        err = YoutubeDL.validate_outtmpl(tmpl)
+        if err:
+            parser.error('invalid %s %r: %s' % (msg, tmpl, error_to_compat_str(err)))
+
+    for k, tmpl in opts.outtmpl.items():
+        validate_outtmpl(tmpl, '%s output template' % k)
+    for tmpl in opts.forceprint:
+        validate_outtmpl(tmpl, 'print template')
+
     if opts.extractaudio and not opts.keepvideo and opts.format is None:
         opts.format = 'bestaudio/best'