Speck-Analysis.ipynb 383 KB

Speck Analysis Notebook

The following block initializes the chipwhisperer, as always

import chipwhisperer as cw
import time
WARNING:ChipWhisperer Other:ChipWhisperer update available! See https://chipwhisperer.readthedocs.io/en/latest/installing.html for updating instructions
# Path to the Speck .hex file for reflashing
# Exactly one PATH assignment should be active; the commented-out alternatives
# point at the firmware builds in the cw_firmware_masked / cw_firmware_hiding
# directories.
PATH="/home/msc/documents/obsidian_notes/master-aits/subjects/implementation_attacks_and_countermeasures/praktikum/speck_cpa_cw/cw_firmware/simple-speck-CWLITEARM.hex"
#PATH="/home/juan/documents/master-aits/subjects/implementation_attacks_and_countermeasures/praktikum/speck_cpa_cw/cw_firmware_masked/simple-speck-CWLITEARM.hex"
#PATH="/home/msc/documents/obsidian_notes/master-aits/subjects/implementation_attacks_and_countermeasures/praktikum/speck_cpa_cw/cw_firmware_hiding/simple-speck-CWLITEARM.hex"
# NOTE(review): this requires `scope` to exist already (it is created in the
# connection cell further down); presumably it disconnects a stale scope handle
# before reconnecting - confirm against the ChipWhisperer API.
scope.dis()
True
def flash(scope, prog):
    """Program the target with the firmware image referenced by the global ``PATH``.

    Args:
        scope: ChipWhisperer scope object the target is attached to.
        prog: Programmer class handed through to ``cw.program_target``.
    """
    firmware_image = PATH
    cw.program_target(scope, prog, firmware_image)

def reset_target(scope):
    """Pulse the target's nRST line: drive it low, then release it to high-Z.

    Each pin state is held for 50 ms before moving on.

    Args:
        scope: Object exposing the reset pin as ``scope.io.nrst``.
    """
    for pin_state in ('low', 'high_z'):
        scope.io.nrst = pin_state
        time.sleep(0.05)

# (Re)connect to the scope. If `scope` already exists from an earlier cell run
# it is reconnected when needed; otherwise the NameError raised by the lookup
# is used as the signal to create a fresh scope object.
try:
    if not scope.connectStatus:
        scope.con()
except NameError:
    scope = cw.scope()

# Attach the target; on an IOError rebuild the scope first and retry once.
try:
    target = cw.target(scope)
except IOError:
    print("INFO: Caught exception on reconnecting to target - attempting to reconnect to scope first.")
    print("INFO: This is a work-around when USB has died without Python knowing. Ignore errors above this line.")
    scope = cw.scope()
    target = cw.target(scope)

print("INFO: Found ChipWhisperer😍")

# Programmer class passed to flash() when reflashing the target.
prog = cw.programmers.STM32FProgrammer
time.sleep(0.05)
# Apply the scope's default configuration.
scope.default_setup()
INFO: Found ChipWhisperer😍

Reset the target if required:

reset_target(scope)

Reflash the target if required:

flash(scope, prog)
Detected known STMF32: STM32F302xB(C)/303xB(C)
Extended erase (0x44), this can take ten seconds or more
Attempting to program 5431 bytes at 0x8000000
STM32F Programming flash...
STM32F Reading flash...
Verified flash OK, 5431 bytes

Set an encryption key:

# 32 bytes of encryption key
#encryption_key = b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f"

# 8 byte encryption key
encryption_key = b"\x11\x22\x33\x44\x55\x66\x77\x88"
# Only an 8-byte key is sent (simpleserial command "s"); any other length is
# silently ignored, so the target keeps whatever key it had before.
if len(encryption_key) == 8:
    target.simpleserial_write("s", encryption_key)

Get the encryption key:

target.simpleserial_write("k", b'\x00'*4)
print(target.simpleserial_read("o", 8))
CWbytearray(b'11 22 33 44 55 66 77 88')

Encrypt a 4-byte block:

#pt = b"\x70\x6f\x6f\x6e\x65\x72\x2e\x20\x49\x6e\x20\x74\x68\x6f\x73\x65"
pt = b"\x4c\x69\x74\x65"
target.simpleserial_write("e", pt)
print(target.simpleserial_read("c", 4))
CWbytearray(b'5d 7c 46 6d')

Capturing the Data

The following code snippet captures power traces of the encryption process for N = 20000 random plaintexts

from tqdm.notebook import trange
import random
import numpy as np

# NOTE(review): `ktp` is created but never used below - plaintexts come from
# Python's `random` module instead.
ktp = cw.ktp.Basic()
trace_array = []
textin_array = []

# This initial plaintext is overwritten in the first loop iteration.
pt = b"\x4c\x69\x74\x65" 
# Fixed seed so the plaintext sequence is reproducible between runs.
random.seed(0x5222322223) 


N = 20000
for i in trange(N, desc='Capturing traces'):
    # Fresh random 4-byte plaintext for every capture.
    pt = bytes([random.randint(0, 255) for i in range(4)])
    # The scope has to be armed before the encryption command is sent.
    scope.arm()
    
    
    target.simpleserial_write('e', pt)
    
    # A truthy return value is treated as a timeout; the trace is then skipped,
    # so fewer than N traces may end up in the arrays.
    ret = scope.capture()
    if ret:
        print("Target timed out!")
        continue
    
    # The ciphertext is read back but not stored or checked.
    response = target.simpleserial_read('c', 4)
    
    # Traces and plaintexts are appended together so both lists stay index-aligned.
    trace_array.append(scope.get_last_trace())
    textin_array.append(pt)

    
