]> jfr.im git - yt-dlp.git/blobdiff - yt_dlp/aes.py
[MainStreaming] Add extractor (#2180)
[yt-dlp.git] / yt_dlp / aes.py
index f52b992df0097c57bcdafd530c2ba6f4961b53e6..8503e3dfd63524ce5246b2d826e3e9dea6fb56d3 100644 (file)
@@ -28,6 +28,48 @@ def aes_gcm_decrypt_and_verify_bytes(data, key, tag, nonce):
 BLOCK_SIZE_BYTES = 16
 
 
+def aes_ecb_encrypt(data, key, iv=None):
+    """
+    Encrypt with aes in ECB mode
+
+    @param {int[]} data        cleartext
+    @param {int[]} key         16/24/32-Byte cipher key
+    @param {int[]} iv          Unused for this mode
+    @returns {int[]}           encrypted data
+    """
+    expanded_key = key_expansion(key)
+    block_count = int(ceil(float(len(data)) / BLOCK_SIZE_BYTES))
+
+    encrypted_data = []
+    for i in range(block_count):
+        block = data[i * BLOCK_SIZE_BYTES: (i + 1) * BLOCK_SIZE_BYTES]
+        encrypted_data += aes_encrypt(block, expanded_key)
+    encrypted_data = encrypted_data[:len(data)]
+
+    return encrypted_data
+
+
+def aes_ecb_decrypt(data, key, iv=None):
+    """
+    Decrypt with aes in ECB mode
+
+    @param {int[]} data        cleartext
+    @param {int[]} key         16/24/32-Byte cipher key
+    @param {int[]} iv          Unused for this mode
+    @returns {int[]}           decrypted data
+    """
+    expanded_key = key_expansion(key)
+    block_count = int(ceil(float(len(data)) / BLOCK_SIZE_BYTES))
+
+    encrypted_data = []
+    for i in range(block_count):
+        block = data[i * BLOCK_SIZE_BYTES: (i + 1) * BLOCK_SIZE_BYTES]
+        encrypted_data += aes_decrypt(block, expanded_key)
+    encrypted_data = encrypted_data[:len(data)]
+
+    return encrypted_data
+
+
 def aes_ctr_decrypt(data, key, iv):
     """
     Decrypt with aes in counter mode
@@ -178,7 +220,7 @@ def aes_encrypt(data, expanded_key):
         data = sub_bytes(data)
         data = shift_rows(data)
         if i != rounds:
-            data = mix_columns(data)
+            data = list(iter_mix_columns(data, MIX_COLUMN_MATRIX))
         data = xor(data, expanded_key[i * BLOCK_SIZE_BYTES: (i + 1) * BLOCK_SIZE_BYTES])
 
     return data
@@ -197,7 +239,7 @@ def aes_decrypt(data, expanded_key):
     for i in range(rounds, 0, -1):
         data = xor(data, expanded_key[i * BLOCK_SIZE_BYTES: (i + 1) * BLOCK_SIZE_BYTES])
         if i != rounds:
-            data = mix_columns_inv(data)
+            data = list(iter_mix_columns(data, MIX_COLUMN_MATRIX_INV))
         data = shift_rows_inv(data)
         data = sub_bytes_inv(data)
     data = xor(data, expanded_key[:BLOCK_SIZE_BYTES])
@@ -375,49 +417,23 @@ def xor(data1, data2):
     return [x ^ y for x, y in zip(data1, data2)]
 
 
-def rijndael_mul(a, b):
-    if a == 0 or b == 0:
-        return 0
-    return RIJNDAEL_EXP_TABLE[(RIJNDAEL_LOG_TABLE[a] + RIJNDAEL_LOG_TABLE[b]) % 0xFF]
-
-
-def mix_column(data, matrix):
-    data_mixed = []
-    for row in range(4):
-        mixed = 0
-        for column in range(4):
-            # xor is (+) and (-)
-            mixed ^= rijndael_mul(data[column], matrix[row][column])
-        data_mixed.append(mixed)
-    return data_mixed
-
-
-def mix_columns(data, matrix=MIX_COLUMN_MATRIX):
-    data_mixed = []
-    for i in range(4):
-        column = data[i * 4: (i + 1) * 4]
-        data_mixed += mix_column(column, matrix)
-    return data_mixed
-
-
-def mix_columns_inv(data):
-    return mix_columns(data, MIX_COLUMN_MATRIX_INV)
+def iter_mix_columns(data, matrix):
+    for i in (0, 4, 8, 12):
+        for row in matrix:
+            mixed = 0
+            for j in range(4):
+                # xor is (+) and (-)
+                mixed ^= (0 if data[i:i + 4][j] == 0 or row[j] == 0 else
+                          RIJNDAEL_EXP_TABLE[(RIJNDAEL_LOG_TABLE[data[i + j]] + RIJNDAEL_LOG_TABLE[row[j]]) % 0xFF])
+            yield mixed
 
 
 def shift_rows(data):
-    data_shifted = []
-    for column in range(4):
-        for row in range(4):
-            data_shifted.append(data[((column + row) & 0b11) * 4 + row])
-    return data_shifted
+    return [data[((column + row) & 0b11) * 4 + row] for column in range(4) for row in range(4)]
 
 
 def shift_rows_inv(data):
-    data_shifted = []
-    for column in range(4):
-        for row in range(4):
-            data_shifted.append(data[((column - row) & 0b11) * 4 + row])
-    return data_shifted
+    return [data[((column - row) & 0b11) * 4 + row] for column in range(4) for row in range(4)]
 
 
 def shift_block(data):