HackIM 2019 crypto writeup: breaking a split-key cipher with a meet-in-the-middle attack

Alongside the challenge description above, we were given the following Python code:

from hashlib import md5
from binascii import hexlify, unhexlify
from secret import key, flag
import sys
# Block size in bytes; fun() asserts that its plaintext argument has this length.
BLOCK_LENGTH = 16
# Length in bytes of each key half given to fun(); toofun() takes 2 * KEY_LENGTH bytes.
KEY_LENGTH = 3
# Number of xor / S-box / permutation rounds performed by fun().
ROUND_COUNT = 16

# Byte substitution table: fun() replaces every byte x with sbox[x] in each round.
sbox = [210, 213, 115, 178, 122, 4, 94, 164, 199, 230, 237, 248, 54, 217, 156, 202, 212, 177, 132, 36, 245, 31, 163, 49, 68, 107, 91, 251, 134, 242, 59, 46, 37, 124, 185, 25, 41, 184, 221, 63, 10, 42, 28, 104, 56, 155, 43, 250, 161, 22, 92, 81, 201, 229, 183, 214, 208, 66, 128, 162, 172, 147, 1, 74, 15, 151, 227, 247, 114, 47, 53, 203, 170, 228, 226, 239, 44, 119, 123, 67, 11, 175, 240, 13, 52, 255, 143, 88, 219, 188, 99, 82, 158, 14, 241, 78, 33, 108, 198, 85, 72, 192, 236, 129, 131, 220, 96, 71, 98, 75, 127, 3, 120, 243, 109, 23, 48, 97, 234, 187, 244, 12, 139, 18, 101, 126, 38, 216, 90, 125, 106, 24, 235, 207, 186, 190, 84, 171, 113, 232, 2, 105, 200, 70, 137, 152, 165, 19, 166, 154, 112, 142, 180, 167, 57, 153, 174, 8, 146, 194, 26, 150, 206, 141, 39, 60, 102, 9, 65, 176, 79, 61, 62, 110, 111, 30, 218, 197, 140, 168, 196, 83, 223, 144, 55, 58, 157, 173, 133, 191, 145, 27, 103, 40, 246, 169, 73, 179, 160, 253, 225, 51, 32, 224, 29, 34, 77, 117, 100, 233, 181, 76, 21, 5, 149, 204, 182, 138, 211, 16, 231, 0, 238, 254, 252, 6, 195, 89, 69, 136, 87, 209, 118, 222, 20, 249, 64, 130, 35, 86, 116, 193, 7, 121, 135, 189, 215, 50, 148, 159, 93, 80, 45, 17, 205, 95]

# Byte permutation: in each round of fun(), output byte i is taken from position p[i].
p = [3, 9, 0, 1, 8, 7, 15, 2, 5, 6, 13, 10, 4, 12, 11, 14]

def xor(a, b):
    """Return a bytearray holding the element-wise XOR of a and b.

    Like zip(), the result stops at the end of the shorter input.
    """
    return bytearray(x ^ y for x, y in zip(a, b))


def fun(key, pt):
    """Encrypt one BLOCK_LENGTH-byte block under a KEY_LENGTH-byte key.

    The key is expanded to a 16-byte round key with MD5. Each of the
    ROUND_COUNT rounds XORs in the round key, substitutes every byte through
    sbox, then permutes the bytes so that output byte i comes from p[i].
    Returns the resulting block hex-encoded (the output of hexlify()).
    """
    assert len(pt) == BLOCK_LENGTH
    assert len(key) == KEY_LENGTH
    round_key = bytearray(md5(key).digest())
    state = bytearray(pt)
    for _ in range(ROUND_COUNT):
        substituted = [sbox[byte] for byte in xor(state, round_key)]
        state = bytearray(substituted[p[i]] for i in range(BLOCK_LENGTH))
    return hexlify(state)
def toofun(key, pt):
    """Double-encrypt pt with fun(): first under key[:KEY_LENGTH], then key[KEY_LENGTH:].

    key must be 2 * KEY_LENGTH bytes long. The intermediate hex output of
    the first fun() call is decoded back to raw bytes before the second
    call. Returns the hex-encoded output of the second call.
    """
    assert len(key) == 2 * KEY_LENGTH
    first_half, second_half = key[:KEY_LENGTH], key[KEY_LENGTH:]
    middle = unhexlify(fun(first_half, pt))
    return fun(second_half, middle)
# Publish one known plaintext/ciphertext pair, then the encrypted flag.
sample_ct = toofun(key, b"16 bit plaintext")
print("16 bit plaintext: %s" % sample_ct)
flag_ct = toofun(key, flag)
print("flag: %s" % flag_ct)

