]> jfr.im git - yt-dlp.git/blobdiff - yt_dlp/utils/_utils.py
[rh:websockets] Migrate websockets to networking framework (#7720)
[yt-dlp.git] / yt_dlp / utils / _utils.py
index d0e3287166db2af6223323de2b2ee00bfb9acb13..b0164a8953d29400f6129ededbdfe9bb57ad625c 100644 (file)
@@ -1,5 +1,3 @@
-import asyncio
-import atexit
 import base64
 import binascii
 import calendar
@@ -54,7 +52,7 @@
     compat_os_name,
     compat_shlex_quote,
 )
-from ..dependencies import websockets, xattr
+from ..dependencies import xattr
 
 __name__ = __name__.rsplit('.', 1)[0]  # Pretend to be the parent module
 
 compiled_regex_type = type(re.compile(''))
 
 
-USER_AGENTS = {
-    'Safari': 'Mozilla/5.0 (X11; Linux x86_64; rv:10.0) AppleWebKit/533.20.25 (KHTML, like Gecko) Version/5.0.4 Safari/533.20.27',
-}
-
-
 class NO_DEFAULT:
     pass
 
@@ -674,6 +667,7 @@ def replace_insane(char):
 
 def sanitize_path(s, force=False):
     """Sanitizes and normalizes path on Windows"""
+    # XXX: this handles drive relative paths (c:sth) incorrectly
     if sys.platform == 'win32':
         force = False
         drive_or_unc, _ = os.path.splitdrive(s)
@@ -692,7 +686,10 @@ def sanitize_path(s, force=False):
         sanitized_path.insert(0, drive_or_unc + os.path.sep)
     elif force and s and s[0] == os.path.sep:
         sanitized_path.insert(0, os.path.sep)
-    return os.path.join(*sanitized_path)
+    # TODO: Fix behavioral differences <3.12
+    # The workaround using `normpath` only superficially passes tests
+    # Ref: https://github.com/python/cpython/pull/100351
+    return os.path.normpath(os.path.join(*sanitized_path))
 
 
 def sanitize_url(url, *, scheme='http'):
@@ -727,14 +724,6 @@ def extract_basic_auth(url):
     return url, f'Basic {auth_payload.decode()}'
 
 
-def sanitized_Request(url, *args, **kwargs):
-    url, auth_header = extract_basic_auth(escape_url(sanitize_url(url)))
-    if auth_header is not None:
-        headers = args[1] if len(args) >= 2 else kwargs.setdefault('headers', {})
-        headers['Authorization'] = auth_header
-    return urllib.request.Request(url, *args, **kwargs)
-
-
 def expand_path(s):
     """Expand shell variables and ~"""
     return os.path.expandvars(compat_expanduser(s))
@@ -834,7 +823,7 @@ def _fix(key):
         _fix('LD_LIBRARY_PATH')  # Linux
         _fix('DYLD_LIBRARY_PATH')  # macOS
 
-    def __init__(self, *args, env=None, text=False, **kwargs):
+    def __init__(self, args, *remaining, env=None, text=False, shell=False, **kwargs):
         if env is None:
             env = os.environ.copy()
         self._fix_pyinstaller_ld_path(env)
@@ -844,7 +833,21 @@ def __init__(self, *args, env=None, text=False, **kwargs):
             kwargs['universal_newlines'] = True  # For 3.6 compatibility
             kwargs.setdefault('encoding', 'utf-8')
             kwargs.setdefault('errors', 'replace')
-        super().__init__(*args, env=env, **kwargs, startupinfo=self._startupinfo)
+
+        if shell and compat_os_name == 'nt' and kwargs.get('executable') is None:
+            if not isinstance(args, str):
+                args = ' '.join(compat_shlex_quote(a) for a in args)
+            shell = False
+            args = f'{self.__comspec()} /Q /S /D /V:OFF /C "{args}"'
+
+        super().__init__(args, *remaining, env=env, shell=shell, **kwargs, startupinfo=self._startupinfo)
+
+    def __comspec(self):
+        comspec = os.environ.get('ComSpec') or os.path.join(
+            os.environ.get('SystemRoot', ''), 'System32', 'cmd.exe')
+        if os.path.isabs(comspec):
+            return comspec
+        raise FileNotFoundError('shell not found: neither %ComSpec% nor %SystemRoot% is set')
 
     def communicate_or_kill(self, *args, **kwargs):
         try:
@@ -894,19 +897,6 @@ def formatSeconds(secs, delim=':', msec=False):
     return '%s.%03d' % (ret, time.milliseconds) if msec else ret
 
 
