]> jfr.im git - yt-dlp.git/blobdiff - yt_dlp/compat/compat_utils.py
[ie/podbayfm] Fix extraction (#10195)
[yt-dlp.git] / yt_dlp / compat / compat_utils.py
index 373389a4661b88391419e89435e4d05caa15e99a..d62b7d0488dc9510cd9c5ffb1a5d3f8d2f834c72 100644 (file)
@@ -1,5 +1,6 @@
 import collections
 import contextlib
+import functools
 import importlib
 import sys
 import types
@@ -14,7 +15,7 @@ def get_package_info(module):
         name=getattr(module, '_yt_dlp__identifier', module.__name__),
         version=str(next(filter(None, (
             getattr(module, attr, None)
-            for attr in ('__version__', 'version_string', 'version')
+            for attr in ('_yt_dlp__version', '__version__', 'version_string', 'version')
         )), None)))
 
 
@@ -22,21 +23,11 @@ def _is_package(module):
     return '__path__' in vars(module)
 
 
-class EnhancedModule(types.ModuleType):
-    def __new__(cls, name, *args, **kwargs):
-        if name not in sys.modules:
-            return super().__new__(cls, name, *args, **kwargs)
-
-        assert not args and not kwargs, 'Cannot pass additional arguments to an existing module'
-        module = sys.modules[name]
-        module.__class__ = cls
-        return module
+def _is_dunder(name):
+    return name.startswith('__') and name.endswith('__')
 
-    def __init__(self, name, *args, **kwargs):
-        # Prevent __new__ from trigerring __init__ again
-        if name not in sys.modules:
-            super().__init__(name, *args, **kwargs)
 
+class EnhancedModule(types.ModuleType):
     def __bool__(self):
         return vars(self).get('__bool__', lambda: True)()
 
@@ -44,7 +35,7 @@ def __getattribute__(self, attr):
         try:
             ret = super().__getattribute__(attr)
         except AttributeError:
-            if attr.startswith('__') and attr.endswith('__'):
+            if _is_dunder(attr):
                 raise
             getter = getattr(self, '__getattr__', None)
             if not getter:
@@ -53,13 +44,11 @@ def __getattribute__(self, attr):
         return ret.fget() if isinstance(ret, property) else ret
 
 
-def passthrough_module(parent, child, allowed_attributes=None, *, callback=lambda _: None):
+def passthrough_module(parent, child, allowed_attributes=(..., ), *, callback=lambda _: None):
     """Passthrough parent module into a child module, creating the parent if necessary"""
-    parent = EnhancedModule(parent)
-
     def __getattr__(attr):
         if _is_package(parent):
-            with contextlib.suppress(ImportError):
+            with contextlib.suppress(ModuleNotFoundError):
                 return importlib.import_module(f'.{attr}', parent.__name__)
 
         ret = from_child(attr)
@@ -68,26 +57,27 @@ def __getattr__(attr):
         callback(attr)
         return ret
 
+    @functools.lru_cache(maxsize=None)
     def from_child(attr):
         nonlocal child
-
-        if allowed_attributes is None:
-            if attr.startswith('__') and attr.endswith('__'):
+        if attr not in allowed_attributes:
+            if ... not in allowed_attributes or _is_dunder(attr):
                 return _NO_ATTRIBUTE
-        elif attr not in allowed_attributes:
-            return _NO_ATTRIBUTE
 
         if isinstance(child, str):
             child = importlib.import_module(child, parent.__name__)
 
-        with contextlib.suppress(AttributeError):
-            return getattr(child, attr)
-
         if _is_package(child):
             with contextlib.suppress(ImportError):
-                return importlib.import_module(f'.{attr}', child.__name__)
+                return passthrough_module(f'{parent.__name__}.{attr}',
+                                          importlib.import_module(f'.{attr}', child.__name__))
+
+        with contextlib.suppress(AttributeError):
+            return getattr(child, attr)
 
         return _NO_ATTRIBUTE
 
+    parent = sys.modules.get(parent, types.ModuleType(parent))
+    parent.__class__ = EnhancedModule
     parent.__getattr__ = __getattr__
     return parent