As the challenge description suggests, a 6-byte key is too large to brute-force directly: there are $256^6 = 2^{48}$ possible keys (in other words, waaay too many). However, the "toofun" function splits the key in half, and each 3-byte half is used separately to encrypt via the "fun" function. Together with the known plaintext ("16 bit plaintext") and its corresponding ciphertext (b'0467a52afa8f15cfb8f0ea40365a6692'), this lets us mount a meet-in-the-middle attack, the same weakness that affects 2-DES. Instead of $2^{48}$ combined keys, we only need to try all $2^{24}$ possible 3-byte keys on each side.

If we write the inverse of the "fun" function, we can pass every key to both fun(key, "16 bit plaintext") and reversed_fun(key, b'0467a52afa8f15cfb8f0ea40365a6692'). There will then be a pair of keys (call them keyA and keyB) such that fun(keyA, "16 bit plaintext") equals reversed_fun(keyB, b'0467a52afa8f15cfb8f0ea40365a6692').

Our original key will be the concatenation of keyA with keyB.

First, I wrote a small helper function to invert the permutations defined by "p" and "sbox".

# Same S-box as in the challenge code.
sbox = [210, 213, 115, 178, 122, 4, 94, 164, 199, 230, 237, 248, 54, 217, 156, 202, 212, 177, 132, 36, 245, 31, 163, 49, 68, 107, 91, 251, 134, 242, 59, 46, 37, 124, 185, 25, 41, 184, 221, 63, 10, 42, 28, 104, 56, 155, 43, 250, 161, 22, 92, 81, 201, 229, 183, 214, 208, 66, 128, 162, 172, 147, 1, 74, 15, 151, 227, 247, 114, 47, 53, 203, 170, 228, 226, 239, 44, 119, 123, 67, 11, 175, 240, 13, 52, 255, 143, 88, 219, 188, 99, 82, 158, 14, 241, 78, 33, 108, 198, 85, 72, 192, 236, 129, 131, 220, 96, 71, 98, 75, 127, 3, 120, 243, 109, 23, 48, 97, 234, 187, 244, 12, 139, 18, 101, 126, 38, 216, 90, 125, 106, 24, 235, 207, 186, 190, 84, 171, 113, 232, 2, 105, 200, 70, 137, 152, 165, 19, 166, 154, 112, 142, 180, 167, 57, 153, 174, 8, 146, 194, 26, 150, 206, 141, 39, 60, 102, 9, 65, 176, 79, 61, 62, 110, 111, 30, 218, 197, 140, 168, 196, 83, 223, 144, 55, 58, 157, 173, 133, 191, 145, 27, 103, 40, 246, 169, 73, 179, 160, 253, 225, 51, 32, 224, 29, 34, 77, 117, 100, 233, 181, 76, 21, 5, 149, 204, 182, 138, 211, 16, 231, 0, 238, 254, 252, 6, 195, 89, 69, 136, 87, 209, 118, 222, 20, 249, 64, 130, 35, 86, 116, 193, 7, 121, 135, 189, 215, 50, 148, 159, 93, 80, 45, 17, 205, 95]

# Same byte permutation as in the challenge code.
p = [3, 9, 0, 1, 8, 7, 15, 2, 5, 6, 13, 10, 4, 12, 11, 14]

def inversePerm(p):
  """Return the inverse permutation of p, so that result[p[i]] == i.

  The result starts as the identity list, so any position that p never
  maps to keeps its own index.
  """
  inverse = list(range(len(p)))
  for index, target in enumerate(p):
    inverse[target] = index
  return inverse


# invert_p[p[i]] == i: used to undo the byte permutation of each round.
invert_p = inversePerm(p)
# invert_sbox[sbox[x]] == x, assuming sbox is a permutation of 0..255.
invert_sbox = inversePerm(sbox)

Then I wrote an inverse of the original "fun" function, which undoes each round's steps in reverse order.

def no_fun_allowed(key, ct):
  """Invert fun(): decrypt one hex-encoded block under a 3-byte key.

  Runs ROUND_COUNT rounds, each undoing one round of fun() in reverse
  order: un-permute the bytes with invert_p, undo the substitution with
  invert_sbox, then XOR the MD5-expanded key back out. Returns the result
  hex-encoded. Assuming sbox is a bijection, no_fun_allowed(key, fun(key, pt))
  equals hexlify(pt).
  """
  ct = unhexlify(ct)
  key = bytearray(unhexlify(md5(key).hexdigest()))
  mitm = bytearray(ct)
  for _ in range(ROUND_COUNT):
    # fun() set out[i] = in[p[i]], so in[j] = out[invert_p[j]].
    nct = bytearray(BLOCK_LENGTH)
    for i in range(BLOCK_LENGTH):
      nct[i] = mitm[invert_p[i]]
    mitm = nct
    # Undo the S-box substitution byte by byte.
    for i in range(BLOCK_LENGTH):
      mitm[i] = invert_sbox[mitm[i]]
    # XOR is its own inverse, so XOR-ing the round key again removes it.
    mitm = xor(mitm, key)
  return hexlify(mitm)

