]> jfr.im git - yt-dlp.git/commitdiff
[core/windows] Improve shell quoting and tests (#9802)
authorSimon Sawicki <redacted>
Sat, 27 Apr 2024 08:37:26 +0000 (10:37 +0200)
committerGitHub <redacted>
Sat, 27 Apr 2024 08:37:26 +0000 (10:37 +0200)
Authored by: Grub4K

test/test_utils.py
yt_dlp/utils/_utils.py

index ddf0a7c24202492fe33d39a1f01c7acf8c560da4..824864577dc0ae4258abce3d46910d9cecbcbb4f 100644 (file)
@@ -2059,7 +2059,22 @@ def test_extract_basic_auth(self):
         assert extract_basic_auth('http://user:pass@foo.bar') == ('http://foo.bar', 'Basic dXNlcjpwYXNz')
 
     @unittest.skipUnless(compat_os_name == 'nt', 'Only relevant on Windows')
-    def test_Popen_windows_escaping(self):
+    def test_windows_escaping(self):
+        tests = [
+            'test"&',
+            '%CMDCMDLINE:~-1%&',
+            'a\nb',
+            '"',
+            '\\',
+            '!',
+            '^!',
+            'a \\ b',
+            'a \\" b',
+            'a \\ b\\',
+            # We replace \r with \n
+            ('a\r\ra', 'a\n\na'),
+        ]
+
         def run_shell(args):
             stdout, stderr, error = Popen.run(
                 args, text=True, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
@@ -2067,15 +2082,18 @@ def run_shell(args):
             assert not error
             return stdout
 
-        # Test escaping
-        assert run_shell(['echo', 'test"&']) == '"test""&"\n'
-        assert run_shell(['echo', '%CMDCMDLINE:~-1%&']) == '"%CMDCMDLINE:~-1%&"\n'
-        assert run_shell(['echo', 'a\nb']) == '"a"\n"b"\n'
-        assert run_shell(['echo', '"']) == '""""\n'
-        assert run_shell(['echo', '\\']) == '\\\n'
-        # Test if delayed expansion is disabled
-        assert run_shell(['echo', '^!']) == '"^!"\n'
-        assert run_shell('echo "^!"') == '"^!"\n'
+        for argument in tests:
+            if isinstance(argument, str):
+                expected = argument
+            else:
+                argument, expected = argument
+
+            args = [sys.executable, '-c', 'import sys; print(end=sys.argv[1])', argument, 'end']
+            assert run_shell(args) == expected
+
+            escaped = shell_quote(argument, shell=True)
+            args = f'{sys.executable} -c "import sys; print(end=sys.argv[1])" {escaped} end'
+            assert run_shell(args) == expected
 
 
 if __name__ == '__main__':
index e3e80f3d33d298effdbbe5d0c4ed6596092a03d3..b637669124f69af1af017e3b3447151ac8d26af5 100644 (file)
@@ -1638,16 +1638,14 @@ def get_filesystem_encoding():
     return encoding if encoding is not None else 'utf-8'
 
 
-_WINDOWS_QUOTE_TRANS = str.maketrans({'"': '\\"', '\\': '\\\\'})
+_WINDOWS_QUOTE_TRANS = str.maketrans({'"': R'\"'})
 _CMD_QUOTE_TRANS = str.maketrans({
     # Keep quotes balanced by replacing them with `""` instead of `\\"`
     '"': '""',
-    # Requires a variable `=` containing `"^\n\n"` (set in `utils.Popen`)
+    # These require an env-variable `=` containing `"^\n\n"` (set in `utils.Popen`)
     # `=` should be unique since variables containing `=` cannot be set using cmd
     '\n': '%=%',
-    # While we are only required to escape backslashes immediately before quotes,
-    # we instead escape all of 'em anyways to be consistent
-    '\\': '\\\\',
+    '\r': '%=%',
     # Use zero length variable replacement so `%` doesn't get expanded
     # `cd` is always set as long as extensions are enabled (`/E:ON` in `utils.Popen`)
     '%': '%%cd:~,%',
@@ -1656,19 +1654,14 @@ def get_filesystem_encoding():
 
 def shell_quote(args, *, shell=False):
     args = list(variadic(args))
-    if any(isinstance(item, bytes) for item in args):
-        deprecation_warning('Passing bytes to utils.shell_quote is deprecated')
-        encoding = get_filesystem_encoding()
-        for index, item in enumerate(args):
-            if isinstance(item, bytes):
-                args[index] = item.decode(encoding)
 
     if compat_os_name != 'nt':
         return shlex.join(args)
 
     trans = _CMD_QUOTE_TRANS if shell else _WINDOWS_QUOTE_TRANS
     return ' '.join(
-        s if re.fullmatch(r'[\w#$*\-+./:?@\\]+', s, re.ASCII) else s.translate(trans).join('""')
+        s if re.fullmatch(r'[\w#$*\-+./:?@\\]+', s, re.ASCII)
+        else re.sub(r'(\\+)("|$)', r'\1\1\2', s).translate(trans).join('""')
         for s in args)