jfr.im git - yt-dlp.git / commitdiff
[rh] Remove additional logging handlers on close (#9032)
author coletdjnz <redacted>
Sat, 17 Feb 2024 22:32:34 +0000 (11:32 +1300)
committer GitHub <redacted>
Sat, 17 Feb 2024 22:32:34 +0000 (11:32 +1300)
Fixes https://github.com/yt-dlp/yt-dlp/issues/8922

Authored by: coletdjnz

test/test_networking.py
yt_dlp/networking/_requests.py
yt_dlp/networking/_websockets.py

index 8cadd86f5a0ef4e95b606b39fa6e1b1c208e30be..10534242a89bb26edceb839ac92bd901ee6c087a 100644 (file: test/test_networking.py)
@@ -13,6 +13,7 @@
 import http.cookiejar
 import http.server
 import io
+import logging
 import pathlib
 import random
 import ssl
@@ -752,6 +753,25 @@ def test_certificate_nocombined_pass(self, handler):
         })
 
 
+class TestRequestHandlerMisc:
+    """Misc generic tests for request handlers, not related to request or validation testing"""
+    @pytest.mark.parametrize('handler,logger_name', [
+        ('Requests', 'urllib3'),
+        ('Websockets', 'websockets.client'),
+        ('Websockets', 'websockets.server')
+    ], indirect=['handler'])
+    def test_remove_logging_handler(self, handler, logger_name):
+        # Ensure any logging handlers, which may contain a YoutubeDL instance,
+        # are removed when we close the request handler
+        # See: https://github.com/yt-dlp/yt-dlp/issues/8922
+        logging_handlers = logging.getLogger(logger_name).handlers
+        before_count = len(logging_handlers)
+        rh = handler()
+        assert len(logging_handlers) == before_count + 1
+        rh.close()
+        assert len(logging_handlers) == before_count
+
+
 class TestUrllibRequestHandler(TestRequestHandlerBase):
     @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
     def test_file_urls(self, handler):
@@ -827,6 +847,7 @@ def test_httplib_validation_errors(self, handler, req, match, version_check):
             assert not isinstance(exc_info.value, TransportError)
 
 
+@pytest.mark.parametrize('handler', ['Requests'], indirect=True)
 class TestRequestsRequestHandler(TestRequestHandlerBase):
     @pytest.mark.parametrize('raised,expected', [
         (lambda: requests.exceptions.ConnectTimeout(), TransportError),
@@ -843,7 +864,6 @@ class TestRequestsRequestHandler(TestRequestHandlerBase):
         (lambda: requests.exceptions.RequestException(), RequestError)
         #  (lambda: requests.exceptions.TooManyRedirects(), HTTPError) - Needs a response object
     ])
-    @pytest.mark.parametrize('handler', ['Requests'], indirect=True)
     def test_request_error_mapping(self, handler, monkeypatch, raised, expected):
         with handler() as rh:
             def mock_get_instance(*args, **kwargs):
@@ -877,7 +897,6 @@ def request(self, *args, **kwargs):
             '3 bytes read, 5 more expected'
         ),
     ])
-    @pytest.mark.parametrize('handler', ['Requests'], indirect=True)
     def test_response_error_mapping(self, handler, monkeypatch, raised, expected, match):
         from requests.models import Response as RequestsResponse
         from urllib3.response import HTTPResponse as Urllib3Response
@@ -896,6 +915,21 @@ def mock_read(*args, **kwargs):
 
         assert exc_info.type is expected
 
+    def test_close(self, handler, monkeypatch):
+        rh = handler()
+        session = rh._get_instance(cookiejar=rh.cookiejar)
+        called = False
+        original_close = session.close
+
+        def mock_close(*args, **kwargs):
+            nonlocal called
+            called = True
+            return original_close(*args, **kwargs)
+
+        monkeypatch.setattr(session, 'close', mock_close)
+        rh.close()
+        assert called
+
 
 def run_validation(handler, error, req, **handler_kwargs):
     with handler(**handler_kwargs) as rh:
@@ -1205,6 +1239,19 @@ def some_preference(rh, request):
         assert director.send(Request('http://')).read() == b''
         assert director.send(Request('http://', headers={'prefer': '1'})).read() == b'supported'
 
+    def test_close(self, monkeypatch):
+        director = RequestDirector(logger=FakeLogger())
+        director.add_handler(FakeRH(logger=FakeLogger()))
+        called = False
+
+        def mock_close(*args, **kwargs):
+            nonlocal called
+            called = True
+
+        monkeypatch.setattr(director.handlers[FakeRH.RH_KEY], 'close', mock_close)
+        director.close()
+        assert called
+
 
 # XXX: do we want to move this to test_YoutubeDL.py?
 class TestYoutubeDLNetworking:
index 00e4bdb490cdf03c701a4849d5bcaf1dfbdb9833..7b19029bfe15c894e03409eae5eba1174935c95b 100644 (file: yt_dlp/networking/_requests.py)
@@ -258,10 +258,10 @@ def __init__(self, *args, **kwargs):
 
         # Forward urllib3 debug messages to our logger
         logger = logging.getLogger('urllib3')
-        handler = Urllib3LoggingHandler(logger=self._logger)
-        handler.setFormatter(logging.Formatter('requests: %(message)s'))
-        handler.addFilter(Urllib3LoggingFilter())
-        logger.addHandler(handler)
+        self.__logging_handler = Urllib3LoggingHandler(logger=self._logger)
+        self.__logging_handler.setFormatter(logging.Formatter('requests: %(message)s'))
+        self.__logging_handler.addFilter(Urllib3LoggingFilter())
+        logger.addHandler(self.__logging_handler)
         # TODO: Use a logger filter to suppress pool reuse warning instead
         logger.setLevel(logging.ERROR)
 
@@ -276,6 +276,9 @@ def __init__(self, *args, **kwargs):
 
     def close(self):
         self._clear_instances()
+        # Remove the logging handler that contains a reference to our logger
+        # See: https://github.com/yt-dlp/yt-dlp/issues/8922
+        logging.getLogger('urllib3').removeHandler(self.__logging_handler)
 
     def _check_extensions(self, extensions):
         super()._check_extensions(extensions)
index ed64080d62a27a9728a86f3106c8f02948b3a35b..159793204b126480271e040d12f501b4a3c67f94 100644 (file: yt_dlp/networking/_websockets.py)
@@ -90,10 +90,12 @@ class WebsocketsRH(WebSocketRequestHandler):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        self.__logging_handlers = {}
         for name in ('websockets.client', 'websockets.server'):
             logger = logging.getLogger(name)
             handler = logging.StreamHandler(stream=sys.stdout)
             handler.setFormatter(logging.Formatter(f'{self.RH_NAME}: %(message)s'))
+            self.__logging_handlers[name] = handler
             logger.addHandler(handler)
             if self.verbose:
                 logger.setLevel(logging.DEBUG)
@@ -103,6 +105,12 @@ def _check_extensions(self, extensions):
         extensions.pop('timeout', None)
         extensions.pop('cookiejar', None)
 
+    def close(self):
+        # Remove the logging handler that contains a reference to our logger
+        # See: https://github.com/yt-dlp/yt-dlp/issues/8922
+        for name, handler in self.__logging_handlers.items():
+            logging.getLogger(name).removeHandler(handler)
+
     def _send(self, request):
         timeout = float(request.extensions.get('timeout') or self.timeout)
         headers = self._merge_headers(request.headers)