Merge "Save KeyChainSnapshots to disk" into pi-dev
am: 02c0bbdc47

Change-Id: I603452c6ea44eec44d5c03fab4130034acc7d11d
diff --git a/services/core/java/com/android/server/locksettings/recoverablekeystore/RecoverableKeyStoreManager.java b/services/core/java/com/android/server/locksettings/recoverablekeystore/RecoverableKeyStoreManager.java
index 8a79e4c..77d7c3c 100644
--- a/services/core/java/com/android/server/locksettings/recoverablekeystore/RecoverableKeyStoreManager.java
+++ b/services/core/java/com/android/server/locksettings/recoverablekeystore/RecoverableKeyStoreManager.java
@@ -31,7 +31,6 @@
 import android.app.PendingIntent;
 import android.content.Context;
 import android.os.Binder;
-import android.os.Process;
 import android.os.RemoteException;
 import android.os.ServiceSpecificException;
 import android.os.UserHandle;
@@ -128,7 +127,7 @@
                     db,
                     new RecoverySessionStorage(),
                     Executors.newSingleThreadExecutor(),
-                    new RecoverySnapshotStorage(),
+                    RecoverySnapshotStorage.newInstance(),
                     new RecoverySnapshotListenersStorage(),
                     platformKeyManager,
                     applicationKeyStorage);
diff --git a/services/core/java/com/android/server/locksettings/recoverablekeystore/serialization/KeyChainSnapshotDeserializer.java b/services/core/java/com/android/server/locksettings/recoverablekeystore/serialization/KeyChainSnapshotDeserializer.java
index dcaa0b4..f789155 100644
--- a/services/core/java/com/android/server/locksettings/recoverablekeystore/serialization/KeyChainSnapshotDeserializer.java
+++ b/services/core/java/com/android/server/locksettings/recoverablekeystore/serialization/KeyChainSnapshotDeserializer.java
@@ -23,6 +23,9 @@
 import static com.android.server.locksettings.recoverablekeystore.serialization.KeyChainSnapshotSchema.TAG_ALIAS;
 import static com.android.server.locksettings.recoverablekeystore.serialization.KeyChainSnapshotSchema.TAG_APPLICATION_KEY;
 import static com.android.server.locksettings.recoverablekeystore.serialization.KeyChainSnapshotSchema.TAG_APPLICATION_KEYS;
+
+import static com.android.server.locksettings.recoverablekeystore.serialization
+        .KeyChainSnapshotSchema.TAG_BACKEND_PUBLIC_KEY;
 import static com.android.server.locksettings.recoverablekeystore.serialization.KeyChainSnapshotSchema.TAG_COUNTER_ID;
 import static com.android.server.locksettings.recoverablekeystore.serialization.KeyChainSnapshotSchema.TAG_RECOVERY_KEY_MATERIAL;
 import static com.android.server.locksettings.recoverablekeystore.serialization.KeyChainSnapshotSchema.TAG_KEY_CHAIN_PROTECTION_PARAMS;
@@ -128,6 +131,11 @@
                     }
                     break;
 
+                case TAG_BACKEND_PUBLIC_KEY:
+                    builder.setTrustedHardwarePublicKey(
+                            readBlobTag(parser, TAG_BACKEND_PUBLIC_KEY));
+                    break;
+
                 case TAG_KEY_CHAIN_PROTECTION_PARAMS_LIST:
                     builder.setKeyChainProtectionParams(readKeyChainProtectionParamsList(parser));
                     break;
diff --git a/services/core/java/com/android/server/locksettings/recoverablekeystore/serialization/KeyChainSnapshotSchema.java b/services/core/java/com/android/server/locksettings/recoverablekeystore/serialization/KeyChainSnapshotSchema.java
index ee8b2cf..ff30ecd 100644
--- a/services/core/java/com/android/server/locksettings/recoverablekeystore/serialization/KeyChainSnapshotSchema.java
+++ b/services/core/java/com/android/server/locksettings/recoverablekeystore/serialization/KeyChainSnapshotSchema.java
@@ -35,6 +35,7 @@
     static final String TAG_RECOVERY_KEY_MATERIAL = "recoveryKeyMaterial";
     static final String TAG_SERVER_PARAMS = "serverParams";
     static final String TAG_TRUSTED_HARDWARE_CERT_PATH = "thmCertPath";
+    static final String TAG_BACKEND_PUBLIC_KEY = "backendPublicKey";
 
     static final String TAG_KEY_CHAIN_PROTECTION_PARAMS_LIST =
             "keyChainProtectionParamsList";