trace_array = np.array(trace_array)
Capturing traces:   0%|          | 0/20000 [00:00<?, ?it/s]

Saving the traces

np.save("sample_traces/20000_encryption_traces_regular.npy", trace_array)
np.save("sample_traces/20000_plaintext_traces_regular.npy", textin_array)

Offline Mode

If no CW is available, load the trace array from a file

import numpy as np
trace_array = np.load("sample_traces/2200_encryption_traces.npy")
textin_array = np.load("sample_traces/2200_plaintext_traces.npy")
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
/tmp/ipykernel_1553243/2460027152.py in <module>
----> 1 trace_array = np.load("sample_traces/2200_encryption_traces.npy")
      2 textin_array = np.load("sample_traces/2200_plaintext_traces.npy")

NameError: name 'np' is not defined

Plotting the Data

%matplotlib notebook
import matplotlib.pylab as plt
plt.figure()
plt.plot(trace_array[2], 'orange')
#plt.plot(trace_array[4], 'g')
plt.show()

Hamming Weight and Speck Model

#Hamming Weight
HW = [bin(n).count("1") for n in range(0, 256)]
def popcount(x):
    """Return the number of set bits of ``x`` using a branch-free parallel bit count.

    The masks are 64 bits wide, so the result is only meaningful for
    non-negative values that fit in 64 bits.
    """
    pair_mask = 0x5555555555555555
    nibble_mask = 0x3333333333333333
    byte_mask = 0x0f0f0f0f0f0f0f0f
    byte_ones = 0x0101010101010101

    # Count bits per 2-bit group, then per nibble, then per byte.
    pairs = x - ((x >> 1) & pair_mask)
    nibbles = (pairs & nibble_mask) + ((pairs >> 2) & nibble_mask)
    per_byte = (nibbles + (nibbles >> 4)) & byte_mask
    # The multiplication sums all byte counts into the top byte.
    return ((per_byte * byte_ones) & 0xffffffffffffffff) >> 56

Further functions for Pearson

def mean(X):
    """Return the arithmetic mean of ``X`` along axis 0 (one value per column)."""
    sample_count = len(X)
    column_totals = np.sum(X, axis=0)
    return column_totals / sample_count

def std_dev(X, X_bar):
    """Return the root of the summed squared deviations of ``X`` from ``X_bar`` along axis 0.

    Note that the sum is not divided by the sample count, so this is the
    standard deviation scaled by sqrt(N). ``cov`` omits the same factor, so the
    two cancel when a correlation coefficient is formed from them.
    """
    deviations = X - X_bar
    squared_total = np.sum(deviations ** 2, axis=0)
    return np.sqrt(squared_total)

def cov(X, X_bar, Y, Y_bar):
    """Return the summed product of the deviations of ``X`` and ``Y`` along axis 0.

    The sum is not divided by the sample count (unnormalised covariance), which
    matches the scaling used by ``std_dev``.
    """
    x_centered = X - X_bar
    y_centered = Y - Y_bar
    return np.sum(x_centered * y_centered, axis=0)

def to16(byte1, byte2):
    """Combine two bytes into one integer, with ``byte1`` as the high byte."""
    high_part = byte1 << 8
    return int(high_part + byte2)

The Speck Simulation

The following Code calculates the basic Speck encryption routine (one xor):

import math

# Cipher parameters (sizes in bits). Only WORD_SIZE, ALPHA and BETA are used by
# the functions visible in this notebook; the remaining values are informational.
NUM_ROUNDS = 22
BLOCK_SIZE = 32
KEY_SIZE = 64
WORD_SIZE = 16


# SHIFTs for SPECK
# ALPHA is the right-rotation amount applied to x, BETA the left-rotation
# amount applied to y (see ER16).
ALPHA = 7
BETA = 2

# Mask that truncates a value to WORD_SIZE bits.
mod_mask = (2 ** WORD_SIZE) -1
# 2**WORD_SIZE; NOTE(review): not referenced by any function shown here.
mod_mask_sub = (2 ** WORD_SIZE)

def ER16(x, y, k):
    """Apply one Speck encryption round to the 16-bit word pair ``(x, y)``.

    Steps: rotate ``x`` right by ALPHA, add ``y`` modulo 2**16, XOR with the
    round key ``k``; then rotate ``y`` left by BETA and XOR it with the new x.

    Returns:
        Tuple ``(new_x, new_y)``.
    """
    # The rotations are built from two shifts that are added and then masked,
    # which is only a true rotation while the inputs fit in 16 bits.
    rotated_x = ((x >> ALPHA) + (x << (16 - ALPHA))) & mod_mask
    keyed_x = k ^ ((rotated_x + y) & mod_mask)

    rotated_y = ((y << BETA) + (y >> (16 - BETA))) & mod_mask
    return keyed_x, keyed_x ^ rotated_y