-def make_HTTPS_handler(params, **kwargs):
-    from ._deprecated import YoutubeDLHTTPSHandler
-    from ..networking._helper import make_ssl_context
-    return YoutubeDLHTTPSHandler(params, context=make_ssl_context(
-        verify=not params.get('nocheckcertificate'),
-        client_certificate=params.get('client_certificate'),
-        client_certificate_key=params.get('client_certificate_key'),
-        client_certificate_password=params.get('client_certificate_password'),
-        legacy_support=params.get('legacyserverconnect'),
-        use_certifi='no-certifi' not in params.get('compat_opts', []),
-    ), **kwargs)
-
-
 def bug_reports_message(before=';'):
     from ..update import REPOSITORY
 
@@ -1143,17 +1133,6 @@ def is_path_like(f):
     return isinstance(f, (str, bytes, os.PathLike))
 
 
-class YoutubeDLCookieProcessor(urllib.request.HTTPCookieProcessor):
-    def __init__(self, cookiejar=None):
-        urllib.request.HTTPCookieProcessor.__init__(self, cookiejar)
-
-    def http_response(self, request, response):
-        return urllib.request.HTTPCookieProcessor.http_response(self, request, response)
-
-    https_request = urllib.request.HTTPCookieProcessor.http_request
-    https_response = http_response
-
-
 def extract_timezone(date_str):
     m = re.search(
         r'''(?x)
@@ -1293,7 +1272,7 @@ def datetime_from_str(date_str, precision='auto', format='%Y%m%d'):
     if precision == 'auto':
         auto_precision = True
         precision = 'microsecond'
-    today = datetime_round(datetime.datetime.utcnow(), precision)
+    today = datetime_round(datetime.datetime.now(datetime.timezone.utc), precision)
     if date_str in ('now', 'today'):
         return today
     if date_str == 'yesterday':
@@ -1356,8 +1335,8 @@ def datetime_round(dt, precision='day'):
         'second': 1,
     }
     roundto = lambda x, n: ((x + n / 2) // n) * n
-    timestamp = calendar.timegm(dt.timetuple())
-    return datetime.datetime.utcfromtimestamp(roundto(timestamp, unit_seconds[precision]))
+    timestamp = roundto(calendar.timegm(dt.timetuple()), unit_seconds[precision])
+    return datetime.datetime.fromtimestamp(timestamp, datetime.timezone.utc)
 
 
 def hyphenate_date(date_str):
@@ -1455,6 +1434,7 @@ def write_string(s, out=None, encoding=None):
     out.flush()
 
 
+# TODO: Use global logger
 def deprecation_warning(msg, *, printer=None, stacklevel=0, **kwargs):
     from .. import _IN_CLI
     if _IN_CLI:
@@ -2005,13 +1985,6 @@ def url_or_none(url):
     return url if re.match(r'^(?:(?:https?|rt(?:m(?:pt?[es]?|fp)|sp[su]?)|mms|ftps?):)?//', url) else None
 
 
-def request_to_url(req):
-    if isinstance(req, urllib.request.Request):
-        return req.get_full_url()
-    else:
-        return req
-
-
 def strftime_or_none(timestamp, date_format='%Y%m%d', default=None):
     datetime_object = None
     try:
@@ -2064,7 +2037,7 @@ def parse_duration(s):
                 )?
                 T)?
                 (?:
-                    (?P<hours>[0-9]+)\s*h(?:ours?)?,?\s*
+                    (?P<hours>[0-9]+)\s*h(?:(?:ou)?rs?)?,?\s*
                 )?
                 (?:
                     (?P<mins>[0-9]+)\s*m(?:in(?:ute)?s?)?,?\s*
@@ -2507,23 +2480,6 @@ def lowercase_escape(s):
         s)
 
 
-def escape_rfc3986(s):
-    """Escape non-ASCII characters as suggested by RFC 3986"""
-    return urllib.parse.quote(s, b"%/;:@&=+$,!~*'()?#[]")
-
-
-def escape_url(url):
-    """Escape URL as suggested by RFC 3986"""
-    url_parsed = urllib.parse.urlparse(url)
-    return url_parsed._replace(
-        netloc=url_parsed.netloc.encode('idna').decode('ascii'),
-        path=escape_rfc3986(url_parsed.path),
-        params=escape_rfc3986(url_parsed.params),
-        query=escape_rfc3986(url_parsed.query),
-        fragment=escape_rfc3986(url_parsed.fragment)
-    ).geturl()
-
-
 def parse_qs(url, **kwargs):
     return urllib.parse.parse_qs(urllib.parse.urlparse(url).query, **kwargs)
 
@@ -2783,9 +2739,10 @@ def fix_kv(m):
     def create_map(mobj):
         return json.dumps(dict(json.loads(js_to_json(mobj.group(1) or '[]', vars=vars))))
 
+    code = re.sub(r'(?:new\s+)?Array\((.*?)\)', r'[\g<1>]', code)
     code = re.sub(r'new Map\((\[.*?\])?\)', create_map, code)
     if not strict:
-        code = re.sub(r'new Date\((".+")\)', r'\g<1>', code)
+        code = re.sub(rf'new Date\(({STRING_RE})\)', r'\g<1>', code)
         code = re.sub(r'new \w+\((.*?)\)', lambda m: json.dumps(m.group(0)), code)
         code = re.sub(r'parseInt\([^\d]+(\d+)[^\d]+\)', r'\1', code)
         code = re.sub(r'\(function\([^)]*\)\s*\{[^}]*\}\s*\)\s*\(\s*(["\'][^)]*["\'])\s*\)', r'\1', code)
@@ -2907,6 +2864,7 @@ def mimetype2ext(mt, default=NO_DEFAULT):
         'quicktime': 'mov',
         'webm': 'webm',
         'vp9': 'vp9',
+        'video/ogg': 'ogv',
         'x-flv': 'flv',
         'x-m4v': 'm4v',
         'x-matroska': 'mkv',
@@ -4481,10 +4439,12 @@ def write_xattr(path, key, value):
             raise XAttrMetadataError(e.errno, e.strerror)
         return
 
-    # UNIX Method 1. Use xattrs/pyxattrs modules
+    # UNIX Method 1. Use os.setxattr/xattrs/pyxattrs modules
 
     setxattr = None
-    if getattr(xattr, '_yt_dlp__identifier', None) == 'pyxattr':
+    if callable(getattr(os, 'setxattr', None)):
+        setxattr = os.setxattr
+    elif getattr(xattr, '_yt_dlp__identifier', None) == 'pyxattr':
         # Unicode arguments are not supported in pyxattr until version 0.5.0
         # See https://github.com/ytdl-org/youtube-dl/issues/5498
         if version_tuple(xattr.__version__) >= (0, 5, 0):
@@ -4961,77 +4921,6 @@ def parse_args(self):
         return self.parser.parse_args(self.all_args)
 
 
-class WebSocketsWrapper:
-    """Wraps websockets module to use in non-async scopes"""
-    pool = None
-
-    def __init__(self, url, headers=None, connect=True):
-        self.loop = asyncio.new_event_loop()
-        # XXX: "loop" is deprecated
-        self.conn = websockets.connect(
-            url, extra_headers=headers, ping_interval=None,
-            close_timeout=float('inf'), loop=self.loop, ping_timeout=float('inf'))
-        if connect:
-            self.__enter__()
-        atexit.register(self.__exit__, None, None, None)
-
-    def __enter__(self):
-        if not self.pool:
-            self.pool = self.run_with_loop(self.conn.__aenter__(), self.loop)
-        return self
-
-    def send(self, *args):
-        self.run_with_loop(self.pool.send(*args), self.loop)
-
-    def recv(self, *args):
-        return self.run_with_loop(self.pool.recv(*args), self.loop)
-
-    def __exit__(self, type, value, traceback):
-        try:
-            return self.run_with_loop(self.conn.__aexit__(type, value, traceback), self.loop)
-        finally:
-            self.loop.close()
-            self._cancel_all_tasks(self.loop)
-
-    # taken from https://github.com/python/cpython/blob/3.9/Lib/asyncio/runners.py with modifications
-    # for contributors: If there's any new library using asyncio needs to be run in non-async, move these function out of this class
-    @staticmethod
-    def run_with_loop(main, loop):
-        if not asyncio.iscoroutine(main):
-            raise ValueError(f'a coroutine was expected, got {main!r}')
-
-        try:
-            return loop.run_until_complete(main)
-        finally:
-            loop.run_until_complete(loop.shutdown_asyncgens())
-            if hasattr(loop, 'shutdown_default_executor'):
-                loop.run_until_complete(loop.shutdown_default_executor())
-
-    @staticmethod
-    def _cancel_all_tasks(loop):
-        to_cancel = asyncio.all_tasks(loop)
-
-        if not to_cancel:
-            return
-
-        for task in to_cancel:
-            task.cancel()
-
-        # XXX: "loop" is removed in python 3.10+
-        loop.run_until_complete(
-            asyncio.gather(*to_cancel, loop=loop, return_exceptions=True))
-
-        for task in to_cancel:
-            if task.cancelled():
-                continue
-            if task.exception() is not None:
-                loop.call_exception_handler({
-                    'message': 'unhandled exception during asyncio.run() shutdown',
-                    'exception': task.exception(),
-                    'task': task,
-                })
-
-
 def merge_headers(*dicts):
     """Merge dicts of http headers case insensitively, prioritizing the latter ones"""
     return {k.title(): v for k, v in itertools.chain.from_iterable(map(dict.items, dicts))}
@@ -5525,7 +5414,7 @@ def info(self, message):
 
     def warning(self, message, *, once=False):
         if self._ydl:
-            self._ydl.report_warning(message, only_once=once)
+            self._ydl.report_warning(message, once)
 
     def error(self, message, *, is_error=True):
         if self._ydl: