jfr.im git - yt-dlp.git/commitdiff
[jsinterp] Some optimizations and refactoring
author pukkandan <redacted>
Tue, 21 Jun 2022 16:16:35 +0000 (21:46 +0530)
committer pukkandan <redacted>
Tue, 21 Jun 2022 17:53:48 +0000 (23:23 +0530)
Motivated by: https://github.com/ytdl-org/youtube-dl/issues/30641#issuecomment-1041904912

Authored by: dirkf, pukkandan

yt_dlp/jsinterp.py

index 56229cd992fa313f31e0fb552ac8e4bca9a7ecb4..c95a0ff57b9904e8b833bc2906c1a432bc852350 100644 (file)
@@ -6,22 +6,19 @@
 
 from .utils import ExtractorError, remove_quotes
 
-_OPERATORS = [
-    ('|', operator.or_),
-    ('^', operator.xor),
-    ('&', operator.and_),
-    ('>>', operator.rshift),
-    ('<<', operator.lshift),
-    ('-', operator.sub),
-    ('+', operator.add),
-    ('%', operator.mod),
-    ('/', operator.truediv),
-    ('*', operator.mul),
-]
-_ASSIGN_OPERATORS = [(op + '=', opfunc) for op, opfunc in _OPERATORS]
-_ASSIGN_OPERATORS.append(('=', (lambda cur, right: right)))
-
-_NAME_RE = r'[a-zA-Z_$][a-zA-Z_$0-9]*'
+_NAME_RE = r'[a-zA-Z_$][\w$]*'
+_OPERATORS = {
+    '|': operator.or_,
+    '^': operator.xor,
+    '&': operator.and_,
+    '>>': operator.rshift,
+    '<<': operator.lshift,
+    '-': operator.sub,
+    '+': operator.add,
+    '%': operator.mod,
+    '/': operator.truediv,
+    '*': operator.mul,
+}
 
 _MATCHING_PARENS = dict(zip('({[', ')}]'))
 _QUOTES = '\'"'
@@ -50,13 +47,11 @@ def __delitem__(self, key):
 
 
 class JSInterpreter:
+    __named_object_counter = 0
+
     def __init__(self, code, objects=None):
-        if objects is None:
-            objects = {}
-        self.code = code
-        self._functions = {}
-        self._objects = objects
-        self.__named_object_counter = 0
+        self.code, self._functions = code, {}
+        self._objects = {} if objects is None else objects
 
     def _named_object(self, namespace, obj):
         self.__named_object_counter += 1
@@ -93,9 +88,9 @@ def _separate(expr, delim=',', max_split=None):
                 break
         yield expr[start:]
 
-    @staticmethod
-    def _separate_at_paren(expr, delim):
-        separated = list(JSInterpreter._separate(expr, delim, 1))
+    @classmethod
+    def _separate_at_paren(cls, expr, delim):
+        separated = list(cls._separate(expr, delim, 1))
         if len(separated) < 2:
             raise ExtractorError(f'No terminating paren {delim} in {expr}')
         return separated[0][1:].strip(), separated[1].strip()
@@ -104,33 +99,29 @@ def interpret_statement(self, stmt, local_vars, allow_recursion=100):
         if allow_recursion < 0:
             raise ExtractorError('Recursion limit reached')
 
-        sub_statements = list(self._separate(stmt, ';'))
-        stmt = (sub_statements or ['']).pop()
+        should_abort = False
+        sub_statements = list(self._separate(stmt, ';')) or ['']
+        stmt = sub_statements.pop().lstrip()
+
         for sub_stmt in sub_statements:
             ret, should_abort = self.interpret_statement(sub_stmt, local_vars, allow_recursion - 1)
             if should_abort:
-                return ret
+                return ret, should_abort
 
