import strformat
import bitops
# Bit rotations use rotateLeftBits and rotateRightBits from the std/bitops
# module imported above.

#[
  SPECK Helper functions

  * words64ToBytes
  * bytesToWords64
]#

proc bytesToWords64(bytes: seq[uint8], numbytes: int): seq[uint64] =
  ## Packs the first `numbytes` bytes of `bytes` into 64-bit words in
  ## little-endian order: within each group of eight bytes, the byte at the
  ## lowest index becomes the least significant byte of the word.
  ##
  ## Only complete groups of eight bytes are converted; a trailing partial
  ## group is ignored. A non-positive `numbytes` yields an empty sequence.
  ## Reading past the end of `bytes` raises `IndexDefect`.

  # Integer division: counting words needs no round trip through float.
  let wordCount = numbytes div 8
  result = newSeqOfCap[uint64](max(wordCount, 0))

  for wordIdx in 0 ..< wordCount:
    let base = wordIdx * 8
    var word = 0'u64
    # Byte `offset` of the group lands in bits [8*offset, 8*offset + 7].
    for offset in 0 ..< 8:
      word = word or (uint64(bytes[base + offset]) shl (8 * offset))
    result.add(word)



proc words64ToBytes(words: seq[uint64], numwords: int): seq[uint8] =
  ## Serialises the first `numwords` entries of `words` into bytes, writing
  ## each 64-bit word least significant byte first (little-endian).
  ##
  ## Reading past the end of `words` raises `IndexDefect`.
  for idx in 0 ..< numwords:
    let word = words[idx]
    # Walk the word one byte at a time, from bit 0 up to bit 56.
    for shift in countup(0, 56, 8):
      result.add(uint8((word shr shift) and 0xff'u64))


#[
  simpler in C:
  #define R(x,y,k) (x=ROR64(x,8), x+=y, x^=k, y=ROL64(y,3), y^=x)
]#
proc R(x: uint64, y: uint64, k: uint64): (uint64, uint64) =
  ## Forward round function: returns the updated `(x, y)` pair.
  ##
  ## `x` is rotated right by 8 bits, `y` is added (wrapping modulo 2^64) and
  ## the round key `k` is XORed in; `y` is rotated left by 3 bits and XORed
  ## with the new `x`.
  let mixedX = (rotateRightBits(x, 8) + y) xor k
  let mixedY = rotateLeftBits(y, 3) xor mixedX
  result = (mixedX, mixedY)

#[
  simpler in C:
  #define RI(x,y,k) (y^=x, y=ROR64(y,3), x^=k, x-=y, x=ROL64(x,8))
]#
proc RI(x: uint64, y: uint64, k: uint64): (uint64, uint64) =
  ## Inverse round function: undoes one application of `R` with the same
  ## round key `k` and returns the previous `(x, y)` pair.
  ##
  ## `y` is XORed with `x` and rotated right by 3 bits; `x` is XORed with
  ## `k`, has the recovered `y` subtracted (wrapping modulo 2^64) and is
  ## rotated left by 8 bits.
  let prevY = rotateRightBits(y xor x, 3)
  let prevX = rotateLeftBits((x xor k) - prevY, 8)
  result = (prevX, prevY)



