#! /usr/bin/env python
#
# Demonstration of the Shamir "Cube" attack from Crypto 2008.
#
# Version of 2009.09.08.
#
# Terminology:
#
#   Let's say a set of key bits ks = { k1, k2, k3 } is
#   "conjugate to" (yoked together with) a set of plaintext
#   bits ps = { p1, p2, p3 } if the boolean polynomial being
#   evaluated contains the terms ( k1^k2^k3 )*p1*p2*p3.
#
#   Let's call this relationship between ks and ps a "conjugacy".
#
# This code demonstrates the Cube attack on a 1-bit keyed
# hash function. (What works for 1 bit works for 128 bits.)
# The complete preprocessing and attack of a 1-bit keyed hash
# function is performed by the subroutine study_function,
# defined near the bottom of this file. study_function
# calls build_attack_kit to find conjugacies, then sets
# up known-answer test cases that are passed to find_key
# to demonstrate that find_key correctly finds the vulnerable
# bits of the key.
#
# When you run this code, it builds a list of small, simple,
# one-bit functions, and then calls study_function on each
# of them. The resulting print-out shows the results of
# the analysis of each function and the results of the
# attacks on the known-answer test cases.
#
# Suggested use:
#   1. Look at the definition of the Conjugacy class
#      (search for "class Conjugacy") for a summary of what a
#      conjugacy is.
#   2. Look at the function find_key to see how a conjugacy
#      helps to find bits of the key.
#   3. Look at the function conjugates_of_ps to see a thorough
#      but slow way of finding conjugacies.
#   4. Run this code, and observe what it finds for the suite
#      of small functions built into it.
#   5. Change those small functions to functions of more interest
#      to you.
#   6. If you're interested in real-life applications, modify
#      the function conjugates_of_ps to make it more practical
#      for dealing with larger problems.
#
# This is just a toy program to demonstrate the concept behind
# the Cube attack. The code was written to be clear and correct,
# not fast.
#
# This code treats a 1-bit keyed hash function. The extension
# to an N-bit cipher or N-bit keyed hash function is fairly obvious,
# but don't miss the opportunity for handling all N bits
# simultaneously in xor_over_patterns.
#
# This code was written by Peter Pearson and put into the
# public domain.
#
# Portability note: everything below avoids Python-2-only syntax
# (print statements, xrange, dict.itervalues); every print call takes
# a single pre-formatted string, so the output text is the same
# whether print is a statement or a function.

# Names exported by "from <module> import *".  The self-test helpers
# test_parity and test_abpim are deliberately left out so that a
# star-import into a test module does not make a test runner pick
# them up as test cases.
__all__ = [
    "parity", "hamming_weight", "bits_of", "all_bit_patterns_in_mask",
    "bexplode", "terms_of_binary_function",
    "KeyedFunction", "Conjugacy", "xor_over_patterns",
    "conjugates_of_ps", "build_attack_kit", "find_key",
    "terms_of_keyed_function", "print_kit", "simplify_kit",
    "study_function", "demo_functions", "main",
]

##################################################################
#   General-purpose objects:
##################################################################

# Parity of each 4-bit value, indexed by that value.
_NIBBLE_PARITY = (0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0)


def parity(x):
    """Return 1 if the number of 1-bits in x is odd, 0 otherwise.

    x must be a non-negative integer of any size.
    """
    assert x >= 0
    # XOR-ing the two halves of a word preserves its parity, so keep
    # folding: first 32 bits at a time until x fits in 32 bits, then
    # down to 16, 8 and 4 bits, and finish with a table lookup.
    word_mask = (1 << 32) - 1
    while x > word_mask:
        x = (x >> 32) ^ (x & word_mask)
    for width in (16, 8, 4):
        if x >> width:
            x = (x >> width) ^ (x & ((1 << width) - 1))
    return _NIBBLE_PARITY[x]


def hamming_weight(x):
    """Return the number of 1-bits in x (x must be non-negative)."""
    return sum(1 for _ in bits_of(x))