-        should_abort = False
-        stmt = stmt.lstrip()
-        stmt_m = re.match(r'var\s', stmt)
-        if stmt_m:
-            expr = stmt[len(stmt_m.group(0)):]
+        m = re.match(r'(?P<var>var\s)|return(?:\s+|$)', stmt)
+        if not m:  # Try interpreting it as an expression
+            expr = stmt
+        elif m.group('var'):
+            expr = stmt[len(m.group(0)):]
         else:
-            return_m = re.match(r'return(?:\s+|$)', stmt)
-            if return_m:
-                expr = stmt[len(return_m.group(0)):]
-                should_abort = True
-            else:
-                # Try interpreting it as an expression
-                expr = stmt
+            expr = stmt[len(m.group(0)):]
+            should_abort = True
 
-        v = self.interpret_expression(expr, local_vars, allow_recursion)
-        return v, should_abort
+        return self.interpret_expression(expr, local_vars, allow_recursion), should_abort
 
     def interpret_expression(self, expr, local_vars, allow_recursion):
         expr = expr.strip()
-        if expr == '':  # Empty expression
+        if not expr:
             return None
 
         if expr.startswith('{'):
@@ -156,8 +147,8 @@ def interpret_expression(self, expr, local_vars, allow_recursion):
                 for item in self._separate(inner)])
             expr = name + outer
 
-        m = re.match(r'try\s*', expr)
-        if m:
+        m = re.match(r'(?P<try>try)\s*|(?:(?P<catch>catch)|(?P<for>for)|(?P<switch>switch))\s*\(', expr)
+        if m and m.group('try'):
             if expr[m.end()] == '{':
                 try_expr, expr = self._separate_at_paren(expr[m.end():], '}')
             else:
@@ -167,21 +158,19 @@ def interpret_expression(self, expr, local_vars, allow_recursion):
                 return ret
             return self.interpret_statement(expr, local_vars, allow_recursion - 1)[0]
 
-        m = re.match(r'catch\s*\(', expr)
-        if m:
+        elif m and m.group('catch'):
             # We ignore the catch block
             _, expr = self._separate_at_paren(expr, '}')
             return self.interpret_statement(expr, local_vars, allow_recursion - 1)[0]
 
-        m = re.match(r'for\s*\(', expr)
-        if m:
+        elif m and m.group('for'):
             constructor, remaining = self._separate_at_paren(expr[m.end() - 1:], ')')
             if remaining.startswith('{'):
                 body, expr = self._separate_at_paren(remaining, '}')
             else:
-                m = re.match(r'switch\s*\(', remaining)  # FIXME
-                if m:
-                    switch_val, remaining = self._separate_at_paren(remaining[m.end() - 1:], ')')
+                switch_m = re.match(r'switch\s*\(', remaining)  # FIXME
+                if switch_m:
+                    switch_val, remaining = self._separate_at_paren(remaining[switch_m.end() - 1:], ')')
                     body, expr = self._separate_at_paren(remaining, '}')
                     body = 'switch(%s){%s}' % (switch_val, body)
                 else:
@@ -206,8 +195,7 @@ def interpret_expression(self, expr, local_vars, allow_recursion):
                         f'Premature return in the initialization of a for loop in {constructor!r}')
             return self.interpret_statement(expr, local_vars, allow_recursion - 1)[0]
 
-        m = re.match(r'switch\s*\(', expr)
-        if m:
+        elif m and m.group('switch'):
             switch_val, remaining = self._separate_at_paren(expr[m.end() - 1:], ')')
             switch_val = self.interpret_expression(switch_val, local_vars, allow_recursion)
             body, expr = self._separate_at_paren(remaining, '}')
@@ -250,55 +238,63 @@ def interpret_expression(self, expr, local_vars, allow_recursion):
                 ret = local_vars[var]
             expr = expr[:start] + json.dumps(ret) + expr[end:]
 
-        for op, opfunc in _ASSIGN_OPERATORS:
-            m = re.match(rf'''(?x)
-                (?P<out>{_NAME_RE})(?:\[(?P<index>[^\]]+?)\])?
-                \s*{re.escape(op)}
-                (?P<expr>.*)$''', expr)
-            if not m:
-                continue
-            right_val = self.interpret_expression(m.group('expr'), local_vars, allow_recursion)
+        if not expr:
+            return None
 