def simple_speck(plaintext, key, arg=None):
    """Leakage model: Hamming weight of the state after one Speck round.

    Args:
        plaintext: Four bytes; bytes 0/1 form the little-endian y word and
            bytes 2/3 the little-endian x word.
        key: Round-key guess XORed in by ``ER16``.
        arg: Unused; accepted so that the function matches the three-argument
            model callback signature.

    Returns:
        ``popcount((x << 8) + y)`` of the round output. The two 16-bit words
        overlap by 8 bits in that sum; it is not a 32-bit concatenation.
    """
    low_word = int(plaintext[0]) + (int(plaintext[1]) << 8)
    high_word = int(plaintext[2]) + (int(plaintext[3]) << 8)

    high_word, low_word = ER16(high_word, low_word, key)
    return popcount(low_word + (high_word << 8))

## According to the Paper "Breaking Speck using CPA"
def PowerRightHalfKey(plaintext, key, knownkey=None):
    """Leakage model for the first round: Hamming weight of ``(ROR(x) + y) ^ key``.

    Args:
        plaintext: Four bytes; bytes 0/1 are the little-endian y word and
            bytes 2/3 the little-endian x word.
        key: Round-key guess.
        knownkey: Unused; present for callback-signature compatibility.
    """
    right_word = int(plaintext[0]) + (int(plaintext[1]) << 8)
    left_word = int(plaintext[2]) + (int(plaintext[3]) << 8)

    rotated = ((left_word >> ALPHA) + (left_word << (16 - ALPHA))) & mod_mask
    keyed = ((rotated + right_word) & mod_mask) ^ key
    return popcount(keyed)

## According to the Paper "Breaking Speck using CPA" 
def PowerLeftHalfKeyII(plaintext, key, knownkey):
    """Leakage model for the round that follows the already recovered round keys.

    The plaintext is pushed through one Speck round per entry of ``knownkey``
    and the guess ``key`` is used as the round key of the next round.

    Args:
        plaintext: Four bytes; bytes 0/1 are the little-endian y word and
            bytes 2/3 the little-endian x word.
        key: Round-key guess for the round under attack.
        knownkey: Sequence of round keys recovered so far (``None`` or empty
            for the first round). Only the first four entries are used.

    Returns:
        With zero or one known keys, the Hamming weight of x right after the
        XOR with ``key``. With two or more known keys, the Hamming weight of y
        at the end of the attacked round (after ``y = ROL(y) ^ x``).
    """
    def ror(word):
        # x = ROR(x, ALPHA)
        return ((word << (16 - ALPHA)) + (word >> ALPHA)) & mod_mask

    def rol(word):
        # y = ROL(y, BETA)
        return ((word >> (16 - BETA)) + (word << BETA)) & mod_mask

    y = (int(plaintext[1]) << 8) + int(plaintext[0])
    x = (int(plaintext[3]) << 8) + int(plaintext[2])

    # `is None` instead of `== None`, so array-like inputs do not trigger an
    # element-wise comparison. Entries beyond the fourth were never used.
    known = [] if knownkey is None else list(knownkey)[:4]

    # -------------- first round -----------------#
    x = (ror(x) + y) & mod_mask
    if not known:
        return popcount(x ^ key)
    x ^= known[0]

    # -------------- second round -----------------#
    y = rol(y) ^ x
    x = (ror(x) + y) & mod_mask
    if len(known) == 1:
        return popcount(x ^ key)
    x ^= known[1]
    y = rol(y) ^ x

    # -------------- third to fifth round -----------------#
    # The remaining known keys are applied first and the guess last; every
    # round also completes the y update, and the leak is taken on y.
    for round_key in known[2:] + [key]:
        x = ((ror(x) + y) & mod_mask) ^ round_key
        y = rol(y) ^ x

    return popcount(y)
    
    
    ## According to the Paper "Breaking Speck using CPA" 
def PowerLeftHalfKey(plaintext, key, knownkey):
    """Leakage model for the second round key, given the first round key.

    Args:
        plaintext: Four bytes; bytes 0/1 are the little-endian y word and
            bytes 2/3 the little-endian x word.
        key: Guess for the second round key.
        knownkey: The first round key, as a single integer.

    Returns:
        Hamming weight of x in the second round right after the XOR with ``key``.
    """
    right_word = int(plaintext[0]) + (int(plaintext[1]) << 8)
    left_word = int(plaintext[2]) + (int(plaintext[3]) << 8)

    # First round with the known key.
    rotated_left = ((left_word >> ALPHA) + (left_word << (16 - ALPHA))) & mod_mask
    round1_x = ((rotated_left + right_word) & mod_mask) ^ knownkey
    rotated_right = ((right_word << BETA) + (right_word >> (16 - BETA))) & mod_mask
    round1_y = rotated_right ^ round1_x

    # Second round up to and including the key XOR.
    rotated_x = ((round1_x >> ALPHA) + (round1_x << (16 - ALPHA))) & mod_mask
    round2_sum = (rotated_x + round1_y) & mod_mask
    return popcount(round2_sum ^ key)

def simple_speck_partial(plaintext, key, knownkey):
    """Leakage model: Hamming weight of the state after two Speck rounds.

    The first round uses ``knownkey``, the second one the guess ``key``.

    Returns:
        ``popcount((x << 8) + y)`` of the second round's output.
    """
    low_word = int(plaintext[0]) + (int(plaintext[1]) << 8)
    high_word = int(plaintext[2]) + (int(plaintext[3]) << 8)

    for round_key in (knownkey, key):
        high_word, low_word = ER16(high_word, low_word, round_key)

    return popcount(low_word + (high_word << 8))