Now we're ready to iterate through all $2^{24}$ possible 3-byte keys. As discussed above, we pass each one to both the "fun" function and our newly created "no_fun_allowed" function. I chose to save the respective outputs to separate files, as shown below.

# Stage 0 - Collect all possible keys

# Run the known plaintext forwards and the known ciphertext backwards under
# every possible 3-byte half-key (256**3 = 2**24 of them). Each output line
# has the form "<key hex> :: <block hex>" (43 bytes), so each file grows to
# roughly 720 MB.
KNOWN_PLAINTEXT = b'16 bit plaintext'
KNOWN_CIPHERTEXT = b'0467a52afa8f15cfb8f0ea40365a6692'

# The with-statement closes (and flushes) both files even if an exception
# interrupts the loop part-way.
with open("ptOutput.txt", "w") as ptout, open("ctOutput.txt", "w") as ctout:
  for i in range(0, 256):
    for j in range(0, 256):
      for k in range(0, 256):
        key = bytearray([i, j, k])
        # hexlify() returns bytes on Python 3, so decode before joining with
        # str. On Python 2 this yields an ASCII unicode string, which
        # file.write() accepts.
        key_hex = hexlify(key).decode("ascii")
        ptout.write(key_hex + " :: " + fun(key, KNOWN_PLAINTEXT).decode("ascii") + "\n")
        ctout.write(key_hex + " :: " + no_fun_allowed(key, KNOWN_CIPHERTEXT).decode("ascii") + "\n")

I then used Python sets to find the intermediate value common to both files: the block that sits between the two encryptions. It turned out to be: 36e8c221ea84efcaab2c393786f938d0

# Stage 1 - Find Meet In The Middle Collision

# Each line is "<key hex> :: <block hex>"; field 2 is the intermediate block.
# Holding both sets in memory may need several GB of RAM.
with open("ptOutput.txt", "r") as f:
  ptSet = set(line.split(" ")[2].strip() for line in f)
with open("ctOutput.txt", "r") as f:
  ctSet = set(line.split(" ")[2].strip() for line in f)

# The intersection is a set. Extract the string itself so it can be printed
# and compared against individual lines in Stage 2.
common = ptSet & ctSet
if not common:
  raise SystemExit("No meet-in-the-middle collision found")
if len(common) > 1:
  print("Warning: %d collisions found, using the first" % len(common))
collision = sorted(common)[0]
print("Collision > " + collision)

With the common intermediate value in hand, we can look up which key produced it in each file. This recovers both halves of the original key; concatenated, they give the original key: a277b5c1bc8b

# Stage 2 - Get Key

# Find the colliding intermediate block in both files. The matching key on
# the plaintext side is the first half of the 6-byte key; the one on the
# ciphertext side is the second half. Lines are "<key hex> :: <block hex>".
key = ""
for path, label in (("ptOutput.txt", "Left side key > "),
                    ("ctOutput.txt", "Right side key > ")):
  with open(path, "r") as f:
    for line in f:
      fields = line.split(" ")
      if fields[2].strip() != collision:
        continue
      half = fields[0].strip()
      print(label + half)
      key += half
print("Key > " + key)

Finally, all that's left is to recover the flag by inverting the "toofun" function.

# Stage 3 - Get Flag

# Encrypted flag, as printed by the challenge.
flag = b'04b34e5af4a1f5260f6043b8b9abb4f8'
# Stage 2 built the key as a hex string; convert it back to raw bytes.
key = unhexlify(key)

def not_so_fun_anymore(key, ct):
  """Invert toofun(): decrypt hex-encoded ct with a 6-byte key.

  toofun() encrypts with key[:KEY_LENGTH] first and key[KEY_LENGTH:]
  second, so decryption undoes the second half first. Returns the raw
  plaintext (str on Python 2, bytes on Python 3).
  """
  assert len(key) == 2 * KEY_LENGTH
  key1 = key[:KEY_LENGTH]
  key2 = key[KEY_LENGTH:]

  # Intermediate block between the two encryptions. It is not named "hash",
  # so the builtin hash() is not shadowed.
  middle = no_fun_allowed(key2, ct)
  pt = no_fun_allowed(key1, middle)
  # unhexlify() works on both Python 2 and 3, unlike str.decode("hex").
  return unhexlify(pt)

