jfr.im git - yt-dlp.git/blobdiff - yt_dlp/jsinterp.py
[jsinterp] Handle `NaN` in bitwise operators
[yt-dlp.git] / yt_dlp / jsinterp.py
index e25997129d0ae69f4a9103c9453df9a7fdc0afba..965b1c0f2980bb90cd0ad72c6636eea07b796967 100644 (file)
@@ -9,6 +9,7 @@
 from .utils import (
     NO_DEFAULT,
     ExtractorError,
+    function_with_repr,
     js_to_json,
     remove_quotes,
     truncate_string,
 
 def _js_bit_op(op):
     def zeroise(x):
-        return 0 if x in (None, JS_Undefined) else x
+        if x in (None, JS_Undefined):
+            return 0
+        with contextlib.suppress(TypeError):
+            if math.isnan(x):  # NB: NaN cannot be checked by membership
+                return 0
+        return x
 
     def wrapped(a, b):
         return op(zeroise(a), zeroise(b)) & 0xffffffff
@@ -184,7 +190,8 @@ def interpret_statement(self, stmt, local_vars, allow_recursion, *args, **kwargs
                     cls.write('=> Raises:', e, '<-|', stmt, level=allow_recursion)
                 raise
             if cls.ENABLED and stmt.strip():
-                cls.write(['->', '=>'][should_ret], repr(ret), '<-|', stmt, level=allow_recursion)
+                if should_ret or not repr(ret) == stmt:
+                    cls.write(['->', '=>'][should_ret], repr(ret), '<-|', stmt, level=allow_recursion)
             return ret, should_ret
         return interpret_statement
 
@@ -205,8 +212,6 @@ class JSInterpreter:
         'y': 4096,  # Perform a "sticky" search that matches starting at the current position in the target string
     }
 
-    _EXC_NAME = '__yt_dlp_exception__'
-
     def __init__(self, code, objects=None):
         self.code, self._functions = code, {}
         self._objects = {} if objects is None else objects
@@ -220,6 +225,8 @@ def __init__(self, msg, expr=None, *args, **kwargs):
     def _named_object(self, namespace, obj):
         self.__named_object_counter += 1
         name = f'__yt_dlp_jsinterp_obj{self.__named_object_counter}'
+        if callable(obj) and not isinstance(obj, function_with_repr):
+            obj = function_with_repr(obj, f'F<{self.__named_object_counter}>')
         namespace[name] = obj
         return name
 
@@ -241,7 +248,7 @@ def _separate(expr, delim=',', max_split=None):
             return
         counters = {k: 0 for k in _MATCHING_PARENS.values()}
         start, splits, pos, delim_len = 0, 0, 0, len(delim) - 1
-        in_quote, escaping, after_op, in_regex_char_group = None, False, True, False
+        in_quote, escaping, after_op, in_regex_char_group, in_unary_op = None, False, True, False, False
         for idx, char in enumerate(expr):
             if not in_quote and char in _MATCHING_PARENS:
                 counters[_MATCHING_PARENS[char]] += 1
@@ -256,9 +263,11 @@ def _separate(expr, delim=',', max_split=None):
                 elif in_quote == '/' and char in '[]':
                     in_regex_char_group = char == '['
             escaping = not escaping and in_quote and char == '\\'
-            after_op = not in_quote and char in OP_CHARS or (char.isspace() and after_op)
+            in_unary_op = (not in_quote and not in_regex_char_group
+                           and after_op not in (True, False) and char in '-+')
+            after_op = char if (not in_quote and char in OP_CHARS) else (char.isspace() and after_op)
 
-            if char != delim[pos] or any(counters.values()) or in_quote:
+            if char != delim[pos] or any(counters.values()) or in_quote or in_unary_op:
                 pos = 0
                 continue
             elif pos != delim_len:
@@ -343,7 +352,8 @@ def interpret_statement(self, stmt, local_vars, allow_recursion=100):
             inner, outer = self._separate(expr, expr[0], 1)
             if expr[0] == '/':
                 flags, outer = self._regex_flags(outer)
-                inner = re.compile(inner[1:], flags=flags)
+                # Avoid https://github.com/python/cpython/issues/74534
+                inner = re.compile(inner[1:].replace('[[', r'[\['), flags=flags)
             else:
                 inner = json.loads(js_to_json(f'{inner}{expr[0]}', strict=True))
             if not outer:
@@ -354,11 +364,11 @@ def interpret_statement(self, stmt, local_vars, allow_recursion=100):
             obj = expr[4:]
             if obj.startswith('Date('):
                 left, right = self._separate_at_paren(obj[4:])
-                expr = unified_timestamp(
+                date = unified_timestamp(
                     self.interpret_expression(left, local_vars, allow_recursion), False)
-                if not expr:
+                if date is None:
                     raise self.Exception(f'Failed to parse date {left!r}', expr)
-                expr = self._dump(int(expr * 1000), local_vars) + right
+                expr = self._dump(int(date * 1000), local_vars) + right
             else:
                 raise self.Exception(f'Unsupported object {obj}', expr)
 
@@ -402,10 +412,25 @@ def dict_item(key, val):
 
         m = re.match(r'''(?x)
                 (?P<try>try)\s*\{|
+                (?P<if>if)\s*\(|
                 (?P<switch>switch)\s*\(|
                 (?P<for>for)\s*\(
                 ''', expr)
         md = m.groupdict() if m else {}
+        if md.get('if'):
+            cndn, expr = self._separate_at_paren(expr[m.end() - 1:])
+            if_expr, expr = self._separate_at_paren(expr.lstrip())
+            # TODO: "else if" is not handled
+            else_expr = None
+            m = re.match(r'else\s*{', expr)
+            if m:
+                else_expr, expr = self._separate_at_paren(expr[m.end() - 1:])
+            cndn = _js_ternary(self.interpret_expression(cndn, local_vars, allow_recursion))
+            ret, should_abort = self.interpret_statement(
+                if_expr if cndn else else_expr, local_vars, allow_recursion)
+            if should_abort:
+                return ret, True
+
         if md.get('try'):
             try_expr, expr = self._separate_at_paren(expr[m.end() - 1:])
             err = None
@@ -768,7 +793,8 @@ def extract_object(self, objname):
             fields)
         for f in fields_m:
             argnames = f.group('args').split(',')
-            obj[remove_quotes(f.group('key'))] = self.build_function(argnames, f.group('code'))
+            name = remove_quotes(f.group('key'))
+            obj[name] = function_with_repr(self.build_function(argnames, f.group('code')), f'F<{name}>')
 
         return obj
 
@@ -790,7 +816,9 @@ def extract_function_code(self, funcname):
         return [x.strip() for x in func_m.group('args').split(',')], code
 
     def extract_function(self, funcname):
-        return self.extract_function_from_code(*self.extract_function_code(funcname))
+        return function_with_repr(
+            self.extract_function_from_code(*self.extract_function_code(funcname)),
+            f'F<{funcname}>')
 
     def extract_function_from_code(self, argnames, code, *global_stack):
         local_vars = {}