jfr.im git - yt-dlp.git/commitdiff
Support module level `__bool__` and `property`
author pukkandan <redacted>
Wed, 8 Feb 2023 01:55:36 +0000 (07:25 +0530)
committer pukkandan <redacted>
Wed, 8 Feb 2023 01:58:45 +0000 (07:28 +0530)
yt_dlp/compat/__init__.py
yt_dlp/compat/compat_utils.py

index 5d3db4b4ca00982d5485af42e5007d90176d3ec7..5cc78ebc2b474358d2bc8e00c9f961188c178423 100644 (file)
@@ -8,7 +8,7 @@
 
 # XXX: Implement this the same way as other DeprecationWarnings without circular import
 passthrough_module(__name__, '._legacy', callback=lambda attr: warnings.warn(
-    DeprecationWarning(f'{__name__}.{attr} is deprecated'), stacklevel=3))
+    DeprecationWarning(f'{__name__}.{attr} is deprecated'), stacklevel=5))
 
 
 # HTMLParseError has been deprecated in Python 3.3 and removed in
index 82e17628105f26cd251cd2c98751540969258b6b..b67944e6bd601b81c0c86260a7fc0d37c47c0686 100644 (file)
@@ -23,48 +23,75 @@ def get_package_info(module):
 
 
 def _is_package(module):
-    try:
-        module.__getattribute__('__path__')
-    except AttributeError:
-        return False
-    return True
+    return '__path__' in vars(module)
+
+
+class EnhancedModule(types.ModuleType):
+    def __new__(cls, name, *args, **kwargs):
+        if name not in sys.modules:
+            return super().__new__(cls, name, *args, **kwargs)
+
+        assert not args and not kwargs, 'Cannot pass additional arguments to an existing module'
+        module = sys.modules[name]
+        module.__class__ = cls
+        return module
+
+    def __init__(self, name, *args, **kwargs):
+        # Prevent __init__ from re-running on an existing module returned by __new__
+        if name not in sys.modules:
+            super().__init__(name, *args, **kwargs)
+
+    def __bool__(self):
+        return vars(self).get('__bool__', lambda: True)()
+
+    def __getattribute__(self, attr):
+        try:
+            ret = super().__getattribute__(attr)
+        except AttributeError:
+            if attr.startswith('__') and attr.endswith('__'):
+                raise
+            getter = getattr(self, '__getattr__', None)
+            if not getter:
+                raise
+            ret = getter(attr)
+        return ret.fget() if isinstance(ret, property) else ret
 
 
 def passthrough_module(parent, child, allowed_attributes=None, *, callback=lambda _: None):
-    parent_module = importlib.import_module(parent)
-    child_module = None  # Import child module only as needed
-
-    class PassthroughModule(types.ModuleType):
-        def __getattr__(self, attr):
-            if _is_package(parent_module):
-                with contextlib.suppress(ImportError):
-                    return importlib.import_module(f'.{attr}', parent)
-
-            ret = self.__from_child(attr)
-            if ret is _NO_ATTRIBUTE:
-                raise AttributeError(f'module {parent} has no attribute {attr}')
-            callback(attr)
-            return ret
-
-        def __from_child(self, attr):
-            if allowed_attributes is None:
-                if attr.startswith('__') and attr.endswith('__'):
-                    return _NO_ATTRIBUTE
-            elif attr not in allowed_attributes:
-                return _NO_ATTRIBUTE
+    """Passthrough parent module into a child module, creating the parent if necessary"""
+    parent = EnhancedModule(parent)
 
-            nonlocal child_module
-            child_module = child_module or importlib.import_module(child, parent)
+    def __getattr__(attr):
+        if _is_package(parent):
+            with contextlib.suppress(ImportError):
+                return importlib.import_module(f'.{attr}', parent.__name__)
 
-            with contextlib.suppress(AttributeError):
-                return getattr(child_module, attr)
+        ret = from_child(attr)
+        if ret is _NO_ATTRIBUTE:
+            raise AttributeError(f'module {parent.__name__} has no attribute {attr}')
+        callback(attr)
+        return ret
 
-            if _is_package(child_module):
-                with contextlib.suppress(ImportError):
-                    return importlib.import_module(f'.{attr}', child)
+    def from_child(attr):
+        nonlocal child
 
+        if allowed_attributes is None:
+            if attr.startswith('__') and attr.endswith('__'):
+                return _NO_ATTRIBUTE
+        elif attr not in allowed_attributes:
             return _NO_ATTRIBUTE
 
-    # Python 3.6 does not have module level __getattr__
-    # https://peps.python.org/pep-0562/
-    sys.modules[parent].__class__ = PassthroughModule
+        if isinstance(child, str):
+            child = importlib.import_module(child, parent.__name__)
+
+        with contextlib.suppress(AttributeError):
+            return getattr(child, attr)
+
+        if _is_package(child):
+            with contextlib.suppress(ImportError):
+                return importlib.import_module(f'.{attr}', child.__name__)
+
+        return _NO_ATTRIBUTE
+
+    parent.__getattr__ = __getattr__
+    return parent