diff --git a/services/core/java/com/android/server/locksettings/recoverablekeystore/serialization/KeyChainSnapshotSerializer.java b/services/core/java/com/android/server/locksettings/recoverablekeystore/serialization/KeyChainSnapshotSerializer.java
index f817a8f..17a16bf 100644
--- a/services/core/java/com/android/server/locksettings/recoverablekeystore/serialization/KeyChainSnapshotSerializer.java
+++ b/services/core/java/com/android/server/locksettings/recoverablekeystore/serialization/KeyChainSnapshotSerializer.java
@@ -24,6 +24,9 @@
 import static com.android.server.locksettings.recoverablekeystore.serialization.KeyChainSnapshotSchema.TAG_ALIAS;
 import static com.android.server.locksettings.recoverablekeystore.serialization.KeyChainSnapshotSchema.TAG_APPLICATION_KEY;
 import static com.android.server.locksettings.recoverablekeystore.serialization.KeyChainSnapshotSchema.TAG_APPLICATION_KEYS;
+
+import static com.android.server.locksettings.recoverablekeystore.serialization
+        .KeyChainSnapshotSchema.TAG_BACKEND_PUBLIC_KEY;
 import static com.android.server.locksettings.recoverablekeystore.serialization.KeyChainSnapshotSchema.TAG_COUNTER_ID;
 import static com.android.server.locksettings.recoverablekeystore.serialization.KeyChainSnapshotSchema.TAG_RECOVERY_KEY_MATERIAL;
 import static com.android.server.locksettings.recoverablekeystore.serialization.KeyChainSnapshotSchema.TAG_KEY_CHAIN_PROTECTION_PARAMS;
@@ -159,6 +162,10 @@
         writePropertyTag(xmlSerializer, TAG_SERVER_PARAMS, keyChainSnapshot.getServerParams());
         writePropertyTag(xmlSerializer, TAG_TRUSTED_HARDWARE_CERT_PATH,
                 keyChainSnapshot.getTrustedHardwareCertPath());
