fsveritysetup.py: allow specifying the hash algorithm

Signed-off-by: Eric Biggers <ebiggers@google.com>
diff --git a/fsveritysetup.py b/fsveritysetup.py
index ab0106e..692304c 100755
--- a/fsveritysetup.py
+++ b/fsveritysetup.py
@@ -13,14 +13,64 @@
 import subprocess
 import sys
 import tempfile
+import zlib
 
 DATA_BLOCK_SIZE = 4096
 HASH_BLOCK_SIZE = 4096
-HASH_ALGORITHM = 'sha256'
 FS_VERITY_MAGIC = b'TrueBrew'
 FS_VERITY_SALT_SIZE = 8
 FS_VERITY_EXT_ELIDE = 0
 FS_VERITY_EXT_PATCH = 1
+FS_VERITY_ALG_CRC32 = 0
+FS_VERITY_ALG_SHA256 = 1
+
+
+class CRC32Hash(object):
+  """hashlib-compatible wrapper for zlib.crc32()."""
+
+  digest_size = 4
+
+  # Big endian, to be compatible with veritysetup --hash=crc32, which uses
+  # libgcrypt, which uses big endian CRC-32.
+  class Digest(ctypes.BigEndianStructure):
+    _fields_ = [('remainder', ctypes.c_uint32)]
+
+  def __init__(self, remainder=0):
+    self.remainder = remainder
+
+  def update(self, string):
+    self.remainder = zlib.crc32(bytes(string), self.remainder)
+
+  def digest(self):
+    digest = CRC32Hash.Digest()
+    digest.remainder = self.remainder
+    return serialize_struct(digest)
+
+  def hexdigest(self):
+    return binascii.hexlify(self.digest()).decode('ascii')
+
+  def copy(self):
+    return CRC32Hash(self.remainder)
+
+
+class HashAlgorithm(object):
+
+  def __init__(self, code, name, digest_size):
+    self.code = code
+    self.name = name
+    self.digest_size = digest_size
+
+  def create(self):
+    if self.name == 'crc32':
+      return CRC32Hash()
+    else:
+      return hashlib.new(self.name)
+
+
+HASH_ALGORITHMS = [
+    HashAlgorithm(FS_VERITY_ALG_CRC32, 'crc32', 4),
+    HashAlgorithm(FS_VERITY_ALG_SHA256, 'sha256', 32),
+]
 
 
 class fsverity_header(ctypes.LittleEndianStructure):
@@ -105,9 +155,9 @@
   return bytes(ctypes.string_at(ctypes.pointer(struct), ctypes.sizeof(struct)))
 
 
-def veritysetup(data_filename, tree_filename, salt):
+def veritysetup(data_filename, tree_filename, salt, algorithm):
   """Built-in Merkle tree generation algorithm."""
-  salted_hash = hashlib.new(HASH_ALGORITHM)
+  salted_hash = algorithm.create()
   salted_hash.update(salt)
   hashes_per_block = HASH_BLOCK_SIZE // salted_hash.digest_size
   level_blocks = [os.stat(data_filename).st_size // DATA_BLOCK_SIZE]
@@ -215,11 +265,12 @@
 class FSVerityGenerator(object):
   """Sets up a file for fs-verity."""
 
-  def __init__(self, in_filename, out_filename, salt, **kwargs):
+  def __init__(self, in_filename, out_filename, salt, algorithm, **kwargs):
     self.in_filename = in_filename
     self.original_size = os.stat(in_filename).st_size
     self.out_filename = out_filename
     self.salt = salt
+    self.algorithm = algorithm
     assert len(salt) == FS_VERITY_SALT_SIZE
 
     self.extensions = kwargs.get('extensions')
@@ -297,7 +348,8 @@
       tree_filename = f.name
 
     if self.builtin_veritysetup:
-      root_hash = veritysetup(data_filename, tree_filename, self.salt)
+      root_hash = veritysetup(data_filename, tree_filename, self.salt,
+                              self.algorithm)
     else:
       # Delegate to 'veritysetup' to actually build the Merkle tree.
       cmd = [
@@ -307,7 +359,7 @@
           tree_filename,
           '--salt=' + binascii.hexlify(self.salt).decode('ascii'),
           '--no-superblock',
-          '--hash={}'.format(HASH_ALGORITHM),
+          '--hash={}'.format(self.algorithm.name),
           '--data-block-size={}'.format(DATA_BLOCK_SIZE),
           '--hash-block-size={}'.format(HASH_BLOCK_SIZE),
       ]
@@ -332,10 +384,9 @@
     header.maj_version = 1
     header.min_version = 0
     header.log_blocksize = ilog2(DATA_BLOCK_SIZE)
-    assert HASH_ALGORITHM == 'sha256'
-    header.log_arity = ilog2(DATA_BLOCK_SIZE / 32)  # sha256 hash size
-    header.meta_algorithm = 1  # sha256
-    header.data_algorithm = 1  # sha256
+    header.log_arity = ilog2(DATA_BLOCK_SIZE / self.algorithm.digest_size)
+    header.meta_algorithm = self.algorithm.code
+    header.data_algorithm = self.algorithm.code
     header.size = self.original_size
     header.extension_count = len(self.extensions)
     header.salt = self.salt
@@ -389,7 +440,7 @@
         outfile.write(serialize_struct(hdr_offset))
 
         # Compute the fs-verity measurement.
-        measurement = hashlib.new(HASH_ALGORITHM)
+        measurement = self.algorithm.create()
         measurement.update(header)
         measurement.update(extensions)
         measurement.update(binascii.unhexlify(root_hash))
@@ -400,6 +451,15 @@
     return (measurement, root_hash)
 
 
+def convert_hash_argument(argstring):
+  for alg in HASH_ALGORITHMS:
+    if alg.name == argstring:
+      return alg
+  raise argparse.ArgumentTypeError(
+      'Unrecognized algorithm: "{}".  Choices are: {}'.format(
+          argstring, [alg.name for alg in HASH_ALGORITHMS]))
+
+
 def convert_salt_argument(argstring):
   try:
     b = binascii.unhexlify(argstring)
@@ -463,6 +523,12 @@
       help='{}-byte salt, given as a {}-character hex string'.format(
           FS_VERITY_SALT_SIZE, FS_VERITY_SALT_SIZE * 2))
   parser.add_argument(
+      '--hash',
+      type=convert_hash_argument,
+      default='sha256',
+      help="""Hash algorithm to use.  Available algorithms: {}.
+            Default is sha256.""".format([alg.name for alg in HASH_ALGORITHMS]))
+  parser.add_argument(
       '--patch',
       metavar='<offset,patchfile>',
       type=convert_patch_argument,
@@ -499,6 +565,7 @@
         args.in_filename,
         args.out_filename,
         args.salt,
+        args.hash,
         extensions=args.extensions,
         builtin_veritysetup=args.builtin_veritysetup)
   except BadExtensionListError as e: