]> jfr.im git - yt-dlp.git/blobdiff - yt_dlp/utils/_utils.py
[cleanup] Standardize `import datetime as dt` (#8978)
[yt-dlp.git] / yt_dlp / utils / _utils.py
index 180bec245a3e7a48ccb2e70aa4edc682b89e757c..dec514674f5cf51b0036831c5875915fc8861401 100644 (file)
@@ -1,5 +1,3 @@
-import asyncio
-import atexit
 import base64
 import binascii
 import calendar
@@ -7,7 +5,7 @@
 import collections
 import collections.abc
 import contextlib
-import datetime
+import datetime as dt
 import email.header
 import email.utils
 import errno
@@ -54,7 +52,7 @@
     compat_os_name,
     compat_shlex_quote,
 )
-from ..dependencies import websockets, xattr
+from ..dependencies import xattr
 
 __name__ = __name__.rsplit('.', 1)[0]  # Pretend to be the parent module
 
@@ -560,7 +558,7 @@ def decode(self, s):
                     s = self._close_object(e)
                     if s is not None:
                         continue
-                raise type(e)(f'{e.msg} in {s[e.pos-10:e.pos+10]!r}', s, e.pos)
+                raise type(e)(f'{e.msg} in {s[e.pos - 10:e.pos + 10]!r}', s, e.pos)
         assert False, 'Too many attempts to decode JSON'
 
 
@@ -638,7 +636,7 @@ def replace_insane(char):
         elif char in '\\/|*<>':
             return '\0_'
         if restricted and (char in '!&\'()[]{}$;`^,#' or char.isspace() or ord(char) > 127):
-            return '\0_'
+            return '' if unicodedata.category(char)[0] in 'CM' else '\0_'
         return char
 
     # Replace look-alike Unicode glyphs
@@ -669,6 +667,7 @@ def replace_insane(char):
 
 def sanitize_path(s, force=False):
     """Sanitizes and normalizes path on Windows"""
+    # XXX: this handles drive relative paths (c:sth) incorrectly
     if sys.platform == 'win32':
         force = False
         drive_or_unc, _ = os.path.splitdrive(s)
@@ -687,7 +686,10 @@ def sanitize_path(s, force=False):
         sanitized_path.insert(0, drive_or_unc + os.path.sep)
     elif force and s and s[0] == os.path.sep:
         sanitized_path.insert(0, os.path.sep)
-    return os.path.join(*sanitized_path)
+    # TODO: Fix behavioral differences <3.12
+    # The workaround using `normpath` only superficially passes tests
+    # Ref: https://github.com/python/cpython/pull/100351
+    return os.path.normpath(os.path.join(*sanitized_path))
 
 
 def sanitize_url(url, *, scheme='http'):
@@ -821,7 +823,7 @@ def _fix(key):
         _fix('LD_LIBRARY_PATH')  # Linux
         _fix('DYLD_LIBRARY_PATH')  # macOS
 
-    def __init__(self, *args, env=None, text=False, **kwargs):
+    def __init__(self, args, *remaining, env=None, text=False, shell=False, **kwargs):
         if env is None:
             env = os.environ.copy()
         self._fix_pyinstaller_ld_path(env)
@@ -831,7 +833,21 @@ def __init__(self, *args, env=None, text=False, **kwargs):
             kwargs['universal_newlines'] = True  # For 3.6 compatibility
             kwargs.setdefault('encoding', 'utf-8')
             kwargs.setdefault('errors', 'replace')
-        super().__init__(*args, env=env, **kwargs, startupinfo=self._startupinfo)
+
+        if shell and compat_os_name == 'nt' and kwargs.get('executable') is None:
+            if not isinstance(args, str):
+                args = ' '.join(compat_shlex_quote(a) for a in args)
+            shell = False
+            args = f'{self.__comspec()} /Q /S /D /V:OFF /C "{args}"'
+
+        super().__init__(args, *remaining, env=env, shell=shell, **kwargs, startupinfo=self._startupinfo)
+
+    def __comspec(self):
+        comspec = os.environ.get('ComSpec') or os.path.join(
+            os.environ.get('SystemRoot', ''), 'System32', 'cmd.exe')
+        if os.path.isabs(comspec):
+            return comspec
+        raise FileNotFoundError('shell not found: neither %ComSpec% nor %SystemRoot% is set')
 
     def communicate_or_kill(self, *args, **kwargs):
         try:
@@ -1134,14 +1150,14 @@ def extract_timezone(date_str):
         timezone = TIMEZONE_NAMES.get(m and m.group('tz').strip())
         if timezone is not None:
             date_str = date_str[:-len(m.group('tz'))]
