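Add AES-OCB decryption.

Also refactor to extract functionality that will be common to all AEAD
modes. The tests now accumulate Update and Finish output instead of
keeping only the last chunk, and cover OCB round trips, corrupted,
garbage and truncated ciphertexts, empty messages, varying chunk
lengths, and abort.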

Change-Id: I4bcf12c9d2d464ab1af559c69031904ffae45e25
diff --git a/google_keymaster_test.cpp b/google_keymaster_test.cpp
index 6c672be..f5a9f6b 100644
--- a/google_keymaster_test.cpp
+++ b/google_keymaster_test.cpp
@@ -99,6 +99,8 @@
         blob_.key_material = NULL;
     }
 
+    const keymaster_key_blob_t& key_blob() const { return blob_; }
+
     SoftKeymasterDevice device_;
 
     AuthorizationSet params_;
@@ -923,6 +925,7 @@
   protected:
     void GenerateKey(keymaster_algorithm_t algorithm, keymaster_padding_t padding,
                      uint32_t key_size) {
+        params_.Clear();
         params_.push_back(Authorization(TAG_PURPOSE, KM_PURPOSE_ENCRYPT));
         params_.push_back(Authorization(TAG_PURPOSE, KM_PURPOSE_DECRYPT));
         params_.push_back(Authorization(TAG_ALGORITHM, algorithm));
@@ -933,12 +936,16 @@
         params_.push_back(Authorization(TAG_AUTH_TIMEOUT, 300));
         if (static_cast<int>(padding) != -1)
             params_.push_back(TAG_PADDING, padding);
+
+        FreeKeyBlob();
+        FreeCharacteristics();
         EXPECT_EQ(KM_ERROR_OK, device()->generate_key(device(), params_.data(), params_.size(),
                                                       &blob_, &characteristics_));
     }
 
     void GenerateSymmetricKey(keymaster_algorithm_t algorithm, uint32_t key_size,
                               keymaster_block_mode_t block_mode, uint32_t chunk_length) {
+        params_.Clear();
         params_.push_back(Authorization(TAG_PURPOSE, KM_PURPOSE_ENCRYPT));
         params_.push_back(Authorization(TAG_PURPOSE, KM_PURPOSE_DECRYPT));
         params_.push_back(Authorization(TAG_ALGORITHM, algorithm));
@@ -951,6 +958,8 @@
         params_.push_back(Authorization(TAG_APPLICATION_ID, "app_id", 6));
         params_.push_back(Authorization(TAG_AUTH_TIMEOUT, 300));
 
+        FreeKeyBlob();
+        FreeCharacteristics();
         EXPECT_EQ(KM_ERROR_OK, device()->generate_key(device(), params_.data(), params_.size(),
                                                       &blob_, &characteristics_));
     }
@@ -970,7 +979,7 @@
             device()->update(device(), op_handle, reinterpret_cast<const uint8_t*>(message), size,
                              input_consumed, &out_tmp, &out_length);
         if (out_tmp)
-            *output = string(reinterpret_cast<char*>(out_tmp), out_length);
+            output->append(reinterpret_cast<char*>(out_tmp), out_length);
         free(out_tmp);
         return error;
     }
@@ -981,7 +990,7 @@
         keymaster_error_t error = device()->finish(device(), op_handle, NULL /* signature */,
                                                    0 /* signature_length */, &out_tmp, &out_length);
         if (out_tmp)
-            *output = string(reinterpret_cast<char*>(out_tmp), out_length);
+            output->append(reinterpret_cast<char*>(out_tmp), out_length);
         free(out_tmp);
         return error;
     }
@@ -996,16 +1005,16 @@
         EXPECT_EQ(KM_ERROR_OK, UpdateOperation(op_handle, message, size, &result, &input_consumed));
         EXPECT_EQ(size, input_consumed);
         EXPECT_EQ(KM_ERROR_OK, FinishOperation(op_handle, &result));
-
+        EXPECT_EQ(KM_ERROR_INVALID_OPERATION_HANDLE, device()->abort(device(), op_handle));
         return result;
     }
 
-    string EncryptMessage(const void* message, size_t size) {
-        return ProcessMessage(KM_PURPOSE_ENCRYPT, blob_, message, size);
+    string EncryptMessage(const string& message) {
+        return ProcessMessage(KM_PURPOSE_ENCRYPT, blob_, message.data(), message.size());
     }
 
-    string DecryptMessage(const void* ciphertext, size_t size) {
-        return ProcessMessage(KM_PURPOSE_DECRYPT, blob_, ciphertext, size);
+    string DecryptMessage(const string& ciphertext) {
+        return ProcessMessage(KM_PURPOSE_DECRYPT, blob_, ciphertext.data(), ciphertext.size());
     }
 
     void AddClientParams(AuthorizationSet* set) { set->push_back(TAG_APPLICATION_ID, "app_id", 6); }
@@ -1027,10 +1036,10 @@
 TEST_F(EncryptionOperationsTest, RsaOaepSuccess) {
     GenerateKey(KM_ALGORITHM_RSA, KM_PAD_RSA_OAEP, 512);
     const char message[] = "Hello World!";
-    string ciphertext1 = EncryptMessage(message, strlen(message));
+    string ciphertext1 = EncryptMessage(string(message));
     EXPECT_EQ(512 / 8, ciphertext1.size());
 
-    string ciphertext2 = EncryptMessage(message, strlen(message));
+    string ciphertext2 = EncryptMessage(string(message));
     EXPECT_EQ(512 / 8, ciphertext2.size());
 
     // OAEP randomizes padding so every result should be different.
@@ -1040,10 +1049,10 @@
 TEST_F(EncryptionOperationsTest, RsaOaepRoundTrip) {
     GenerateKey(KM_ALGORITHM_RSA, KM_PAD_RSA_OAEP, 512);
     const char message[] = "Hello World!";
-    string ciphertext = EncryptMessage(message, strlen(message));
+    const string ciphertext = EncryptMessage(message);
     EXPECT_EQ(512 / 8, ciphertext.size());
 
-    string plaintext = DecryptMessage(ciphertext.data(), ciphertext.size());
+    const string plaintext = DecryptMessage(ciphertext);
     EXPECT_EQ(message, plaintext);
 }
 
@@ -1064,7 +1073,7 @@
 TEST_F(EncryptionOperationsTest, RsaOaepCorruptedDecrypt) {
     GenerateKey(KM_ALGORITHM_RSA, KM_PAD_RSA_OAEP, 512);
     const char message[] = "Hello World!";
-    string ciphertext = EncryptMessage(message, strlen(message));
+    string ciphertext = EncryptMessage(string(message));
     EXPECT_EQ(512 / 8, ciphertext.size());
 
     // Corrupt the ciphertext
@@ -1083,10 +1092,10 @@
 TEST_F(EncryptionOperationsTest, RsaPkcs1Success) {
     GenerateKey(KM_ALGORITHM_RSA, KM_PAD_RSA_PKCS1_1_5_ENCRYPT, 512);
     const char message[] = "Hello World!";
-    string ciphertext1 = EncryptMessage(message, strlen(message));
+    string ciphertext1 = EncryptMessage(string(message));
     EXPECT_EQ(512 / 8, ciphertext1.size());
 
-    string ciphertext2 = EncryptMessage(message, strlen(message));
+    string ciphertext2 = EncryptMessage(string(message));
     EXPECT_EQ(512 / 8, ciphertext2.size());
 
     // PKCS1 v1.5 randomizes padding so every result should be different.
@@ -1096,10 +1105,10 @@
 TEST_F(EncryptionOperationsTest, RsaPkcs1RoundTrip) {
     GenerateKey(KM_ALGORITHM_RSA, KM_PAD_RSA_PKCS1_1_5_ENCRYPT, 512);
     const char message[] = "Hello World!";
-    string ciphertext = EncryptMessage(message, strlen(message));
+    const string ciphertext = EncryptMessage(message);
     EXPECT_EQ(512 / 8, ciphertext.size());
 
-    string plaintext = DecryptMessage(ciphertext.data(), ciphertext.size());
+    const string plaintext = DecryptMessage(ciphertext);
     EXPECT_EQ(message, plaintext);
 }
 
@@ -1120,7 +1129,7 @@
 TEST_F(EncryptionOperationsTest, RsaPkcs1CorruptedDecrypt) {
     GenerateKey(KM_ALGORITHM_RSA, KM_PAD_RSA_PKCS1_1_5_ENCRYPT, 512);
     const char message[] = "Hello World!";
-    string ciphertext = EncryptMessage(message, strlen(message));
+    string ciphertext = EncryptMessage(string(message));
     EXPECT_EQ(512 / 8, ciphertext.size());
 
     // Corrupt the ciphertext
@@ -1139,15 +1148,146 @@
 TEST_F(EncryptionOperationsTest, AesOcbSuccess) {
     GenerateSymmetricKey(KM_ALGORITHM_AES, 128, KM_MODE_OCB, 4096);
     const char message[] = "Hello World!";
-    string ciphertext1 = EncryptMessage(message, strlen(message));
+    const string ciphertext1 = EncryptMessage(message);
     EXPECT_EQ(12 /* nonce */ + strlen(message) + 16 /* tag */, ciphertext1.size());
 
-    string ciphertext2 = EncryptMessage(message, strlen(message));
+    const string ciphertext2 = EncryptMessage(message);
     EXPECT_EQ(12 /* nonce */ + strlen(message) + 16 /* tag */, ciphertext2.size());
 
     // OCB uses a random nonce, so every output should be different
     EXPECT_NE(ciphertext1, ciphertext2);
 }
 
+TEST_F(EncryptionOperationsTest, AesOcbRoundTripSuccess) {
+    GenerateSymmetricKey(KM_ALGORITHM_AES, 128, KM_MODE_OCB, 4096);
+    const string message("Hello World!");
+    const string ciphertext = EncryptMessage(message);
+    EXPECT_EQ(12 /* nonce */ + message.size() + 16 /* tag */, ciphertext.length());
+    // Decrypting our own ciphertext with the same key must reproduce the input.
+    const string plaintext = DecryptMessage(ciphertext);
+    EXPECT_EQ(message, plaintext);
+}
+
+TEST_F(EncryptionOperationsTest, AesOcbRoundTripCorrupted) {
+    GenerateSymmetricKey(KM_ALGORITHM_AES, 128, KM_MODE_OCB, 4096);
+    const char message[] = "Hello World!";
+    string ciphertext = EncryptMessage(string(message));
+    EXPECT_EQ(12 /* nonce */ + strlen(message) + 16 /* tag */, ciphertext.size());
+    // Corrupt the middle byte (inside the encrypted payload) so authentication fails.
+    ciphertext[ciphertext.size() / 2]++;
+    // Stop the test if Begin fails rather than driving an uninitialized handle.
+    uint64_t op_handle = 0;
+    ASSERT_EQ(KM_ERROR_OK, BeginOperation(KM_PURPOSE_DECRYPT, key_blob(), &op_handle));
+
+    string result;
+    size_t input_consumed = 0;
+    EXPECT_EQ(KM_ERROR_OK, UpdateOperation(op_handle, ciphertext.c_str(), ciphertext.length(),
+                                           &result, &input_consumed));
+    EXPECT_EQ(ciphertext.length(), input_consumed);
+    EXPECT_EQ(KM_ERROR_VERIFICATION_FAILED, FinishOperation(op_handle, &result));
+}
+
+TEST_F(EncryptionOperationsTest, AesDecryptGarbage) {
+    GenerateSymmetricKey(KM_ALGORITHM_AES, 128, KM_MODE_OCB, 4096);
+    string ciphertext(128, 'a');  // 128 bytes: long enough to hold a nonce and a tag.
+    uint64_t op_handle = 0;
+    ASSERT_EQ(KM_ERROR_OK, BeginOperation(KM_PURPOSE_DECRYPT, key_blob(), &op_handle));
+    // Update accepts the bytes; the forgery is reported by Finish.
+    string result;
+    size_t input_consumed = 0;
+    EXPECT_EQ(KM_ERROR_OK, UpdateOperation(op_handle, ciphertext.c_str(), ciphertext.length(),
+                                           &result, &input_consumed));
+    EXPECT_EQ(ciphertext.length(), input_consumed);
+    EXPECT_EQ(KM_ERROR_VERIFICATION_FAILED, FinishOperation(op_handle, &result));
+}
+
+TEST_F(EncryptionOperationsTest, AesDecryptTooShort) {
+    // Try decrypting garbage ciphertext that is too short to be valid (< nonce + tag).
+    GenerateSymmetricKey(KM_ALGORITHM_AES, 128, KM_MODE_OCB, 4096);
+    string ciphertext(12 /* nonce */ + 15 /* one byte short of a tag */, 'a');
+    uint64_t op_handle = 0;
+    ASSERT_EQ(KM_ERROR_OK, BeginOperation(KM_PURPOSE_DECRYPT, key_blob(), &op_handle));
+    // Update accepts the bytes; the length error is reported by Finish.
+    string result;
+    size_t input_consumed = 0;
+    EXPECT_EQ(KM_ERROR_OK, UpdateOperation(op_handle, ciphertext.c_str(), ciphertext.length(),
+                                           &result, &input_consumed));
+    EXPECT_EQ(ciphertext.length(), input_consumed);
+    EXPECT_EQ(KM_ERROR_INVALID_INPUT_LENGTH, FinishOperation(op_handle, &result));
+}
+
+TEST_F(EncryptionOperationsTest, AesOcbRoundTripEmptySuccess) {
+    // An empty plaintext still yields nonce + tag and must decrypt back to "".
+    GenerateSymmetricKey(KM_ALGORITHM_AES, 128, KM_MODE_OCB, 4096);
+    const string message;
+    const string ciphertext = EncryptMessage(message);
+    EXPECT_EQ(12 /* nonce */ + message.size() + 16 /* tag */, ciphertext.size());
+
+    const string plaintext = DecryptMessage(ciphertext);
+    EXPECT_EQ(message, plaintext);
+}
+
+TEST_F(EncryptionOperationsTest, AesOcbRoundTripEmptyCorrupted) {
+    // Should even detect corruption of empty messages.
+    GenerateSymmetricKey(KM_ALGORITHM_AES, 128, KM_MODE_OCB, 4096);
+    const char message[] = "";
+    string ciphertext = EncryptMessage(string(message));
+    EXPECT_EQ(12 /* nonce */ + strlen(message) + 16 /* tag */, ciphertext.size());
+    // With no payload, the middle byte (index 14 of 28) lies inside the tag.
+    ciphertext[ciphertext.size() / 2]++;
+    // Stop the test if Begin fails rather than driving an uninitialized handle.
+    uint64_t op_handle = 0;
+    ASSERT_EQ(KM_ERROR_OK, BeginOperation(KM_PURPOSE_DECRYPT, key_blob(), &op_handle));
+
+    string result;
+    size_t input_consumed = 0;
+    EXPECT_EQ(KM_ERROR_OK, UpdateOperation(op_handle, ciphertext.c_str(), ciphertext.length(),
+                                           &result, &input_consumed));
+    EXPECT_EQ(ciphertext.length(), input_consumed);
+    EXPECT_EQ(KM_ERROR_VERIFICATION_FAILED, FinishOperation(op_handle, &result));
+}
+
+TEST_F(EncryptionOperationsTest, AesOcbFullChunk) {
+    GenerateSymmetricKey(KM_ALGORITHM_AES, 128, KM_MODE_OCB, 4096);
+    const string message(4096, 'a');  // Exactly one 4096-byte chunk, hence a single tag.
+    const string ciphertext = EncryptMessage(message);
+    EXPECT_EQ(12 /* nonce */ + message.size() + 16 /* tag */, ciphertext.size());
+
+    const string plaintext = DecryptMessage(ciphertext);
+    EXPECT_EQ(message, plaintext);
+}
+
+TEST_F(EncryptionOperationsTest, AesOcbVariousChunkLengths) {
+    const string message(128, 'a');
+    for (unsigned chunk_length = 1; chunk_length <= 128; ++chunk_length) {
+        GenerateSymmetricKey(KM_ALGORITHM_AES, 128, KM_MODE_OCB, chunk_length);
+        string ciphertext = EncryptMessage(message);
+        // Each (possibly partial) chunk carries its own 16-byte tag: ceil(len / chunk).
+        size_t expected_tag_count = (message.length() + chunk_length - 1) / chunk_length;
+        EXPECT_EQ(12 /* nonce */ + message.length() + 16 * expected_tag_count, ciphertext.size())
+            << "Unexpected ciphertext size for chunk length " << chunk_length
+            << " expected tag count was " << expected_tag_count
+            << " but actual tag count was probably "
+            << (ciphertext.size() - message.length() - 12) / 16;
+        string plaintext = DecryptMessage(ciphertext);
+        EXPECT_EQ(message, plaintext) << "Round trip failed for chunk length " << chunk_length;
+    }
+}
+
+TEST_F(EncryptionOperationsTest, AesOcbAbort) {
+    GenerateSymmetricKey(KM_ALGORITHM_AES, 128, KM_MODE_OCB, 4096);
+    const char message[] = "Hello";
+    // Start an encryption, feed it data, abort it, and check the handle was released.
+    uint64_t op_handle = 0;
+    ASSERT_EQ(KM_ERROR_OK, BeginOperation(KM_PURPOSE_ENCRYPT, key_blob(), &op_handle));
+    string result;
+    size_t input_consumed = 0;
+    EXPECT_EQ(KM_ERROR_OK,
+              UpdateOperation(op_handle, message, strlen(message), &result, &input_consumed));
+    EXPECT_EQ(strlen(message), input_consumed);
+    EXPECT_EQ(KM_ERROR_OK, device()->abort(device(), op_handle));
+    EXPECT_EQ(KM_ERROR_INVALID_OPERATION_HANDLE, device()->abort(device(), op_handle));
+}
+
 }  // namespace test
 }  // namespace keymaster