jfr.im git - yt-dlp.git/commitdiff
[networking] Add `extensions` attribute to `Response` (#9756)
author bashonly <redacted>
Sat, 4 May 2024 22:19:42 +0000 (17:19 -0500)
committer GitHub <redacted>
Sat, 4 May 2024 22:19:42 +0000 (22:19 +0000)
CurlCFFIRH now provides an `impersonate` field in its responses' extensions.

Authored by: bashonly

test/test_networking.py
yt_dlp/networking/_curlcffi.py
yt_dlp/networking/common.py

index b50f70d086a6b27d597e80996816ff53ff9699f0..d613cb5681490933b964ec98b67cd5fc6bbe34ce 100644 (file)
@@ -785,6 +785,31 @@ def test_supported_impersonate_targets(self, handler):
                 assert res.status == 200
                 assert std_headers['user-agent'].lower() not in res.read().decode().lower()
 
+    def test_response_extensions(self, handler):
+        """Successful responses expose the handler's request target under extensions['impersonate']."""
+        with handler() as rh:
+            for target in rh.supported_targets:
+                request = Request(
+                    f'http://127.0.0.1:{self.http_port}/gen_200', extensions={'impersonate': target})
+                res = validate_and_send(rh, request)
+                assert res.extensions['impersonate'] == rh._get_request_target(request)
+
+    def test_http_error_response_extensions(self, handler):
+        """The response attached to an HTTPError also exposes extensions['impersonate']."""
+        with handler() as rh:
+            for target in rh.supported_targets:
+                request = Request(
+                    f'http://127.0.0.1:{self.http_port}/gen_404', extensions={'impersonate': target})
+                try:
+                    validate_and_send(rh, request)
+                except HTTPError as e:
+                    res = e.response
+                else:
+                    # Without this, a missing HTTPError would leave `res` unbound on the first
+                    # target, or silently re-check the previous target's stale response.
+                    raise AssertionError(f'expected HTTPError for impersonate target {target!r}')
+                assert res.extensions['impersonate'] == rh._get_request_target(request)
+
 
 class TestRequestHandlerMisc:
     """Misc generic tests for request handlers, not related to request or validation testing"""
index 39d1f70fb053cbd1b9f70152b29c745d89f2dbc2..10751a105057d163f77bb47ee9d706ff2eaef37f 100644 (file)
@@ -132,6 +132,21 @@ def _check_extensions(self, extensions):
         extensions.pop('cookiejar', None)
         extensions.pop('timeout', None)
 
+    def send(self, request: Request) -> Response:
+        """Send *request* and store ``self._get_request_target(request)`` in the response extensions.
+
+        The value goes under the ``'impersonate'`` key, including on an HTTPError's response before re-raising.
+        """
+        impersonate_target = self._get_request_target(request)
+        try:
+            res = super().send(request)
+        except HTTPError as err:
+            err.response.extensions['impersonate'] = impersonate_target
+            raise
+        else:
+            res.extensions['impersonate'] = impersonate_target
+            return res
+
     def _send(self, request: Request):
         max_redirects_exceeded = False
         session: curl_cffi.requests.Session = self._get_instance(
index 4c66ba66aaf3ee5e7b884be41c81ada4d1d27d52..a2217034c90cb3812c47dc94cd57087b42feae82 100644 (file)
@@ -497,6 +497,7 @@ class Response(io.IOBase):
     @param headers: response headers.
     @param status: Response HTTP status code. Default is 200 OK.
     @param reason: HTTP status reason. Will use built-in reasons based on status code if not provided.
+    @param extensions: Dictionary of handler-specific response extensions.
     """
 
     def __init__(
@@ -505,7 +506,9 @@ def __init__(
             url: str,
             headers: Mapping[str, str],
             status: int = 200,
-            reason: str = None):
+            reason: str = None,
+            extensions: dict = None
+    ):
 
         self.fp = fp
         self.headers = Message()
@@ -517,6 +520,7 @@ def __init__(
             self.reason = reason or HTTPStatus(status).phrase
         except ValueError:
             self.reason = None
+        self.extensions = extensions or {}
 
     def readable(self):
         return self.fp.readable()