+        if (keyChainSnapshot.getTrustedHardwarePublicKey() != null) {
+            writePropertyTag(xmlSerializer, TAG_BACKEND_PUBLIC_KEY,
+                    keyChainSnapshot.getTrustedHardwarePublicKey());
+        }
     }
 
     private static void writePropertyTag(
diff --git a/services/core/java/com/android/server/locksettings/recoverablekeystore/storage/RecoverySnapshotStorage.java b/services/core/java/com/android/server/locksettings/recoverablekeystore/storage/RecoverySnapshotStorage.java
index 3f93cc6..c02b103 100644
--- a/services/core/java/com/android/server/locksettings/recoverablekeystore/storage/RecoverySnapshotStorage.java
+++ b/services/core/java/com/android/server/locksettings/recoverablekeystore/storage/RecoverySnapshotStorage.java
@@ -17,13 +17,28 @@
 package com.android.server.locksettings.recoverablekeystore.storage;
 
 import android.annotation.Nullable;
+import android.os.Environment;
 import android.security.keystore.recovery.KeyChainSnapshot;
+import android.util.Log;
 import android.util.SparseArray;
 
 import com.android.internal.annotations.GuardedBy;
+import com.android.internal.annotations.VisibleForTesting;
+import com.android.server.locksettings.recoverablekeystore.serialization
+        .KeyChainSnapshotDeserializer;
+import com.android.server.locksettings.recoverablekeystore.serialization
+        .KeyChainSnapshotParserException;
+import com.android.server.locksettings.recoverablekeystore.serialization.KeyChainSnapshotSerializer;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.security.cert.CertificateEncodingException;
+import java.util.Locale;
 
 /**
- * In-memory storage for recovery snapshots.
+ * Storage for recovery snapshots. Stores snapshots in memory, backed by disk storage.
  *
  * <p>Recovery snapshots are generated after a successful screen unlock. They are only generated if
  * the recoverable keystore has been mutated since the previous snapshot. This class stores only the
@@ -33,14 +48,46 @@
  * {@link com.android.server.locksettings.recoverablekeystore.KeySyncTask} thread.
  */
 public class RecoverySnapshotStorage {
+
+    private static final String TAG = "RecoverySnapshotStorage";
+
+    private static final String ROOT_PATH = "system";
+    private static final String STORAGE_PATH = "recoverablekeystore/snapshots/";
+
     @GuardedBy("this")
     private final SparseArray<KeyChainSnapshot> mSnapshotByUid = new SparseArray<>();
 
+    private final File rootDirectory;
+
+    /**
+     * Creates a new instance that stores snapshots in /data/system/recoverablekeystore/snapshots.
+     *
+     * <p>NOTE: every call returns a distinct instance with its own in-memory store; separate
+     * instances share only the snapshot files on disk.
+     */
+    public static RecoverySnapshotStorage newInstance() {
+        return new RecoverySnapshotStorage(
+                new File(Environment.getDataDirectory(), ROOT_PATH));
+    }
+
+    @VisibleForTesting
+    public RecoverySnapshotStorage(File rootDirectory) {
+        this.rootDirectory = rootDirectory;
+    }
+
     /**
      * Sets the latest {@code snapshot} for the recovery agent {@code uid}.
      */
     public synchronized void put(int uid, KeyChainSnapshot snapshot) {
         mSnapshotByUid.put(uid, snapshot);
+
+        try {
+            writeToDisk(uid, snapshot);
+        } catch (IOException | CertificateEncodingException e) {
+            Log.e(TAG,
+                    String.format(Locale.US, "Error persisting snapshot for %d to disk", uid),
+                    e);
+        }
     }
 
     /**
@@ -48,7 +95,17 @@
      */
     @Nullable
     public synchronized KeyChainSnapshot get(int uid) {
-        return mSnapshotByUid.get(uid);
+        KeyChainSnapshot snapshot = mSnapshotByUid.get(uid);
+        if (snapshot != null) {
+            return snapshot;
+        }
+
+        try {
+            return readFromDisk(uid);
+        } catch (IOException | KeyChainSnapshotParserException e) {
+            Log.e(TAG, String.format(Locale.US, "Error reading snapshot for %d from disk", uid), e);
+            return null;
+        }
     }
 
     /**
@@ -56,5 +113,66 @@
      */
     public synchronized void remove(int uid) {
         mSnapshotByUid.remove(uid);
+        getSnapshotFile(uid).delete();
+    }
+
+    /**
+     * Writes the snapshot for recovery agent {@code uid} to disk, deleting the file on failure.
+     *
+     * @throws IOException if an IO error occurs writing to disk.
+     */
+    private void writeToDisk(int uid, KeyChainSnapshot snapshot)
+            throws IOException, CertificateEncodingException {
+        File snapshotFile = getSnapshotFile(uid);
+
+        try (
+            FileOutputStream fileOutputStream = new FileOutputStream(snapshotFile)
+        ) {
+            KeyChainSnapshotSerializer.serialize(snapshot, fileOutputStream);
+        } catch (IOException | CertificateEncodingException e) {
+            // If we fail to write the latest snapshot, we should delete any older snapshot that
+            // happens to be around. Otherwise snapshot syncs might end up going 'back in time'.
+            snapshotFile.delete();
+            throw e;
+        }
+    }
+
+    /**
+     * Reads the last snapshot for recovery agent {@code uid} from disk, deleting it on failure.
+     *
+     * @return The snapshot read from the file (a missing file throws instead of returning null).
+     * @throws IOException if no snapshot file exists or an IO error occurs reading from disk.
+     */
+    @Nullable
+    private KeyChainSnapshot readFromDisk(int uid)
+            throws IOException, KeyChainSnapshotParserException {
+        File snapshotFile = getSnapshotFile(uid);
+
+        try (
+            FileInputStream fileInputStream = new FileInputStream(snapshotFile)
+        ) {
+            return KeyChainSnapshotDeserializer.deserialize(fileInputStream);
+        } catch (IOException | KeyChainSnapshotParserException e) {
+            // If we fail to read the latest snapshot, we should delete it in case it is in some way
+            // corrupted. We can regenerate snapshots anyway.
+            snapshotFile.delete();
+            throw e;
+        }
+    }
+
+    private File getSnapshotFile(int uid) {
+        File folder = getStorageFolder();
+        String fileName = getSnapshotFileName(uid);
+        return new File(folder, fileName);
+    }
+
+    private String getSnapshotFileName(int uid) {
+        return String.format(Locale.US, "%d.xml", uid);
+    }
+
+    private File getStorageFolder() {
+        File folder = new File(rootDirectory, STORAGE_PATH);
+        folder.mkdirs();
+        return folder;
     }
 }
diff --git a/services/tests/servicestests/src/com/android/server/locksettings/recoverablekeystore/KeySyncTaskTest.java b/services/tests/servicestests/src/com/android/server/locksettings/recoverablekeystore/KeySyncTaskTest.java
index 9ae45ea..81a73efd 100644
--- a/services/tests/servicestests/src/com/android/server/locksettings/recoverablekeystore/KeySyncTaskTest.java
+++ b/services/tests/servicestests/src/com/android/server/locksettings/recoverablekeystore/KeySyncTaskTest.java
@@ -38,6 +38,7 @@
 import static org.mockito.Mockito.when;
 
 import android.content.Context;
+import android.os.FileUtils;
 import android.security.keystore.AndroidKeyStoreSecretKey;
 import android.security.keystore.KeyGenParameterSpec;
 import android.security.keystore.KeyProperties;
@@ -49,7 +50,6 @@
 import android.support.test.filters.SmallTest;
 import android.support.test.runner.AndroidJUnit4;
 
-import android.util.Log;
 import com.android.server.locksettings.recoverablekeystore.storage.RecoverableKeyStoreDb;
 import com.android.server.locksettings.recoverablekeystore.storage.RecoverySnapshotStorage;
 
@@ -72,6 +72,9 @@
 @SmallTest
 @RunWith(AndroidJUnit4.class)
 public class KeySyncTaskTest {
+
+    private static final String SNAPSHOT_TOP_LEVEL_DIRECTORY = "recoverablekeystore";
+
     private static final String KEY_ALGORITHM = "AES";
     private static final String ANDROID_KEY_STORE_PROVIDER = "AndroidKeyStore";
     private static final String TEST_ROOT_CERT_ALIAS = "trusted_root";
@@ -117,7 +120,7 @@
                 TEST_ROOT_CERT_ALIAS);
         mRecoverableKeyStoreDb.setActiveRootOfTrust(TEST_USER_ID, TEST_RECOVERY_AGENT_UID2,
                 TEST_ROOT_CERT_ALIAS);
-        mRecoverySnapshotStorage = new RecoverySnapshotStorage();
+        mRecoverySnapshotStorage = new RecoverySnapshotStorage(context.getFilesDir());
 
         mKeySyncTask = new KeySyncTask(
                 mRecoverableKeyStoreDb,
@@ -139,6 +142,10 @@
     public void tearDown() {
         mRecoverableKeyStoreDb.close();
         mDatabaseFile.delete();
+
+        File file = new File(InstrumentationRegistry.getTargetContext().getFilesDir(),
+                SNAPSHOT_TOP_LEVEL_DIRECTORY);
+        FileUtils.deleteContentsAndDir(file);
     }
 
     @Test
diff --git a/services/tests/servicestests/src/com/android/server/locksettings/recoverablekeystore/serialization/KeyChainSnapshotSerializerTest.java b/services/tests/servicestests/src/com/android/server/locksettings/recoverablekeystore/serialization/KeyChainSnapshotSerializerTest.java
index 6c2958e..2f4da86 100644
--- a/services/tests/servicestests/src/com/android/server/locksettings/recoverablekeystore/serialization/KeyChainSnapshotSerializerTest.java
+++ b/services/tests/servicestests/src/com/android/server/locksettings/recoverablekeystore/serialization/KeyChainSnapshotSerializerTest.java
@@ -45,6 +45,7 @@
     private static final int MAX_ATTEMPTS = 21;
     private static final byte[] SERVER_PARAMS = new byte[] { 8, 2, 4 };
     private static final byte[] KEY_BLOB = new byte[] { 124, 53, 53, 53 };
+    private static final byte[] PUBLIC_KEY_BLOB = new byte[] { 6, 6, 6, 6, 6, 6, 7 };
     private static final CertPath CERT_PATH = TestData.CERT_PATH_1;
     private static final int SECRET_TYPE = KeyChainProtectionParams.TYPE_LOCKSCREEN;
     private static final int LOCK_SCREEN_UI = KeyChainProtectionParams.UI_FORMAT_PASSWORD;
@@ -93,6 +94,11 @@
     }
 
     @Test
+    public void roundTrip_persistsBackendPublicKey() throws Exception {
+        assertThat(roundTrip().getTrustedHardwarePublicKey()).isEqualTo(PUBLIC_KEY_BLOB);
+    }
+
+    @Test
     public void roundTrip_persistsParamsList() throws Exception {
         assertThat(roundTrip().getKeyChainProtectionParams()).hasSize(1);
     }
@@ -163,6 +169,12 @@
         assertThat(roundTripKeys().get(2).getEncryptedKeyMaterial()).isEqualTo(TEST_KEY_3_BYTES);
     }
 