def test_parity():
    """Check parity() against a bit-by-bit count.

    Covers every 16-bit value and each of those values shifted left
    by 0..31 bits.  Raises AssertionError on a mismatch and prints a
    confirmation line on success.
    """
    for i in range(65536):
        p = 0
        x = i
        while x > 0:
            p ^= 1
            x &= x - 1          # Clear the lowest 1-bit.
        assert p == parity(i)
        for j in range(32):
            assert p == parity(i << j)
    print("Test of parity() completed OK.")


def bits_of(x):
    """This iterator yields one by one the 1-bits in x, lowest first.

    Each value yielded is a power of two (a single-bit mask), not a
    bit index.

    Example:
    >>> for b in bits_of(6):
    ...     print(b)
    ...
    2
    4
    """
    assert x >= 0
    while x > 0:
        remaining = x & (x - 1)     # x with its lowest 1-bit cleared.
        yield x & ~remaining        # That lowest 1-bit alone.
        x = remaining


def all_bit_patterns_in_mask(mask):
    """This iterator yields all bit patterns that are unchanged when
    intersected with mask, which should be a non-negative integer.
    Patterns are produced in increasing numerical order, starting
    with 0; a mask of 0 yields just 0.

    Example:
    >>> for p in all_bit_patterns_in_mask(5):
    ...     print(p)
    ...
    0
    1
    4
    5
    """
    assert mask >= 0
    pattern = 0
    while True:
        yield pattern
        # Subtracting mask borrows through the bits that are not in
        # mask, so masking the difference steps to the next larger
        # subset of mask; it wraps to 0 after the last one (== mask).
        pattern = (pattern - mask) & mask
        if pattern == 0:
            return


def test_abpim(mask):
    """Confirm that all_bit_patterns_in_mask returns the right
    patterns for the given mask.

    Prints a message if an expected pattern was not produced.  A
    pattern that is produced but not expected (or produced twice)
    makes list.remove raise ValueError.
    """
    correct = [i for i in range(mask + 1) if i == (i & mask)]
    for i in all_bit_patterns_in_mask(mask):
        correct.remove(i)
    if len(correct) != 0:
        print("Failed for mask: %s %s" % (mask, correct))


def bexplode(x, n):
    """Extract n bits of x into a list, LSB first, returning a list
    of {0,1} values.

    Example: bexplode( 12, 5 ) returns [0,0,1,1,0].
    """
    return [(x >> i) & 1 for i in range(n)]


def terms_of_binary_function(bitnames, f_table):
    """Return a human-readable string describing a binary function
    as an XOR combination of terms that are ANDs of bits.

    Easiest by example: for
        bitnames = [ "b2", "b1", "b0" ]
        f_table = int( "01111000", 2 )
    we return
        "b0*b1 ^ b2"

    bitnames is listed most-significant bit first.
    The least-significant bit of f_table contains the value of the
    binary function for the input b0 = b1 = ... = 0.
    The number of bits in f_table is understood to be
    2**len( bitnames ); if shorter, zeros are inferred.
    The all-zero function (f_table == 0) gives the empty string.
    """
    # Implementation:
    #
    #  - The terms into which we decompose f(x) are ordered thusly:
    #      1  x0  x1  x0x1  x2  x0x2  x1x2  x0x1x2 ...
    #    where 1 is considered to be the 0th term.
    #
    #  - The function f(x) represented by f_table is considered to be
    #    a linear combination, sum_i{ alpha_i f_i(x) }, where f_i(x)
    #    is 1 iff x == i.  Thus, alpha_i is just the value of bit i
    #    of f_table.
    #
    #  - With these definitions, f_i(x) = sum_j{ beta_j term_j },
    #    where beta_j is 1 if (i|j) == j, 0 otherwise.

    # Each bit is present or absent in each term.
    n_terms = 2 ** len(bitnames)
    assert f_table < 2 ** n_terms

    # Figure out which terms are included:
    term_coeffs = [0] * n_terms
    i = 0
    this_bit = 1
    while this_bit <= f_table:
        if f_table & this_bit:
            for j in range(n_terms):
                if (i | j) == j:
                    term_coeffs[j] ^= 1
        i += 1
        this_bit <<= 1

    # Figure out the names of the terms.  Each new bit doubles the
    # list: the bit alone, then every earlier non-constant term
    # multiplied by it.
    term_names = ["1"]
    for b in reversed(bitnames):
        products = ["%s*%s" % (t, b) for t in term_names[1:]]
        term_names.append(b)
        term_names.extend(products)

    return " ^ ".join(
        name for name, coeff in zip(term_names, term_coeffs) if coeff)