def speck_keyschedule(plaintext, key, known_keys):
    """Leakage model: Hamming weight of the state after ``len(known_keys) + 1`` rounds.

    Despite its name this function does not compute a key schedule. It
    encrypts the plaintext with every round key in ``known_keys`` and then
    performs one more round with the guess ``key``.

    Returns:
        ``popcount((x << 8) + y)`` of the final round's output.
    """
    low_word = int(plaintext[0]) + (int(plaintext[1]) << 8)
    high_word = int(plaintext[2]) + (int(plaintext[3]) << 8)

    # Known round keys first, the guessed key last.
    for round_key in [*known_keys, key]:
        high_word, low_word = ER16(high_word, low_word, round_key)

    return popcount(low_word + (high_word << 8))

   
    
simple_speck(b'\xf5\xf9\xa97', 0x1)
9

This method works for calculating the correct key from the power trace

For the following C-Implementation, the function calc_mean_from_trace() seems to gather the correct key from the trace:

u16 i;
Ct[0]=Pt[0]; Ct[1]=Pt[1];

for(i=0;i<22; i++) {
    ER16(Ct[1],Ct[0],0x69);
}

This also works for 2-byte keys:

ER16(Ct[1],Ct[0],0xdead);
def _max_abs_correlations(traces, plaintexts, model_callback, leftmost, other_keybyte, argument):
    """Return, for each of the 256 key-byte guesses, the peak absolute Pearson correlation.

    Args:
        traces: 2-D array with one row per trace and one column per sample point.
        plaintexts: Sequence of plaintexts, index-aligned with ``traces``. It is
            iterated once per key guess, so it must not be a one-shot iterator.
        model_callback: ``f(plaintext, key16, argument)`` returning the
            hypothetical leakage for a 16-bit key guess.
        leftmost: If truthy the guessed byte is the high byte of the 16-bit key
            and ``other_keybyte`` the low byte; otherwise the roles are swapped.
        other_keybyte: Value used for the byte that is not being guessed.
        argument: Passed through unchanged as third argument of ``model_callback``.

    Returns:
        List of 256 values; entry ``k`` is the maximum over all sample points
        of ``|corr(traces, model(k))|``.
    """
    traces = np.asarray(traces)

    # The trace statistics do not depend on the key guess, so centre once.
    centered_traces = traces - np.sum(traces, axis=0) / len(traces)
    trace_norm = np.sqrt(np.sum(centered_traces ** 2, axis=0))

    maxcpa = [0] * 256   # Correlations
    for key in range(256):
        if leftmost:
            candidate = (key << 8) + other_keybyte
        else:
            candidate = (other_keybyte << 8) + key

        # Column vector of model outputs so it broadcasts over all sample points.
        hws = np.array([[model_callback(pt, candidate, argument) for pt in plaintexts]]).transpose()
        centered_hws = hws - np.sum(hws, axis=0) / len(hws)
        hws_norm = np.sqrt(np.sum(centered_hws ** 2, axis=0))

        # Numerator and denominator both omit the 1/N factor, so it cancels.
        cpaoutput = np.sum(centered_traces * centered_hws, axis=0) / (trace_norm * hws_norm)
        maxcpa[key] = max(abs(cpaoutput))

    return maxcpa


def calculate_correlations_plot(traces, plaintexts, model_callback, leftmost=True, other_keybyte=0x00, argument=None):
    """Run the CPA and return ``(maxcpa, best_guess)`` for plotting.

    ``maxcpa`` is the list of 256 peak correlations and ``best_guess`` the
    index of its maximum. See ``_max_abs_correlations`` for the parameters.
    """
    maxcpa = _max_abs_correlations(traces, plaintexts, model_callback, leftmost, other_keybyte, argument)
    best_guess = int(np.argmax(maxcpa))
    return (maxcpa, best_guess)


def calculate_correlations(traces, plaintexts, model_callback, leftmost=True, other_keybyte=0x00, argument=None):
    """Run the CPA and return the two best key-byte guesses.

    Returns:
        ``([best_guess, best_corr], [second_guess, second_corr])``. The best
        guess comes from ``argmax`` and the runner-up is the second-to-last
        entry of ``argsort``.
    """
    maxcpa = _max_abs_correlations(traces, plaintexts, model_callback, leftmost, other_keybyte, argument)
    best_guess = int(np.argmax(maxcpa))
    second_guess = int(np.argsort(maxcpa, axis=0)[-2])
    return ([best_guess, maxcpa[best_guess]], [second_guess, maxcpa[second_guess]])


def calculate_correlations_retcorr(traces, plaintexts, model_callback, leftmost=True, other_keybyte=0x00, argument=None):
    """Run the CPA and return ``(best_guess, maxcpa)``.

    Same computation as ``calculate_correlations_plot`` with the tuple order
    swapped.
    """
    maxcpa = _max_abs_correlations(traces, plaintexts, model_callback, leftmost, other_keybyte, argument)
    best_guess = int(np.argmax(maxcpa))
    return best_guess, maxcpa
    
from tqdm import tnrange


