2222
2323import asyncio
2424from concurrent .futures import ThreadPoolExecutor
25- from typing import Any , Callable , NamedTuple , Set , Union , List
25+ from typing import Callable , NamedTuple , Set , Union , List
2626
2727from .encoding import to_bytes
2828
3333 'remove_padding' ,
3434]
3535
36+
3637class Pass (NamedTuple ):
3738 block_index : int
3839 index : int
3940 byte : int
4041
42+
4143class Fail (NamedTuple ):
4244 block_index : int
4345 message : str
4446 is_critical : bool = False
4547
48+
4649class Done (NamedTuple ):
4750 block_index : int
4851 C0 : List [int ]
@@ -54,26 +57,28 @@ class Done(NamedTuple):
5457OracleFunc = Callable [[bytes ], bool ]
5558ResultCallback = Callable [[ResultType ], bool ]
5659PlainTextCallback = Callable [[List [int ]], bool ]
57-
60+
5861
5962class Context (NamedTuple ):
6063 block_size : int
6164 oracle : OracleFunc
62-
63- executor : ThreadPoolExecutor
65+
66+ executor : ThreadPoolExecutor
6467 loop : asyncio .AbstractEventLoop
65-
68+
6669 tasks : Set [asyncio .Task [ResultType ]]
6770
6871 latest_plaintext : List [int ]
6972 plaintext : List [int ]
70-
73+
7174 result_callback : ResultCallback
7275 plaintext_callback : PlainTextCallback
7376
77+
7478def dummy_callback (* a , ** ka ):
7579 pass
7680
81+
7782def solve (ciphertext : bytes ,
7883 block_size : int ,
7984 oracle : OracleFunc ,
@@ -87,6 +92,7 @@ def solve(ciphertext: bytes,
8792 result_callback , plaintext_callback )
8893 return loop .run_until_complete (future )
8994
95+
9096async def solve_async (ciphertext : bytes ,
9197 block_size : int ,
9298 oracle : OracleFunc ,
@@ -96,43 +102,47 @@ async def solve_async(ciphertext: bytes,
96102 ) -> List [int ]:
97103
98104 ciphertext = list (ciphertext )
99- assert len (ciphertext ) % block_size == 0 , \
100- 'ciphertext length must be a multiple of block_size'
101- assert len (ciphertext ) // block_size > 1 , \
102- 'cannot solve with only one block'
105+
106+ if not len (ciphertext ) % block_size == 0 :
107+ raise ValueError ('ciphertext length must be a multiple of block_size' )
108+ if not len (ciphertext ) // block_size > 1 :
109+ raise ValueError ('cannot solve with only one block' )
103110
104111 ctx = create_solve_context (ciphertext , block_size , oracle , parallel ,
105112 result_callback , plaintext_callback )
106113
107114 while True :
108- done_tasks , _ = await asyncio .wait (ctx .tasks , return_when = asyncio .FIRST_COMPLETED )
109-
115+ done_tasks , _ = await asyncio .wait (ctx .tasks ,
116+ return_when = asyncio .FIRST_COMPLETED )
117+
110118 for task in done_tasks :
111119 result = await task
112-
120+
113121 ctx .result_callback (result )
114122 ctx .tasks .remove (task )
115-
123+
116124 if isinstance (result , Pass ):
117- update_latest_plaintext (ctx , result .block_index , result .index , result .byte )
125+ update_latest_plaintext (
126+ ctx , result .block_index , result .index , result .byte )
118127 if isinstance (result , Done ):
119128 update_plaintext (ctx , result .block_index , result .C0 , result .X1 )
120-
129+
121130 if len (ctx .tasks ) == 0 :
122131 break
123-
132+
124133 # Check if any block failed
125134 error_block_indices = set ()
126-
135+
127136 for i , byte in enumerate (ctx .plaintext ):
128137 if byte is None :
129138 error_block_indices .add (i // block_size + 1 )
130-
139+
131140 for idx in error_block_indices :
132141 result_callback (Fail (idx , f'cannot decrypt cipher block { idx } ' , True ))
133-
142+
134143 return ctx .plaintext
135144
145+
136146def create_solve_context (ciphertext , block_size , oracle , parallel ,
137147 result_callback , plaintext_callback ) -> Context :
138148 tasks = set ()
@@ -143,30 +153,27 @@ def create_solve_context(ciphertext, block_size, oracle, parallel,
143153
144154 plaintext = [None ] * (len (cipher_blocks ) - 1 ) * block_size
145155 latest_plaintext = plaintext .copy ()
146-
156+
147157 executor = ThreadPoolExecutor (parallel )
148158 loop = asyncio .get_running_loop ()
149159 ctx = Context (block_size , oracle , executor , loop , tasks ,
150160 latest_plaintext , plaintext ,
151161 result_callback , plaintext_callback )
152-
162+
153163 for i in range (1 , len (cipher_blocks )):
154164 run_block_task (ctx , i , cipher_blocks [i - 1 ], cipher_blocks [i ], [])
155165
156166 return ctx
157167
168+
158169def run_block_task (ctx : Context , block_index , C0 , C1 , X1 ):
159170 future = solve_block (ctx , block_index , C0 , C1 , X1 )
160171 task = ctx .loop .create_task (future )
161172 ctx .tasks .add (task )
162173
163- async def solve_block (
164- ctx : Context ,
165- block_index : int ,
166- C0 : List [int ],
167- C1 : List [int ],
168- X1 : List [int ] = [],
169- ) -> ResultType :
174+
175+ async def solve_block (ctx : Context , block_index : int , C0 : List [int ],
176+ C1 : List [int ], X1 : List [int ] = []) -> ResultType :
170177 # X1 = decrypt(C1)
171178 # P1 = xor(C0, X1)
172179
@@ -195,46 +202,56 @@ async def solve_block(
195202 for byte in hits :
196203 X1_test = [byte ^ padding , * X1 ]
197204 run_block_task (ctx , block_index , C0 , C1 , X1_test )
198-
205+
199206 return Pass (block_index , index , byte ^ padding ^ C0 [index ])
200207
201- async def get_oracle_hits (ctx : Context , C0 : List [int ], C1 : List [int ], index : int ):
202-
208+
209+ async def get_oracle_hits (ctx : Context , C0 : List [int ], C1 : List [int ],
210+ index : int ):
211+
203212 C0 = C0 .copy ()
204213 futures = {}
205-
214+
206215 for byte in range (256 ):
207216 C0 [index ] = byte
208217 ciphertext = bytes (C0 + C1 )
209218 futures [byte ] = ctx .loop .run_in_executor (
210219 ctx .executor , ctx .oracle , ciphertext )
211-
220+
212221 hits = []
213-
222+
214223 for byte , future in futures .items ():
215224 is_valid = await future
216225 if is_valid :
217226 hits .append (byte )
218-
227+
219228 return hits
220229
221- def update_latest_plaintext (ctx : Context , block_index : int , index : int , byte : int ):
230+
231+ def update_latest_plaintext (ctx : Context , block_index : int , index : int ,
232+ byte : int ):
233+
222234 i = (block_index - 1 ) * ctx .block_size + index
223235 ctx .latest_plaintext [i ] = byte
224236 ctx .plaintext_callback (ctx .latest_plaintext )
225237
226- def update_plaintext (ctx : Context , block_index : int , C0 : List [int ], X1 : List [int ]):
238+
239+ def update_plaintext (ctx : Context , block_index : int , C0 : List [int ],
240+ X1 : List [int ]):
241+
227242 assert len (C0 ) == len (X1 ) == ctx .block_size
228243 block = compute_plaintext (C0 , X1 )
229-
244+
230245 i = (block_index - 1 ) * ctx .block_size
231246 ctx .latest_plaintext [i :i + ctx .block_size ] = block
232247 ctx .plaintext [i :i + ctx .block_size ] = block
233248 ctx .plaintext_callback (ctx .plaintext )
234249
250+
235251def compute_plaintext (C0 : List [int ], X1 : List [int ]):
236252 return [c ^ x for c , x in zip (C0 , X1 )]
237253
254+
238255def convert_to_bytes (byte_list : List [int ], replacement = b' ' ):
239256 '''
240257 Convert a list of int into bytes, replace invalid byte with replacement.
@@ -249,6 +266,7 @@ def convert_to_bytes(byte_list: List[int], replacement=b' '):
249266 byte_list [i ] = byte
250267 return bytes (byte_list )
251268
269+
252270def remove_padding (data : Union [str , bytes , List [int ]]) -> bytes :
253271 '''
254272 Remove PKCS#7 padding bytes.
0 commit comments