]> jfr.im git - yt-dlp.git/commitdiff
[downloader/hls] Assemble single-file WebVTT subtitles from HLS segments
author Felix S <redacted>
Wed, 28 Apr 2021 10:47:30 +0000 (16:17 +0530)
committer Felix S <redacted>
Wed, 28 Apr 2021 11:51:14 +0000 (17:21 +0530)
yt_dlp/compat.py
yt_dlp/downloader/hls.py
yt_dlp/extractor/common.py
yt_dlp/webvtt.py [new file with mode: 0644]

index 3ebf1ee7a6f02f48dc7a14bd85c1f87b685a9687..863bd2287cd9365bc3fba3c28731a9865c26bcfb 100644 (file)
@@ -3018,10 +3018,24 @@ def compat_ctypes_WINFUNCTYPE(*args, **kwargs):
         return ctypes.WINFUNCTYPE(*args, **kwargs)
 
 
+# Types of compiled patterns and match objects, for isinstance() checks.
+# The public re.Pattern / re.Match names may be missing on older Python
+# versions; in that case, recover the types from live objects instead.
+compat_Pattern = getattr(re, 'Pattern', None) or type(re.compile(''))
+
+
+compat_Match = getattr(re, 'Match', None) or type(re.compile('').match(''))
+
+
 __all__ = [
     'compat_HTMLParseError',
     'compat_HTMLParser',
     'compat_HTTPError',
+    'compat_Match',
+    'compat_Pattern',
     'compat_Struct',
     'compat_b64decode',
     'compat_basestring',
index f4e41a6c7b0f6405a68903e6756475809265cbe1..cee3807ceb963b21f03a48cac87c9d3694532eb4 100644 (file)
@@ -2,6 +2,7 @@
 
 import errno
 import re
+import io
 import binascii
 try:
     from Crypto.Cipher import AES
@@ -27,7 +28,9 @@
     parse_m3u8_attributes,
     sanitize_open,
     update_url_query,
+    bug_reports_message,
 )
+from .. import webvtt
 
 
 class HlsFD(FragmentFD):
@@ -78,6 +81,8 @@ def real_download(self, filename, info_dict):
         man_url = info_dict['url']
         self.to_screen('[%s] Downloading m3u8 manifest' % self.FD_NAME)
 
+        is_webvtt = info_dict['ext'] == 'vtt'
+
         urlh = self.ydl.urlopen(self._prepare_url(info_dict, man_url))
         man_url = urlh.geturl()
         s = urlh.read().decode('utf-8', 'ignore')
@@ -142,6 +147,8 @@ def is_ad_fragment_end(s):
         else:
             self._prepare_and_start_frag_download(ctx)
 
+        extra_state = ctx.setdefault('extra_state', {})
+
         fragment_retries = self.params.get('fragment_retries', 0)
         skip_unavailable_fragments = self.params.get('skip_unavailable_fragments', True)
         test = self.params.get('test', False)
@@ -308,6 +315,42 @@ def download_fragment(fragment):
 
                 return frag_content, frag_index
 
+            # Hook applied to each fragment's bytes right before they are
+            # appended to the output (see append_fragment); by default the
+            # bytes pass through unchanged.
+            pack_fragment = lambda frag_content, _: frag_content
+
+            if is_webvtt:
+                def pack_fragment(frag_content, frag_index):
+                    """
+                    Rewrite one WebVTT segment so that all segments can be
+                    concatenated into a single stand-alone WebVTT file.
+
+                    The fragment with frag_index == 1 keeps its header blocks,
+                    and its X-TIMESTAMP-MAP values are remembered in
+                    extra_state. For every other fragment, header blocks are
+                    dropped and cue times are shifted by the difference between
+                    that fragment's timestamp map and the first one's.
+                    Comment blocks are not HeaderBlocks, so they are kept from
+                    every fragment.
+
+                    Returns the rewritten fragment encoded as UTF-8 bytes.
+                    """
+                    output = io.StringIO()
+                    # Offset added to this fragment's cue times (same units as
+                    # webvtt.CueBlock.start/end); stays 0 without a usable map.
+                    adjust = 0
+                    for block in webvtt.parse_fragment(frag_content):
+                        if isinstance(block, webvtt.CueBlock):
+                            block.start += adjust
+                            block.end += adjust
+                        # Magic is a HeaderBlock subclass, so it has to be tested
+                        # before the generic HeaderBlock branch below.
+                        elif isinstance(block, webvtt.Magic):
+                            # XXX: we do not handle MPEGTS overflow
+                            if frag_index == 1:
+                                # NOTE(review): assumes fragment 1 is the first fragment
+                                # actually appended; if it is skipped, no WEBVTT header
+                                # is written and the .get(..., 0) defaults below are
+                                # used as the reference map — confirm.
+                                extra_state['webvtt_mpegts'] = block.mpegts or 0
+                                extra_state['webvtt_local'] = block.local or 0
+                                # XXX: block.local = block.mpegts = None ?
+                            else:
+                                # Shift = (this MPEGTS - first MPEGTS)
+                                #       - (this LOCAL - first LOCAL)
+                                if block.mpegts is not None and block.local is not None:
+                                    adjust = (
+                                        (block.mpegts - extra_state.get('webvtt_mpegts', 0))
+                                        - (block.local - extra_state.get('webvtt_local', 0))
+                                    )
+                                # Only the first fragment's magic line is written out.
+                                continue
+                        elif isinstance(block, webvtt.HeaderBlock):
+                            if frag_index != 1:
+                                # XXX: this should probably be silent as well
+                                # or verify that all segments contain the same data
+                                # NOTE(review): relies on bug_reports_message accepting a
+                                # leading message argument — confirm against utils.
+                                self.report_warning(bug_reports_message(
+                                    'Discarding a %s block found in the middle of the stream; '
+                                    'if the subtitles display incorrectly,'
+                                    % (type(block).__name__)))
+                                continue
+                        block.write_into(output)
+
+                    return output.getvalue().encode('utf-8')
+
             def append_fragment(frag_content, frag_index):
                 if frag_content:
                     fragment_filename = '%s-Frag%d' % (ctx['tmpfilename'], frag_index)
@@ -315,6 +358,7 @@ def append_fragment(frag_content, frag_index):
                         file, frag_sanitized = sanitize_open(fragment_filename, 'rb')
                         ctx['fragment_filename_sanitized'] = frag_sanitized
                         file.close()
+                        frag_content = pack_fragment(frag_content, frag_index)
                         self._append_fragment(ctx, frag_content)
                         return True
                     except EnvironmentError as ose:
index 6257c17cd7e93466617a1e11ccfd1c62fa8a85fa..803c7fa066a18ed6972f7a9bedc5daeb5577c837 100644 (file)
@@ -2035,6 +2035,12 @@ def extract_media(x_media_line):
                     'url': url,
                     'ext': determine_ext(url),
                 }
+                if sub_info['ext'] == 'm3u8':
+                    # Per RFC 8216 §3.1, the only possible subtitle format m3u8
+                    # files may contain is WebVTT:
+                    # <https://tools.ietf.org/html/rfc8216#section-3.1>
+                    sub_info['ext'] = 'vtt'
+                    sub_info['protocol'] = 'm3u8_native'
                 subtitles.setdefault(lang, []).append(sub_info)
             if media_type not in ('VIDEO', 'AUDIO'):
                 return