def calc_mean_from_trace(traces, plaintexts):
    """Recover the high byte of the first round key by CPA and plot the result.

    For each of the 256 guesses the ``simple_speck`` model is evaluated with
    the guess as high byte and 0x00 as low byte, then correlated against every
    sample point of ``traces``.

    Args:
        traces: 2-D array with one row per trace.
        plaintexts: Plaintexts, index-aligned with ``traces``. The previous
            version ignored this parameter and read the global
            ``textin_array`` instead.

    Returns:
        Index (0..255) of the guess with the highest peak absolute correlation.
    """
    maxcpa = [0] * 256

    t_bar = mean(traces)
    o_t = std_dev(traces, t_bar)

    for key in range(0, 256):
        # Use the function argument, not the notebook-global textin_array.
        hws = np.array([[simple_speck(textin, (key << 8) + 0x00) for textin in plaintexts]]).transpose()

        hws_bar = mean(hws)
        o_hws = std_dev(hws, hws_bar)
        correlation = cov(traces, t_bar, hws, hws_bar)
        cpaoutput = correlation / (o_t * o_hws)
        maxcpa[key] = max(abs(cpaoutput))

    plt.figure()
    plt.plot(maxcpa, 'orange')
    plt.show()

    guess = np.argmax(maxcpa)
    print("Key guess: (xored with) = ", hex(guess))
    return guess

calc_mean_from_trace(trace_array, textin_array)
Key guess: (xored with) =  0x80
128
def calc_mean_from_trace_two_byte_key(traces, plaintexts):
    """Recover a 16-bit round key with the ``simple_speck`` model, one byte at a time.

    The high byte is attacked first with the low byte fixed to 0x00; the low
    byte is then attacked with the recovered high byte in place. The two best
    candidates of each step are printed.

    Returns:
        The recovered 16-bit key, ``(high_byte << 8) + low_byte``.
    """
    high_best, high_runner_up = calculate_correlations(traces, plaintexts, simple_speck, True, 0x00)
    print(f"[+] Highest Correlation: {hex(high_best[0])} ({high_best[1]}) and Second Highest: {hex(high_runner_up[0])} ({high_runner_up[1]})")

    low_best, low_runner_up = calculate_correlations(traces, plaintexts, simple_speck, False, high_best[0])
    print(f"[+] Highest Correlation: {hex(low_best[0])} ({low_best[1]}) and Second Highest: {hex(low_runner_up[0])} ({low_runner_up[1]})")

    recovered_key = to16(high_best[0], low_best[0])
    return recovered_key
possible_firstkeys = calc_mean_from_trace_two_byte_key(trace_array, textin_array)
[+] Highest Correlation: 0x22 (0.3829957203458803) and Second Highest: 0x23 (0.3746689384242885)
[+] Highest Correlation: 0x11 (0.44296688610734375) and Second Highest: 0x51 (0.42720967143621397)
%matplotlib notebook
import matplotlib.pylab as plt
import json

def analyze_correlations(traces, plaintexts, steps=200, max_traces=None, highlight_key=0x22):
    """Plot how the per-key peak correlation evolves with the number of traces.

    The CPA on the high key byte (``simple_speck`` model, low byte 0x00) is
    repeated on growing prefixes of the data set.

    Args:
        traces: 2-D array with one row per trace.
        plaintexts: Plaintexts, index-aligned with ``traces``.
        steps: Number of traces added per evaluation.
        max_traces: Largest prefix to evaluate. Defaults to ``len(traces)`` and
            is capped to it, so no evaluation silently re-uses the full data
            set under a larger label. This replaces the hard-coded 20000.
        highlight_key: Key byte drawn in a darker colour.

    Returns:
        Dict mapping each key byte (0..255) to its list of peak correlations,
        one entry per evaluated prefix size.
    """
    available = len(traces)
    max_traces = available if max_traces is None else min(max_traces, available)

    allkeys = {keybyte: [] for keybyte in range(256)}
    stats = []

    # `+ 1` makes the last full step inclusive; the previous range() stopped
    # one step short of max_traces.
    for i in range(steps, max_traces + 1, steps):
        print(f"[+] Calculating Correlations with {i} traces")
        _, corrs = calculate_correlations_retcorr(traces[:i], plaintexts[:i], simple_speck, True, 0x00)

        for keybyte in range(256):
            allkeys[keybyte].append(corrs[keybyte])
        stats.append(i)

    plt.figure()

    for keybyte, correlations in allkeys.items():
        colour = 'gray' if keybyte == highlight_key else 'lightgray'
        plt.plot(stats, correlations, color=colour)

    plt.ylabel('Correlation')
    plt.xlabel('Number of Traces')

    plt.show()
    return allkeys
    
data = analyze_correlations(trace_array, textin_array)
import json
with open("20kregular.json", "w") as out:
    out.write(json.dumps(data))