+    @Test
+    public void serialize_doesNotThrowForNullPublicKey() throws Exception {
+        KeyChainSnapshotSerializer.serialize(
+                createTestKeyChainSnapshotNoPublicKey(), new ByteArrayOutputStream());
+    }
+
     private static List<WrappedApplicationKey> roundTripKeys() throws Exception {
         return roundTrip().getWrappedApplicationKeys();
     }
@@ -180,6 +192,41 @@
     }
 
     private static KeyChainSnapshot createTestKeyChainSnapshot() throws Exception {
+        return new KeyChainSnapshot.Builder()
+                .setCounterId(COUNTER_ID)
+                .setSnapshotVersion(SNAPSHOT_VERSION)
+                .setServerParams(SERVER_PARAMS)
+                .setMaxAttempts(MAX_ATTEMPTS)
+                .setEncryptedRecoveryKeyBlob(KEY_BLOB)
+                .setKeyChainProtectionParams(createKeyChainProtectionParamsList())
+                .setWrappedApplicationKeys(createKeys())
+                .setTrustedHardwareCertPath(CERT_PATH)
+                .setTrustedHardwarePublicKey(PUBLIC_KEY_BLOB)
+                .build();
+    }
+
+    private static KeyChainSnapshot createTestKeyChainSnapshotNoPublicKey() throws Exception {
+        return new KeyChainSnapshot.Builder()
+                .setCounterId(COUNTER_ID)
+                .setSnapshotVersion(SNAPSHOT_VERSION)
+                .setServerParams(SERVER_PARAMS)
+                .setMaxAttempts(MAX_ATTEMPTS)
+                .setEncryptedRecoveryKeyBlob(KEY_BLOB)
+                .setKeyChainProtectionParams(createKeyChainProtectionParamsList())
+                .setWrappedApplicationKeys(createKeys())
+                .setTrustedHardwareCertPath(CERT_PATH)
+                .build();
+    }
+
+    private static List<WrappedApplicationKey> createKeys() {
+        ArrayList<WrappedApplicationKey> keyList = new ArrayList<>();
+        keyList.add(createKey(TEST_KEY_1_ALIAS, TEST_KEY_1_BYTES));
+        keyList.add(createKey(TEST_KEY_2_ALIAS, TEST_KEY_2_BYTES));
+        keyList.add(createKey(TEST_KEY_3_ALIAS, TEST_KEY_3_BYTES));
+        return keyList;
+    }
+
+    private static List<KeyChainProtectionParams> createKeyChainProtectionParamsList() {
         KeyDerivationParams keyDerivationParams =
                 KeyDerivationParams.createScryptParams(SALT, MEMORY_DIFFICULTY);
         KeyChainProtectionParams keyChainProtectionParams = new KeyChainProtectionParams.Builder()
@@ -191,22 +238,7 @@
         ArrayList<KeyChainProtectionParams> keyChainProtectionParamsList =
                 new ArrayList<>(1);
         keyChainProtectionParamsList.add(keyChainProtectionParams);
-
-        ArrayList<WrappedApplicationKey> keyList = new ArrayList<>();
-        keyList.add(createKey(TEST_KEY_1_ALIAS, TEST_KEY_1_BYTES));
-        keyList.add(createKey(TEST_KEY_2_ALIAS, TEST_KEY_2_BYTES));
-        keyList.add(createKey(TEST_KEY_3_ALIAS, TEST_KEY_3_BYTES));
-
-        return new KeyChainSnapshot.Builder()
-                .setCounterId(COUNTER_ID)
-                .setSnapshotVersion(SNAPSHOT_VERSION)
-                .setServerParams(SERVER_PARAMS)
-                .setMaxAttempts(MAX_ATTEMPTS)
-                .setEncryptedRecoveryKeyBlob(KEY_BLOB)
-                .setKeyChainProtectionParams(keyChainProtectionParamsList)
-                .setWrappedApplicationKeys(keyList)
-                .setTrustedHardwareCertPath(CERT_PATH)
-                .build();
+        return keyChainProtectionParamsList;
     }
 
     private static WrappedApplicationKey createKey(String alias, byte[] bytes) {
diff --git a/services/tests/servicestests/src/com/android/server/locksettings/recoverablekeystore/storage/RecoverySnapshotStorageTest.java b/services/tests/servicestests/src/com/android/server/locksettings/recoverablekeystore/storage/RecoverySnapshotStorageTest.java
index c772956..ad14c3a 100644
--- a/services/tests/servicestests/src/com/android/server/locksettings/recoverablekeystore/storage/RecoverySnapshotStorageTest.java
+++ b/services/tests/servicestests/src/com/android/server/locksettings/recoverablekeystore/storage/RecoverySnapshotStorageTest.java
@@ -1,27 +1,82 @@
 package com.android.server.locksettings.recoverablekeystore.storage;
 
+import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
 
+import android.content.Context;
+import android.os.FileUtils;
+import android.security.keystore.recovery.KeyChainProtectionParams;
 import android.security.keystore.recovery.KeyChainSnapshot;
+import android.security.keystore.recovery.KeyDerivationParams;
+import android.security.keystore.recovery.WrappedApplicationKey;
+import android.support.test.InstrumentationRegistry;
 import android.support.test.filters.SmallTest;
 import android.support.test.runner.AndroidJUnit4;
 
 import com.android.server.locksettings.recoverablekeystore.TestData;
 
+import com.google.common.io.Files;
+
+import org.junit.After;
+import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 
+import java.io.File;
+import java.nio.charset.StandardCharsets;
+import java.security.cert.CertPath;
 import java.security.cert.CertificateException;
 import java.util.ArrayList;
+import java.util.List;
 
 @SmallTest
 @RunWith(AndroidJUnit4.class)
 public class RecoverySnapshotStorageTest {
-    private static final KeyChainSnapshot MINIMAL_KEYCHAIN_SNAPSHOT =
-            createMinimalKeyChainSnapshot();
+    private static final int COUNTER_ID = 432546;
+    private static final int MAX_ATTEMPTS = 10;
+    private static final byte[] SERVER_PARAMS = new byte[] { 12, 8, 2, 4, 15, 64 };
+    private static final byte[] KEY_BLOB = new byte[] { 124, 56, 53, 99, 0, 0, 1 };
+    private static final CertPath CERT_PATH = TestData.CERT_PATH_2;
+    private static final int SECRET_TYPE = KeyChainProtectionParams.TYPE_LOCKSCREEN;
+    private static final int LOCK_SCREEN_UI = KeyChainProtectionParams.UI_FORMAT_PATTERN;
+    private static final byte[] SALT = new byte[] { 1, 2, 3, 4, 5, 6, 5, 4, 3, 2, 1 };
+    private static final int MEMORY_DIFFICULTY = 12;
+    private static final byte[] SECRET = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0 };
 
-    private final RecoverySnapshotStorage mRecoverySnapshotStorage = new RecoverySnapshotStorage();
+    private static final String TEST_KEY_1_ALIAS = "alias1";
+    private static final byte[] TEST_KEY_1_BYTES = new byte[] { 100, 32, 43, 66, 77, 88 };
+
+    private static final String TEST_KEY_2_ALIAS = "alias11";
+    private static final byte[] TEST_KEY_2_BYTES = new byte[] { 100, 0, 0, 99, 33, 11 };
+
+    private static final String TEST_KEY_3_ALIAS = "alias111";
+    private static final byte[] TEST_KEY_3_BYTES = new byte[] { 1, 1, 1, 0, 2, 8, 100 };
+
+    private static final int TEST_UID = 1000;
+    private static final String SNAPSHOT_DIRECTORY = "recoverablekeystore/snapshots";
+    private static final String SNAPSHOT_FILE_PATH = "1000.xml";
+    private static final String SNAPSHOT_TOP_LEVEL_DIRECTORY = "recoverablekeystore";
+
+    private static final KeyChainSnapshot MINIMAL_KEYCHAIN_SNAPSHOT =
+            createTestKeyChainSnapshot(1);
+
+    private Context mContext;
+    private RecoverySnapshotStorage mRecoverySnapshotStorage;
+
+    @Before
+    public void setUp() {
+        mContext = InstrumentationRegistry.getTargetContext();
+        mRecoverySnapshotStorage = new RecoverySnapshotStorage(mContext.getFilesDir());
+    }
+
+    @After
+    public void tearDown() {
+        File file = new File(mContext.getFilesDir(), SNAPSHOT_TOP_LEVEL_DIRECTORY);
+        FileUtils.deleteContentsAndDir(file);
+    }
 
     @Test
     public void get_isNullForNonExistentSnapshot() {
@@ -30,37 +85,153 @@
 
     @Test
     public void get_returnsSetSnapshot() {
-        int userId = 1000;
+        mRecoverySnapshotStorage.put(TEST_UID, MINIMAL_KEYCHAIN_SNAPSHOT);
 
-        mRecoverySnapshotStorage.put(userId, MINIMAL_KEYCHAIN_SNAPSHOT);
-
-        assertEquals(MINIMAL_KEYCHAIN_SNAPSHOT, mRecoverySnapshotStorage.get(userId));
+        assertEquals(MINIMAL_KEYCHAIN_SNAPSHOT, mRecoverySnapshotStorage.get(TEST_UID));
     }
 
     @Test
-    public void remove_removesSnapshots() {
-        int userId = 1000;
+    public void get_readsFromDiskIfNoneInMemory() {
+        mRecoverySnapshotStorage.put(TEST_UID, MINIMAL_KEYCHAIN_SNAPSHOT);
+        RecoverySnapshotStorage storage = new RecoverySnapshotStorage(mContext.getFilesDir());
 
-        mRecoverySnapshotStorage.put(userId, MINIMAL_KEYCHAIN_SNAPSHOT);
-        mRecoverySnapshotStorage.remove(userId);
-
-        assertNull(mRecoverySnapshotStorage.get(1000));
+        assertKeyChainSnapshotsAreEqual(MINIMAL_KEYCHAIN_SNAPSHOT, storage.get(TEST_UID));
     }
 
-    private static KeyChainSnapshot createMinimalKeyChainSnapshot() {
+    @Test
+    public void get_deletesFileIfItIsInvalidSnapshot() throws Exception {
+        File folder = new File(mContext.getFilesDir(), SNAPSHOT_DIRECTORY);
+        folder.mkdirs();
+        File file = new File(folder, SNAPSHOT_FILE_PATH);
+        byte[] fileContents = "<keyChainSnapshot></keyChainSnapshot>".getBytes(
+                StandardCharsets.UTF_8);
+        Files.write(fileContents, file);
+        assertTrue(file.exists());
+
+        assertNull(mRecoverySnapshotStorage.get(TEST_UID));
+
+        assertFalse(file.exists());
+    }
+
+    @Test
+    public void put_overwritesOldFiles() {
+        int snapshotVersion = 2;
+        mRecoverySnapshotStorage.put(TEST_UID, MINIMAL_KEYCHAIN_SNAPSHOT);
+
+        mRecoverySnapshotStorage.put(TEST_UID, createTestKeyChainSnapshot(snapshotVersion));
+
+        KeyChainSnapshot snapshot = new RecoverySnapshotStorage(mContext.getFilesDir())
+                .get(TEST_UID);
+        assertEquals(snapshotVersion, snapshot.getSnapshotVersion());
+    }
+
+    @Test
+    public void put_doesNotThrowIfCannotCreateFiles() throws Exception {
+        File evilFile = new File(mContext.getFilesDir(), "recoverablekeystore");
+        Files.write(new byte[] { 1 }, evilFile);
+
+        mRecoverySnapshotStorage.put(TEST_UID, MINIMAL_KEYCHAIN_SNAPSHOT);
+
+        assertNull(new RecoverySnapshotStorage(mContext.getFilesDir()).get(TEST_UID));
+    }
+
+    @Test
+    public void remove_removesSnapshotsFromMemory() {
+        mRecoverySnapshotStorage.put(TEST_UID, MINIMAL_KEYCHAIN_SNAPSHOT);
+        mRecoverySnapshotStorage.remove(TEST_UID);
+
+        assertNull(mRecoverySnapshotStorage.get(TEST_UID));
+    }
+
+    @Test
+    public void remove_removesSnapshotsFromDisk() {
+        mRecoverySnapshotStorage.put(TEST_UID, MINIMAL_KEYCHAIN_SNAPSHOT);
+
+        new RecoverySnapshotStorage(mContext.getFilesDir()).remove(TEST_UID);
+
+        assertNull(new RecoverySnapshotStorage(mContext.getFilesDir()).get(TEST_UID));
+    }
+
+    private void assertKeyChainSnapshotsAreEqual(KeyChainSnapshot a, KeyChainSnapshot b) {
+        assertEquals(b.getCounterId(), a.getCounterId());
+        assertEquals(b.getSnapshotVersion(), a.getSnapshotVersion());
+        assertArrayEquals(b.getServerParams(), a.getServerParams());
+        assertEquals(b.getMaxAttempts(), a.getMaxAttempts());
+        assertArrayEquals(b.getEncryptedRecoveryKeyBlob(), a.getEncryptedRecoveryKeyBlob());
+        assertEquals(b.getTrustedHardwareCertPath(), a.getTrustedHardwareCertPath());
+
+        List<WrappedApplicationKey> aKeys = a.getWrappedApplicationKeys();
+        List<WrappedApplicationKey> bKeys = b.getWrappedApplicationKeys();
+        assertEquals(bKeys.size(), aKeys.size());
+        for (int i = 0; i < aKeys.size(); i++) {
+            assertWrappedApplicationKeysAreEqual(aKeys.get(i), bKeys.get(i));
+        }
+
+        List<KeyChainProtectionParams> aParams = a.getKeyChainProtectionParams();
+        List<KeyChainProtectionParams> bParams = b.getKeyChainProtectionParams();
+        assertEquals(bParams.size(), aParams.size());
+        for (int i = 0; i < aParams.size(); i++) {
+            assertKeyChainProtectionParamsAreEqual(aParams.get(i), bParams.get(i));
+        }
+    }
+
+    private void assertWrappedApplicationKeysAreEqual(
+            WrappedApplicationKey a, WrappedApplicationKey b) {
+        assertEquals(b.getAlias(), a.getAlias());
+        assertArrayEquals(b.getEncryptedKeyMaterial(), a.getEncryptedKeyMaterial());
+    }
+
+    private void assertKeyChainProtectionParamsAreEqual(
+            KeyChainProtectionParams a, KeyChainProtectionParams b) {
+        assertEquals(b.getUserSecretType(), a.getUserSecretType());
+        assertEquals(b.getLockScreenUiFormat(), a.getLockScreenUiFormat());
+        assertKeyDerivationParamsAreEqual(a.getKeyDerivationParams(), b.getKeyDerivationParams());
+    }
+
+    private void assertKeyDerivationParamsAreEqual(KeyDerivationParams a, KeyDerivationParams b) {
+        assertEquals(b.getAlgorithm(), a.getAlgorithm());
+        assertEquals(b.getMemoryDifficulty(), a.getMemoryDifficulty());
+        assertArrayEquals(b.getSalt(), a.getSalt());
+    }
+
+    private static KeyChainSnapshot createTestKeyChainSnapshot(int snapshotVersion) {
+        KeyDerivationParams keyDerivationParams =
+                KeyDerivationParams.createScryptParams(SALT, MEMORY_DIFFICULTY);
+        KeyChainProtectionParams keyChainProtectionParams = new KeyChainProtectionParams.Builder()
+                .setKeyDerivationParams(keyDerivationParams)
+                .setUserSecretType(SECRET_TYPE)
+                .setLockScreenUiFormat(LOCK_SCREEN_UI)
+                .setSecret(SECRET)
+                .build();
+        ArrayList<KeyChainProtectionParams> keyChainProtectionParamsList =
+                new ArrayList<>(1);
+        keyChainProtectionParamsList.add(keyChainProtectionParams);
+
+        ArrayList<WrappedApplicationKey> keyList = new ArrayList<>();
+        keyList.add(createKey(TEST_KEY_1_ALIAS, TEST_KEY_1_BYTES));
+        keyList.add(createKey(TEST_KEY_2_ALIAS, TEST_KEY_2_BYTES));
+        keyList.add(createKey(TEST_KEY_3_ALIAS, TEST_KEY_3_BYTES));
+
         try {
             return new KeyChainSnapshot.Builder()
-                    .setCounterId(1)
-                    .setSnapshotVersion(1)
-                    .setServerParams(new byte[0])
-                    .setMaxAttempts(10)
-                    .setEncryptedRecoveryKeyBlob(new byte[0])
-                    .setKeyChainProtectionParams(new ArrayList<>())
-                    .setWrappedApplicationKeys(new ArrayList<>())
-                    .setTrustedHardwareCertPath(TestData.CERT_PATH_1)
+                    .setCounterId(COUNTER_ID)
+                    .setSnapshotVersion(snapshotVersion)
+                    .setServerParams(SERVER_PARAMS)
+                    .setMaxAttempts(MAX_ATTEMPTS)
+                    .setEncryptedRecoveryKeyBlob(KEY_BLOB)
+                    .setKeyChainProtectionParams(keyChainProtectionParamsList)
+                    .setWrappedApplicationKeys(keyList)
+                    .setTrustedHardwareCertPath(CERT_PATH)
                     .build();
         } catch (CertificateException e) {
             throw new RuntimeException(e);
         }
     }
+
+    private static WrappedApplicationKey createKey(String alias, byte[] bytes) {
+        return new WrappedApplicationKey.Builder()
+                .setAlias(alias)
+                .setEncryptedKeyMaterial(bytes)
+                .build();
+    }
 }