#[
  Key schedule: expands the initial key words, one round at a time,
  into the sequence of round keys.
]#
proc speck128256KeySchedule(K: seq[uint64]): seq[uint64] =
  ## Expands the four 64-bit key words in `K` into the 34 round keys that
  ## `speck128256Encrypt` and `speck128256Decrypt` index as `rk[0..33]`.
  ##
  ## `K[0]` is emitted unchanged as the first round key. Every further key is
  ## produced by running `R` on one of the remaining key words (cycling through
  ## `K[1]`, `K[2]`, `K[3]`) with the round index as the round constant.
  ##
  ## Raises `IndexDefect` when `K` holds fewer than four words.
  const rounds = 34

  var D: uint64 = K[3]
  var C: uint64 = K[2]
  var B: uint64 = K[1]
  var A: uint64 = K[0]

  result = newSeqOfCap[uint64](rounds)

  # `countup` is inclusive, so the last batch must start at i = 30: three keys
  # per iteration give rk[0..32], and the final `add` below supplies rk[33].
  # An upper bound of 33 would run one extra iteration and append three
  # surplus keys.
  for i in countup(0, rounds - 4, 3):
    result.add(A)
    (B, A) = R(B, A, uint64(i))      # ER64(B,A,i)
    result.add(A)
    (C, A) = R(C, A, uint64(i + 1))  # ER64(C,A,i+1)
    result.add(A)
    (D, A) = R(D, A, uint64(i + 2))  # ER64(D,A,i+2)

  result.add(A)



proc speck128256Encrypt(Pt: seq[uint64], rk: seq[uint64]): seq[uint64] =
  ## Encrypts one block given as two 64-bit words (`Pt[0]`, `Pt[1]`) by
  ## applying `R` 34 times with the round keys `rk[0]` through `rk[33]`.
  ##
  ## Returns the two ciphertext words in the same word order as the input.
  var
    y = Pt[0]
    x = Pt[1]

  for roundIdx in 0 .. 33:
    (x, y) = R(x, y, rk[roundIdx])

  result = @[y, x]

proc speck128256Decrypt(Ct: seq[uint64], rk: seq[uint64]): seq[uint64] =
  ## Decrypts one block given as two 64-bit words (`Ct[0]`, `Ct[1]`) by
  ## applying `RI` 34 times with the round keys in reverse order, from
  ## `rk[33]` down to `rk[0]`.
  ##
  ## Returns the two plaintext words in the same word order as the input.
  var
    y = Ct[0]
    x = Ct[1]

  for roundIdx in countdown(33, 0):
    (x, y) = RI(x, y, rk[roundIdx])

  result = @[y, x]


proc test() =
  ## Self-test: encrypts a fixed plaintext block under a fixed 256-bit key,
  ## compares every intermediate value with hard-coded reference values, then
  ## decrypts again and checks that the round trip restores the plaintext.
  ##
  ## Uses `doAssert` rather than `assert` so the checks still run in builds
  ## that compile plain assertions out (e.g. `-d:danger`).

  echo "[+] Starting tests"

  # plaintext and key as byte arrays
  let pt = @[uint8 0x70, 0x6f, 0x6f, 0x6e, 0x65, 0x72, 0x2e, 0x20,
             0x49, 0x6e, 0x20, 0x74, 0x68, 0x6f, 0x73, 0x65]
  let key = @[uint8 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
              0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
              0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
              0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f]

  # Start encryption routine
  let Pt = bytesToWords64(pt, 16)
  let K = bytesToWords64(key, 32)
  let rk = speck128256KeySchedule(K)
  let Ct = speck128256Encrypt(Pt, rk)
  let ct = words64ToBytes(Ct, 2)

  # decryption
  let reversed_PtB = speck128256Decrypt(Ct, rk)
  let reversed_pt = words64ToBytes(reversed_PtB, 2)

  # check if every result is correct according to the reference values
  doAssert Pt == @[0x202e72656e6f6f70'u64, 0x65736f6874206e49'u64]
  doAssert K == @[0x0706050403020100'u64, 0x0f0e0d0c0b0a0908'u64,
                  0x1716151413121110'u64, 0x1f1e1d1c1b1a1918'u64]
  doAssert Ct == @[0x4eeeb48d9c188f43'u64, 0x4109010405c0f53e'u64]
  doAssert ct == @[uint8 67, 143, 24, 156, 141, 180, 238, 78,
                   62, 245, 192, 5, 4, 1, 9, 65]
  # the round trip must hold at word level as well as at byte level
  doAssert reversed_PtB == Pt
  doAssert reversed_pt == pt
  echo "[+] All tests successfull"

# Top-level call: runs the self-test when this module is initialised.
test()