[+] Calculating Correlations with 200 traces
[+] Calculating Correlations with 400 traces
[+] Calculating Correlations with 600 traces
[+] Calculating Correlations with 800 traces
[+] Calculating Correlations with 1000 traces
[+] Calculating Correlations with 1200 traces
[+] Calculating Correlations with 1400 traces
[+] Calculating Correlations with 1600 traces
[+] Calculating Correlations with 1800 traces
[+] Calculating Correlations with 2000 traces
[+] Calculating Correlations with 2200 traces
[+] Calculating Correlations with 2400 traces
[+] Calculating Correlations with 2600 traces
[+] Calculating Correlations with 2800 traces
[+] Calculating Correlations with 3000 traces
[+] Calculating Correlations with 3200 traces
[+] Calculating Correlations with 3400 traces
[+] Calculating Correlations with 3600 traces
[+] Calculating Correlations with 3800 traces
[+] Calculating Correlations with 4000 traces
[+] Calculating Correlations with 4200 traces
[+] Calculating Correlations with 4400 traces
[+] Calculating Correlations with 4600 traces
[+] Calculating Correlations with 4800 traces
[+] Calculating Correlations with 5000 traces
[+] Calculating Correlations with 5200 traces
[+] Calculating Correlations with 5400 traces
[+] Calculating Correlations with 5600 traces
[+] Calculating Correlations with 5800 traces
[+] Calculating Correlations with 6000 traces
[+] Calculating Correlations with 6200 traces
[+] Calculating Correlations with 6400 traces
[+] Calculating Correlations with 6600 traces
[+] Calculating Correlations with 6800 traces
[+] Calculating Correlations with 7000 traces
[+] Calculating Correlations with 7200 traces
[+] Calculating Correlations with 7400 traces
[+] Calculating Correlations with 7600 traces
[+] Calculating Correlations with 7800 traces
[+] Calculating Correlations with 8000 traces
[+] Calculating Correlations with 8200 traces
[+] Calculating Correlations with 8400 traces
[+] Calculating Correlations with 8600 traces
[+] Calculating Correlations with 8800 traces
[+] Calculating Correlations with 9000 traces
[+] Calculating Correlations with 9200 traces
[+] Calculating Correlations with 9400 traces
[+] Calculating Correlations with 9600 traces
[+] Calculating Correlations with 9800 traces
[+] Calculating Correlations with 10000 traces
[+] Calculating Correlations with 10200 traces
[+] Calculating Correlations with 10400 traces
[+] Calculating Correlations with 10600 traces
[+] Calculating Correlations with 10800 traces
[+] Calculating Correlations with 11000 traces
[+] Calculating Correlations with 11200 traces
[+] Calculating Correlations with 11400 traces
[+] Calculating Correlations with 11600 traces
[+] Calculating Correlations with 11800 traces
[+] Calculating Correlations with 12000 traces
[+] Calculating Correlations with 12200 traces
[+] Calculating Correlations with 12400 traces
[+] Calculating Correlations with 12600 traces
[+] Calculating Correlations with 12800 traces
[+] Calculating Correlations with 13000 traces
[+] Calculating Correlations with 13200 traces
[+] Calculating Correlations with 13400 traces
[+] Calculating Correlations with 13600 traces
[+] Calculating Correlations with 13800 traces
[+] Calculating Correlations with 14000 traces
[+] Calculating Correlations with 14200 traces
[+] Calculating Correlations with 14400 traces
[+] Calculating Correlations with 14600 traces
[+] Calculating Correlations with 14800 traces
[+] Calculating Correlations with 15000 traces
[+] Calculating Correlations with 15200 traces
[+] Calculating Correlations with 15400 traces
[+] Calculating Correlations with 15600 traces
[+] Calculating Correlations with 15800 traces
[+] Calculating Correlations with 16000 traces
[+] Calculating Correlations with 16200 traces
[+] Calculating Correlations with 16400 traces
[+] Calculating Correlations with 16600 traces
[+] Calculating Correlations with 16800 traces
[+] Calculating Correlations with 17000 traces
[+] Calculating Correlations with 17200 traces
[+] Calculating Correlations with 17400 traces
import json
with open("out.json", "w") as out:
    out.write(json.dumps(data))
plt.figure()
a= {150: 0.30808746837860246, 300: 0.3244230833989644, 450: 0.31733689560071926, 600: 0.2772590504271152, 750: 0.26619518283629395, 900: 0.24880631112549342}
plt.plot(a.keys(), a.values(), color='lightgray')
[<matplotlib.lines.Line2D at 0x7fa929b4b850>]
def calc_mean_from_trace_two_byte_key_plot(traces, plaintexts):
    """Plot the per-guess peak correlations for both bytes of the first round key.

    The high byte is attacked first (low byte fixed to 0x00); its best guess is
    then used while attacking the low byte. Both curves go into one figure.
    """
    plt.figure()

    high_curve, high_best = calculate_correlations_plot(traces, plaintexts, simple_speck, True, 0x00)
    low_curve, _ = calculate_correlations_plot(traces, plaintexts, simple_speck, False, high_best)

    curves = (
        (high_curve, 'orange', "Left Roundkey"),
        (low_curve, 'blue', "Right Roundkey"),
    )
    for curve, colour, caption in curves:
        plt.plot(curve, colour, label=caption)

    plt.ylabel('Correlation')
    plt.xlabel('Key Byte')

    plt.legend(loc="upper left")
    plt.show()
    
calc_mean_from_trace_two_byte_key_plot(trace_array, textin_array)

Testing the Paper Algos

The two functions PowerRightHalfKey and PowerLeftHalfKey are implemented as described in the paper "Breaking Speck using CPA".

from tqdm import tnrange