##################################################################
#   Objects specific to this attack:
##################################################################

class KeyedFunction(object):
    """A named one-bit function f( key, plaintext ) together with the
    widths of its two inputs and, optionally, a stored secret key.
    """

    def __init__(self, name, n_key_bits, n_plaintext_bits, f):
        """name is an arbitrary text name.

        n_key_bits is the number of bits in the key.  We assume that
        any k satisfying 0 <= k < 2**n_key_bits is a legal key.

        n_plaintext_bits is the number of bits in the plaintext.  We
        assume that any p satisfying 0 <= p < 2**n_plaintext_bits is
        a legal plaintext.

        f( key, plaintext ) should return 0 or 1.
        """
        assert n_key_bits > 0
        assert n_plaintext_bits > 0
        self.name = name
        self.f = f
        self.n_key_bits = n_key_bits
        self.n_plaintext_bits = n_plaintext_bits
        self.plaintext_bit_mask = (1 << n_plaintext_bits) - 1
        self.secret_key = None

    def __repr__(self):
        return "KeyedFunction(%r, %d, %d)" % (
            self.name, self.n_key_bits, self.n_plaintext_bits)

    def all_possible_keys(self):
        """Iterator yields all 2**n_key_bits keys, in increasing order."""
        key = 0
        limit = 1 << self.n_key_bits
        while key < limit:
            yield key
            key += 1

    def set_key(self, key):
        """Set the secret key that will be used by evaluate_blind."""
        assert 0 <= key < (1 << self.n_key_bits)
        self.secret_key = key

    def evaluate(self, key, plaintext):
        """Return f( key, plaintext ).

        If key is None, the stored secret key (which must have been
        set) is used instead.
        """
        assert 0 <= plaintext < (1 << self.n_plaintext_bits)
        if key is None:
            assert self.secret_key is not None
            return self.f(self.secret_key, plaintext)
        assert 0 <= key < (1 << self.n_key_bits)
        return self.f(key, plaintext)

    def evaluate_blind(self, plaintext):
        """Return f( self.secret_key, plaintext ).

        (The point is that the secret key is not explicitly revealed.)
        """
        assert self.secret_key is not None
        assert 0 <= plaintext < (1 << self.n_plaintext_bits)
        return self.f(self.secret_key, plaintext)


class Conjugacy(object):
    """Contains a record of a conjugacy between a set of key bits and
    a set of plaintext bits.

    Attributes:
        ks = a set of key bits, represented as a bitmap.
        c  = the constant term.
        ps = a set of plaintext bits, represented as a bitmap.

    The defining property is this: Given the 1-bit function f(k,p)
    for which this conjugacy holds, for any key k,
        c = parity( ks & k ) ^ xor_over_patterns( f, k, ps ).
    """

    def __init__(self, ks, c, ps):
        self.ks = ks
        self.c = c
        self.ps = ps

    def __repr__(self):
        return "Conjugacy(ks=0x%x, c=%d, ps=0x%x)" % (
            self.ks, self.c, self.ps)


def xor_over_patterns(f, k, p_bitsubset):
    """Return the xor of f(k,p) for all p composed entirely of bits
    in p_bitsubset.
    """
    result = 0
    for dp in all_bit_patterns_in_mask(p_bitsubset):
        result ^= f(k, dp)
    return result


