]> jfr.im git - yt-dlp.git/blame - youtube_dl/aes.py
add an aes implementation
[yt-dlp.git] / youtube_dl / aes.py
CommitLineData
# Public API of this module.
__all__ = ['aes_encrypt', 'key_expansion', 'aes_ctr_decrypt', 'aes_decrypt_text']
2
3import base64
4from math import ceil
5
# Size in bytes of one cipher block (and of one round key).
BLOCK_SIZE_BYTES = 16
7
def aes_ctr_decrypt(data, key, counter):
    """
    Decrypt with aes in counter mode

    Each 16-Byte block of 'data' is XORed with the encryption of the next
    counter block; the last block is zero-padded for the XOR and the result
    is cut back to the length of 'data'.

    @param {int[]} data        cipher
    @param {int[]} key         16/24/32-Byte cipher key
    @param {instance} counter  Instance whose next_value function (@returns {int[]}  16-Byte block)
                               returns the next counter block
    @returns {int[]}           decrypted data
    """
    expanded_key = key_expansion(key)
    # Ceiling division in pure integer arithmetic (no float round trip).
    block_count = (len(data) + BLOCK_SIZE_BYTES - 1) // BLOCK_SIZE_BYTES

    decrypted_data = []
    for i in range(block_count):
        counter_block = counter.next_value()
        # list() makes a private copy, so padding never touches the caller's data.
        block = list(data[i * BLOCK_SIZE_BYTES:(i + 1) * BLOCK_SIZE_BYTES])
        block += [0] * (BLOCK_SIZE_BYTES - len(block))

        cipher_counter_block = aes_encrypt(counter_block, expanded_key)
        decrypted_data += xor(block, cipher_counter_block)

    # Drop the bytes that only exist because of the zero padding.
    return decrypted_data[:len(data)]
