X-Git-Url: https://jfr.im/git/yt-dlp.git/blobdiff_plain/f6a765ceb59c55aea06921880c1c87d1ff36e5de..d4b52ce3fcb8d9578ed12365648eaba8718c603e:/yt_dlp/compat/compat_utils.py diff --git a/yt_dlp/compat/compat_utils.py b/yt_dlp/compat/compat_utils.py index 373389a46..d62b7d048 100644 --- a/yt_dlp/compat/compat_utils.py +++ b/yt_dlp/compat/compat_utils.py @@ -1,5 +1,6 @@ import collections import contextlib +import functools import importlib import sys import types @@ -14,7 +15,7 @@ def get_package_info(module): name=getattr(module, '_yt_dlp__identifier', module.__name__), version=str(next(filter(None, ( getattr(module, attr, None) - for attr in ('__version__', 'version_string', 'version') + for attr in ('_yt_dlp__version', '__version__', 'version_string', 'version') )), None))) @@ -22,21 +23,11 @@ def _is_package(module): return '__path__' in vars(module) -class EnhancedModule(types.ModuleType): - def __new__(cls, name, *args, **kwargs): - if name not in sys.modules: - return super().__new__(cls, name, *args, **kwargs) - - assert not args and not kwargs, 'Cannot pass additional arguments to an existing module' - module = sys.modules[name] - module.__class__ = cls - return module +def _is_dunder(name): + return name.startswith('__') and name.endswith('__') - def __init__(self, name, *args, **kwargs): - # Prevent __new__ from trigerring __init__ again - if name not in sys.modules: - super().__init__(name, *args, **kwargs) +class EnhancedModule(types.ModuleType): def __bool__(self): return vars(self).get('__bool__', lambda: True)() @@ -44,7 +35,7 @@ def __getattribute__(self, attr): try: ret = super().__getattribute__(attr) except AttributeError: - if attr.startswith('__') and attr.endswith('__'): + if _is_dunder(attr): raise getter = getattr(self, '__getattr__', None) if not getter: @@ -53,13 +44,11 @@ def __getattribute__(self, attr): return ret.fget() if isinstance(ret, property) else ret -def passthrough_module(parent, child, allowed_attributes=None, *, callback=lambda _: None): +def passthrough_module(parent, child, allowed_attributes=(..., ), *, callback=lambda _: None): """Passthrough parent module into a child module, creating the parent if necessary""" - parent = EnhancedModule(parent) - def __getattr__(attr): if _is_package(parent): - with contextlib.suppress(ImportError): + with contextlib.suppress(ModuleNotFoundError): return importlib.import_module(f'.{attr}', parent.__name__) ret = from_child(attr) @@ -68,26 +57,27 @@ def __getattr__(attr): callback(attr) return ret + @functools.lru_cache(maxsize=None) def from_child(attr): nonlocal child - - if allowed_attributes is None: - if attr.startswith('__') and attr.endswith('__'): + if attr not in allowed_attributes: + if ... not in allowed_attributes or _is_dunder(attr): return _NO_ATTRIBUTE - elif attr not in allowed_attributes: - return _NO_ATTRIBUTE if isinstance(child, str): child = importlib.import_module(child, parent.__name__) - with contextlib.suppress(AttributeError): - return getattr(child, attr) - if _is_package(child): with contextlib.suppress(ImportError): - return importlib.import_module(f'.{attr}', child.__name__) + return passthrough_module(f'{parent.__name__}.{attr}', + importlib.import_module(f'.{attr}', child.__name__)) + + with contextlib.suppress(AttributeError): + return getattr(child, attr) return _NO_ATTRIBUTE + parent = sys.modules.get(parent, types.ModuleType(parent)) + parent.__class__ = EnhancedModule parent.__getattr__ = __getattr__ return parent