-            if m.groupdict().get('index'):
-                lvar = local_vars[m.group('out')]
-                idx = self.interpret_expression(m.group('index'), local_vars, allow_recursion)
-                if not isinstance(idx, int):
-                    raise ExtractorError(f'List indices must be integers: {idx}')
-                cur = lvar[idx]
-                val = opfunc(cur, right_val)
-                lvar[idx] = val
-                return val
+        m = re.match(fr'''(?x)
+            (?P<assign>
+                (?P<out>{_NAME_RE})(?:\[(?P<index>[^\]]+?)\])?\s*
+                (?P<op>{"|".join(map(re.escape, _OPERATORS))})?
+                =(?P<expr>.*)$
+            )|(?P<return>
+                (?!if|return|true|false|null)(?P<name>{_NAME_RE})$
+            )|(?P<indexing>
+                (?P<in>{_NAME_RE})\[(?P<idx>.+)\]$
+            )|(?P<attribute>
+                (?P<var>{_NAME_RE})(?:\.(?P<member>[^(]+)|\[(?P<member2>[^\]]+)\])\s*
+            )|(?P<function>
+                (?P<fname>{_NAME_RE})\((?P<args>[\w$,]*)\)$
+            )''', expr)
+        if m and m.group('assign'):
+            if not m.group('op'):
+                opfunc = lambda curr, right: right
             else:
-                cur = local_vars.get(m.group('out'))
-                val = opfunc(cur, right_val)
-                local_vars[m.group('out')] = val
-                return val
+                opfunc = _OPERATORS[m.group('op')]
+            right_val = self.interpret_expression(m.group('expr'), local_vars, allow_recursion)
+            left_val = local_vars.get(m.group('out'))
+
+            if not m.group('index'):
+                local_vars[m.group('out')] = opfunc(left_val, right_val)
+                return local_vars[m.group('out')]
+            elif left_val is None:
+                raise ExtractorError(f'Cannot index undefined variable: {m.group("out")}')
+
+            idx = self.interpret_expression(m.group('index'), local_vars, allow_recursion)
+            if not isinstance(idx, int):
+                raise ExtractorError(f'List indices must be integers: {idx}')
+            left_val[idx] = opfunc(left_val[idx], right_val)
+            return left_val[idx]
 
-        if expr.isdigit():
+        elif expr.isdigit():
             return int(expr)
 
-        if expr == 'break':
+        elif expr == 'break':
             raise JS_Break()
         elif expr == 'continue':
             raise JS_Continue()
 
-        var_m = re.match(
-            r'(?!if|return|true|false|null)(?P<name>%s)$' % _NAME_RE,
-            expr)
-        if var_m:
-            return local_vars[var_m.group('name')]
+        elif m and m.group('return'):
+            return local_vars[m.group('name')]
 
         with contextlib.suppress(ValueError):
             return json.loads(expr)
 
-        m = re.match(
-            r'(?P<in>%s)\[(?P<idx>.+)\]$' % _NAME_RE, expr)
-        if m:
+        if m and m.group('indexing'):
             val = local_vars[m.group('in')]
             idx = self.interpret_expression(m.group('idx'), local_vars, allow_recursion)
             return val[idx]
 
-        for op, opfunc in _OPERATORS:
+        for op, opfunc in _OPERATORS.items():
             separated = list(self._separate(expr, op))
             if len(separated) < 2:
                 continue
@@ -314,10 +310,7 @@ def interpret_expression(self, expr, local_vars, allow_recursion):
                 raise ExtractorError(f'Premature right-side return of {op} in {expr!r}')
             return opfunc(left_val or 0, right_val)
 
-        m = re.match(
-            r'(?P<var>%s)(?:\.(?P<member>[^(]+)|\[(?P<member2>[^]]+)\])\s*' % _NAME_RE,
-            expr)
-        if m:
+        if m and m.group('attribute'):
             variable = m.group('var')
             member = remove_quotes(m.group('member') or m.group('member2'))
             arg_str = expr[m.end():]