# Decode explicitly so the concatenation also works on Python 3, where
# not_so_fun_anymore() returns bytes; latin-1 maps every byte value.
print("hackim19{" + not_so_fun_anymore(key, flag).decode("latin-1") + "}")

As a result, we get the flag: hackim19{1337_1n_m1ddl38f}

Full Solver:

from hashlib import md5
from binascii import hexlify, unhexlify
import sys
import os
# Block size in bytes; fun() asserts that its plaintext argument has this length.
BLOCK_LENGTH = 16
# Length in bytes of each key half; the full key is 2 * KEY_LENGTH bytes.
KEY_LENGTH = 3
# Number of rounds in fun() / no_fun_allowed().
ROUND_COUNT = 16

# Same S-box as in the challenge code.
sbox = [210, 213, 115, 178, 122, 4, 94, 164, 199, 230, 237, 248, 54, 217, 156, 202, 212, 177, 132, 36, 245, 31, 163, 49, 68, 107, 91, 251, 134, 242, 59, 46, 37, 124, 185, 25, 41, 184, 221, 63, 10, 42, 28, 104, 56, 155, 43, 250, 161, 22, 92, 81, 201, 229, 183, 214, 208, 66, 128, 162, 172, 147, 1, 74, 15, 151, 227, 247, 114, 47, 53, 203, 170, 228, 226, 239, 44, 119, 123, 67, 11, 175, 240, 13, 52, 255, 143, 88, 219, 188, 99, 82, 158, 14, 241, 78, 33, 108, 198, 85, 72, 192, 236, 129, 131, 220, 96, 71, 98, 75, 127, 3, 120, 243, 109, 23, 48, 97, 234, 187, 244, 12, 139, 18, 101, 126, 38, 216, 90, 125, 106, 24, 235, 207, 186, 190, 84, 171, 113, 232, 2, 105, 200, 70, 137, 152, 165, 19, 166, 154, 112, 142, 180, 167, 57, 153, 174, 8, 146, 194, 26, 150, 206, 141, 39, 60, 102, 9, 65, 176, 79, 61, 62, 110, 111, 30, 218, 197, 140, 168, 196, 83, 223, 144, 55, 58, 157, 173, 133, 191, 145, 27, 103, 40, 246, 169, 73, 179, 160, 253, 225, 51, 32, 224, 29, 34, 77, 117, 100, 233, 181, 76, 21, 5, 149, 204, 182, 138, 211, 16, 231, 0, 238, 254, 252, 6, 195, 89, 69, 136, 87, 209, 118, 222, 20, 249, 64, 130, 35, 86, 116, 193, 7, 121, 135, 189, 215, 50, 148, 159, 93, 80, 45, 17, 205, 95]

# Same byte permutation as in the challenge code: output byte i comes from p[i].
p = [3, 9, 0, 1, 8, 7, 15, 2, 5, 6, 13, 10, 4, 12, 11, 14]

def inversePerm(p):
  """Return the inverse permutation of p, so that result[p[i]] == i.

  The result starts as the identity list, so any position that p never
  maps to keeps its own index.
  """
  inverse = list(range(len(p)))
  for index, target in enumerate(p):
    inverse[target] = index
  return inverse


# invert_p[p[i]] == i: used to undo the byte permutation of each round.
invert_p = inversePerm(p)
# invert_sbox[sbox[x]] == x, assuming sbox is a permutation of 0..255.
invert_sbox = inversePerm(sbox)

def xor(a, b):
  """Return a bytearray holding the element-wise XOR of a and b.

  Like zip(), the result stops at the end of the shorter input.
  """
  return bytearray(x ^ y for x, y in zip(a, b))


def fun(key, pt):
  """Encrypt one BLOCK_LENGTH-byte block under a KEY_LENGTH-byte key.

  The key is expanded to a 16-byte round key with MD5. Each of the
  ROUND_COUNT rounds XORs in the round key, substitutes every byte through
  sbox, then permutes the bytes so that output byte i comes from p[i].
  Returns the resulting block hex-encoded (the output of hexlify()).
  """
  assert len(pt) == BLOCK_LENGTH
  assert len(key) == KEY_LENGTH
  round_key = bytearray(md5(key).digest())
  state = bytearray(pt)
  for _ in range(ROUND_COUNT):
    substituted = [sbox[byte] for byte in xor(state, round_key)]
    state = bytearray(substituted[p[i]] for i in range(BLOCK_LENGTH))
  return hexlify(state)

