Differential Cryptanalysis
I’ll first explain how the attack works, then walk through minidiff, a challenge from Unbreakable Teams that no one solved during the contest.
Intro to Cryptanalysis
PS: I tried to explain the more general case first. If it doesn’t make sense, skip to the example and read the two side by side.
Imagine you had a cipher with the following structure:
plaintext -> addroundkey0 -> sbox -> addroundkey1 -> sbox -> addroundkey2 -> ciphertext
If we run two plaintexts pt0 and pt1 through the cipher and get their ciphertexts ct0 and ct1, the input differential is pt0 ^ pt1 and the output differential is ct0 ^ ct1. XORing a round key into both states does not change their differential, since (a ^ k) ^ (b ^ k) = a ^ b, so only the sbox steps can change it.
Now we fix one input differential, collect the output differentials it produces over many plaintexts, and check whether the cipher is biased towards any of them. For example, the characteristic 3 -> 5 -> 6 could hold for 8/16 of the possible inputs at each step (the middle value, 5, is the differential between the two sbox operations).
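To see where such a bias comes from, it can help to tabulate, for a single sbox, how often each input difference turns into each output difference. The sketch below does that for the 4-bit sbox borrowed from the MiniDiffCipher example further down; the function names are my own and purely illustrative, and the 3 -> 5 -> 6 numbers above are not taken from this table.

```python
# 4-bit sbox copied from the MiniDiffCipher example later in this post.
SBOX = [7, 10, 13, 0, 3, 6, 9, 12, 2, 15, 5, 8, 11, 14, 1, 4]


def difference_table(sbox):
    """ddt[din][dout] = number of inputs x with sbox[x] ^ sbox[x ^ din] == dout."""
    n = len(sbox)
    ddt = [[0] * n for _ in range(n)]
    for din in range(n):
        for x in range(n):
            ddt[din][sbox[x] ^ sbox[x ^ din]] += 1
    return ddt


def best_transitions(ddt, din):
    """(count, dout) pairs for input difference `din`, most frequent first."""
    row = ddt[din]
    return sorted(((count, dout) for dout, count in enumerate(row) if count),
                  reverse=True)


DDT = difference_table(SBOX)

if __name__ == "__main__":
    # Print the most likely output difference for a few single-bit input differences.
    for din in (1, 2, 4, 8):
        count, dout = best_transitions(DDT, din)[0]
        print(f"{din:#x} -> {dout:#x} holds for {count}/16 inputs")
```

Rows with a few large entries are the ones worth chaining into a characteristic; a row that is spread out evenly gives the attacker nothing to work with.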
We can use a characteristic like 3 -> 5 -> 6 to recover the key with less brute force. By checking only the key values that could have produced each step of the characteristic, we limit the keyspace to 8 * 8 * 8 instead of 16 * 16 * 16. The catch is that we need a plaintext pair that actually follows the characteristic, and since it only holds with some probability, finding one takes a few tries. Even so, this is much faster than the brute force alternative (it will make more sense in the example).
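Here is a minimal sketch of that idea on the toy cipher above (plaintext -> addroundkey0 -> sbox -> addroundkey1 -> sbox -> addroundkey2 -> ciphertext). It assumes 4-bit blocks and keys, reuses the sbox from the example below, and all function names are made up for illustration. Instead of hard-coding 3 -> 5, it computes the most likely difference after the first sbox itself, then keeps every last-key guess that reproduces that difference exactly as often as expected once the last round is undone.

```python
import random

SBOX = [7, 10, 13, 0, 3, 6, 9, 12, 2, 15, 5, 8, 11, 14, 1, 4]
INV_SBOX = [SBOX.index(v) for v in range(16)]


def encrypt(pt, keys):
    """pt -> addroundkey0 -> sbox -> addroundkey1 -> sbox -> addroundkey2."""
    state = SBOX[pt ^ keys[0]]
    state = SBOX[state ^ keys[1]]
    return state ^ keys[2]


def best_first_round_diff(din):
    """Most likely difference after the first sbox, and how many inputs give it."""
    counts = [0] * 16
    for x in range(16):
        counts[SBOX[x] ^ SBOX[x ^ din]] += 1
    best = max(range(16), key=lambda d: counts[d])
    return best, counts[best]


def score_last_key(guess, din, dmid, oracle):
    """Count plaintexts whose pair shows `dmid` after undoing the last round with `guess`."""
    hits = 0
    for pt in range(16):
        a = INV_SBOX[oracle(pt) ^ guess]
        b = INV_SBOX[oracle(pt ^ din) ^ guess]
        hits += (a ^ b) == dmid
    return hits


def recover_last_key_candidates(din, oracle):
    """Keep the guesses whose hit count matches what the sbox predicts."""
    dmid, expected = best_first_round_diff(din)
    scores = {g: score_last_key(g, din, dmid, oracle) for g in range(16)}
    return [g for g, s in scores.items() if s == expected], scores


if __name__ == "__main__":
    keys = [random.randrange(16) for _ in range(3)]
    cands, _ = recover_last_key_candidates(0x3, lambda pt: encrypt(pt, keys))
    print("real last key:", keys[2], "candidates:", cands)
```

The real last key always survives because, with the right guess, the partially decrypted pairs show exactly the differences that came out of the first sbox. Wrong guesses can survive too, so this only shrinks the list; a second input difference would be needed to narrow it further.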
Example
We will be analyzing the following cipher:
class MiniDiffCipher:
    def __init__(self, round_keys):
        self.round_keys = round_keys
        self.sbox = [7, 10, 13, 0, 3, 6, 9, 12, 2, 15, 5, 8, 11, 14, 1, 4]

    def substitute(self, state):
        # Slide a 4-bit window over the 24-bit state one bit at a time,
        # replacing each window in place with its sbox value.
        bin_state = format(state, "024b")
        out_bits = [int(x) for x in bin_state]
        for i in range(24 - 4 + 1):
            window = out_bits[i : i + 4]
            window_val = int("".join(str(x) for x in window), 2)
            sub_val = self.sbox[window_val]
            sub_bin = format(sub_val, "04b")
            for j in range(4):
                out_bits[i + j] = int(sub_bin[j])
        out_str = "".join(str(bit) for bit in out_bits)
        return int(out_str, 2)

    def permute(self, state):
        # Split the state into 6 nibbles (most significant first) and reorder them.
        nibbles = [(state >> ((5 - i) * 4)) & 0xF for i in range(6)]
        permuted = [nibbles[2], nibbles[5], nibbles[0], nibbles[3], nibbles[4], nibbles[1]]
        result = 0
        for nib in permuted:
            result = (result << 4) | nib
        return result

    def encrypt_block(self, plaintext):
        state = plaintext ^ self.round_keys[0]
        # Three full rounds: substitute, permute, add round key.
        for r in range(1, 4):
            state = self.substitute(state)
            state = self.permute(state)
            state ^= self.round_keys[r]
        # Final round has no permutation.
        state = self.substitute(state)
        state ^= self.round_keys[4]
        return state
First, we have to find a valid differential. Since the plaintext space is very big (2**24), trying every possible input differential would take ages, so I decided to test only ones that modify just a few bits (0x1, 0x2, 0x3, 0x4, …, 0xf, 0x10, 0x20, …).
To do this, I implemented the cipher in C and then made a helper to check the input differentials I wanted to test. It is attached here.
We have to attack the last round key first, because we only see ciphertexts and can’t observe differentials in the middle of the cipher, so we look for high-probability differentials for it. Some high-probability ones I found and used were:
0x800000: Top 30 output differences:
1) 0x048020 occurred 535488 times (3.1918%)
0x400000: Top 30 output differences:
1) 0x024010 occurred 688214 times (4.1021%)
0xC00000: Top 30 output differences:
1) 0x008000 occurred 540672 times (3.2227%)
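For readers who would rather experiment in Python than C, the kind of sweep that produces tables like the ones above can be sketched as follows. This is not the author's helper: it samples random plaintexts instead of trying all of them, uses made-up round keys, and assumes the differences are measured just before the final substitution (which is where brute_key4 checks them), so the percentages it prints need not match the numbers above. All function names are illustrative.

```python
import random
from collections import Counter

SBOX = [7, 10, 13, 0, 3, 6, 9, 12, 2, 15, 5, 8, 11, 14, 1, 4]


def substitute(state):
    """Same sliding 4-bit window as MiniDiffCipher.substitute, on a list of bits."""
    bits = [(state >> (23 - i)) & 1 for i in range(24)]
    for i in range(21):
        val = bits[i] << 3 | bits[i + 1] << 2 | bits[i + 2] << 1 | bits[i + 3]
        sub = SBOX[val]
        bits[i:i + 4] = [(sub >> (3 - j)) & 1 for j in range(4)]
    return int("".join(map(str, bits)), 2)


def permute(state):
    """Same nibble reordering as MiniDiffCipher.permute."""
    nib = [(state >> ((5 - i) * 4)) & 0xF for i in range(6)]
    out = 0
    for n in (nib[2], nib[5], nib[0], nib[3], nib[4], nib[1]):
        out = (out << 4) | n
    return out


def state_before_last_sbox(pt, round_keys):
    """Run the cipher up to (not including) the final substitution."""
    state = pt ^ round_keys[0]
    for r in range(1, 4):
        state = permute(substitute(state)) ^ round_keys[r]
    return state


def sample_differences(din, round_keys, samples, rng):
    """Histogram of the difference reaching the last sbox layer for input difference `din`."""
    seen = Counter()
    for _ in range(samples):
        pt = rng.randrange(1 << 24)
        a = state_before_last_sbox(pt, round_keys)
        b = state_before_last_sbox(pt ^ din, round_keys)
        seen[a ^ b] += 1
    return seen


def low_weight_differences():
    """My reading of the list 0x1..0xf, 0x10, 0x20, ...: one non-zero nibble at a time."""
    return [v << (4 * pos) for pos in range(6) for v in range(1, 16)]


if __name__ == "__main__":
    rng = random.Random(0)
    keys = [rng.randrange(1 << 24) for _ in range(5)]  # made-up round keys
    for din in (0x400000, 0x800000, 0xC00000):
        dout, hits = sample_differences(din, keys, 2000, rng).most_common(1)[0]
        print(f"{din:#08x}: best {dout:#08x} in {hits}/2000 samples")
```

Looping over low_weight_differences() instead of the three values in the main block gives the full sweep, at the cost of a much longer run in pure Python.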
I had to run several smaller sweeps, each one shrinking the set of candidate keys, because my laptop couldn’t handle that many operations at once, even with multiprocessing.
I first used the 0x400000 -> 0x024010 differential (function brute_key4). For the brute force, we pick count random plaintexts, pair each one with its partner (plaintext ^ input differential), and look up the ciphertexts of both. Then, for each key guess (in this first sweep, the whole keyspace), we undo the last round with that guess and check whether the highest-probability differential shows up. Lastly, we save the surviving candidates for the next sweep. Example code snippet in case this is confusing:
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>
#include <time.h>

// KEY_SPACE, nthreads, tid, data, substitute_inv, k4_candidates and
// num_candidates are defined elsewhere in the full program.
#define TRIAL_OFFSET 0x400000
#define TARGET_DIFF_K4 0x024010

void brute_key4(uint32_t count, const uint32_t *round_keys) {
    uint32_t *local_cands = malloc(((KEY_SPACE+1)/nthreads + 1) * sizeof *local_cands);
    size_t local_count = 0;

    // precompute the random plaintexts we are going to use
    unsigned int seed = (unsigned int)time(NULL) ^ (tid * 0x9e3779b9);
    uint32_t xs[count], ys[count];
    for (uint32_t i = 0; i < count; ++i) {
        xs[i] = rand_r(&seed) & KEY_SPACE;
        ys[i] = xs[i] ^ TRIAL_OFFSET;
    }

    // test all keys
    for (uint32_t k4 = 0; k4 <= KEY_SPACE; ++k4) {
        uint32_t diffs[count], cnts[count];
        uint32_t nunique = 0;

        // for each key, we test the plaintexts I generated; data contains
        // the ciphertexts for the provided plaintexts
        for (uint32_t i = 0; i < count; ++i) {
            // we decrypt until the round before with the key we are testing
            uint32_t ct1 = substitute_inv(data[xs[i]] ^ k4);
            uint32_t ct2 = substitute_inv(data[ys[i]] ^ k4);
            uint32_t diff = ct1 ^ ct2;

            // count how many times each distinct difference shows up
            uint32_t j;
            for (j = 0; j < nunique; ++j) {
                if (diffs[j] == diff) { cnts[j]++; break; }
            }
            if (j == nunique) {
                diffs[nunique] = diff;
                cnts[nunique] = 1;
                nunique++;
            }
        }

        // here I required just 1 hit of the target differential, I could
        // have added more, but this did best in testing
        bool hit = false;
        for (uint32_t j = 0; j < nunique; ++j) {
            if (diffs[j] == TARGET_DIFF_K4) {
                hit = true;
                break;
            }
        }
        if (hit) {
            local_cands[local_count++] = k4;
        }
    }

    // hand the surviving candidates over to the next sweep
    k4_candidates = local_cands;
    num_candidates = local_count;
}
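The snippet above relies on substitute_inv, which is not shown in this post. A minimal Python sketch of what it has to do, assuming it simply undoes MiniDiffCipher.substitute: since substitute applies the sbox to windows 0, 1, ..., 20 in that order, the inverse applies the inverse sbox to windows 20, 19, ..., 0. The helper names _slide and key_survives are my own; key_survives mirrors the per-key test in brute_key4 for a list of ciphertext pairs.

```python
SBOX = [7, 10, 13, 0, 3, 6, 9, 12, 2, 15, 5, 8, 11, 14, 1, 4]
INV_SBOX = [SBOX.index(v) for v in range(16)]


def _slide(state, table, positions):
    """Apply a 4-bit lookup table to overlapping windows, in the given order."""
    bits = [(state >> (23 - i)) & 1 for i in range(24)]
    for i in positions:
        val = bits[i] << 3 | bits[i + 1] << 2 | bits[i + 2] << 1 | bits[i + 3]
        out = table[val]
        bits[i:i + 4] = [(out >> (3 - j)) & 1 for j in range(4)]
    return int("".join(map(str, bits)), 2)


def substitute(state):
    """Forward direction: windows 0..20, left to right."""
    return _slide(state, SBOX, range(21))


def substitute_inv(state):
    """Inverse direction: windows 20..0, right to left, with the inverse sbox."""
    return _slide(state, INV_SBOX, range(20, -1, -1))


def key_survives(k4_guess, ct_pairs, target_diff, threshold=1):
    """Undo the last round with `k4_guess` and count pairs showing `target_diff`."""
    hits = 0
    for ct_a, ct_b in ct_pairs:
        diff = substitute_inv(ct_a ^ k4_guess) ^ substitute_inv(ct_b ^ k4_guess)
        hits += diff == target_diff
    return hits >= threshold


if __name__ == "__main__":
    x = 0xABCDEF
    print(hex(substitute_inv(substitute(x))))  # round trip gives back x
```

The order matters because the windows overlap: each window's output feeds into the next, so the last window applied has to be the first one undone.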
Now, by combining multiple high-probability differentials, we can recover k4, or at least narrow it down to a very small number of candidates.
Here is some code where you can test this yourself. You can add the parameters you want to try in the following arrays:
int input_diffs[] = {
    0x800000,
};
int output_diffs[][5] = {
    {0x048020},
};
int lengths[] = {
    1,
};
int counts[] = {
    200 * 20,
};
int thresholds[] = {
    1,
};
Using the k4 candidates that are left, we can then repeat the same process to recover k3, then k2, then k1, and lastly k0. These should be much easier to recover, since the differentials pass through the sbox fewer times and so behave less randomly.
Figuring out these values and writing a working script for k3, k2, k1, and k0 is left as an exercise for the reader, so they can test their understanding of this topic.
Hope you learned something new :D