add SIDs and password storage support to Keyguard base

Change-Id: I2b1bb41a5e40e45e810f2bd299edb6b8765df9e6
diff --git a/keyguard.cpp b/keyguard.cpp
index e435a3f..3e1f338 100644
--- a/keyguard.cpp
+++ b/keyguard.cpp
@@ -22,34 +22,74 @@
 
 namespace keyguard {
 
-Keyguard::~Keyguard() {
-    if (password_key_.buffer.get()) {
-        memset_s(password_key_.buffer.get(), 0, password_key_.length);
-    }
-}
+// Format version stored in every password handle; handles with another version are rejected.
+static const uint8_t HANDLE_VERSION = 0;
+/**
+ * Internal only structure for easy (de)serialization of password handles. Packed
+ * because the signed fields are copied as one contiguous byte range from |version|.
+ */
+struct __attribute__ ((__packed__)) password_handle_t {
+    // fields included in signature (must stay first and contiguous)
+    uint8_t version = HANDLE_VERSION;
+    secure_id_t user_id;
+    secure_id_t authenticator_id;
+    // fields not included in signature
+    salt_t salt;
+    uint8_t signature[32];
+};
 
+// Enrolls request.provided_password for request.user_id. If request.password_handle is
+// set, it must match the stored file and request.enrolled_password must verify against it.
 void Keyguard::Enroll(const EnrollRequest &request, EnrollResponse *response) {
     if (response == NULL) return;
 
-    SizedBuffer enrolled_password;
     if (!request.provided_password.buffer.get()) {
         response->error = KG_ERROR_INVALID;
         return;
     }
 
-    size_t salt_length;
-    UniquePtr<uint8_t> salt;
-    GetSalt(&salt, &salt_length);
+    secure_id_t user_id = 0;
 
-    size_t signature_length;
-    UniquePtr<uint8_t> signature;
-    ComputePasswordSignature(password_key_.buffer.get(),
-                password_key_.length, request.provided_password.buffer.get(),
-                request.provided_password.length, salt.get(), salt_length, &signature,
-                &signature_length);
+    if (request.password_handle.buffer.get() == NULL) {
+        // No existing handle supplied: generate a new user SecureID.
+        GetRandom(&user_id, sizeof(secure_id_t));
+    } else {
+        // The handle must be a well-formed password_handle_t that matches the
+        // stored password file before any of its fields are trusted.
+        if (request.password_handle.length != sizeof(password_handle_t) ||
+                !ValidatePasswordFile(request.user_id, request.password_handle)) {
+            response->error = KG_ERROR_INVALID;
+            return;
+        }
 
-    SerializeHandle(salt.get(), salt_length, signature.get(), signature_length, enrolled_password);
-    response->SetEnrolledPasswordHandle(&enrolled_password);
+        password_handle_t *pw_handle =
+            reinterpret_cast<password_handle_t *>(request.password_handle.buffer.get());
+        if (!DoVerify(pw_handle, request.enrolled_password)) {
+            // incorrect old password
+            response->error = KG_ERROR_INVALID;
+            return;
+        }
+        // Keep the existing user SecureID across password changes.
+        user_id = pw_handle->user_id;
+    }
+
+    salt_t salt;
+    GetRandom(&salt, sizeof(salt));
+
+    secure_id_t authenticator_id;
+    GetRandom(&authenticator_id, sizeof(authenticator_id));
+
+    SizedBuffer password_handle;
+    if (!CreatePasswordHandle(&password_handle,
+            salt, user_id, authenticator_id, request.provided_password.buffer.get(),
+            request.provided_password.length)) {
+        response->error = KG_ERROR_INVALID;
+        return;
+    }
+
+    WritePasswordFile(request.user_id, password_handle);
+
+    response->SetEnrolledPasswordHandle(&password_handle);
 }
 
 void Keyguard::Verify(const VerifyRequest &request, VerifyResponse *response) {
@@ -60,104 +100,155 @@
         return;
     }
 
-    size_t salt_length, signature_length;
-    uint8_t *salt, *signature;
-    keyguard_error_t error = DeserializeHandle(
-            &request.password_handle, &salt, &salt_length, &signature, &signature_length);
+    secure_id_t user_id, authenticator_id;
+    password_handle_t *password_handle = reinterpret_cast<password_handle_t *>(
+            request.password_handle.buffer.get());
 
-    if (error != KG_ERROR_OK) {
-        response->error = error;
+    // Sanity check: the handle must be exactly one password_handle_t of the
+    // current version before any of its fields are read.
+    if (password_handle == NULL || request.password_handle.length != sizeof(password_handle_t)
+            || password_handle->version != HANDLE_VERSION) {
+        response->error = KG_ERROR_INVALID;
         return;
     }
 
-    size_t provided_password_signature_length;
-    UniquePtr<uint8_t> provided_password_signature;
-    ComputePasswordSignature(password_key_.buffer.get(),
-            password_key_.length, request.provided_password.buffer.get(), request.provided_password.length,
-            salt, salt_length, &provided_password_signature, &provided_password_signature_length);
+    if (!ValidatePasswordFile(request.user_id, request.password_handle)) {
+        // we don't allow access to keys if we can't validate the file.
+        // we must allow this case to support authentication before we decrypt
+        // /data, however.
+        user_id = 0;
+        authenticator_id = 0;
+    } else {
+        user_id = password_handle->user_id;
+        authenticator_id = password_handle->authenticator_id;
+    }
 
-    if (provided_password_signature_length == signature_length &&
-            memcmp_s(signature, provided_password_signature.get(), signature_length) == 0) {
+    // MintAuthToken takes a 32-bit timestamp, so narrow once here.
+    struct timespec time;
+    clock_gettime(CLOCK_MONOTONIC_RAW, &time);
+    const uint32_t timestamp = static_cast<uint32_t>(time.tv_sec);
+
+    if (DoVerify(password_handle, request.provided_password)) {
         // Signature matches
         SizedBuffer auth_token;
-        MintAuthToken(request.user_id, &auth_token.buffer, &auth_token.length);
+        MintAuthToken(&auth_token.buffer, &auth_token.length, timestamp,
+                user_id, authenticator_id);
         response->SetVerificationToken(&auth_token);
     } else {
         response->error = KG_ERROR_INVALID;
     }
 }
 
-void Keyguard::MintAuthToken(uint32_t user_id, UniquePtr<uint8_t> *auth_token, size_t *length) {
+/**
+ * Builds a signed password handle for |password| into |password_handle_buffer|.
+ * The signature covers the packed (version, user_id, authenticator_id) prefix
+ * of the handle followed by the password, computed with the password key and |salt|.
+ * Returns false if no password key is available.
+ */
+bool Keyguard::CreatePasswordHandle(SizedBuffer *password_handle_buffer, salt_t salt,
+        secure_id_t user_id, secure_id_t authenticator_id, const uint8_t *password,
+        size_t password_length) {
+    password_handle_buffer->buffer.reset(new uint8_t[sizeof(password_handle_t)]);
+    password_handle_buffer->length = sizeof(password_handle_t);
+
+    password_handle_t *password_handle = reinterpret_cast<password_handle_t *>(
+            password_handle_buffer->buffer.get());
+    password_handle->version = HANDLE_VERSION;
+    password_handle->salt = salt;
+    password_handle->user_id = user_id;
+    password_handle->authenticator_id = authenticator_id;
+
+    UniquePtr<uint8_t> password_key;
+    size_t password_key_length = 0;
+    GetPasswordKey(&password_key, &password_key_length);
+
+    if (!password_key.get() || password_key_length == 0) {
+        return false;
+    }
+
+    size_t metadata_length = sizeof(user_id) /* user id */
+        + sizeof(authenticator_id) /* auth id */ + sizeof(uint8_t) /* version */;
+    size_t to_sign_length = metadata_length + password_length;
+    // Heap-allocated instead of a VLA: password_length is caller-controlled
+    // (unbounded stack use) and VLAs are not standard C++.
+    SizedBuffer to_sign;
+    to_sign.buffer.reset(new uint8_t[to_sign_length]);
+    to_sign.length = to_sign_length;
+    memcpy(to_sign.buffer.get(), &password_handle->version, metadata_length);
+    memcpy(to_sign.buffer.get() + metadata_length, password, password_length);
+
+    ComputePasswordSignature(password_handle->signature, sizeof(password_handle->signature),
+            password_key.get(), password_key_length, to_sign.buffer.get(), to_sign_length,
+            salt);
+
+    // Scrub the plaintext password copy and the key before they are freed.
+    memset_s(to_sign.buffer.get(), 0, to_sign_length);
+    memset_s(password_key.get(), 0, password_key_length);
+    return true;
+}
+
+/**
+ * Recomputes a handle for |password| from the salt and ids stored in
+ * |expected_handle| and compares the whole result against it.
+ */
+bool Keyguard::DoVerify(const password_handle_t *expected_handle, const SizedBuffer &password) {
+    if (!password.buffer.get()) return false;
+
+    SizedBuffer provided_handle;
+    if (!CreatePasswordHandle(&provided_handle, expected_handle->salt, expected_handle->user_id,
+            expected_handle->authenticator_id, password.buffer.get(), password.length)) {
+        return false;
+    }
+
+    return memcmp_s(provided_handle.buffer.get(), expected_handle, sizeof(*expected_handle)) == 0;
+}
+
+/**
+ * Returns true iff ReadPasswordFile() yields a non-empty handle for
+ * |uid| that is byte-identical to |provided_handle|.
+ */
+bool Keyguard::ValidatePasswordFile(uint32_t uid, const SizedBuffer &provided_handle) {
+    SizedBuffer stored_handle;
+    ReadPasswordFile(uid, &stored_handle);
+
+    if (!stored_handle.buffer.get() || stored_handle.length == 0) return false;
+    if (!provided_handle.buffer.get()) return false;
+
+    // do we also verify the signature here?
+    return stored_handle.length == provided_handle.length &&
+        memcmp_s(stored_handle.buffer.get(), provided_handle.buffer.get(), stored_handle.length)
+            == 0;
+}
+
+/**
+ * Allocates an AuthToken carrying |timestamp|, |user_id| and |authenticator_id|,
+ * signs it with the auth token key and hands ownership to |auth_token|.
+ * Does nothing if |auth_token| is NULL; |length|, if non-NULL, gets sizeof(AuthToken).
+ */
+void Keyguard::MintAuthToken(UniquePtr<uint8_t> *auth_token, size_t *length,
+        uint32_t timestamp, secure_id_t user_id, secure_id_t authenticator_id) {
     if (auth_token == NULL) return;
 
-    AuthToken *token = new AuthToken;
-    SizedBuffer serialized_auth_token;
+    // Value-initialize so fields not assigned below are zero instead of
+    // uninitialized heap bytes that would be signed and returned to the caller.
+    AuthToken *token = new AuthToken();
 
-    struct timespec time;
-    clock_gettime(CLOCK_MONOTONIC_RAW, &time);
-
-    token->auth_token_size = sizeof(AuthToken) -
-        sizeof(token->auth_token_tag) - sizeof(token->auth_token_size);
-    token->user_id = user_id;
-    token->timestamp = static_cast<uint64_t>(time.tv_sec);
+    token->root_secure_user_id = user_id;
+    token->auxiliary_secure_user_id = authenticator_id;
+    token->timestamp = timestamp;
 
     UniquePtr<uint8_t> auth_token_key;
-    size_t key_len;
+    size_t key_len = 0;
     GetAuthTokenKey(&auth_token_key, &key_len);
 
-    size_t hash_len = (size_t)((uint8_t *)&token->hmac_tag - (uint8_t *)token);
-    size_t signature_len;
-    UniquePtr<uint8_t> signature;
-    ComputeSignature(auth_token_key.get(), key_len,
-            reinterpret_cast<uint8_t *>(token), hash_len, &signature, &signature_len);
+    // Sign every byte of the token that precedes the hmac field itself.
+    size_t hash_len = (size_t)((uint8_t *)&token->hmac - (uint8_t *)token);
+    ComputeSignature(token->hmac, sizeof(token->hmac), auth_token_key.get(), key_len,
+            reinterpret_cast<uint8_t *>(token), hash_len);
 
-    memset(&token->hmac, 0, sizeof(token->hmac));
-
-    memcpy(&token->hmac, signature.get(), signature_len > sizeof(token->hmac)
-            ? sizeof(token->hmac) : signature_len);
     if (length != NULL) *length = sizeof(AuthToken);
     auth_token->reset(reinterpret_cast<uint8_t *>(token));
 }
 
-void Keyguard::SerializeHandle(const uint8_t *salt, size_t salt_length, const uint8_t *signature,
-        size_t signature_length, SizedBuffer &result) {
-    const size_t buffer_len = 2 * sizeof(size_t) + salt_length + signature_length;
-    result.buffer.reset(new uint8_t[buffer_len]);
-    result.length = buffer_len;
-    uint8_t *buffer = result.buffer.get();
-    memcpy(buffer, &salt_length, sizeof(salt_length));
-    buffer += sizeof(salt_length);
-    memcpy(buffer, salt, salt_length);
-    buffer += salt_length;
-    memcpy(buffer, &signature_length, sizeof(signature_length));
-    buffer += sizeof(signature_length);
-    memcpy(buffer, signature, signature_length);
-}
-
-keyguard_error_t Keyguard::DeserializeHandle(const SizedBuffer *handle, uint8_t **salt,
-        size_t *salt_length, uint8_t **password, size_t *password_length) {
-    if (handle && handle->length > (2 * sizeof(size_t))) {
-        int read = 0;
-        uint8_t *buffer = handle->buffer.get();
-        memcpy(salt_length, buffer, sizeof(*salt_length));
-        read += sizeof(*salt_length);
-        if (read + *salt_length < handle->length) {
-            *salt = buffer + read;
-            read += *salt_length;
-            if (read + sizeof(*password_length) < handle->length) {
-                buffer += read;
-                memcpy(password_length, buffer, sizeof(*password_length));
-                *password = buffer + sizeof(*password_length);
-            } else {
-                return KG_ERROR_INVALID;
-            }
-        } else {
-            return KG_ERROR_INVALID;
-        }
-
-        return KG_ERROR_OK;
-    }
-    return KG_ERROR_INVALID;
-}
-
 
 }