def conjugates_of_ps(kf, ps):
    """Return ( ks, c ), where ks is a subset of the key bits of
    keyed function kf that is conjugate to plaintext bit set ps
    with constant term c.

    The smallest qualifying ks is returned.  ks == 0 means either
    that the xor over ps does not depend on the key at all (c is then
    that constant value) or that no linear relation exists (in which
    case ( 0, 0 ) is returned); callers treat both as "nothing
    useful".
    """
    assert ps & kf.plaintext_bit_mask == ps

    # Find the (possibly empty) set of key bits that are
    # conjugate to this set (ps) of plaintext bits.
    # Say we have a bit-map ps specifying a subset of plaintext bits.
    # Define x( k ) = xor_over_patterns( f, k, ps ).
    # Then we say the set of key bits ks is conjugate to ps
    # if there is some constant c in {0,1} such that
    #    x( k ) = c ^ parity( k & ks ) for all k.

    # Implementation note:
    # Finding conjugate sets of key bits by exhaustive search is
    # impractical in realistic situations.  We do it this way for
    # clarity and correctness.  A practical implementation would
    # need to be much more clever.

    # x( k ) does not depend on the candidate ks, so tabulate it once
    # instead of re-deriving it for every candidate.
    keys = list(kf.all_possible_keys())
    x = [xor_over_patterns(kf.evaluate, k, ps) for k in keys]

    for ks in keys:
        c = x[0] ^ parity(keys[0] & ks)
        if all((xk ^ parity(k & ks)) == c for k, xk in zip(keys, x)):
            return ks, c
    return 0, 0


def build_attack_kit(kf):
    """Find the conjugacy relationships for the keyed function kf.

    Return a list of conjugacies (Conjugacy objects), ordered by key
    mask and then by plaintext mask.
    """
    result = []
    # Consider every subset of plaintext bits except the empty set.
    # The subsets range over the plaintext width, which need not
    # equal the key width.
    for ps in all_bit_patterns_in_mask(kf.plaintext_bit_mask):
        if ps == 0:
            continue
        ks, c = conjugates_of_ps(kf, ps)
        if ks:
            result.append(Conjugacy(ks, c, ps))
    # Conjugacy defines no ordering of its own, so sort on an
    # explicit key.  Each ps occurs at most once, which makes
    # ( ks, ps ) a total order.
    result.sort(key=lambda a: (a.ks, a.ps))
    return result


def find_key(f, attacker_kit):
    """Find as much as we can of the secret key used by f, given the
    "attacker kit" returned by build_attack_kit.

    f( plaintext ) evaluates the function under the unknown key.
    Return whatever bits of the key we find, as a bitmap.  The
    recovered parity is attributed to every bit of a.ks, so the
    result is only meaningful for kits whose key masks are single
    bits (see simplify_kit).
    """
    result = 0
    for a in attacker_kit:
        # c ends up as a.c ^ (xor of f over the cube a.ps), which by
        # the defining property of a conjugacy is parity( key & a.ks ).
        c = a.c
        for dp in all_bit_patterns_in_mask(a.ps):
            c ^= f(dp)
        if c != 0:
            result |= a.ks
    return result


def terms_of_keyed_function(kf):
    """Return a human-readable string describing the keyed function
    as an XOR combination of terms that are ANDs of bits.

    Plaintext bits are named p00, p01, ... and key bits k00, k01, ...
    """
    bitnames = (["p%02d" % i for i in range(kf.n_plaintext_bits)]
                + ["k%02d" % i for i in range(kf.n_key_bits)])
    bitnames.reverse()
    # Build the truth table with the plaintext in the low-order index
    # bits and the key in the high-order index bits.
    f_table = 0
    b = 1
    for k in kf.all_possible_keys():
        for p in all_bit_patterns_in_mask(kf.plaintext_bit_mask):
            if kf.evaluate(k, p):
                f_table |= b
            b <<= 1
    return terms_of_binary_function(bitnames, f_table)