32
def key_expansion(data):
    """
    Generate key schedule

    The input is copied, then extended 4 Bytes at a time: every new word is
    the XOR of the previous word (optionally transformed) with the word one
    key length earlier.

    @param {int[]} data  16/24/32-Byte cipher key
    @returns {int[]}     176/208/240-Byte expanded key
    """
    data = list(data)  # copy, the caller's key is left untouched
    rcon_iteration = 1
    key_size_bytes = len(data)
    # Floor division keeps this an int on Python 3 as well (it is used as a
    # loop bound and as a slice index below).
    expanded_key_size_bytes = (key_size_bytes // 4 + 7) * BLOCK_SIZE_BYTES

    while len(data) < expanded_key_size_bytes:
        # First word of each key-length group goes through the schedule core.
        temp = data[-4:]
        temp = key_schedule_core(temp, rcon_iteration)
        rcon_iteration += 1
        data += xor(temp, data[-key_size_bytes:4 - key_size_bytes])

        # Next three words are plain XORs.
        for _ in range(3):
            temp = data[-4:]
            data += xor(temp, data[-key_size_bytes:4 - key_size_bytes])

        # 32-Byte keys substitute one extra word per group.
        if key_size_bytes == 32:
            temp = data[-4:]
            temp = sub_bytes(temp)
            data += xor(temp, data[-key_size_bytes:4 - key_size_bytes])

        # Remaining plain words needed to complete a 24/32-Byte group.
        for _ in range(3 if key_size_bytes == 32 else 2 if key_size_bytes == 24 else 0):
            temp = data[-4:]
            data += xor(temp, data[-key_size_bytes:4 - key_size_bytes])

    # The last group may overshoot the required schedule length.
    data = data[:expanded_key_size_bytes]

    return data
66
def aes_encrypt(data, expanded_key):
    """
    Encrypt one block with aes

    @param {int[]} data          16-Byte state
    @param {int[]} expanded_key  176/208/240-Byte expanded key
    @returns {int[]}             16-Byte cipher
    """
    # One round key per round plus the initial one; floor division keeps the
    # count an int so it can be passed to range() on Python 3.
    rounds = len(expanded_key) // BLOCK_SIZE_BYTES - 1

    data = xor(data, expanded_key[:BLOCK_SIZE_BYTES])
    for i in range(1, rounds + 1):
        data = sub_bytes(data)
        data = shift_rows(data)
        # The final round skips the column mixing step.
        if i != rounds:
            data = mix_columns(data)
        data = xor(data, expanded_key[i * BLOCK_SIZE_BYTES:(i + 1) * BLOCK_SIZE_BYTES])

    return data
86
def aes_decrypt_text(data, password, key_size_bytes):
    """
    Decrypt text
    - The first 8 Bytes of decoded 'data' are the 8 high Bytes of the counter
    - The cipher key is retrieved by encrypting the first 16 Byte of 'password'
      with the first 'key_size_bytes' Bytes from 'password' (if necessary filled with 0's)
    - Mode of operation is 'counter'

    @param {str} data                    Base64 encoded string
    @param {str,unicode} password        Password (will be encoded with utf-8)
    @param {int} key_size_bytes          Possible values: 16 for 128-Bit, 24 for 192-Bit or 32 for 256-Bit
    @returns {str}                       Decrypted data
    """
    NONCE_LENGTH_BYTES = 8

    # bytearray yields ints on both Python 2 and 3, unlike iterating the raw
    # result of b64decode()/encode() and calling ord() on each item.
    data = list(bytearray(base64.b64decode(data)))
    password = list(bytearray(password.encode('utf-8')))

    # Truncate or zero-pad the password to the requested key size.
    key = password[:key_size_bytes] + [0] * (key_size_bytes - len(password))
    # NOTE(review): for key_size_bytes == 24 the repeat count is 1, so the
    # derived key is only 16 Bytes long; behaviour kept as in the original -
    # confirm whether a 24-Byte key was intended.
    key = aes_encrypt(key[:BLOCK_SIZE_BYTES], key_expansion(key)) * (key_size_bytes // BLOCK_SIZE_BYTES)

    nonce = data[:NONCE_LENGTH_BYTES]
    cipher = data[NONCE_LENGTH_BYTES:]

    class Counter(object):
        """Hands out successive 16-Byte counter blocks, starting at nonce + zeros."""

        def __init__(self, initial_block):
            self._value = initial_block

        def next_value(self):
            current = self._value
            self._value = inc(current)
            return current

    initial_block = nonce + [0] * (BLOCK_SIZE_BYTES - NONCE_LENGTH_BYTES)
    decrypted_data = aes_ctr_decrypt(cipher, key, Counter(initial_block))
    plaintext = ''.join(chr(x) for x in decrypted_data)

    return plaintext
122
# Round constants, indexed by the key schedule iteration number.
RCON = (0x8d, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36)
# Byte substitution table used by sub_bytes (256 entries, indexed by byte value).
SBOX = (0x63, 0x7C, 0x77, 0x7B, 0xF2, 0x6B, 0x6F, 0xC5, 0x30, 0x01, 0x67, 0x2B, 0xFE, 0xD7, 0xAB, 0x76,
        0xCA, 0x82, 0xC9, 0x7D, 0xFA, 0x59, 0x47, 0xF0, 0xAD, 0xD4, 0xA2, 0xAF, 0x9C, 0xA4, 0x72, 0xC0,
        0xB7, 0xFD, 0x93, 0x26, 0x36, 0x3F, 0xF7, 0xCC, 0x34, 0xA5, 0xE5, 0xF1, 0x71, 0xD8, 0x31, 0x15,
        0x04, 0xC7, 0x23, 0xC3, 0x18, 0x96, 0x05, 0x9A, 0x07, 0x12, 0x80, 0xE2, 0xEB, 0x27, 0xB2, 0x75,
        0x09, 0x83, 0x2C, 0x1A, 0x1B, 0x6E, 0x5A, 0xA0, 0x52, 0x3B, 0xD6, 0xB3, 0x29, 0xE3, 0x2F, 0x84,
        0x53, 0xD1, 0x00, 0xED, 0x20, 0xFC, 0xB1, 0x5B, 0x6A, 0xCB, 0xBE, 0x39, 0x4A, 0x4C, 0x58, 0xCF,
        0xD0, 0xEF, 0xAA, 0xFB, 0x43, 0x4D, 0x33, 0x85, 0x45, 0xF9, 0x02, 0x7F, 0x50, 0x3C, 0x9F, 0xA8,
        0x51, 0xA3, 0x40, 0x8F, 0x92, 0x9D, 0x38, 0xF5, 0xBC, 0xB6, 0xDA, 0x21, 0x10, 0xFF, 0xF3, 0xD2,
        0xCD, 0x0C, 0x13, 0xEC, 0x5F, 0x97, 0x44, 0x17, 0xC4, 0xA7, 0x7E, 0x3D, 0x64, 0x5D, 0x19, 0x73,
        0x60, 0x81, 0x4F, 0xDC, 0x22, 0x2A, 0x90, 0x88, 0x46, 0xEE, 0xB8, 0x14, 0xDE, 0x5E, 0x0B, 0xDB,
        0xE0, 0x32, 0x3A, 0x0A, 0x49, 0x06, 0x24, 0x5C, 0xC2, 0xD3, 0xAC, 0x62, 0x91, 0x95, 0xE4, 0x79,
        0xE7, 0xC8, 0x37, 0x6D, 0x8D, 0xD5, 0x4E, 0xA9, 0x6C, 0x56, 0xF4, 0xEA, 0x65, 0x7A, 0xAE, 0x08,
        0xBA, 0x78, 0x25, 0x2E, 0x1C, 0xA6, 0xB4, 0xC6, 0xE8, 0xDD, 0x74, 0x1F, 0x4B, 0xBD, 0x8B, 0x8A,
        0x70, 0x3E, 0xB5, 0x66, 0x48, 0x03, 0xF6, 0x0E, 0x61, 0x35, 0x57, 0xB9, 0x86, 0xC1, 0x1D, 0x9E,
        0xE1, 0xF8, 0x98, 0x11, 0x69, 0xD9, 0x8E, 0x94, 0x9B, 0x1E, 0x87, 0xE9, 0xCE, 0x55, 0x28, 0xDF,
        0x8C, 0xA1, 0x89, 0x0D, 0xBF, 0xE6, 0x42, 0x68, 0x41, 0x99, 0x2D, 0x0F, 0xB0, 0x54, 0xBB, 0x16)
# Coefficients applied by mix_column; entry [row][column] is 1, 2 or 3.
MIX_COLUMN_MATRIX = ((2, 3, 1, 1),
                     (1, 2, 3, 1),
                     (1, 1, 2, 3),
                     (3, 1, 1, 2))
144
def sub_bytes(data):
    """
    Substitute every Byte through SBOX

    @param {int[]} data  Bytes to substitute (each 0..255)
    @returns {int[]}     new list with SBOX[x] for every x in 'data'
    """
    # A real list (not a lazy map object) because callers index and slice it.
    return [SBOX[x] for x in data]
147
def rotate(data):
    """
    Rotate a sequence left by one position

    @param {int[]} data  sequence to rotate
    @returns {int[]}     new sequence with the first element moved to the end
    """
    # Slicing on both sides also copes with an empty sequence.
    return data[1:] + data[:1]
150
def key_schedule_core(data, rcon_iteration):
    """
    Transform one word for the key schedule

    The word is rotated left by one Byte, substituted through sub_bytes and
    its first Byte is XORed with RCON[rcon_iteration].

    @param {int[]} data          4-Byte word
    @param {int} rcon_iteration  index into RCON
    @returns {int[]}             transformed 4-Byte word
    """
    data = rotate(data)
    data = sub_bytes(data)
    data[0] = data[0] ^ RCON[rcon_iteration]

    return data
157
def xor(data1, data2):
    """
    XOR two sequences element by element

    @param {int[]} data1  first operand
    @param {int[]} data2  second operand
    @returns {int[]}      list of x ^ y, as long as the shorter operand
    """
    # Tuple-unpacking lambdas are a syntax error on Python 3; the
    # comprehension works on both major versions and returns a real list.
    return [x ^ y for x, y in zip(data1, data2)]
160
def mix_column(data):
    """
    Mix one 4-Byte column with MIX_COLUMN_MATRIX

    Every output Byte is the XOR of the four input Bytes, each first
    multiplied by the matrix coefficient (1, 2 or 3) for that position.

    @param {int[]} data  4-Byte column
    @returns {int[]}     mixed 4-Byte column
    """
    data_mixed = []
    for row in range(4):
        mixed = 0
        for column in range(4):
            factor = MIX_COLUMN_MATRIX[row][column]
            addend = data[column]
            if factor in (2, 3):
                # Multiply by 2: shift left, folding an overflow back into
                # 8 bits with the 0x1b reduction constant.
                addend <<= 1
                if addend > 0xff:
                    addend = (addend & 0xff) ^ 0x1b
                # Multiply by 3 = (multiply by 2) XOR the original Byte.
                if factor == 3:
                    addend ^= data[column]
            mixed ^= addend & 0xff
        data_mixed.append(mixed)
    return data_mixed
177
def mix_columns(data):
    """
    Apply mix_column to each of the four columns of the state

    @param {int[]} data  16-Byte state (four consecutive 4-Byte columns)
    @returns {int[]}     16-Byte mixed state
    """
    data_mixed = []
    for start in range(0, 16, 4):
        data_mixed += mix_column(data[start:start + 4])
    return data_mixed
184
def shift_rows(data):
    """
    Shift the rows of the state

    The state is read as four consecutive 4-Byte columns; the Byte placed at
    (column, row) of the result is taken from column (column + row) mod 4 of
    the same row.

    @param {int[]} data  16-Byte state
    @returns {int[]}     16-Byte shifted state
    """
    return [data[((column + row) & 0b11) * 4 + row]
            for column in range(4)
            for row in range(4)]
191
def inc(data):
    """
    Increment a big-endian Byte counter by one

    The input is not modified. A counter consisting only of 255s wraps
    around to all zeros.

    @param {int[]} data  counter Bytes, most significant first
    @returns {int[]}     new list holding the incremented counter
    """
    data = list(data)  # copy
    # Walk from the least significant Byte, carrying over every 255.
    for i in range(len(data) - 1, -1, -1):
        if data[i] == 255:
            data[i] = 0
        else:
            data[i] += 1
            break
    return data