diff --git a/yt_dlp/webvtt.py b/yt_dlp/webvtt.py
new file mode 100644 (file)
index 0000000..4d02683
--- /dev/null
@@ -0,0 +1,368 @@
+# coding: utf-8
+"""
+A partial parser for WebVTT segments. Interprets enough of the WebVTT stream
+to be able to assemble a single stand-alone subtitle file, suitably adjusting
+timestamps on the way, while everything else is passed through unmodified.
+
+Regular expressions based on the W3C WebVTT specification
+<https://www.w3.org/TR/webvtt1/>. The X-TIMESTAMP-MAP extension is described
+in RFC 8216 §3.5 <https://tools.ietf.org/html/rfc8216#section-3.5>.
+"""
+
+# The docstring must precede this statement to become the module's __doc__;
+# placed after it, the string is just a discarded expression.
+from __future__ import unicode_literals, print_function, division
+
+import re
+import io
+from .utils import int_or_none
+from .compat import (
+    compat_str as str,
+    compat_Pattern,
+    compat_Match,
+)
+
+
+class _MatchParser(object):
+    """
+    A cursor over a string: syntax elements are tried at the current
+    offset, and the offset is moved past each element that is accepted.
+    """
+
+    def __init__(self, string):
+        self._data, self._pos = string, 0
+
+    def match(self, r):
+        """
+        Try *r* at the current offset without moving it.
+
+        For a compiled pattern, return the match object or None; for a
+        literal string, return its length if it occurs at the offset,
+        else None. Any other argument raises ValueError.
+        """
+        if isinstance(r, compat_Pattern):
+            return r.match(self._data, self._pos)
+        if not isinstance(r, str):
+            raise ValueError(r)
+        return len(r) if self._data.startswith(r, self._pos) else None
+
+    def advance(self, by):
+        """
+        Move the offset forward by the extent of *by*, then return *by*.
+
+        *by* may be None (no movement), a match object (its matched length),
+        a string (its length) or an int (that many characters); anything
+        else raises ValueError.
+        """
+        if by is None:
+            step = 0
+        elif isinstance(by, compat_Match):
+            step = by.end() - by.start()
+        elif isinstance(by, str):
+            step = len(by)
+        elif isinstance(by, int):
+            step = by
+        else:
+            raise ValueError(by)
+        self._pos += step
+        return by
+
+    def consume(self, r):
+        """Like match(), but also advance past whatever was matched."""
+        result = self.match(r)
+        return self.advance(result)
+
+    def child(self):
+        """Fork a child cursor that starts at the current offset."""
+        return _MatchChildParser(self)
+
+
+class _MatchChildParser(_MatchParser):
+    """
+    A speculative cursor forked from a parent parser. It reads the same text,
+    starting at the parent's current offset, but moves independently, so an
+    attempt that fails can simply be abandoned; commit() keeps the progress.
+    """
+
+    def __init__(self, parent):
+        super(_MatchChildParser, self).__init__(parent._data)
+        self._pos = parent._pos
+        self.__parent = parent
+
+    def commit(self):
+        """
+        Copy this cursor's offset into the parent and return the parent.
+        """
+        parent = self.__parent
+        parent._pos = self._pos
+        return parent
+
+
+class ParseError(Exception):
+    """
+    Signals that parsing failed; the message records the parser's offset
+    and up to 20 characters of the text found there.
+    """
+
+    def __init__(self, parser):
+        pos = parser._pos
+        excerpt = parser._data[pos:pos + 20]
+        message = "Parse error at position %u (near %r)" % (pos, excerpt)
+        super(ParseError, self).__init__(message)
+
+
+# Matches only at the very end of the input.
+_REGEX_EOF = re.compile(r'\Z')
+
+# A single WebVTT line terminator: CRLF, a lone CR or a lone LF.
+_REGEX_NL = re.compile(r'(?:\r\n|[\r\n])')
+
+# One or more consecutive line terminators (blank lines between blocks).
+_REGEX_BLANK = re.compile(r'(?:\r\n|[\r\n])+')
+
+# A WebVTT timestamp, [hh:]mm:ss.ttt. Groups: hours (optional), minutes,
+# seconds and milliseconds (also optional in this pattern).
+_REGEX_TS = re.compile(r'''(?x)
+    (?:([0-9]{2,}):)?
+    ([0-9]{2}):
+    ([0-9]{2})\.
+    ([0-9]{3})?
+''')
+
+
+def _parse_ts(ts):
+    """
+    Convert a parsed WebVTT timestamp (a re.Match obtained from _REGEX_TS)
+    into an MPEG PES timestamp: a tick counter at 90 kHz resolution.
+
+    _REGEX_TS makes both the hours and the milliseconds groups optional;
+    an absent group counts as zero instead of crashing int(None).
+    """
+
+    hours, minutes, seconds, millis = ts.groups()
+    total_ms = (
+        int(hours or 0) * 3600000
+        + int(minutes) * 60000
+        + int(seconds) * 1000
+        + int(millis or 0)
+    )
+    # 90 ticks per millisecond at 90 kHz
+    return 90 * total_ms
+
+
+def _format_ts(ts):
+    """
+    Convert an MPEG PES timestamp (a tick counter at 90 kHz resolution)
+    into a WebVTT timestamp of the form HH:MM:SS.mmm.
+    This will lose sub-millisecond precision: ticks are rounded to the
+    nearest millisecond (for non-negative input, halves round up).
+    """
+
+    msec = int((ts + 45) // 90)
+    # divmod() returns (quotient, remainder): peel off milliseconds, then
+    # seconds, then minutes; whatever remains is the hour count.
+    secs, msec = divmod(msec, 1000)
+    mins, secs = divmod(secs, 60)
+    hrs, mins = divmod(mins, 60)
+    return '%02u:%02u:%02u.%03u' % (hrs, mins, secs, msec)
+
+
+class Block(object):
+    """
+    Base class for WebVTT blocks.
+
+    The generic implementation handles blocks that are passed through
+    untouched: a subclass supplies a compiled ``_REGEX``, parse() keeps the
+    matched text as ``raw`` and write_into() emits it again verbatim.
+    """
+
+    def __init__(self, **kwargs):
+        # Every keyword argument becomes an instance attribute.
+        for name in kwargs:
+            setattr(self, name, kwargs[name])
+
+    @classmethod
+    def parse(cls, parser):
+        """
+        Match cls._REGEX at the parser's position. On success, move the
+        parser past the match and return a block holding the matched text;
+        otherwise return None and leave the parser where it was.
+        """
+        match = parser.match(cls._REGEX)
+        if match:
+            parser.advance(match)
+            return cls(raw=match.group(0))
+        return None
+
+    def write_into(self, stream):
+        """Write the block's raw text to *stream*."""
+        stream.write(self.raw)
+
+
+class HeaderBlock(Block):
+    """
+    Base class for blocks that may only occur in the header section of a
+    WebVTT file, i.e. before the first cue block.
+    """
+
+
+class Magic(HeaderBlock):
+    """
+    The file header: the "WEBVTT" magic line, an optional X-TIMESTAMP-MAP
+    line, and the blank line (or end of input) that ends the header.
+
+    Parsed instances carry:
+      extra  -- whatever follows "WEBVTT" on the magic line, including the
+                separating space or tab, or None
+      local  -- the X-TIMESTAMP-MAP LOCAL time in 90 kHz ticks, or None
+      mpegts -- the X-TIMESTAMP-MAP MPEGTS value, or None
+    """
+
+    _REGEX = re.compile(r'\ufeff?WEBVTT([ \t][^\r\n]*)?(?:\r\n|[\r\n])')
+
+    # XXX: The X-TIMESTAMP-MAP extension is described in RFC 8216 §3.5
+    # <https://tools.ietf.org/html/rfc8216#section-3.5>, but the RFC
+    # doesn’t specify the exact grammar nor where in the WebVTT
+    # syntax it should be placed; the below has been devised based
+    # on usage in the wild
+    #
+    # And strictly speaking, the presence of this extension violates
+    # the W3C WebVTT spec. Oh well.
+
+    _REGEX_TSMAP = re.compile(r'X-TIMESTAMP-MAP=')
+    _REGEX_TSMAP_LOCAL = re.compile(r'LOCAL:')
+    _REGEX_TSMAP_MPEGTS = re.compile(r'MPEGTS:([0-9]+)')
+    # Separator between the LOCAL and MPEGTS entries; blanks around the
+    # comma are tolerated.
+    _REGEX_TSMAP_SEP = re.compile(r'[ \t]*,[ \t]*')
+
+    @classmethod
+    def __parse_tsmap(cls, parser):
+        """
+        Parse the entries of an X-TIMESTAMP-MAP header, up to and including
+        its line terminator, and return (local, mpegts). Either value is
+        None when the header does not mention it.
+
+        Raises ParseError on anything else.
+        """
+        parser = parser.child()
+        # Either key may be absent from the header; without these defaults
+        # the return statement below would raise UnboundLocalError.
+        local = mpegts = None
+
+        while True:
+            if parser.consume(cls._REGEX_TSMAP_LOCAL):
+                m = parser.consume(_REGEX_TS)
+                if m is None:
+                    raise ParseError(parser)
+                local = _parse_ts(m)
+            else:
+                m = parser.consume(cls._REGEX_TSMAP_MPEGTS)
+                if m is None:
+                    raise ParseError(parser)
+                mpegts = int_or_none(m.group(1))
+                if mpegts is None:
+                    raise ParseError(parser)
+            if parser.consume(cls._REGEX_TSMAP_SEP):
+                continue
+            if parser.consume(_REGEX_NL):
+                break
+            raise ParseError(parser)
+
+        parser.commit()
+        return local, mpegts
+
+    @classmethod
+    def parse(cls, parser):
+        """
+        Parse the file header at the parser's position. Unlike the other
+        block types this raises ParseError rather than returning None,
+        because every WebVTT file has to start with the magic line.
+        """
+        parser = parser.child()
+
+        m = parser.consume(cls._REGEX)
+        if not m:
+            raise ParseError(parser)
+
+        extra = m.group(1)
+        local, mpegts = None, None
+        if parser.consume(cls._REGEX_TSMAP):
+            local, mpegts = cls.__parse_tsmap(parser)
+        # The header is closed by a blank line; also accept the end of input
+        # so that segments consisting of nothing but a header still parse.
+        if not parser.consume(_REGEX_NL) and not parser.match(_REGEX_EOF):
+            raise ParseError(parser)
+        parser.commit()
+        return cls(extra=extra, mpegts=mpegts, local=local)
+
+    def write_into(self, stream):
+        """
+        Write the magic line, an X-TIMESTAMP-MAP line when either mapping
+        value is non-zero, and the blank line ending the header. A byte
+        order mark present in the input is not reproduced.
+        """
+        stream.write('WEBVTT')
+        if self.extra is not None:
+            stream.write(self.extra)
+        stream.write('\n')
+        if self.local or self.mpegts:
+            stream.write('X-TIMESTAMP-MAP=LOCAL:')
+            stream.write(_format_ts(self.local if self.local is not None else 0))
+            stream.write(',MPEGTS:')
+            stream.write(str(self.mpegts if self.mpegts is not None else 0))
+            stream.write('\n')
+        stream.write('\n')
+
+
+class StyleBlock(HeaderBlock):
+    """
+    A STYLE block from the file header. The style sheet is not interpreted;
+    the inherited Block.parse/Block.write_into keep its raw text verbatim.
+    """
+
+    # "STYLE", optional spaces/tabs and a line terminator, then any number
+    # of lines not containing "-->", then the blank line ending the block.
+    _REGEX = re.compile(r'''(?x)
+        STYLE[\ \t]*(?:\r\n|[\r\n])
+        ((?:(?!-->)[^\r\n])+(?:\r\n|[\r\n]))*
+        (?:\r\n|[\r\n])
+    ''')
+
+
+class RegionBlock(HeaderBlock):
+    """
+    A REGION definition block from the file header. The region settings are
+    not interpreted; the inherited Block.parse/Block.write_into keep the raw
+    text verbatim.
+    """
+
+    # "REGION" and optional spaces/tabs are followed by a line terminator
+    # before the settings lines (as in StyleBlock). Without that terminator
+    # in the pattern, the conforming form "REGION\nid:x\n\n" would match only
+    # "REGION\n", and the settings lines would then be rejected as malformed
+    # cues. The terminator is optional so "REGION settings..." still matches.
+    _REGEX = re.compile(r'''(?x)
+        REGION[\ \t]*(?:\r\n|[\r\n])?
+        ((?:(?!-->)[^\r\n])+(?:\r\n|[\r\n]))*
+        (?:\r\n|[\r\n])
+    ''')
+
+
+class CommentBlock(Block):
+    """
+    A NOTE (comment) block, kept as raw text. parse_fragment accepts it both
+    in the header section and between cues.
+    """
+
+    # "NOTE" followed by a space, tab or line terminator, then any number of
+    # lines not containing "-->", then the blank line ending the block.
+    _REGEX = re.compile(r'''(?x)
+        NOTE(?:\r\n|[\ \t\r\n])
+        ((?:(?!-->)[^\r\n])+(?:\r\n|[\r\n]))*
+        (?:\r\n|[\r\n])
+    ''')
+
+
+class CueBlock(Block):
+    """
+    A cue block. The payload is not interpreted.
+
+    Parsed instances carry:
+      id       -- the optional cue identifier line (no terminator), or None
+      start    -- cue start time in 90 kHz ticks
+      end      -- cue end time in 90 kHz ticks
+      settings -- the cue settings text after the end time, or None
+      text     -- the payload lines, with their original line terminators
+    """
+
+    _REGEX_ID = re.compile(r'((?:(?!-->)[^\r\n])+)(?:\r\n|[\r\n])')
+    _REGEX_ARROW = re.compile(r'[ \t]+-->[ \t]+')
+    _REGEX_SETTINGS = re.compile(r'[ \t]+((?:(?!-->)[^\r\n])+)')
+    _REGEX_PAYLOAD = re.compile(r'[^\r\n]+(?:\r\n|[\r\n])?')
+
+    @classmethod
+    def parse(cls, parser):
+        """
+        Parse one cue block at the parser's position. Returns None, leaving
+        the parser where it was, if the text there is not a cue.
+        """
+        parser = parser.child()
+
+        cue_id = None
+        m = parser.consume(cls._REGEX_ID)
+        if m:
+            cue_id = m.group(1)
+
+        m0 = parser.consume(_REGEX_TS)
+        if not m0:
+            return None
+        if not parser.consume(cls._REGEX_ARROW):
+            return None
+        m1 = parser.consume(_REGEX_TS)
+        if not m1:
+            return None
+        m2 = parser.consume(cls._REGEX_SETTINGS)
+        if not parser.consume(_REGEX_NL):
+            return None
+
+        start = _parse_ts(m0)
+        end = _parse_ts(m1)
+        settings = m2.group(1) if m2 is not None else None
+
+        # The payload runs up to the first blank line or the end of input.
+        text = io.StringIO()
+        while True:
+            m = parser.consume(cls._REGEX_PAYLOAD)
+            if not m:
+                break
+            text.write(m.group(0))
+
+        parser.commit()
+        return cls(
+            id=cue_id,
+            start=start, end=end, settings=settings,
+            text=text.getvalue()
+        )
+
+    def write_into(self, stream):
+        """
+        Write the cue (identifier, timing line, payload) followed by a blank
+        line, re-formatting the possibly shifted start/end times.
+        """
+        if self.id is not None:
+            stream.write(self.id)
+            stream.write('\n')
+        stream.write(_format_ts(self.start))
+        stream.write(' --> ')
+        stream.write(_format_ts(self.end))
+        if self.settings is not None:
+            stream.write(' ')
+            stream.write(self.settings)
+        stream.write('\n')
+        stream.write(self.text)
+        # A payload taken from the very end of a fragment may lack a final
+        # line terminator; add one so the blank line below still separates
+        # this cue from the next block (possibly from the next fragment).
+        if self.text and not self.text.endswith(('\r', '\n')):
+            stream.write('\n')
+        stream.write('\n')
+
+
+def parse_fragment(frag_content):
+    """
+    A generator yielding (partially) parsed WebVTT blocks from *frag_content*,
+    a bytes object holding a complete WebVTT file encoded as UTF-8.
+
+    The Magic header comes first, then any region, style and comment blocks
+    of the header section, then comment and cue blocks. Blank lines between
+    blocks are skipped silently. Raises ParseError on text that fits none of
+    the expected block types.
+    """
+
+    parser = _MatchParser(frag_content.decode('utf-8'))
+
+    yield Magic.parse(parser)
+
+    header_kinds = (RegionBlock, StyleBlock, CommentBlock)
+    body_kinds = (CommentBlock, CueBlock)  # XXX: comments could be skipped
+
+    def at_end():
+        return parser.match(_REGEX_EOF) is not None
+
+    def next_block(kinds):
+        # Return the first block any of *kinds* parses here, else None.
+        for kind in kinds:
+            found = kind.parse(parser)
+            if found:
+                return found
+        return None
+
+    # Header section: stop at the first block that is not a header block.
+    while not at_end():
+        if parser.consume(_REGEX_BLANK):
+            continue
+        block = next_block(header_kinds)
+        if block is None:
+            break
+        yield block
+
+    # Body section: everything left must be a comment or a cue.
+    while not at_end():
+        if parser.consume(_REGEX_BLANK):
+            continue
+        block = next_block(body_kinds)
+        if block is None:
+            raise ParseError(parser)
+        yield block