@@ -332,7 +325,6 @@ def assertion(cndn, msg):
                     raise ExtractorError(f'{member} {msg}: {expr}')
 
             def eval_method():
-                nonlocal member
                 if variable == 'String':
                     obj = str
                 elif variable in local_vars:
@@ -342,8 +334,8 @@ def eval_method():
                         self._objects[variable] = self.extract_object(variable)
                     obj = self._objects[variable]
 
+                # Member access
                 if arg_str is None:
-                    # Member access
                     if member == 'length':
                         return len(obj)
                     return obj[member]
@@ -418,9 +410,7 @@ def eval_method():
                     except ValueError:
                         return -1
 
-                if isinstance(obj, list):
-                    member = int(member)
-                return obj[member](argvals)
+                return obj[int(member) if isinstance(obj, list) else member](argvals)
 
             if remaining:
                 return self.interpret_expression(
@@ -429,9 +419,8 @@ def eval_method():
             else:
                 return eval_method()
 
-        m = re.match(r'^(?P<func>%s)\((?P<args>[a-zA-Z0-9_$,]*)\)$' % _NAME_RE, expr)
-        if m:
-            fname = m.group('func')
+        elif m and m.group('function'):
+            fname = m.group('fname')
             argvals = tuple(
                 int(v) if v.isdigit() else local_vars[v]
                 for v in self._separate(m.group('args')))
@@ -441,8 +430,7 @@ def eval_method():
                 self._functions[fname] = self.extract_function(fname)
             return self._functions[fname](argvals)
 
-        if expr:
-            raise ExtractorError('Unsupported JS expression %r' % expr)
+        raise ExtractorError(f'Unsupported JS expression {expr!r}')
 
     def extract_object(self, objname):
         _FUNC_NAME_RE = r'''(?:[a-zA-Z$0-9]+|"[a-zA-Z$0-9]+"|'[a-zA-Z$0-9]+')'''
@@ -471,14 +459,17 @@ def extract_function_code(self, funcname):
         """ @returns argnames, code """
         func_m = re.search(
             r'''(?x)
-                (?:function\s+%s|[{;,]\s*%s\s*=\s*function|var\s+%s\s*=\s*function)\s*
+                (?:
+                    function\s+%(name)s|
+                    [{;,]\s*%(name)s\s*=\s*function|
+                    var\s+%(name)s\s*=\s*function
+                )\s*
                 \((?P<args>[^)]*)\)\s*
-                (?P<code>\{(?:(?!};)[^"]|"([^"]|\\")*")+\})''' % (
-                re.escape(funcname), re.escape(funcname), re.escape(funcname)),
+                (?P<code>{(?:(?!};)[^"]|"([^"]|\\")*")+})''' % {'name': re.escape(funcname)},
             self.code)
         code, _ = self._separate_at_paren(func_m.group('code'), '}')  # refine the match
         if func_m is None:
-            raise ExtractorError('Could not find JS function %r' % funcname)
+            raise ExtractorError(f'Could not find JS function "{funcname}"')
         return func_m.group('args').split(','), code
 
     def extract_function(self, funcname):
@@ -492,11 +483,9 @@ def extract_function_from_code(self, argnames, code, *global_stack):
                 break
             start, body_start = mobj.span()
             body, remaining = self._separate_at_paren(code[body_start - 1:], '}')
-            name = self._named_object(
-                local_vars,
-                self.extract_function_from_code(
-                    [str.strip(x) for x in mobj.group('args').split(',')],
-                    body, local_vars, *global_stack))
+            name = self._named_object(local_vars, self.extract_function_from_code(
+                [x.strip() for x in mobj.group('args').split(',')],
+                body, local_vars, *global_stack))
             code = code[:start] + name + remaining
         return self.build_function(argnames, code, local_vars, *global_stack)