import matplotlib as plt
import math
import numpy as np
import json

# Hamming-weight lookup table: HW[b] is the number of set bits in byte value b.
HW = [sum((n >> bit) & 1 for bit in range(8)) for n in range(256)]
def popcount(x):
    """Return the number of set bits in a 64-bit unsigned integer ``x``.

    Branch-free SWAR reduction: count bits per 2-bit field, then per 4-bit
    field, then per byte, and finally sum all byte counts into the top byte
    with one multiplication (truncated to 64 bits) and shift it down.
    """
    m1 = 0x5555555555555555   # 01 repeated
    m2 = 0x3333333333333333   # 0011 repeated
    m4 = 0x0f0f0f0f0f0f0f0f   # 00001111 repeated
    h01 = 0x0101010101010101  # 1 in every byte
    pair_counts = x - ((x >> 1) & m1)
    nibble_counts = (pair_counts & m2) + ((pair_counts >> 2) & m2)
    byte_counts = (nibble_counts + (nibble_counts >> 4)) & m4
    return ((byte_counts * h01) & 0xffffffffffffffff) >> 56

# Helper Functions
def mean(X):
    """Return the arithmetic mean of ``X`` along axis 0 (column-wise)."""
    column_totals = np.sum(X, axis=0)
    n_rows = len(X)
    return column_totals / n_rows

def std_dev(X, X_bar):
    """Return sqrt(sum((X - X_bar)**2)) along axis 0.

    Note that this is *not* divided by the sample count. The missing 1/n
    cancels against the equally unnormalised ``cov`` when the two are
    combined into a Pearson correlation coefficient.
    """
    deviations = X - X_bar
    sum_of_squares = np.sum(deviations ** 2, axis=0)
    return np.sqrt(sum_of_squares)

def cov(X, X_bar, Y, Y_bar):
    """Return the unnormalised co-variation sum((X - X_bar) * (Y - Y_bar)) along axis 0.

    Broadcasting applies: e.g. an (N, S) ``X`` against an (N, 1) ``Y`` yields
    one value per column of ``X``.
    """
    dx = X - X_bar
    dy = Y - Y_bar
    return np.sum(dx * dy, axis=0)

def to16(byte1, byte2):
    """Pack two byte values into a 16-bit word: ``(byte1 << 8) + byte2``.

    Both operands are converted to Python ``int`` *before* shifting, matching
    the ``int(...)`` conversions in ``simple_speck``. With NumPy ``uint8``
    scalars (e.g. elements of a uint8 array), NumPy 2 promotion rules keep
    ``byte1 << 8`` in 8-bit arithmetic and the high byte would be lost.

    Args:
        byte1: high byte.
        byte2: low byte.

    Returns:
        int: the combined value.
    """
    return (int(byte1) << 8) + int(byte2)


# Speck32/64 parameters: a 32-bit block made of two 16-bit words, 64-bit key.

WORD_SIZE = 16
BLOCK_SIZE = 2 * WORD_SIZE
KEY_SIZE = 4 * WORD_SIZE
NUM_ROUNDS = 22


# Rotation amounts used by the Speck round function (ROR by ALPHA, ROL by BETA).
ALPHA = 7
BETA = 2

# Bit mask for reducing values to one word, and the word modulus itself.
mod_mask = (1 << WORD_SIZE) - 1
mod_mask_sub = 1 << WORD_SIZE

def ER16(x, y, k):
    """Apply one Speck32 encryption round to the word pair (x, y) with round key k.

        x' = ((x >>> ALPHA) + y mod 2**16) XOR k
        y' = (y <<< BETA) XOR x'

    The rotations are written for 16-bit words, so ``x`` and ``y`` are
    assumed to be < 2**16.

    Returns:
        tuple: (x', y').
    """
    x_rotated = ((x >> ALPHA) + (x << (16 - ALPHA))) & mod_mask
    next_x = ((x_rotated + y) & mod_mask) ^ k
    y_rotated = ((y << BETA) + (y >> (16 - BETA))) & mod_mask
    return next_x, next_x ^ y_rotated



def simple_speck(plaintext, key):
    """Hamming-weight leakage hypothesis for one Speck32 round.

    The plaintext bytes are packed into two 16-bit words:
      * y = bytes 1:0 (byte 1 high)
      * x = bytes 3:2 (byte 3 high)

    One round ``ER16(x, y, key)`` is applied. The result is the Hamming
    weight of the 32-bit round output ``(x' << 16) | y'``.

    Args:
        plaintext: indexable holding at least 4 byte values.
        key: 16-bit round-key hypothesis.

    Returns:
        int: number of set bits in the round output (at most 32 when ``key``
        fits in 16 bits).
    """
    y = (int(plaintext[1]) << 8) + int(plaintext[0])
    x = (int(plaintext[3]) << 8) + int(plaintext[2])

    new_x, new_y = ER16(x, y, key)
    # Both outputs are full 16-bit words, so they must be placed 16 bits
    # apart. Packing them only 8 bits apart makes them overlap, and carries
    # then corrupt the Hamming weight.
    return popcount((new_x << 16) | new_y)


def calc_corr(traces, plaintexts):
    """Correlation power analysis over the high byte of the 16-bit round key.

    For every guess 0..255 of the round key's high byte (low byte fixed to
    0x00), the Hamming-weight hypothesis from ``simple_speck`` is
    Pearson-correlated with every sample point of ``traces``.

    Args:
        traces: array presumably shaped (n_traces, n_samples). TODO: confirm
            against the capture format.
        plaintexts: sequence of n_traces plaintexts, each with >= 4 byte values.

    Returns:
        tuple: (best_guess, maxcpa).
            best_guess: the key-byte guess with the largest peak |correlation|.
            maxcpa: list of 256 floats, the peak |correlation| per guess.
    """
    maxcpa = [0.0] * 256

    # The trace statistics do not depend on the key guess; compute them once.
    t_bar = mean(traces)
    o_t = std_dev(traces, t_bar)

    for key in range(256):
        guess = (key << 8) + 0x00  # only the high byte is hypothesised
        hws = np.array([[simple_speck(pt, guess) for pt in plaintexts]]).transpose()

        hws_bar = mean(hws)
        o_hws = std_dev(hws, hws_bar)
        correlation = cov(traces, t_bar, hws, hws_bar)
        # A constant sample point or a constant hypothesis gives 0/0 = NaN.
        # Non-finite values are skipped below. Otherwise a NaN can win
        # Python's max() and np.argmax would then select that guess.
        with np.errstate(divide="ignore", invalid="ignore"):
            cpaoutput = correlation / (o_t * o_hws)
        abs_corr = np.abs(cpaoutput)
        finite = abs_corr[np.isfinite(abs_corr)]
        maxcpa[key] = float(finite.max()) if finite.size else 0.0

    best_guess = int(np.argmax(maxcpa))
    return best_guess, maxcpa


def analyze_correlations(traces, plaintexts, steps=200, max_traces=1000,
                         correct_key=0x22, plot_path="plot_hiding.png"):
    """Track per-guess peak correlation as the number of traces grows.

    ``calc_corr`` is run on the prefixes of size steps, 2*steps, ... up to
    and including ``max_traces``. The result is plotted against the trace
    count: one curve per key guess, with ``correct_key`` highlighted. The
    plot is saved to ``plot_path``.

    Args:
        traces: trace array passed through to ``calc_corr``.
        plaintexts: plaintexts matching ``traces`` row for row.
        steps: prefix-size increment.
        max_traces: largest prefix size. Clamped to the available data.
        correct_key: key byte drawn darker on the plot.
        plot_path: output image path.

    Returns:
        dict: maps each key byte (0..255) to its list of peak correlations,
        one entry per prefix size.
    """
    # The module-level ``plt`` is the bare ``matplotlib`` package, which has
    # no figure/plot/savefig API. Use pyplot here.
    import matplotlib.pyplot as plt

    max_traces = min(max_traces, len(traces), len(plaintexts))

    allkeys = {j: [] for j in range(256)}
    stats = []

    # range() excludes its stop value; +1 so that max_traces itself is analysed.
    for i in range(steps, max_traces + 1, steps):
        _, corrs = calc_corr(traces[:i], plaintexts[:i])
        for j in range(256):
            allkeys[j].append(corrs[j])
        stats.append(i)

    fig = plt.figure()
    for keybyte, correlations in allkeys.items():
        if keybyte == correct_key:
            # Raised zorder keeps the highlighted curve above the 255 others.
            plt.plot(stats, correlations, color='gray', zorder=3)
        else:
            plt.plot(stats, correlations, color='lightgray')

    plt.ylabel('Correlation')
    plt.xlabel('Number of Traces')

    plt.savefig(plot_path)
    plt.close(fig)
    return allkeys


# Script driver: load the captured traces/plaintexts, run the CPA sweep,
# and dump the per-guess correlation curves as JSON.
print("[+] Loading Hiding Data")
trace_array, textin_array = (
    np.load("../sample_traces/5000_encryption_traces_with_hiding_random.npy"),
    np.load("../sample_traces/5000_plaintext_traces_with_hiding_random.npy"),
)

print("[+] Calculating Correlations")
allkeys = analyze_correlations(trace_array, textin_array)

# JSON object keys must be strings, so the int key bytes are written as "0".."255".
with open("2000k_correlations.json", "w") as json_file:
    json_file.write(json.dumps(allkeys))