def no_fun_allowed(key, ct):
  """Invert fun(): decrypt one hex-encoded block under a 3-byte key.

  Each of the ROUND_COUNT rounds undoes one round of fun() in reverse
  order: un-permute with invert_p, un-substitute with invert_sbox, then
  XOR the MD5-expanded key back out. Returns the result hex-encoded.
  """
  state = bytearray(unhexlify(ct))
  round_key = bytearray(md5(key).digest())
  for _ in range(ROUND_COUNT):
    unpermuted = [state[invert_p[i]] for i in range(BLOCK_LENGTH)]
    state = xor(bytearray(invert_sbox[byte] for byte in unpermuted), round_key)
  return hexlify(state)


# Stage 0 - Collect all possible keys

# Run the known plaintext forwards and the known ciphertext backwards under
# every possible 3-byte half-key (256**3 = 2**24 of them). Each output line
# has the form "<key hex> :: <block hex>" (43 bytes), so each file grows to
# roughly 720 MB.
KNOWN_PLAINTEXT = b'16 bit plaintext'
KNOWN_CIPHERTEXT = b'0467a52afa8f15cfb8f0ea40365a6692'

# The with-statement closes (and flushes) both files even if an exception
# interrupts the loop part-way.
with open("ptOutput.txt", "w") as ptout, open("ctOutput.txt", "w") as ctout:
  for i in range(0, 256):
    for j in range(0, 256):
      for k in range(0, 256):
        key = bytearray([i, j, k])
        # hexlify() returns bytes on Python 3, so decode before joining with
        # str. On Python 2 this yields an ASCII unicode string, which
        # file.write() accepts.
        key_hex = hexlify(key).decode("ascii")
        ptout.write(key_hex + " :: " + fun(key, KNOWN_PLAINTEXT).decode("ascii") + "\n")
        ctout.write(key_hex + " :: " + no_fun_allowed(key, KNOWN_CIPHERTEXT).decode("ascii") + "\n")


# Stage 1 - Find Meet In The Middle Collision

# Each line is "<key hex> :: <block hex>"; field 2 is the intermediate block.
# Holding both sets in memory may need several GB of RAM.
with open("ptOutput.txt", "r") as f:
  ptSet = set(line.split(" ")[2].strip() for line in f)
with open("ctOutput.txt", "r") as f:
  ctSet = set(line.split(" ")[2].strip() for line in f)

# The intersection is a set. Extract the string itself so it can be printed
# and compared against individual lines in Stage 2.
common = ptSet & ctSet
if not common:
  raise SystemExit("No meet-in-the-middle collision found")
if len(common) > 1:
  print("Warning: %d collisions found, using the first" % len(common))
collision = sorted(common)[0]
print("Collision > " + collision)


# Stage 2 - Get Key

# Find the colliding intermediate block in both files. The matching key on
# the plaintext side is the first half of the 6-byte key; the one on the
# ciphertext side is the second half. Lines are "<key hex> :: <block hex>".
key = ""
for path, label in (("ptOutput.txt", "Left side key > "),
                    ("ctOutput.txt", "Right side key > ")):
  with open(path, "r") as f:
    for line in f:
      fields = line.split(" ")
      if fields[2].strip() != collision:
        continue
      half = fields[0].strip()
      print(label + half)
      key += half
print("Key > " + key)


# Stage 3 - Get Flag

# Encrypted flag, as printed by the challenge.
flag = b'04b34e5af4a1f5260f6043b8b9abb4f8'
# Stage 2 built the key as a hex string; convert it back to raw bytes.
key = unhexlify(key)

def not_so_fun_anymore(key, ct):
  """Invert toofun(): decrypt hex-encoded ct with a 6-byte key.

  toofun() encrypts with key[:KEY_LENGTH] first and key[KEY_LENGTH:]
  second, so decryption undoes the second half first. Returns the raw
  plaintext (str on Python 2, bytes on Python 3).
  """
  assert len(key) == 2 * KEY_LENGTH
  key1 = key[:KEY_LENGTH]
  key2 = key[KEY_LENGTH:]

  # Intermediate block between the two encryptions. It is not named "hash",
  # so the builtin hash() is not shadowed.
  middle = no_fun_allowed(key2, ct)
  pt = no_fun_allowed(key1, middle)
  # unhexlify() works on both Python 2 and 3, unlike str.decode("hex").
  return unhexlify(pt)

# Decode explicitly so the concatenation also works on Python 3, where
# not_so_fun_anymore() returns bytes; latin-1 maps every byte value.
print("hackim19{" + not_so_fun_anymore(key, flag).decode("latin-1") + "}")