def print_kit(attacker_kit):
    """Print the attacker kit as a table, one conjugacy per line."""
    print("ks c ps")
    print("-- - --")
    for a in attacker_kit:
        print("%02x %d %02x" % (a.ks, a.c, a.ps))


def simplify_kit(attacker_kit):
    """Return a kit consisting of only the "simple" entries in the
    given attacker kit.

    The result is ordered by key mask.  If anything was dropped, the
    simplified kit is printed.
    """
    # If two entries give the same key bits, use the one
    # with the smaller number of plaintext bits to be toggled.
    best = {}
    for a in attacker_kit:
        if (a.ks not in best
                or hamming_weight(a.ps) < hamming_weight(best[a.ks].ps)):
            best[a.ks] = a

    # Don't bother with key-bit masks that have more than one
    # bit.  (This is a simple demonstration program.  Let's not
    # get into solving linear equations.)
    announced_single_bit_policy = False
    result = []
    for ks in sorted(best):
        if hamming_weight(ks) > 1:
            if not announced_single_bit_policy:
                announced_single_bit_policy = True
                print("For simplicity, we're ignoring ks masks"
                      " of more than 1 bit.")
        else:
            result.append(best[ks])

    if len(result) < len(attacker_kit):
        print("Simplified attacker kit:")
        print_kit(result)

    return result


def study_function(kf):
    """Analyse the given function: build an attacker kit, and
    demonstrate its effectiveness on various secret keys.

    Everything is reported on standard output; nothing is returned.
    kf's stored secret key is overwritten during the demonstration.
    """
    print("\nStudying the function \"%s\":" % kf.name)

    # For small functions, terms_of_keyed_function( kf ) gives a
    # readable break-down of kf into terms; it is not printed by
    # default.

    # Build an "attacker kit", which consists of triplets ( ks, c, ps ),
    # where ks is a bit mask of key bits, c is 0 or 1, ps is a bit
    # mask of plaintext bits, and (the interesting part) for every key
    # k, the sum (mod 2) of f( k, p ) over all p such that p&ps == p
    # is c + parity( k & ks ).
    attacker_kit = build_attack_kit(kf)

    if len(attacker_kit) < 1:
        print("Nothing was found for the attacker kit.")
        return

    print("Here is the attacker kit we have assembled:")
    print_kit(attacker_kit)

    attacker_kit = simplify_kit(attacker_kit)
    if len(attacker_kit) < 1:
        print("All conjugacies in the attacker kit were too complex")
        print("for this simple demo program.")
        return

    key_bits_claimed = 0
    for a in attacker_kit:
        key_bits_claimed |= a.ks

    # Prove that this attacker kit works:
    keys_tried = 0
    keys_succeeded = 0
    for secret_k in kf.all_possible_keys():
        keys_tried += 1
        kf.set_key(secret_k)
        k_found = find_key(kf.evaluate_blind, attacker_kit)
        assert (k_found & key_bits_claimed) == k_found
        if k_found == (secret_k & key_bits_claimed):
            keys_succeeded += 1
        else:
            print("Failure for key = %02x:" % secret_k)
            for kbit in bits_of(key_bits_claimed):
                print("For bit 0x%02x, we got 0x%02x,"
                      " right answer is 0x%02x."
                      % (kbit, k_found & kbit, secret_k & kbit))
    print("The key bits in the mask %02x were successfully extracted"
          " %d times out of %d."
          % (key_bits_claimed, keys_succeeded, keys_tried))


##################################################################

