Merkle-Hellman knapsack and the LLL algorithm - LostMyPlaintext

An illustrative implementation of the original Merkle-Hellman knapsack cryptosystem and of the algorithm that can break it in polynomial time.


Content:


Merkle-Hellman Knapsack Cryptosystem

Making use of elementary math in a pretty clever way, the Merkle-Hellman knapsack cryptosystem, created by Ralph Merkle and Martin Hellman in 1978, was among the first proposed public-key cryptosystems. Its security relies on the subset sum problem (a special case of the knapsack problem), which is known to be NP-complete and can be described as follows:

Say you have a set (or knapsack) A of n seemingly random integers and some other integer x that is not in A. Assuming it is possible, find a combination of elements of A that sum up to x. Or, to put it in a more interesting way: assuming it is possible, find a sequence B, also of length n, where each $B_i \in \{0,1\}$, such that $A_0B_0 + A_1B_1 + \dots + A_{n-1}B_{n-1} = x$.

When you think of this problem in terms of the second description, you might get an idea of how it could be used to encrypt a message: convert your message to binary and use that as your sequence B, then generate a sequence A of the same length and compute $x = A_0B_0 + A_1B_1 + \dots + A_{n-1}B_{n-1}$. If you use x as your ciphertext, any potential attacker will have great difficulty discovering what your original message was. However, this also means it will be just as hard for the intended recipient to read the original plaintext if they only know the value x and the knapsack A. Consider the following example:

Say you're presented with A = [1875, 621, 1660, 606, 170, 1382, 262, 521] and x = 2802. Since A is quite small in this example, you can probably figure out that B = [0,1,1,0,0,0,0,1] (621 + 1660 + 521 = 2802), which is binary for "a". However, notice that it would be very hard, if not impossible, to do the same efficiently for a much larger A.
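To make the difficulty concrete, here is a minimal brute-force sketch (the function name `subset_sum_solutions` is hypothetical, not part of the original code). It simply tries all 2^n choices of B, which is instant for the 8-element example but grows exponentially with n.

```python
from itertools import product


def subset_sum_solutions(A, x):
    """Return every 0/1 list B with sum(A[i] * B[i]) == x by trying all 2**n candidates."""
    return [
        list(B)
        for B in product((0, 1), repeat=len(A))
        if sum(a * b for a, b in zip(A, B)) == x
    ]


A = [1875, 621, 1660, 606, 170, 1382, 262, 521]
for B in subset_sum_solutions(A, 2802):
    bits = "".join(map(str, B))
    print(bits, chr(int(bits, 2)))
```

The printed solutions include `01100001 a`; with 8 elements that is only 256 candidates, but a 256-bit message would already need 2^256.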

To figure out how we can build a public-key cryptosystem out of this, we first need to talk about superincreasing knapsacks: knapsacks in which every element is greater than the sum of all previous elements. If your knapsack has this property, you can recover B in polynomial time knowing only the knapsack and x, meaning that the subset sum problem can be solved efficiently for superincreasing knapsacks. Let's examine the following example:

Consider the superincreasing knapsack w = [9, 15, 28, 60, 116, 236, 466, 941] and x = 509. You can recover B with the following algorithm:

B = []
for i from 7 down to 0 do:
    if w_i <= x then:
        x = x - w_i
        B.appendLeft(1)
    else:
        B.appendLeft(0)
return B
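The greedy algorithm above translates almost line for line into Python. This is a sketch with a hypothetical function name; unlike the pseudocode, it also returns `None` when x turns out not to be a subset sum of w (some x is left over at the end).

```python
def solve_superincreasing(w, x):
    """Greedy subset-sum solver for a superincreasing knapsack w.

    Walks w from its largest element down: since each w[i] is larger than the
    sum of all smaller elements, w[i] must be used whenever w[i] <= x.
    Returns the 0/1 list B, or None if x is not a subset sum of w.
    """
    B = [0] * len(w)
    for i in range(len(w) - 1, -1, -1):
        if w[i] <= x:
            x -= w[i]
            B[i] = 1
    return B if x == 0 else None


w = [9, 15, 28, 60, 116, 236, 466, 941]
print(solve_superincreasing(w, 509))  # [0, 1, 1, 0, 0, 0, 1, 0]
```

For x = 509 the recovered B is 01100010, which is binary for "b" (15 + 28 + 466 = 509).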

Knowing this, Merkle and Hellman figured they could use modular arithmetic to transform a superincreasing knapsack into a seemingly general knapsack, so that the former, together with the transformation, could serve as a private key and the latter as a public key. If Alice wants to receive secret messages from Bob, she first selects a superincreasing knapsack w and applies this transformation to obtain a seemingly general knapsack A. She then makes A public, and Bob computes $x = A_0B_0 + A_1B_1 + \dots + A_{n-1}B_{n-1}$ (again, B is the plaintext message in binary). Finally, Bob sends the ciphertext x to Alice, who applies the inverse of the transformation to it and uses the result alongside the superincreasing knapsack to solve the easy version of the subset sum problem, obtaining the plaintext message. In this scenario Alice can easily decrypt the message because she knows how to undo the transformation, while an attacker, knowing only x and A, will apparently need to solve the general knapsack problem, which is known to be hard.

But how does this transformation work exactly?

As mentioned before, it is done through modular arithmetic. After selecting a superincreasing sequence w, Alice selects a modulus q such that $q > w_0 + w_1 + \dots + w_{n-1}$ (where n is the length of w). She then selects a multiplier r such that q and r are coprime, in other words, such that $\gcd(q, r) = 1$. Alice's private key consists of the knapsack w and the two values q and r. To generate her public key, Alice computes $A_i = r \cdot w_i \bmod q$ for i from 0 to n-1, and the resulting sequence becomes her general knapsack A. Consider the following example:

Alice generates the superincreasing knapsack w = [3, 10, 24, 46, 88, 175, 352, 707] and selects the values:

q = 1409 > 3+10+24+46+88+175+352+707 = 1405 and r = 831 (gcd(1409, 831) = 1), keeping these to herself. Then she computes the values of the public knapsack A:

831*3 mod 1409 = 1084

831*10 mod 1409 = 1265

831*24 mod 1409 = 218

831*46 mod 1409 = 183

831*88 mod 1409 = 1269

831*175 mod 1409 = 298

831*352 mod 1409 = 849

831*707 mod 1409 = 1373

And finally makes A = [1084, 1265, 218, 183, 1269, 298, 849, 1373] public.
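The key derivation above can be reproduced with a short sketch. The function name `derive_public_key` is hypothetical; it checks the three conditions on (w, q, r) described earlier and then applies $A_i = r \cdot w_i \bmod q$.

```python
from math import gcd


def derive_public_key(w, q, r):
    """Turn a private key (w, q, r) into the public knapsack A_i = r * w_i mod q."""
    # w must be superincreasing: each element exceeds the sum of the previous ones
    total = 0
    for value in w:
        if value <= total:
            raise ValueError("w is not superincreasing")
        total += value
    if q <= total:
        raise ValueError("q must be greater than sum(w)")
    if gcd(q, r) != 1:
        raise ValueError("q and r must be coprime")
    return [(r * value) % q for value in w]


w = [3, 10, 24, 46, 88, 175, 352, 707]
print(sum(w))                           # 1405, so q = 1409 is large enough
print(derive_public_key(w, 1409, 831))  # [1084, 1265, 218, 183, 1269, 298, 849, 1373]
```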

Now suppose Bob wants to send Alice the message "01100011" (binary form of "c"). He first computes 1084*0+1265*1+218*1+183*0+1269*0+298*0+849*1+1373*1 = 3705 and sends this result to Alice.
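Bob's side of the exchange is just a sum of the public values selected by the 1 bits. A minimal sketch, assuming one ASCII character encoded as 8 bits with the most significant bit first (the helper names are hypothetical):

```python
def char_to_bits(c):
    """8-bit binary form of a single ASCII character, e.g. 'c' -> '01100011'."""
    return format(ord(c), "08b")


def encrypt_bits(bits, A):
    """Merkle-Hellman encryption of a bit string: sum of A[i] for every bit set to 1."""
    if len(bits) != len(A):
        raise ValueError("message and public key must have the same length")
    return sum(a for bit, a in zip(bits, A) if bit == "1")


A = [1084, 1265, 218, 183, 1269, 298, 849, 1373]
print(char_to_bits("c"))                   # 01100011
print(encrypt_bits(char_to_bits("c"), A))  # 3705
```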

On her end, Alice calculates I, the inverse of r modulo q, in other words a value I such that $r \cdot I \equiv 1 \pmod{q}$ (here I = 724, since 831*724 = 601644 = 427*1409 + 1). She can now invert the transformation by calculating x = 3705*I mod q = 1093. Using x alongside her private superincreasing knapsack, she retrieves the plaintext message with the same algorithm presented above (indeed 1093 = 10 + 24 + 352 + 707, which selects the bits 01100011). The message in this example is obviously very short, which makes this specific ciphertext quite vulnerable, but in general this whole process seems to force a potential attacker (knowing only A and the ciphertext) to solve the general subset sum problem in order to recover the plaintext. However, this is not necessarily the case.
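Alice's decryption can be sketched as follows. The function name is hypothetical, and `pow(r, -1, q)` (the built-in modular inverse) assumes Python 3.8 or newer; the original code computes the inverse with its own extended Euclidean algorithm instead.

```python
def decrypt_to_bits(x, w, q, r):
    """Undo the modular transformation, then solve the easy superincreasing knapsack."""
    I = pow(r, -1, q)          # modular inverse of r mod q (Python 3.8+)
    s = (x * I) % q
    bits = []
    for value in reversed(w):  # greedy step, largest element first
        if value <= s:
            s -= value
            bits.append("1")
        else:
            bits.append("0")
    return "".join(reversed(bits))


w = [3, 10, 24, 46, 88, 175, 352, 707]
q, r = 1409, 831
print(pow(r, -1, q))               # 724
print((3705 * pow(r, -1, q)) % q)  # 1093
bits = decrypt_to_bits(3705, w, q, r)
print(bits, chr(int(bits, 2)))     # 01100011 c
```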


LLL Algorithm:

To understand how the Merkle-Hellman knapsack cryptosystem can be broken, we first need to understand what lattices are (you'll need some elementary linear algebra to follow along).

We can define a lattice L as the ℤ-linear span of a set of n linearly independent vectors. Or to put it more simply:

$$L = \{a_1v_1 + a_2v_2 + \dots + a_nv_n : a_i \in \mathbb{Z}\}$$

The vectors $v_1, \dots, v_n$ form a basis of L.

Geometrically, a lattice in ℝ² is a regularly spaced grid of points, for example:

[Figure: a lattice in ℝ²]
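A picture like that can be generated by enumerating integer combinations of two basis vectors. This is a small sketch; the basis (2, 1), (1, 3) and the function name are arbitrary choices for illustration, and only coefficients with |a1|, |a2| <= k are listed, since the lattice itself is infinite.

```python
def lattice_points(v1, v2, k):
    """Points a1*v1 + a2*v2 of a 2-D lattice, for integer coefficients |a1|, |a2| <= k."""
    return {
        (a1 * v1[0] + a2 * v2[0], a1 * v1[1] + a2 * v2[1])
        for a1 in range(-k, k + 1)
        for a2 in range(-k, k + 1)
    }


points = lattice_points((2, 1), (1, 3), 2)
print(sorted(points))
```

Plotting these points gives the grid-like picture; a different basis, such as (2, 1) and (3, 4), spans exactly the same lattice.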

Now let's analyze how we can relate this to breaking the Merkle-Hellman knapsack.

Suppose you have a 1×n matrix X and a 1×1 matrix Y, and you want to find the n×1 solution matrix S of the matrix equation XS = Y, where the entries of S can only be 0 or 1 (in the context of the Merkle-Hellman knapsack, X is the public key, Y is the ciphertext and S is the plaintext the attacker is after).

Now consider the following matrices:

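$$
M = \begin{bmatrix}
1 & 0 & \cdots & 0 & 0 \\
0 & 1 & \cdots & 0 & 0 \\
\vdots & \vdots & \ddots & \vdots & \vdots \\
0 & 0 & \cdots & 1 & 0 \\
X_{1,1} & X_{1,2} & \cdots & X_{1,n} & -Y_{1,1}
\end{bmatrix}
$$

$$
K = \begin{bmatrix} S_{1,1} \\ S_{2,1} \\ \vdots \\ S_{n,1} \\ 1 \end{bmatrix}
$$

$$
C = \begin{bmatrix} S_{1,1} \\ S_{2,1} \\ \vdots \\ S_{n,1} \\ 0 \end{bmatrix}
$$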

Note that if S is a solution to the matrix equation XS = Y then the matrix equation MK = C holds true.

Now let $m_1, m_2, \dots, m_{n+1}$ be the columns of M. We can write C as $C = K_{1,1}m_1 + K_{2,1}m_2 + \dots + K_{n,1}m_n + K_{n+1,1}m_{n+1} = S_{1,1}m_1 + \dots + S_{n,1}m_n + m_{n+1}$. Since every $K_{i,1}$ is an integer, C is in fact a vector in the lattice spanned by the columns of M.

Further note that because the entries of C are all either 0 or 1, the Euclidean length of the vector C will be quite short:

$$\|C\| = \sqrt{S_{1,1}^2 + S_{2,1}^2 + \dots + S_{n,1}^2} \le \sqrt{n}$$

And if we were to calculate C then we would also know S, in other words we would know the binary form of the plaintext.
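The construction can be checked numerically. This sketch builds M from X and Y exactly as in the matrices above (identity block, zero last column, last row [X, -Y]) and verifies MK = C on a small made-up knapsack; the values of X and S and the function names are hypothetical.

```python
def build_M(X, Y):
    """(n+1) x (n+1) matrix: n x n identity on top (last column 0), last row [X_1 ... X_n, -Y]."""
    n = len(X)
    M = [[1 if i == j else 0 for j in range(n + 1)] for i in range(n)]
    M.append(list(X) + [-Y])
    return M


def mat_vec(M, K):
    """Matrix-vector product M K for plain Python lists."""
    return [sum(m * k for m, k in zip(row, K)) for row in M]


X = [3, 5, 7, 11]                     # hypothetical public knapsack
S = [1, 0, 1, 1]                      # hypothetical 0/1 solution
Y = sum(x * s for x, s in zip(X, S))  # 21

M = build_M(X, Y)
C = mat_vec(M, S + [1])               # K = S with a 1 appended
print(C)                              # [1, 0, 1, 1, 0]
print(sum(c * c for c in C))          # 3, the squared length of C (at most n = 4)
```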

This is where the LLL algorithm (created by Arjen Lenstra, Hendrik Lenstra and László Lovász) comes in. Given a basis of a lattice as input, this algorithm computes, in polynomial time, a reduced basis (a basis made of short, nearly orthogonal vectors) of the same lattice.

If we use the matrix M as input, the LLL algorithm will output short vectors in the lattice spanned by the columns of M. Finding C is not guaranteed, since LLL only finds vectors that are short up to a factor exponential in the dimension, not necessarily the shortest ones, but there is a good chance that C will be among them.

Consider the following example:

Suppose Alice publishes the following public key: A = [367, 272, 1753, 708, 17, 1623, 1562, 978]. Bob wants to send the message "01100001" to Alice and uses this public key to encrypt it, obtaining the ciphertext ct = 3003.

Knowing only Alice's public key A and the ciphertext ct, an attacker first constructs the following matrix:

	[  1   0    0   0  0    0    0   0     0]
	[  0   1    0   0  0    0    0   0     0]
	[  0   0    1   0  0    0    0   0     0]
	[  0   0    0   1  0    0    0   0     0]
M =	[  0   0    0   0  1    0    0   0     0]
	[  0   0    0   0  0    1    0   0     0]
	[  0   0    0   0  0    0    1   0     0]
	[  0   0    0   0  0    0    0   1     0]
	[367 272 1753 708 17 1623 1562 978 -3003]

The attacker then applies the LLL algorithm to it, obtaining the following result:

         [0  0  2  1  0 -2 -1 -2 -1] 
         [1 -1  0  0 -1 -1  1  0 -1]
         [1  0  0  0  1 -1  0  1  2]
         [0 -1  1  0  1  0 -1  2 -1]
LLL(M) = [0  0  0  2  1  0  0 -1  1]
         [0  0  0  1 -1 -2 -2  0  2]
         [0  0  1  0 -1  0  0  1  1]
         [1  1  0  1  1  0 -2 -1 -1]
         [0 -2  1 -1 -1  1  1 -1  0]

Notice how the first eight entries of the first column of the resulting matrix form the precise message Bob encrypted: "01100001" (which is binary for "a"). Thus the attacker has recovered the original plaintext using nothing but public information.
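An attacker does not have to eyeball the output: each reduced vector can be tested automatically. A minimal sketch (the helper `check_candidate` is hypothetical) that accepts a vector only if it has the shape of C (0/1 entries, last entry 0) and its bits really solve the public knapsack:

```python
def check_candidate(vector, A, ct):
    """Return the bit string of a lattice vector shaped like C that solves the knapsack, else None."""
    *bits, last = vector
    if last != 0 or any(b not in (0, 1) for b in bits):
        return None
    if sum(a * b for a, b in zip(A, bits)) != ct:
        return None
    return "".join(map(str, bits))


A = [367, 272, 1753, 708, 17, 1623, 1562, 978]
ct = 3003
bits = check_candidate([0, 1, 1, 0, 0, 0, 0, 1, 0], A, ct)  # first column of LLL(M)
print(bits, chr(int(bits, 2)))                              # 01100001 a
print(check_candidate([0, -1, 0, -1, 0, 0, 0, 1, -2], A, ct))  # None (second column)
```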

The LLL algorithm:

LLL(M):                     # M_0, ..., M_{n-1} are the basis vectors
    V = GS(M)               # "GS" refers to the Gram-Schmidt process (**)
    W_{i,j} = inner_product(M_i, V_j) / inner_product(V_j, V_j)   # for every j < i
    h = 1
    while h < n:            # n is the number of vectors in M
        for j from h-1 down to 0:
            if abs(W_{h,j}) > 1/2:
                M_h = M_h - round(W_{h,j}) * M_j
                V = GS(M)   # update V and W
                W_{i,j} = inner_product(M_i, V_j) / inner_product(V_j, V_j)
        if inner_product(V_h, V_h) >= (0.99 (***) - W_{h,h-1}^2) * inner_product(V_{h-1}, V_{h-1}):
            h = h + 1
        else:
            swap(M_h, M_{h-1})
            V = GS(M)       # update V and W
            W_{i,j} = inner_product(M_i, V_j) / inner_product(V_j, V_j)
            h = max(h-1, 1)
    return M

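(**) The Gram-Schmidt process returns an orthogonal (not normalized) basis for the subspace spanned by the vectors of M. (***) The value 0.99 is the δ parameter of the Lovász condition; δ must lie strictly between 1/4 and 1, and values close to 1 give a stronger reduction at the cost of more iterations (the original paper used δ = 3/4).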
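For readers without SageMath, here is a self-contained sketch of the pseudocode above in plain Python. It is not the author's implementation: it uses exact `fractions.Fraction` arithmetic to sidestep the floating-point problems mentioned below, treats rows as basis vectors, and recomputes Gram-Schmidt after every change, which is slow but keeps the code close to the pseudocode. All function names are hypothetical.

```python
from fractions import Fraction


def dot(u, v):
    """Inner product of two vectors given as lists."""
    return sum(a * b for a, b in zip(u, v))


def gram_schmidt(B):
    """Gram-Schmidt on the rows of B, returning (Bstar, mu) with
    mu[i][j] = <b_i, b*_j> / <b*_j, b*_j> for j < i (exact Fractions)."""
    n = len(B)
    Bstar = []
    mu = [[Fraction(0)] * n for _ in range(n)]
    for i in range(n):
        v = [Fraction(x) for x in B[i]]
        for j in range(i):
            mu[i][j] = Fraction(dot(B[i], Bstar[j])) / dot(Bstar[j], Bstar[j])
            v = [a - mu[i][j] * b for a, b in zip(v, Bstar[j])]
        Bstar.append(v)
    return Bstar, mu


def lll(B, delta=Fraction(99, 100)):
    """LLL-reduce the rows of B (assumed linearly independent integer vectors)."""
    B = [list(row) for row in B]
    n = len(B)
    Bstar, mu = gram_schmidt(B)
    h = 1
    while h < n:
        # Size reduction: make |mu[h][j]| <= 1/2 for every j < h
        for j in range(h - 1, -1, -1):
            if abs(mu[h][j]) > Fraction(1, 2):
                q = round(mu[h][j])
                B[h] = [a - q * b for a, b in zip(B[h], B[j])]
                Bstar, mu = gram_schmidt(B)
        # Lovász condition
        if dot(Bstar[h], Bstar[h]) >= (delta - mu[h][h - 1] ** 2) * dot(Bstar[h - 1], Bstar[h - 1]):
            h += 1
        else:
            B[h], B[h - 1] = B[h - 1], B[h]
            Bstar, mu = gram_schmidt(B)
            h = max(h - 1, 1)
    return B


def is_lll_reduced(B, delta=Fraction(99, 100)):
    """Check size reduction and the Lovász condition on the rows of B."""
    Bstar, mu = gram_schmidt(B)
    size_reduced = all(abs(mu[i][j]) <= Fraction(1, 2) for i in range(len(B)) for j in range(i))
    lovasz = all(
        dot(Bstar[k], Bstar[k]) >= (delta - mu[k][k - 1] ** 2) * dot(Bstar[k - 1], Bstar[k - 1])
        for k in range(1, len(B))
    )
    return size_reduced and lovasz


# The example lattice from the text; rows here are the columns of M above
A = [367, 272, 1753, 708, 17, 1623, 1562, 978]
ct = 3003
basis = [[int(i == j) for j in range(len(A))] + [a] for i, a in enumerate(A)]
basis.append([0] * len(A) + [-ct])
reduced = lll(basis)
for row in reduced:
    print(row)
```

Because the bookkeeping differs from the author's Sage code, the reduced basis printed here may not match the matrix shown earlier row for row; what is guaranteed is that it spans the same lattice and satisfies both LLL conditions.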


About the implementation:

A few notes and clarifications before presenting the code:

  • This implementation is based on the original knapsack cryptosystem proposed by Merkle and Hellman. If you found this interesting, you might want to read about other cryptosystems based on the knapsack problem.
  • For this implementation the public and private knapsacks always have the same size (in bits) as the plaintext. You could also use shorter knapsacks and encrypt a given plaintext in blocks.
  • One thing you can do to make this cryptosystem slightly stronger is to permute the public key values. This was not implemented, and it would not prevent the system from being broken anyway.
  • Although I wrote this program so that it can be run with Python 3, the "break cipher" feature will not work properly if you do so. The intended way is to run it with SageMath, which avoids problems caused by floating-point limitations.
  • A lot of sagemath features were purposely not used to better illustrate how certain processes work, namely the LLL algorithm.
  • Code can also be found on my github: https://github.com/0xA2/Merkle-Hellman-Knapsack-and-LLL/blob/master/knapsack.sage
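The block-wise variant mentioned in the notes could look roughly like this. It is a hypothetical sketch, not part of the program below: it reuses the 8-element example key (w = [3, 10, ..., 707], q = 1409, r = 831) and encrypts one byte per block.

```python
def encrypt_blocks(message, A):
    """Encrypt a message one byte at a time with an 8-element public knapsack A."""
    ciphertexts = []
    for byte in message.encode():
        bits = format(byte, "08b")  # most significant bit first, matching A[0]
        ciphertexts.append(sum(a for bit, a in zip(bits, A) if bit == "1"))
    return ciphertexts


def decrypt_blocks(ciphertexts, w, q, r):
    """Invert encrypt_blocks with the private key (w, q, r)."""
    I = pow(r, -1, q)  # modular inverse of r (Python 3.8+)
    out = bytearray()
    for x in ciphertexts:
        s = (x * I) % q
        byte = 0
        for i in range(len(w) - 1, -1, -1):  # greedy superincreasing solve
            if w[i] <= s:
                s -= w[i]
                byte |= 1 << (len(w) - 1 - i)
        out.append(byte)
    return out.decode()


A = [1084, 1265, 218, 183, 1269, 298, 849, 1373]
w = [3, 10, 24, 46, 88, 175, 352, 707]
ct = encrypt_blocks("ca", A)
print(ct)                               # [3705, 2856]
print(decrypt_blocks(ct, w, 1409, 831)) # ca
```

One downside of this design is that the scheme is deterministic, so identical bytes always produce identical ciphertext blocks.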


The actual code:

####################################################################
#                                                                  #
#         ### Merkle-Hellman Knapsack Sage 9.0 Version ###         #
#                                                                  #
#  WARNING: I wrote this program to be compatible with Python 3,   #
#           however the 'break cipher' feature is only fully       #
#           available when run with sagemath to avoid floating     #
#           point limitations. Also, as I'm sure you can tell,     #
#           this cryptosystem can be easily broken and should      #
#           in no circumstance be considered for any serious       #
#           cryptographic use.                                     #
#                                                                  #
####################################################################

from random import SystemRandom
import binascii

# Private key randomness coefficients
# (NOTE: For demonstration purposes I kept these values very low in order to artificially increase the program's efficiency, in an ideal scenario they should be much higher!)
W_RANGE = 10
Q_RANGE = 10

banner = '''-------------------------------
    Merkle-Hellman Knapsack
-------------------------------'''

# Encryption related functions - start

# Verify if public key has a valid length
# (NOTE: in this implementation key-size is always equal to plaintext length in bits)
def verify_publickey(pt,public_key):
    return len(pt.encode())*8 == len(public_key)

def encrypt(pt,public_key):
    # Plaintext as a bit string: 8 bits per byte, most significant bit first
    bits = "".join(format(byte,'08b') for byte in pt.encode())
    # Ciphertext = sum of the public key elements selected by the 1 bits
    return str(sum(int(bits[i])*public_key[i] for i in range(0,len(public_key))))

# (NOTE: in this implementation public key is not permutated)
def gen_keypair(pt_len):
    # Generating Private Key:
    # Generating random superincreasing set w
    w = []
    s = 2
    for _ in range(0,pt_len):
        value = SystemRandom().randrange(s,s+W_RANGE)
        w.append(value)
        s += value
    # Generating q such that q > sum
    q = SystemRandom().randrange(s,s+Q_RANGE)
    # Generating r such that r and q are coprime
    while True:
        r = SystemRandom().randrange(2,q)
        if egcd(r,q)[0] == 1:
            break
    private_key = (w,q,r)
    #Calculating Public Key:
    public_key = [(n*r)%q for n in w]
    return (public_key, private_key)

# Encryption related functions - end


# Auxiliary functions for gcd and modulo inverse calculations

def egcd(a,b):
    if a == 0:
        return (b,0,1)
    g,y,x = egcd(b%a,a)
    return (g,x-(b//a)*y,y)

def modinverse(a,m):
    g,x,y = egcd(a,m)
    if g != 1:
        raise Exception('Something went wrong, modular inverse does not exist')
    return x%m


# Decryption related functions - start

def verify_privatekey(private_key):
    if egcd(private_key[1],private_key[2])[0] != 1:
        print ("\nError: q and r are not coprime!\n")
        return False
    total = 0
    for value in private_key[0]:
        if value <= total:
            print ("\nError: w is not a superincreasing sequence!\n")
            return False
        total += value
    if total >= private_key[1]:
        print ("\nError: q is not greater than the sum of all elements of w!\n")
        return False
    return True

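def decrypt(ct,private_key):
    w,q,r = private_key
    # Undo the modular transformation: s = ct * r^-1 mod q
    s = (ct*modinverse(r,q))%q
    # Greedy solution of the superincreasing knapsack, from the largest element down
    pt = ""
    for i in range(len(w)-1,-1,-1):
        if w[i] <= s:
            s -= w[i]
            pt += "1"
        else:
            pt += "0"
    bits = pt[::-1]
    # Convert the bits to bytes (keeping leading zero bytes, unlike hex()) and decode
    return int(bits,2).to_bytes((len(bits)+7)//8,"big").decode()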

# Auxiliary functions for vector operations

def vsum(u,v):
    try:
        ret = []
        for i in range(0,len(v)):
            ret.append(v[i]+u[i])
        return ret
    except:
        print ("\nError in vector sum calculation!\n")

def scalar_product(n,v):
    try:
        ret = []
        for i in range(0,len(v)):
            ret.append(n*v[i])
        return ret
    except:
        print ("\nError in vector scalar product calculation!\n")

def dot_product(u,v):
    try:
        ret = 0
        for i in range(0,len(v)):
            ret += v[i]*u[i]
        return ret
    except:
        print ("\nError in vector dot product calculation!\n")

# Cryptanalysis related functions

def GramSchmidt(M):
    # Gram-Schmidt orthogonalization of the rows of M (not normalized).
    # projection_coefficients[(i,j)] = <b*_i, b_j> / <b*_i, b*_i> for i < j.
    # Tuple keys are used because string keys collide for large matrices,
    # e.g. str(1)+str(112) == str(11)+str(12).
    try:
        orthG = [M[0]]
        projection_coefficients = {}
        for j in range(1,len(M)):
            orthG.append(M[j])
            for i in range(0,j):
                projection_coefficients[(i,j)] = (dot_product(orthG[i],M[j]))/(dot_product(orthG[i],orthG[i]))
                orthG[j] = vsum(orthG[j],scalar_product(-1*projection_coefficients[(i,j)],orthG[i]))
        return (orthG,projection_coefficients)
    except:
        print ("\nError in Gram-Schmidt orthogonalization process!\n")

def LLL(M,d):
    # Reduces the rows of M; d is the Lovász constant (0.99 in break_cipher)
    try:
        while True:
            GSoG, GSpc = GramSchmidt(M)
            # Size reduction (uses the coefficients computed at the start of this pass)
            for j in range(1,len(M)):
                for i in range(j-1,-1,-1):
                    if abs(GSpc[(i,j)]) > 1/2:
                        M[j] = vsum(M[j],scalar_product(-1*round(GSpc[(i,j)]),M[i]))
            GSoG, GSpc = GramSchmidt(M)
            # Lovász condition: swap the first adjacent pair that violates it, then start over
            swapped = False
            for j in range(0,len(M)-1):
                tmp0 = vsum(GSoG[j+1],scalar_product(GSpc[(j,j+1)],GSoG[j]))
                if dot_product(tmp0,tmp0) < d*(dot_product(GSoG[j],GSoG[j])):
                    M[j], M[j+1] = M[j+1], M[j]
                    swapped = True
                    break
            if not swapped:
                return M
    except:
        print ("\nError in LLL reduction calculations!\n")


def break_cipher(ct,public_key):
    try:
        #Converting the knapsack problem into a lattice problem
        #Initializing and setting up the matrix M
        M = [[1 if i==j else 0 for i in range(0,len(public_key))] for j in range(0,len(public_key))]
        for i in range(0,len(public_key)):
            M[i].append(public_key[i])
        M.append([0 for _ in range(0,len(public_key))])
        M[len(public_key)].append(-ct)
        #Find short vectors in the lattice spanned by the columns of M
        short_vectors = LLL(M,0.99)
        print ("\nShort vectors found > " + str(short_vectors))
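        flag = 0
        for vector in short_vectors:
            # LLL may return -C instead of C, so flip vectors whose entries are all 0 or -1
            if all(n == 0 or n == -1 for n in vector):
                vector = [-n for n in vector]
            # C has 0/1 entries and a last entry of 0
            if any(n != 0 and n != 1 for n in vector) or vector[-1] != 0:
                continue
            bits = "".join(str(n) for n in vector[:-1])
            # The candidate bits must actually solve the public knapsack
            if sum(int(b)*a for b,a in zip(bits,public_key)) != ct:
                continue
            try:
                pt = int(bits,2).to_bytes((len(bits)+7)//8,"big").decode()
            except UnicodeDecodeError:
                continue
            print ("\nPossible plaintext found > " + pt + "\n")
            flag = 1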

        if not flag:
            print ("\nNo possible plaintexts found using LLL reduction!\n")

    except:
        print ("\nFailed to break Merkle-Hellman knapsack encryption for desired ciphertext!\n")

# Decryption related functions - end


# The Main Function handles user input, menu conditions and the retrieval of information from provided text files

def main():
    while True:
        print (banner)
        try:
            print ("1) Encrypt\n2) Decrypt\n3) Generate Key Pair\n4) Exit")
            op = str(input("> "))
        except:
            print ("Input Error!")
            continue

        # Main menu option 1
        if op == "1":
            try:
                pt = str(input("Plaintext to encrypt > "))
                print ("Public key:\n1) Use your own key\n2) Have key files generated for you")
                op1 = str(input("> "))
            except:
                print ("Input Error!")
                continue

            # Encrypt menu option 1
            if op1 == "1":
                try:
                    pub_file =  str(input("Enter the name of your public key file(file should have one number per line)\n> "))
                    public_key = []
                    with open(pub_file,"r") as f:
                        for line in f:
                            line = line.strip()  # safer than line[:-1] if the last line has no "\n"
                            if not line:
                                continue
                            if int(line) <= 0:
                                raise Exception()
                            public_key.append(int(line))
                except:
                    print ("Invalid key error!")
                    continue
                if not verify_publickey(pt,public_key):
                    print("\nInvalid key error!\n")
                    continue

            # Encrypt menu option 2
            elif op1 == "2":
                try:
                    key = gen_keypair(len(pt.encode())*8)
                    print ("\nKey pair generated to encrypt your plaintext:\n\nPublic Key > " + str(key[0]) + "\n\nPrivate Key(w,q,r) > " + str(key[1]))
                    with open("publickey.txt","w") as pub:
                        for n in key[0]:
                            pub.write(str(n) + "\n")
                    with open("privatekey.txt","w") as prv:
                        prv.write("w:\n")
                        for n in key[1][0]:
                            prv.write(str(n) + "\n")
                        prv.write("q:\n")
                        prv.write(str(key[1][1]) + "\n")
                        prv.write("r:\n")
                        prv.write(str(key[1][2]) + "\n")
                    public_key = key[0]
                    print ("\nPublic and Private keys have been saved to 'publickey.txt' and 'privatekey.txt' respectively.\n")
                except:
                    print ("\nInput Error!\n")
                    continue
            else:
                print ("\nInvalid option!\n")
                continue
            ct = encrypt(pt,public_key)
            print ("\nCiphertext > " + ct + "\n") 


        # Main menu option 2
        elif op == "2":
            try:
                ct = int(input("Ciphertext to decrypt (in decimal) > "))
            except:
                print ("\nInput error!\n")
                continue
            print ("\nPrivate key:\n1) Use your own key\n2) Break Cipher (no private key required)")
            op2 = str(input("> "))

            # Decrypt menu option 1
            if op2 == "1":
                try:
                    prv_file = str(input("\nEnter the name of a private key file:\n> "))
                    values = []
                    with open(prv_file,"r") as prv:
                        for line in prv:
                            line = line.strip()  # safer than line[:-1] if the last line has no "\n"
                            if not line or line in ("w:","q:","r:"):
                                continue
                            if int(line) <= 0:
                                raise Exception()
                            values.append(int(line))
                    w = values[:-2]
                    q = values[-2]
                    r = values[-1]
                    private_key = (w,q,r)
                    if not verify_privatekey(private_key):
                        print ("\nInvalid key error!\n")
                        continue
                    pt = decrypt(ct,private_key)
                    print ("\nPlaintext > " + pt + "\n")
                except:
                    print ("\nInvalid key error!\n")
                    continue


            # Decrypt menu option 2
            elif op2 == "2":
                try:
                    pub_file =  str(input("Enter the name of a public key file\n> "))
                    public_key = []
                    with open(pub_file,"r") as pub:
                        for line in pub:
                            line = line.strip()  # safer than line[:-1] if the last line has no "\n"
                            if not line:
                                continue
                            if int(line) <= 0:
                                raise Exception()
                            public_key.append(int(line))
                except:
                    print ("\nInvalid key error!\n")
                    continue
                break_cipher(ct,public_key)
            else:
                print ("\nInvalid option!\n")
                continue


        # Main menu option 3
        elif op == "3":
            try:
                size = int(input("Enter key size(in bytes):\n> "))
                key = gen_keypair(size*8)
                print ("\nKey pair generated to encrypt your plaintext:\n\nPublic Key > " + str(key[0]) + "\n\nPrivate Key(w,q,r) > " + str(key[1]))
                with open("publickey.txt","w") as pub:
                    for n in key[0]:
                        pub.write(str(n) + "\n")
                with open("privatekey.txt","w") as prv:
                    prv.write("w:\n")
                    for n in key[1][0]:
                        prv.write(str(n) + "\n")
                    prv.write("q:\n")
                    prv.write(str(key[1][1]) + "\n")
                    prv.write("r:\n")
                    prv.write(str(key[1][2]) + "\n")
                print ("\nPublic and Private keys have been saved to 'publickey.txt' and 'privatekey.txt' respectively.\n")
            except:
                print ("\nInput error!\n")
                continue


        # Main menu option 4
        elif op == "4":
            return 0
        else:
            print ("\nInvalid option!\n")

if __name__ == "__main__":
    main()