-        timezone = datetime.timedelta(hours=timezone or 0)
+        timezone = dt.timedelta(hours=timezone or 0)
     else:
         date_str = date_str[:-len(m.group('tz'))]
         if not m.group('sign'):
-            timezone = datetime.timedelta()
+            timezone = dt.timedelta()
         else:
             sign = 1 if m.group('sign') == '+' else -1
-            timezone = datetime.timedelta(
+            timezone = dt.timedelta(
                 hours=sign * int(m.group('hours')),
                 minutes=sign * int(m.group('minutes')))
     return timezone, date_str
@@ -1160,8 +1176,8 @@ def parse_iso8601(date_str, delimiter='T', timezone=None):
 
     with contextlib.suppress(ValueError):
         date_format = f'%Y-%m-%d{delimiter}%H:%M:%S'
-        dt = datetime.datetime.strptime(date_str, date_format) - timezone
-        return calendar.timegm(dt.timetuple())
+        dt_ = dt.datetime.strptime(date_str, date_format) - timezone
+        return calendar.timegm(dt_.timetuple())
 
 
 def date_formats(day_first=True):
@@ -1182,12 +1198,12 @@ def unified_strdate(date_str, day_first=True):
 
     for expression in date_formats(day_first):
         with contextlib.suppress(ValueError):
-            upload_date = datetime.datetime.strptime(date_str, expression).strftime('%Y%m%d')
+            upload_date = dt.datetime.strptime(date_str, expression).strftime('%Y%m%d')
     if upload_date is None:
         timetuple = email.utils.parsedate_tz(date_str)
         if timetuple:
             with contextlib.suppress(ValueError):
-                upload_date = datetime.datetime(*timetuple[:6]).strftime('%Y%m%d')
+                upload_date = dt.datetime(*timetuple[:6]).strftime('%Y%m%d')
     if upload_date is not None:
         return str(upload_date)
 
@@ -1217,8 +1233,8 @@ def unified_timestamp(date_str, day_first=True):
 
     for expression in date_formats(day_first):
         with contextlib.suppress(ValueError):
-            dt = datetime.datetime.strptime(date_str, expression) - timezone + datetime.timedelta(hours=pm_delta)
-            return calendar.timegm(dt.timetuple())
+            dt_ = dt.datetime.strptime(date_str, expression) - timezone + dt.timedelta(hours=pm_delta)
+            return calendar.timegm(dt_.timetuple())
 
     timetuple = email.utils.parsedate_tz(date_str)
     if timetuple:
@@ -1256,11 +1272,11 @@ def datetime_from_str(date_str, precision='auto', format='%Y%m%d'):
     if precision == 'auto':
         auto_precision = True
         precision = 'microsecond'
-    today = datetime_round(datetime.datetime.utcnow(), precision)
+    today = datetime_round(dt.datetime.now(dt.timezone.utc), precision)
     if date_str in ('now', 'today'):
         return today
     if date_str == 'yesterday':
-        return today - datetime.timedelta(days=1)
+        return today - dt.timedelta(days=1)
     match = re.match(
         r'(?P<start>.+)(?P<sign>[+-])(?P<time>\d+)(?P<unit>microsecond|second|minute|hour|day|week|month|year)s?',
         date_str)
@@ -1275,13 +1291,13 @@ def datetime_from_str(date_str, precision='auto', format='%Y%m%d'):
             if unit == 'week':
                 unit = 'day'
                 time *= 7
-            delta = datetime.timedelta(**{unit + 's': time})
+            delta = dt.timedelta(**{unit + 's': time})
             new_date = start_time + delta
         if auto_precision:
             return datetime_round(new_date, unit)
         return new_date
 
-    return datetime_round(datetime.datetime.strptime(date_str, format), precision)
+    return datetime_round(dt.datetime.strptime(date_str, format), precision)
 
 
 def date_from_str(date_str, format='%Y%m%d', strict=False):
@@ -1296,21 +1312,21 @@ def date_from_str(date_str, format='%Y%m%d', strict=False):
     return datetime_from_str(date_str, precision='microsecond', format=format).date()
 
 
-def datetime_add_months(dt, months):
+def datetime_add_months(dt_, months):
     """Increment/Decrement a datetime object by months."""
-    month = dt.month + months - 1
-    year = dt.year + month // 12
+    month = dt_.month + months - 1
+    year = dt_.year + month // 12
     month = month % 12 + 1
-    day = min(dt.day, calendar.monthrange(year, month)[1])
-    return dt.replace(year, month, day)
+    day = min(dt_.day, calendar.monthrange(year, month)[1])
+    return dt_.replace(year, month, day)
 
 
-def datetime_round(dt, precision='day'):
+def datetime_round(dt_, precision='day'):
     """
     Round a datetime object's time to a specific precision
     """
     if precision == 'microsecond':
-        return dt
+        return dt_
 
     unit_seconds = {
         'day': 86400,
@@ -1319,8 +1335,8 @@ def datetime_round(dt, precision='day'):
         'second': 1,
     }
     roundto = lambda x, n: ((x + n / 2) // n) * n
-    timestamp = calendar.timegm(dt.timetuple())
-    return datetime.datetime.utcfromtimestamp(roundto(timestamp, unit_seconds[precision]))
+    timestamp = roundto(calendar.timegm(dt_.timetuple()), unit_seconds[precision])
+    return dt.datetime.fromtimestamp(timestamp, dt.timezone.utc)
 
 
 def hyphenate_date(date_str):
@@ -1341,11 +1357,11 @@ def __init__(self, start=None, end=None):
         if start is not None:
             self.start = date_from_str(start, strict=True)
         else:
-            self.start = datetime.datetime.min.date()
+            self.start = dt.datetime.min.date()
         if end is not None:
             self.end = date_from_str(end, strict=True)
         else:
-            self.end = datetime.datetime.max.date()
+            self.end = dt.datetime.max.date()
         if self.start > self.end:
             raise ValueError('Date range: "%s" , the start date must be before the end date' % self)
 
@@ -1356,13 +1372,16 @@ def day(cls, day):
 
     def __contains__(self, date):
         """Check if the date is in the range"""
-        if not isinstance(date, datetime.date):
+        if not isinstance(date, dt.date):
             date = date_from_str(date)
         return self.start <= date <= self.end
 
     def __repr__(self):
         return f'{__name__}.{type(self).__name__}({self.start.isoformat()!r}, {self.end.isoformat()!r})'
 
+    def __str__(self):
+        return f'{self.start} to {self.end}'
+
     def __eq__(self, other):
         return (isinstance(other, DateRange)
                 and self.start == other.start and self.end == other.end)
@@ -1408,7 +1427,8 @@ def write_string(s, out=None, encoding=None):
         s = re.sub(r'([\r\n]+)', r' \1', s)
 
     enc, buffer = None, out
-    if 'b' in getattr(out, 'mode', ''):
+    # `mode` might be `None` (Ref: https://github.com/yt-dlp/yt-dlp/issues/8816)
+    if 'b' in (getattr(out, 'mode', None) or ''):
         enc = encoding or preferredencoding()
     elif hasattr(out, 'buffer'):
         buffer = out.buffer
@@ -1869,6 +1889,7 @@ def setproctitle(title):
     buf = ctypes.create_string_buffer(len(title_bytes))
     buf.value = title_bytes
     try:
+        # PR_SET_NAME = 15      Ref: /usr/include/linux/prctl.h
         libc.prctl(15, buf, 0, 0, 0)
     except AttributeError:
         return  # Strange libc, just skip this
@@ -1975,12 +1996,12 @@ def strftime_or_none(timestamp, date_format='%Y%m%d', default=None):
         if isinstance(timestamp, (int, float)):  # unix timestamp
             # Using naive datetime here can break timestamp() in Windows
             # Ref: https://github.com/yt-dlp/yt-dlp/issues/5185, https://github.com/python/cpython/issues/94414
-            # Also, datetime.datetime.fromtimestamp breaks for negative timestamps
+            # Also, dt.datetime.fromtimestamp breaks for negative timestamps
             # Ref: https://github.com/yt-dlp/yt-dlp/issues/6706#issuecomment-1496842642
-            datetime_object = (datetime.datetime.fromtimestamp(0, datetime.timezone.utc)
-                               + datetime.timedelta(seconds=timestamp))
+            datetime_object = (dt.datetime.fromtimestamp(0, dt.timezone.utc)
+                               + dt.timedelta(seconds=timestamp))
         elif isinstance(timestamp, str):  # assume YYYYMMDD
-            datetime_object = datetime.datetime.strptime(timestamp, '%Y%m%d')
+            datetime_object = dt.datetime.strptime(timestamp, '%Y%m%d')
         date_format = re.sub(  # Support %s on windows
             r'(?<!%)(%%)*%s', rf'\g<1>{int(datetime_object.timestamp())}', date_format)
         return datetime_object.strftime(date_format)
@@ -2244,6 +2265,9 @@ def __getitem__(self, idx):
             raise self.IndexError()
         return entries[0]
 
+    def __bool__(self):
+        return bool(self.getslice(0, 1))
+
 
 class OnDemandPagedList(PagedList):
     """Download pages until a page with less than maximum results"""
@@ -2723,9 +2747,10 @@ def fix_kv(m):
     def create_map(mobj):
         return json.dumps(dict(json.loads(js_to_json(mobj.group(1) or '[]', vars=vars))))
 
+    code = re.sub(r'(?:new\s+)?Array\((.*?)\)', r'[\g<1>]', code)
     code = re.sub(r'new Map\((\[.*?\])?\)', create_map, code)
     if not strict:
-        code = re.sub(r'new Date\((".+")\)', r'\g<1>', code)
+        code = re.sub(rf'new Date\(({STRING_RE})\)', r'\g<1>', code)
         code = re.sub(r'new \w+\((.*?)\)', lambda m: json.dumps(m.group(0)), code)
         code = re.sub(r'parseInt\([^\d]+(\d+)[^\d]+\)', r'\1', code)
         code = re.sub(r'\(function\([^)]*\)\s*\{[^}]*\}\s*\)\s*\(\s*(["\'][^)]*["\'])\s*\)', r'\1', code)
@@ -3217,6 +3242,8 @@ def match_str(filter_str, dct, incomplete=False):
 def match_filter_func(filters, breaking_filters=None):
     if not filters and not breaking_filters:
         return None
+    repr_ = f'{match_filter_func.__module__}.{match_filter_func.__qualname__}({filters}, {breaking_filters})'
+
     breaking_filters = match_filter_func(breaking_filters) or (lambda _, __: None)
     filters = set(variadic(filters or []))
 
@@ -3224,6 +3251,7 @@ def match_filter_func(filters, breaking_filters=None):
     if interactive:
         filters.remove('-')
 
+    @function_with_repr.set_repr(repr_)
     def _match_func(info_dict, incomplete=False):
         ret = breaking_filters(info_dict, incomplete)
         if ret is not None:
@@ -4422,10 +4450,12 @@ def write_xattr(path, key, value):
             raise XAttrMetadataError(e.errno, e.strerror)
         return
 
-    # UNIX Method 1. Use xattrs/pyxattrs modules
+    # UNIX Method 1. Use os.setxattr/xattrs/pyxattrs modules
 
     setxattr = None
-    if getattr(xattr, '_yt_dlp__identifier', None) == 'pyxattr':
+    if callable(getattr(os, 'setxattr', None)):
+        setxattr = os.setxattr
+    elif getattr(xattr, '_yt_dlp__identifier', None) == 'pyxattr':
         # Unicode arguments are not supported in pyxattr until version 0.5.0
         # See https://github.com/ytdl-org/youtube-dl/issues/5498
         if version_tuple(xattr.__version__) >= (0, 5, 0):
@@ -4445,7 +4475,7 @@ def write_xattr(path, key, value):
            else 'xattr' if check_executable('xattr', ['-h']) else None)
     if not exe:
         raise XAttrUnavailableError(
-            'Couldn\'t find a tool to set the xattrs. Install either the python "xattr" or "pyxattr" modules or the '
+            'Couldn\'t find a tool to set the xattrs. Install either the "xattr" or "pyxattr" Python modules or the '
             + ('"xattr" binary' if sys.platform != 'linux' else 'GNU "attr" package (which contains the "setfattr" tool)'))
 
     value = value.decode()
@@ -4460,10 +4490,10 @@ def write_xattr(path, key, value):
 
 
 def random_birthday(year_field, month_field, day_field):
-    start_date = datetime.date(1950, 1, 1)
-    end_date = datetime.date(1995, 12, 31)
+    start_date = dt.date(1950, 1, 1)
+    end_date = dt.date(1995, 12, 31)
     offset = random.randint(0, (end_date - start_date).days)
-    random_date = start_date + datetime.timedelta(offset)
+    random_date = start_date + dt.timedelta(offset)
     return {
         year_field: str(random_date.year),
         month_field: str(random_date.month),
@@ -4642,7 +4672,7 @@ def time_seconds(**kwargs):
     """
     Returns TZ-aware time in seconds since the epoch (1970-01-01T00:00:00Z)
     """
-    return time.time() + datetime.timedelta(**kwargs).total_seconds()
+    return time.time() + dt.timedelta(**kwargs).total_seconds()
 
 
 # create a JSON Web Signature (jws) with HS256 algorithm
@@ -4770,8 +4800,9 @@ def parse_http_range(range):
 
 
 def read_stdin(what):
-    eof = 'Ctrl+Z' if compat_os_name == 'nt' else 'Ctrl+D'
-    write_string(f'Reading {what} from STDIN - EOF ({eof}) to end:\n')
+    if what:
+        eof = 'Ctrl+Z' if compat_os_name == 'nt' else 'Ctrl+D'
+        write_string(f'Reading {what} from STDIN - EOF ({eof}) to end:\n')
     return sys.stdin
 
 
@@ -4902,77 +4933,6 @@ def parse_args(self):
         return self.parser.parse_args(self.all_args)
 
 
-class WebSocketsWrapper:
-    """Wraps websockets module to use in non-async scopes"""
-    pool = None
-
-    def __init__(self, url, headers=None, connect=True):
-        self.loop = asyncio.new_event_loop()
-        # XXX: "loop" is deprecated
-        self.conn = websockets.connect(
-            url, extra_headers=headers, ping_interval=None,
-            close_timeout=float('inf'), loop=self.loop, ping_timeout=float('inf'))
-        if connect:
-            self.__enter__()
-        atexit.register(self.__exit__, None, None, None)
-
-    def __enter__(self):
-        if not self.pool:
-            self.pool = self.run_with_loop(self.conn.__aenter__(), self.loop)
-        return self
-
-    def send(self, *args):
-        self.run_with_loop(self.pool.send(*args), self.loop)
-
-    def recv(self, *args):
-        return self.run_with_loop(self.pool.recv(*args), self.loop)
-
-    def __exit__(self, type, value, traceback):
-        try:
-            return self.run_with_loop(self.conn.__aexit__(type, value, traceback), self.loop)
-        finally:
-            self.loop.close()
-            self._cancel_all_tasks(self.loop)
-
-    # taken from https://github.com/python/cpython/blob/3.9/Lib/asyncio/runners.py with modifications
-    # for contributors: If there's any new library using asyncio needs to be run in non-async, move these function out of this class
-    @staticmethod
-    def run_with_loop(main, loop):
-        if not asyncio.iscoroutine(main):
-            raise ValueError(f'a coroutine was expected, got {main!r}')
-
-        try:
-            return loop.run_until_complete(main)
-        finally:
-            loop.run_until_complete(loop.shutdown_asyncgens())
-            if hasattr(loop, 'shutdown_default_executor'):
-                loop.run_until_complete(loop.shutdown_default_executor())
-
-    @staticmethod
-    def _cancel_all_tasks(loop):
-        to_cancel = asyncio.all_tasks(loop)
-
-        if not to_cancel:
-            return
-
-        for task in to_cancel:
-            task.cancel()
-
-        # XXX: "loop" is removed in python 3.10+
-        loop.run_until_complete(
-            asyncio.gather(*to_cancel, loop=loop, return_exceptions=True))
-
-        for task in to_cancel:
-            if task.cancelled():
-                continue
-            if task.exception() is not None:
-                loop.call_exception_handler({
-                    'message': 'unhandled exception during asyncio.run() shutdown',
-                    'exception': task.exception(),
-                    'task': task,
-                })
-
-
 def merge_headers(*dicts):
     """Merge dicts of http headers case insensitively, prioritizing the latter ones"""
     return {k.title(): v for k, v in itertools.chain.from_iterable(map(dict.items, dicts))}
@@ -5023,6 +4983,10 @@ def __init__(self, func, repr_=None):
     def __call__(self, *args, **kwargs):
         return self.func(*args, **kwargs)
 
+    @classmethod
+    def set_repr(cls, repr_):
+        return functools.partial(cls, repr_=repr_)
+
     def __repr__(self):
         if self.__repr:
             return self.__repr
@@ -5121,7 +5085,7 @@ def truncate_string(s, left, right=0):
     assert left > 3 and right >= 0
     if s is None or len(s) <= left + right:
         return s
-    return f'{s[:left-3]}...{s[-right:] if right else ""}'
+    return f'{s[:left - 3]}...{s[-right:] if right else ""}'
 
 
 def orderedSet_from_options(options, alias_dict, *, use_regex=False, start=None):
@@ -5451,6 +5415,17 @@ def calculate_preference(self, format):
         return tuple(self._calculate_field_preference(format, field) for field in self._order)
 
 
+def filesize_from_tbr(tbr, duration):
+    """
+    @param tbr:      Total bitrate in kbps (1000 bits/sec)
+    @param duration: Duration in seconds
+    @returns         Filesize in bytes
+    """
+    if tbr is None or duration is None:
+        return None
+    return int(duration * tbr * (1000 / 8))
+
+
 # XXX: Temporary
 class _YDLLogger:
     def __init__(self, ydl=None):