def demo_functions():
    """Return the tuple of small KeyedFunction objects used by the
    demonstration.
    """
    # Choose some absurdly simple one-bit function of key k and plaintext p.

    # First, some tables used in some of our one-bit functions:
    des_sbox_1 = (
        14, 0, 4, 15, 13, 7, 1, 4, 2, 14, 15, 2, 11, 13, 8, 1,
        3, 10, 10, 6, 6, 12, 12, 11, 5, 9, 9, 5, 0, 3, 7, 8,
        4, 15, 1, 12, 14, 8, 8, 2, 13, 4, 6, 9, 2, 1, 11, 7,
        15, 5, 12, 11, 9, 3, 7, 14, 3, 10, 10, 0, 5, 6, 0, 13,
    )
    shuffled16a = (1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0)
    shuffled16b = (0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1)
    # 256 fixed bits, indexed by ( k << 4 ) + p: one row per key.
    random_bits = (
        1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1,
        1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1,
        0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1,
        1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0,
        1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1,
        1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1,
        1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0,
        1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0,
        0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1,
        0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1,
        0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1,
        1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0,
        1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0,
        1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1,
        0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1,
        0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1,
    )

    # Here's a one-bit function whose terms are explicit:
    def f_explicit(k, p):
        k = bexplode(k, 4)
        p = bexplode(p, 4)
        result = sum((
            k[1] * p[0] * p[2] * p[3],          # Find this term.
            k[1] * k[2] * p[0] * p[1],          # No: two k's.
            k[0] * k[3] * p[0] * p[1] * p[2],   # No: two k's.
            k[3] * p[1] * p[2],                 # Find this term.
        )) & 1
        return result

    # Here's a whole table of one-bit functions:
    return (
        # Finds 2 key bits:
        KeyedFunction("f_explicit", 4, 4, f_explicit),

        # Uninteresting: all terms have multiple k's, so no usable
        # conjugacies are found:
        KeyedFunction("256 random bits", 4, 4,
                      lambda k, p: random_bits[(k << 4) + p]),

        # Finds 3 key bits:
        KeyedFunction("((k + p) >> 3 ) & 1", 4, 4,
                      lambda k, p: ((k + p) >> 3) & 1),

        # Finds 3 key bits:
        KeyedFunction("((k - p) >> 3 ) & 1", 4, 4,
                      lambda k, p: ((k - p) >> 3) & 1),

        # Finds 1 key bit:
        KeyedFunction("( k & p ) & 1", 4, 4,
                      lambda k, p: (k & p) & 1),

        # Finds 1 key bit:
        KeyedFunction("( k & p & ( p >> 1 ) ) & 1", 4, 4,
                      lambda k, p: (k & p & (p >> 1)) & 1),

        # Finds 3 key bits:
        KeyedFunction("( ( k * p ) >> 2 ) & 1", 4, 4,
                      lambda k, p: ((k * p) >> 2) & 1),

        # Four conjugacies, but all multi-bit:
        KeyedFunction("(shuffled16a)[ k ^ p ]", 4, 4,
                      lambda k, p: shuffled16a[k ^ p]),

        # Finds 2 key bits:
        KeyedFunction("shuffled16b[ k ^ p ]", 4, 4,
                      lambda k, p: shuffled16b[k ^ p]),

        # Finds 2 key bits:
        KeyedFunction("shuffled16b[ (k+p) % 16 ]", 4, 4,
                      lambda k, p: shuffled16b[(k + p) % 16]),

        # Finds 3 key bits:
        KeyedFunction("sbox[k^p] bit 0", 6, 6,
                      lambda k, p: des_sbox_1[k ^ p] & 1),

        # Finds 2 key bits:
        KeyedFunction("sbox[k^p] bit 1", 6, 6,
                      lambda k, p: (des_sbox_1[k ^ p] >> 1) & 1),

        # Finds 2 key bits:
        KeyedFunction("sbox[k^p] bit 2", 6, 6,
                      lambda k, p: (des_sbox_1[k ^ p] >> 2) & 1),

        # Finds 4 key bits:
        KeyedFunction("sbox[k^p] bit 3", 6, 6,
                      lambda k, p: (des_sbox_1[k ^ p] >> 3) & 1),
    )


def main():
    """Demonstrate the Cube Attack on a list of simple one-bit
    functions: build the list, then pass the functions one by one to
    study_function, which completes the demonstration on each
    function individually.
    """
    for kf in demo_functions():
        study_function(kf)


if __name__ == "__main__":
    main()