def get_first_keybyte(traces, plaintexts):
    """Recover the first 16-bit round key with the ``PowerRightHalfKey`` model.

    The low byte is attacked first with the high byte fixed to 0x00, then the
    high byte with the recovered low byte in place. The two best candidates of
    each step and the low-byte guess are printed.

    Returns:
        ``(high_byte << 8) + low_byte`` of the best guesses.
    """
    low_best, low_runner_up = calculate_correlations(traces, plaintexts, PowerRightHalfKey, False, 0x00)
    print(f"[+] Highest Correlation: {hex(low_best[0])} ({low_best[1]}) and Second Highest: {hex(low_runner_up[0])} ({low_runner_up[1]})")

    high_best, high_runner_up = calculate_correlations(traces, plaintexts, PowerRightHalfKey, True, low_best[0])
    print(f"[+] Highest Correlation: {hex(high_best[0])} ({high_best[1]}) and Second Highest: {hex(high_runner_up[0])} ({high_runner_up[1]})")

    print("Key guess: (xored with) = ", hex(low_best[0]))
    recovered_key = to16(high_best[0], low_best[0])
    return recovered_key
get_first_keybyte(trace_array, textin_array)
[+] Highest Correlation: 0xdf (0.1501770712319788) and Second Highest: 0xdb (0.1484601794831993)
[+] Highest Correlation: 0x43 (0.17878250618338265) and Second Highest: 0x88 (0.1783296245203708)
Key guess: (xored with) =  0xdf
17375
from tqdm import tnrange


def get_keybytel(traces, plaintexts=None, known_keys=(0x2211, 0x00dd, 0xa8dc)):
    """Run a CPA on the low byte of the round key that follows ``known_keys``.

    Uses the ``speck_keyschedule`` model with the high byte of the guessed
    round key fixed to 0x00, plots the 256 peak correlations and prints the
    best guess.

    Args:
        traces: 2-D array with one row per trace.
        plaintexts: Plaintexts, index-aligned with ``traces``. Defaults to the
            notebook-global ``textin_array``, as in the original call.
        known_keys: Round keys of the preceding rounds. The default is the
            list that used to be hard-coded here.

    Returns:
        List of 256 peak absolute correlations, indexed by key-byte guess.
    """
    if plaintexts is None:
        plaintexts = textin_array

    maxcpa = [0] * 256

    t_bar = mean(traces)
    o_t = std_dev(traces, t_bar)

    for key in range(0, 256):
        # High byte fixed to 0x00, so the 16-bit candidate equals `key`.
        hws = np.array([[speck_keyschedule(textin, (0x00 << 8) + key, known_keys) for textin in plaintexts]]).transpose()

        hws_bar = mean(hws)
        o_hws = std_dev(hws, hws_bar)
        correlation = cov(traces, t_bar, hws, hws_bar)
        cpaoutput = correlation / (o_t * o_hws)
        maxcpa[key] = max(abs(cpaoutput))

    plt.figure()
    plt.plot(maxcpa, 'orange')
    plt.show()

    guess = np.argmax(maxcpa)
    print("Key guess: (xored with) = ", hex(guess))
    return maxcpa
cpa2 = get_keybytel(trace_array)
Key guess: (xored with) =  0x53

The following functions recover a new round key

Each function takes an array of previous round keys and tries to find the next round key

def recv_roundkey(traces, plaintexts, roundkeys):
    """Recover the next round key with the ``speck_keyschedule`` model.

    The low byte is attacked first with the high byte fixed to 0x00, then the
    high byte with the recovered low byte in place. ``roundkeys`` (the round
    keys found so far) is passed through to the model.

    Returns:
        ``(high_byte << 8) + low_byte`` of the best guesses.
    """
    low_best, low_runner_up = calculate_correlations(traces, plaintexts, speck_keyschedule, False, 0x00, roundkeys)
    print(f"[+] Highest Correlation: {hex(low_best[0])} ({low_best[1]}) and Second Highest: {hex(low_runner_up[0])} ({low_runner_up[1]})")

    high_best, high_runner_up = calculate_correlations(traces, plaintexts, speck_keyschedule, True, low_best[0], roundkeys)
    print(f"[+] Highest Correlation: {hex(high_best[0])} ({high_best[1]}) and Second Highest: {hex(high_runner_up[0])} ({high_runner_up[1]})")

    next_roundkey = to16(high_best[0], low_best[0])
    return next_roundkey
    

'''
    This seems to work most reliable (currently, only one byte is wrong....)
'''
def recv_roundkey_two(traces, plaintexts, roundkeys):
    """Recover the next round key with the ``PowerLeftHalfKeyII`` model.

    Same two-step procedure as ``recv_roundkey``: low byte first (high byte
    fixed to 0x00), then the high byte with the recovered low byte in place.

    Returns:
        ``(high_byte << 8) + low_byte`` of the best guesses.
    """
    low_best, low_runner_up = calculate_correlations(traces, plaintexts, PowerLeftHalfKeyII, False, 0x00, roundkeys)
    print(f"[+] Highest Correlation: {hex(low_best[0])} ({low_best[1]}) and Second Highest: {hex(low_runner_up[0])} ({low_runner_up[1]})")

    high_best, high_runner_up = calculate_correlations(traces, plaintexts, PowerLeftHalfKeyII, True, low_best[0], roundkeys)
    print(f"[+] Highest Correlation: {hex(high_best[0])} ({high_best[1]}) and Second Highest: {hex(high_runner_up[0])} ({high_runner_up[1]})")

    next_roundkey = to16(high_best[0], low_best[0])
    return next_roundkey
    
       

