]> jfr.im git - yt-dlp.git/blobdiff - yt_dlp/webvtt.py
[ie/podbayfm] Fix extraction (#10195)
[yt-dlp.git] / yt_dlp / webvtt.py
index ef55e6459ad2ae0ed9045495b3bdbf9f038fef33..9f1a5086b8b9b929bc0b5758ef19958f8ee5c82a 100644 (file)
@@ -1,6 +1,3 @@
-# coding: utf-8
-from __future__ import unicode_literals, print_function, division
-
 """
 A partial parser for WebVTT segments. Interprets enough of the WebVTT stream
 to be able to assemble a single stand-alone subtitle file, suitably adjusting
 in RFC 8216 §3.5 <https://tools.ietf.org/html/rfc8216#section-3.5>.
 """
 
-import re
 import io
-from .utils import int_or_none
-from .compat import (
-    compat_str as str,
-    compat_Pattern,
-    compat_Match,
-)
+import re
 
+from .utils import int_or_none, timetuple_from_msec
 
-class _MatchParser(object):
+
+class _MatchParser:
     """
     An object that maintains the current parsing position and allows
     conveniently advancing it as syntax elements are successfully parsed.
@@ -32,7 +25,7 @@ def __init__(self, string):
         self._pos = 0
 
     def match(self, r):
-        if isinstance(r, compat_Pattern):
+        if isinstance(r, re.Pattern):
             return r.match(self._data, self._pos)
         if isinstance(r, str):
             if self._data.startswith(r, self._pos):
@@ -43,7 +36,7 @@ def match(self, r):
     def advance(self, by):
         if by is None:
             amt = 0
-        elif isinstance(by, compat_Match):
+        elif isinstance(by, re.Match):
             amt = len(by.group(0))
         elif isinstance(by, str):
             amt = len(by)
@@ -70,7 +63,7 @@ class _MatchChildParser(_MatchParser):
     """
 
     def __init__(self, parent):
-        super(_MatchChildParser, self).__init__(parent._data)
+        super().__init__(parent._data)
         self.__parent = parent
         self._pos = parent._pos
 
@@ -84,20 +77,24 @@ def commit(self):
 
 class ParseError(Exception):
     def __init__(self, parser):
-        super(ParseError, self).__init__("Parse error at position %u (near %r)" % (
-            parser._pos, parser._data[parser._pos:parser._pos + 20]
-        ))
+        data = parser._data[parser._pos:parser._pos + 100]
+        super().__init__(f'Parse error at position {parser._pos} (near {data!r})')
 
 
+# While the specification <https://www.w3.org/TR/webvtt1/#webvtt-timestamp>
+# prescribes that hours must be *2 or more* digits, timestamps with a single
+# digit for the hour part has been seen in the wild.
+# See https://github.com/yt-dlp/yt-dlp/issues/921
 _REGEX_TS = re.compile(r'''(?x)
-    (?:([0-9]{2,}):)?
+    (?:([0-9]{1,}):)?
     ([0-9]{2}):
     ([0-9]{2})\.
     ([0-9]{3})?
 ''')
 _REGEX_EOF = re.compile(r'\Z')
-_REGEX_NL = re.compile(r'(?:\r\n|[\r\n])')
+_REGEX_NL = re.compile(r'(?:\r\n|[\r\n]|$)')
 _REGEX_BLANK = re.compile(r'(?:\r\n|[\r\n])+')
+_REGEX_OPTIONAL_WHITESPACE = re.compile(r'[ \t]*')
 
 
 def _parse_ts(ts):
@@ -105,14 +102,8 @@ def _parse_ts(ts):
     Convert a parsed WebVTT timestamp (a re.Match obtained from _REGEX_TS)
     into an MPEG PES timestamp: a tick counter at 90 kHz resolution.
     """
-
-    h, min, s, ms = ts.groups()
-    return 90 * (
-        int(h or 0) * 3600000 +  # noqa: W504,E221,E222
-        int(min)    *   60000 +  # noqa: W504,E221,E222
-        int(s)      *    1000 +  # noqa: W504,E221,E222
-        int(ms)                  # noqa: W504,E221,E222
-    )
+    return 90 * sum(
+        int(part or 0) * mult for part, mult in zip(ts.groups(), (3600_000, 60_000, 1000, 1)))
 
 
 def _format_ts(ts):
@@ -120,14 +111,10 @@ def _format_ts(ts):
     Convert an MPEG PES timestamp into a WebVTT timestamp.
     This will lose sub-millisecond precision.
     """
-    msec = int((ts + 45) // 90)
-    secs, msec = divmod(msec, 1000)
-    mins, secs = divmod(secs, 60)
-    hrs, mins = divmod(mins, 60)
-    return '%02u:%02u:%02u.%03u' % (hrs, mins, secs, msec)
+    return '%02u:%02u:%02u.%03u' % timetuple_from_msec(int((ts + 45) // 90))
 
 
-class Block(object):
+class Block:
     """
     An abstract WebVTT block.
     """
@@ -153,7 +140,6 @@ class HeaderBlock(Block):
     A WebVTT block that may only appear in the header part of the file,
     i.e. before any cue blocks.
     """
-
     pass
 
 
@@ -162,7 +148,7 @@ class Magic(HeaderBlock):
 
     # XXX: The X-TIMESTAMP-MAP extension is described in RFC 8216 §3.5
     # <https://tools.ietf.org/html/rfc8216#section-3.5>, but the RFC
-    # doesnt specify the exact grammar nor where in the WebVTT
+    # doesn't specify the exact grammar nor where in the WebVTT
     # syntax it should be placed; the below has been devised based
     # on usage in the wild
     #
@@ -172,6 +158,13 @@ class Magic(HeaderBlock):
     _REGEX_TSMAP = re.compile(r'X-TIMESTAMP-MAP=')
     _REGEX_TSMAP_LOCAL = re.compile(r'LOCAL:')
     _REGEX_TSMAP_MPEGTS = re.compile(r'MPEGTS:([0-9]+)')
+    _REGEX_TSMAP_SEP = re.compile(r'[ \t]*,[ \t]*')
+
+    # This was removed from the spec in the 2017 revision;
+    # the last spec draft to describe this syntax element is
+    # <https://www.w3.org/TR/2015/WD-webvtt1-20151208/#webvtt-metadata-header>.
+    # Nevertheless, YouTube keeps serving those
+    _REGEX_META = re.compile(r'(?:(?!-->)[^\r\n])+:(?:(?!-->)[^\r\n])+(?:\r\n|[\r\n])')
 
     @classmethod
     def __parse_tsmap(cls, parser):
@@ -194,7 +187,7 @@ def __parse_tsmap(cls, parser):
                         raise ParseError(parser)
                 else:
                     raise ParseError(parser)
-            if parser.consume(','):
+            if parser.consume(cls._REGEX_TSMAP_SEP):
                 continue
             if parser.consume(_REGEX_NL):
                 break
@@ -212,13 +205,18 @@ def parse(cls, parser):
             raise ParseError(parser)
 
         extra = m.group(1)
-        local, mpegts = None, None
-        if parser.consume(cls._REGEX_TSMAP):
-            local, mpegts = cls.__parse_tsmap(parser)
-        if not parser.consume(_REGEX_NL):
+        local, mpegts, meta = None, None, ''
+        while not parser.consume(_REGEX_NL):
+            if parser.consume(cls._REGEX_TSMAP):
+                local, mpegts = cls.__parse_tsmap(parser)
+                continue
+            m = parser.consume(cls._REGEX_META)
+            if m:
+                meta += m.group(0)
+                continue
             raise ParseError(parser)
         parser.commit()
-        return cls(extra=extra, mpegts=mpegts, local=local)
+        return cls(extra=extra, mpegts=mpegts, local=local, meta=meta)
 
     def write_into(self, stream):
         stream.write('WEBVTT')
@@ -231,6 +229,8 @@ def write_into(self, stream):
             stream.write(',MPEGTS:')
             stream.write(str(self.mpegts if self.mpegts is not None else 0))
             stream.write('\n')
+        if self.meta:
+            stream.write(self.meta)
         stream.write('\n')
 
 
@@ -272,10 +272,10 @@ class CueBlock(Block):
     def parse(cls, parser):
         parser = parser.child()
 
-        id = None
+        id_ = None
         m = parser.consume(cls._REGEX_ID)
         if m:
-            id = m.group(1)
+            id_ = m.group(1)
 
         m0 = parser.consume(_REGEX_TS)
         if not m0:
@@ -286,6 +286,7 @@ def parse(cls, parser):
         if not m1:
             return None
         m2 = parser.consume(cls._REGEX_SETTINGS)
+        parser.consume(_REGEX_OPTIONAL_WHITESPACE)
         if not parser.consume(_REGEX_NL):
             return None
 
@@ -302,9 +303,9 @@ def parse(cls, parser):
 
         parser.commit()
         return cls(
-            id=id,
+            id=id_,
             start=start, end=end, settings=settings,
-            text=text.getvalue()
+            text=text.getvalue(),
         )
 
     def write_into(self, stream):
@@ -331,6 +332,26 @@ def as_json(self):
             'settings': self.settings,
         }
 
+    def __eq__(self, other):
+        return self.as_json == other.as_json
+
+    @classmethod
+    def from_json(cls, json):
+        return cls(
+            id=json['id'],
+            start=json['start'],
+            end=json['end'],
+            text=json['text'],
+            settings=json['settings'],
+        )
+
+    def hinges(self, other):
+        if self.text != other.text:
+            return False
+        if self.settings != other.settings:
+            return False
+        return self.start <= self.end == other.start <= other.end
+
 
 def parse_fragment(frag_content):
     """
@@ -338,7 +359,7 @@ def parse_fragment(frag_content):
     a bytes object containing the raw contents of a WebVTT file.
     """
 
-    parser = _MatchParser(frag_content.decode('utf-8'))
+    parser = _MatchParser(frag_content.decode())
 
     yield Magic.parse(parser)