Get the required number of round keys

a = recv_roundkey_two(trace_array, textin_array, [0x2211, 0x00dd])
print(hex(a))
[+] Highest Correlation: 0x1f (0.14902296797030073) and Second Highest: 0x98 (0.14816403020606916)
[+] Highest Correlation: 0xe3 (0.17251298199467352) and Second Highest: 0xa3 (0.15950148234725475)
0xe31f
def get_roundkey(traces, firstkey, plaintexts=None, rounds=4):
    """Recover the round keys that follow ``firstkey``, one after another.

    Each newly recovered key is appended to the list handed to
    ``recv_roundkey_two`` for the next round.

    Args:
        traces: 2-D array with one row per trace.
        firstkey: The already known first round key (16-bit integer).
        plaintexts: Plaintexts, index-aligned with ``traces``. Defaults to the
            notebook-global ``textin_array``, as in the original call.
        rounds: Number of further round keys to recover (default 4, giving
            five round keys in total).

    Returns:
        List ``[firstkey, rk1, ..., rk<rounds>]`` of integers.
    """
    if plaintexts is None:
        plaintexts = textin_array

    round_keys = [firstkey]
    for _ in range(rounds):
        round_keys.append(recv_roundkey_two(traces, plaintexts, round_keys))

    print([hex(k) for k in round_keys])
    return round_keys
def get_key_from_roundkey(roundkey):
    """Invert the key schedule: derive the master-key words from five round keys.

    For every step ``i`` the 16-bit word ``w`` is searched for which
    ``ER16(w, roundkey[i], i)`` yields ``roundkey[i + 1]`` as its second
    output. ``roundkey[0]`` is itself the first key word.

    Args:
        roundkey: Sequence of exactly five consecutive round keys.

    Returns:
        List of hex strings: ``roundkey[0]`` followed by the recovered words.

    Raises:
        ValueError: If ``roundkey`` does not contain exactly five entries.
    """
    if len(roundkey) != 5:
        raise ValueError("Wrong length of roundkey")

    key = [roundkey[0]]
    for i in range(4):
        for candidate in range(2**16):
            # Use the caller's round keys. The previous version read the
            # global `known_keys` here, which may be undefined or unrelated
            # to the `roundkey` argument.
            _, out = ER16(candidate, roundkey[i], i)
            if out == roundkey[i+1]:
                key.append(candidate)
                print(f"Found: {hex(candidate)}")
                # One ER16 round is a bijection in its first argument over
                # 16-bit inputs, so there is exactly one match per step.
                break
    return [hex(k) for k in key]

rk = get_roundkey(trace_array, int(0x2211))  # get the roundkeys, based on the first roundkey
[+] Highest Correlation: 0xdd (0.9278200457894269) and Second Highest: 0xdf (0.8631820447413402)
[+] Highest Correlation: 0x0 (0.9278200457894269) and Second Highest: 0x1 (0.8126688581419583)
[+] Highest Correlation: 0xdc (0.5462925024514341) and Second Highest: 0xd8 (0.48516519669326075)
[+] Highest Correlation: 0xa8 (0.8432509671276792) and Second Highest: 0x88 (0.7925942806028411)
[+] Highest Correlation: 0x9c (0.6194829714625257) and Second Highest: 0xdc (0.5587442317020618)
[+] Highest Correlation: 0x34 (0.7935821344924826) and Second Highest: 0x14 (0.7829382954252116)
[+] Highest Correlation: 0x21 (0.48675689201334293) and Second Highest: 0xde (0.43080062755434895)
[+] Highest Correlation: 0x4a (0.8508098840188391) and Second Highest: 0x6a (0.8267326599115615)
['0x2211', '0xdd', '0xa8dc', '0x349c', '0x4a21']
get_key_from_roundkey(rk)
Found: 0x4433
Found: 0x6655
Found: 0x8877
Found: 0xdb31
['0x2211', '0x4433', '0x6655', '0x8877', '0xdb31']
# These are the roundkeys that __should__ be there for the key
known_keys = [0x2211, 0x00dd, 0xa8dc, 0x349c, 0xb5de]
get_key_from_roundkey(known_keys)
Found: 0x4433
Found: 0x6655
Found: 0x8877
Found: 0x8899
['0x2211', '0x4433', '0x6655', '0x8877', '0x8899']
get_key_from_roundkey([8721, 221, 43228, 13468, 18977])
Found: 0x4433
Found: 0x6655
Found: 0x8877
Found: 0xdb31
['0x2211', '0x4433', '0x6655', '0x8877', '0xdb31']

Analysis and Plots

x = [
        "kb1",
        "kb2",
        "kb3",
        "kb4",
        "kb5",
        "kb6",
        "kb7",
        "kb8",
        "kb9",
        "kb10",
]


y = [
    0.41,
    0.41,
    0.92,
    0.92,
    0.54,
    0.84,
    0.62,
    0.79,
    0.49,
    0.85,
    ]

%matplotlib notebook
import matplotlib.pylab as plt
plt.figure()
plt.bar(x, y, color="orange")
plt.show()