Add AES OCB decryption.

Also, refactor to extract functionality that will be common to all AEAD modes.

Change-Id: I4bcf12c9d2d464ab1af559c69031904ffae45e25
diff --git a/Makefile b/Makefile
index 3378f70..2950f07 100644
--- a/Makefile
+++ b/Makefile
@@ -25,6 +25,7 @@
 LDLIBS=-lcrypto -lpthread -lstdc++
 
 CPPSRCS=\
+	aead_mode_operation.cpp \
 	aes_key.cpp \
 	aes_operation.cpp \
 	asymmetric_key.cpp \
@@ -126,6 +127,7 @@
 	$(GTEST)/src/gtest-all.o
 
 google_keymaster_test: google_keymaster_test.o \
+	aead_mode_operation.o \
 	aes_key.o \
 	aes_operation.o \
 	asymmetric_key.o \
diff --git a/aead_mode_operation.cpp b/aead_mode_operation.cpp
new file mode 100644
index 0000000..44d98cc
--- /dev/null
+++ b/aead_mode_operation.cpp
@@ -0,0 +1,174 @@
+/*
+ * Copyright 2014 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <stdio.h>
+
+#include <openssl/aes.h>
+#include <openssl/rand.h>
+
+#include "aead_mode_operation.h"
+
+namespace keymaster {
+
+keymaster_error_t AeadModeOperation::Begin() {
+    keymaster_error_t error = Initialize(key_, key_size_, nonce_length_, tag_length_);
+    if (error == KM_ERROR_OK) {
+        buffer_end_ = 0;
+        buffer_.reset(new uint8_t[processing_unit_]);
+        if (!buffer_.get())
+            error = KM_ERROR_MEMORY_ALLOCATION_FAILED;
+    }
+    return error;
+}
+
+inline size_t min(size_t a, size_t b) {
+    if (a < b)
+        return a;
+    return b;
+}
+
+keymaster_error_t AeadModeOperation::Update(const Buffer& input, Buffer* output,
+                                            size_t* input_consumed) {
+    // Make an effort to reserve enough output space.  The output buffer will be extended if needed,
+    // but this reduces reallocations.
+    if (!output->reserve(EstimateOutputSize(input, output)))
+        return KM_ERROR_MEMORY_ALLOCATION_FAILED;
+
+    keymaster_error_t error = KM_ERROR_OK;
+    *input_consumed = 0;
+
+    const uint8_t* plaintext = input.peek_read();
+    const uint8_t* plaintext_end = plaintext + input.available_read();
+    while (plaintext < plaintext_end && error == KM_ERROR_OK) {
+        if (buffered_data_length() == processing_unit_) {
+            assert(nonce_handled_);
+            if (!nonce_handled_)
+                return KM_ERROR_UNKNOWN_ERROR;
+            error = ProcessChunk(output);
+            ClearBuffer();
+            IncrementNonce();
+        }
+        plaintext = AppendToBuffer(plaintext, plaintext_end - plaintext);
+        *input_consumed = plaintext - input.peek_read();
+        if (!nonce_handled_)
+            error = HandleNonce(output);
+    }
+    return error;
+}
+
+keymaster_error_t AeadModeOperation::Finish(const Buffer& /* signature */, Buffer* output) {
+    keymaster_error_t error = KM_ERROR_OK;
+    if (!nonce_handled_)
+        error = HandleNonce(output);
+    if (error != KM_ERROR_OK)
+        return error;
+    return ProcessChunk(output);
+}
+
+keymaster_error_t AeadModeOperation::ProcessChunk(Buffer* output) {
+    if (!nonce_handled_)
+        return KM_ERROR_INVALID_INPUT_LENGTH;
+
+    keymaster_error_t error = KM_ERROR_OK;
+    if (purpose() == KM_PURPOSE_DECRYPT) {
+        if (buffered_data_length() < tag_length_)
+            return KM_ERROR_INVALID_INPUT_LENGTH;
+        ExtractTagFromBuffer();
+        logger().info("AeadMode decrypting %d", buffered_data_length());
+        if (!output->reserve(output->available_read() + buffered_data_length()))
+            error = KM_ERROR_MEMORY_ALLOCATION_FAILED;
+        else
+            error = DecryptChunk(nonce_, nonce_length_, tag_, tag_length_, additional_data_,
+                                 buffer_.get(), buffered_data_length(), output);
+    } else {
+        if (!output->reserve(output->available_read() + buffered_data_length() + tag_length_))
+            error = KM_ERROR_MEMORY_ALLOCATION_FAILED;
+        else
+            error = EncryptChunk(nonce_, nonce_length_, tag_length_, additional_data_,
+                                 buffer_.get(), buffered_data_length(), output);
+    }
+    return error;
+}
+
+size_t AeadModeOperation::EstimateOutputSize(const Buffer& input, Buffer* output) {
+    switch (purpose()) {
+    case KM_PURPOSE_ENCRYPT: {
+        size_t chunk_length = processing_unit_;
+        size_t chunk_count = (input.available_read() + chunk_length - 1) / chunk_length;
+        return output->available_read() + nonce_length_ +
+               chunk_count * (chunk_length + tag_length_);
+    }
+    case KM_PURPOSE_DECRYPT: {
+        size_t chunk_length = processing_unit_ - tag_length_;
+        size_t chunk_count =
+            (input.available_read() - nonce_length_ + processing_unit_ - 1) / processing_unit_;
+        return output->available_read() + chunk_length * chunk_count;
+    }
+    default:
+        logger().error("Encountered invalid purpose %d", purpose());
+        return 0;
+    }
+}
+
+keymaster_error_t AeadModeOperation::HandleNonce(Buffer* output) {
+    switch (purpose()) {
+    case KM_PURPOSE_ENCRYPT:
+        if (!RAND_bytes(nonce_, nonce_length_)) {
+            logger().error("Failed to generate nonce");
+            return KM_ERROR_UNKNOWN_ERROR;
+        }
+        if (!output->reserve(nonce_length_))
+            return KM_ERROR_MEMORY_ALLOCATION_FAILED;
+        output->write(nonce_, nonce_length_);
+        nonce_handled_ = true;
+        break;
+    case KM_PURPOSE_DECRYPT:
+        if (buffered_data_length() >= nonce_length_) {
+            memcpy(nonce_, buffer_.get(), nonce_length_);
+            memmove(buffer_.get(), buffer_.get() + nonce_length_,
+                    buffered_data_length() - nonce_length_);
+            buffer_end_ -= nonce_length_;
+            nonce_handled_ = true;
+        }
+        break;
+    default:
+        return KM_ERROR_UNSUPPORTED_PURPOSE;
+    }
+    return KM_ERROR_OK;
+}
+
+void AeadModeOperation::IncrementNonce() {
+    for (int i = nonce_length_ - 1; i > 0; --i)
+        if (++nonce_[i])
+            break;
+}
+
+const uint8_t* AeadModeOperation::AppendToBuffer(const uint8_t* data, size_t data_length) {
+    // Only take as much data as we can fit.
+    if (data_length > buffer_free_space())
+        data_length = buffer_free_space();
+    memcpy(buffer_.get() + buffer_end_, data, data_length);
+    buffer_end_ += data_length;
+    return data + data_length;
+}
+
+void AeadModeOperation::ExtractTagFromBuffer() {
+    assert(buffered_data_length() >= tag_length_);
+    memcpy(tag_, buffer_.get() + buffer_end_ - tag_length_, tag_length_);
+    buffer_end_ -= tag_length_;
+}
+
+}  // namespace keymaster
diff --git a/aead_mode_operation.h b/aead_mode_operation.h
new file mode 100644
index 0000000..02a30ab
--- /dev/null
+++ b/aead_mode_operation.h
@@ -0,0 +1,104 @@
+/*
+ * Copyright 2014 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef SYSTEM_KEYMASTER_AEAD_MODE_OPERATION_H_
+#define SYSTEM_KEYMASTER_AEAD_MODE_OPERATION_H_
+
+#include "operation.h"
+
+namespace keymaster {
+
+class AeadModeOperation : public Operation {
+  public:
+    static const size_t MAX_NONCE_LENGTH = 12;
+    static const size_t MAX_TAG_LENGTH = 16;
+    static const size_t MAX_KEY_LENGTH = 32;
+
+    AeadModeOperation(keymaster_purpose_t purpose, const Logger& logger, uint8_t* key,
+                      size_t key_size, size_t chunk_length, size_t tag_length, size_t nonce_length,
+                      keymaster_blob_t additional_data)
+        : Operation(purpose, logger), key_size_(key_size), tag_length_(tag_length),
+          nonce_length_(nonce_length),
+          processing_unit_(purpose == KM_PURPOSE_DECRYPT ? chunk_length + tag_length
+                                                         : chunk_length),
+          additional_data_(additional_data), nonce_handled_(false) {
+
+        assert(key_size <= MAX_KEY_LENGTH);
+        memcpy(key_, key, key_size);
+    }
+    ~AeadModeOperation() {
+        // Wipe sensitive buffers.
+        memset_s(buffer_.get(), 0, processing_unit_);
+        memset_s(const_cast<uint8_t*>(additional_data_.data), 0, additional_data_.data_length);
+        memset_s(key_, 0, MAX_KEY_LENGTH);
+        delete[] additional_data_.data;
+    }
+
+    virtual keymaster_error_t Begin();
+    virtual keymaster_error_t Update(const Buffer& input, Buffer* output, size_t* input_consumed);
+    virtual keymaster_error_t Finish(const Buffer& /* signature */, Buffer* output);
+
+  protected:
+    size_t buffered_data_length() const { return buffer_end_; }
+    const uint8_t* key() const { return key_; }
+    size_t key_size() const { return key_size_; }
+
+  private:
+    /*
+     * These methods do the actual crypto operations.
+     *
+     * TODO(swillden): Consider refactoring these to a separate class, integrating them via
+     * composition rather than inheritance.
+     */
+    virtual keymaster_error_t Initialize(uint8_t* key, size_t key_size, size_t nonce_length,
+                                         size_t tag_length) = 0;
+    virtual keymaster_error_t EncryptChunk(const uint8_t* nonce, size_t nonce_length,
+                                           size_t tag_length,
+                                           const keymaster_blob_t additional_data, uint8_t* chunk,
+                                           size_t chunk_size, Buffer* output) = 0;
+    virtual keymaster_error_t DecryptChunk(const uint8_t* nonce, size_t nonce_length,
+                                           const uint8_t* tag, size_t tag_length,
+                                           const keymaster_blob_t additional_data, uint8_t* chunk,
+                                           size_t chunk_size, Buffer* output) = 0;
+
+    size_t EstimateOutputSize(const Buffer& input, Buffer* output);
+    keymaster_error_t ProcessChunk(Buffer* output);
+
+    size_t buffer_free_space() const { return processing_unit_ - buffer_end_; }
+
+    const uint8_t* AppendToBuffer(const uint8_t* data, size_t data_length);
+    void ExtractNonceFromBuffer();
+    void ExtractTagFromBuffer();
+    void ClearBuffer() { buffer_end_ = 0; }
+    keymaster_error_t HandleNonce(Buffer* output);
+    void IncrementNonce();
+
+    const size_t key_size_;
+    const size_t tag_length_;
+    const size_t nonce_length_;
+    const size_t processing_unit_;
+    const keymaster_blob_t additional_data_;
+    UniquePtr<uint8_t[]> buffer_;
+    size_t buffer_end_;
+    bool nonce_handled_;
+    uint8_t __attribute__((aligned(16))) key_[MAX_KEY_LENGTH];
+    uint8_t __attribute__((aligned(16))) tag_[MAX_TAG_LENGTH];
+    uint8_t __attribute__((aligned(16))) nonce_[MAX_NONCE_LENGTH];
+};
+
+}  // namespace keymaster
+
+#endif  // SYSTEM_KEYMASTER_AEAD_MODE_OPERATION_H_
diff --git a/aes_key.cpp b/aes_key.cpp
index 5bd1f8a..4ea13e5 100644
--- a/aes_key.cpp
+++ b/aes_key.cpp
@@ -64,7 +64,7 @@
         return NULL;
     }
 
-    // Check required for some modes.
+    // Mac required for some modes.
     uint32_t mac_length;
     if (mac_length_required(block_mode)) {
         if (!authorizations.GetTagValue(TAG_MAC_LENGTH, &mac_length) ||
@@ -102,9 +102,9 @@
         switch (purpose) {
         case KM_PURPOSE_SIGN:
         case KM_PURPOSE_VERIFY:
-            if (block_mode < KM_MODE_FIRST_AUTHENTICATED) {
-                logger.error("Only MACing or authenticated modes are supported for signing and "
-                             "verification purposes.");
+            if (block_mode < KM_MODE_FIRST_MAC) {
+                logger.error("Only MACing modes are supported for signing and verification "
+                             "purposes.");
                 return false;
             }
             break;
@@ -182,8 +182,9 @@
     Operation* op = NULL;
     switch (purpose) {
     case KM_PURPOSE_ENCRYPT:
-        op = new AesOcbEncryptOperation(logger_, key_data_, key_data_size_, chunk_length,
-                                        tag_length, additional_data);
+    case KM_PURPOSE_DECRYPT:
+        op = new AesOcbOperation(purpose, logger_, key_data_, key_data_size_, chunk_length,
+                                 tag_length, additional_data);
         break;
     default:
         *error = KM_ERROR_UNSUPPORTED_PURPOSE;
diff --git a/aes_operation.cpp b/aes_operation.cpp
index 346c857..0c7a86c 100644
--- a/aes_operation.cpp
+++ b/aes_operation.cpp
@@ -14,6 +14,8 @@
  * limitations under the License.
  */
 
+#include <stdio.h>
+
 #include <openssl/aes.h>
 #include <openssl/rand.h>
 
@@ -21,85 +23,58 @@
 
 namespace keymaster {
 
-keymaster_error_t AesOcbEncryptOperation::Begin() {
-    chunk_.reset(new uint8_t[chunk_length_]);
-    if (!chunk_.get())
-        return KM_ERROR_MEMORY_ALLOCATION_FAILED;
+keymaster_error_t AesOcbOperation::Initialize(uint8_t* key, size_t key_size, size_t nonce_length,
+                                              size_t tag_length) {
+    if (tag_length > MAX_TAG_LENGTH ||nonce_length > MAX_NONCE_LENGTH)
+        return KM_ERROR_INVALID_KEY_BLOB;
 
-    if (!RAND_bytes(nonce_, NONCE_LENGTH))
-        return KM_ERROR_UNKNOWN_ERROR;
-
-    if (ae_init(ctx_.get(), key_, key_size_, array_size(nonce_), tag_length_) != AE_SUCCESS) {
-        memset_s(ctx_.get(), 0, ae_ctx_sizeof());
+    if (ae_init(ctx(), key, key_size, nonce_length, tag_length) != AE_SUCCESS) {
+        memset_s(ctx(), 0, ae_ctx_sizeof());
         return KM_ERROR_UNKNOWN_ERROR;
     }
-
     return KM_ERROR_OK;
 }
 
-keymaster_error_t AesOcbEncryptOperation::Update(const Buffer& input, Buffer* output,
-                                                 size_t* input_consumed) {
-    const uint8_t* plaintext = input.peek_read();
-    const uint8_t* plaintext_end = plaintext + input.available_read();
+keymaster_error_t AesOcbOperation::EncryptChunk(const uint8_t* nonce, size_t /* nonce_length */,
+                                                size_t tag_length,
+                                                const keymaster_blob_t additional_data,
+                                                uint8_t* chunk, size_t chunk_size, Buffer* output) {
+    if (!ctx())
+        return KM_ERROR_UNKNOWN_ERROR;
+    uint8_t __attribute__((aligned(16))) tag[MAX_TAG_LENGTH];
 
-    while (plaintext + chunk_unfilled_space() < plaintext_end) {
-        size_t to_process = chunk_unfilled_space();
-        memcpy(chunk_.get() + chunk_offset_, plaintext, to_process);
-        chunk_offset_ += to_process;
-        assert(chunk_offset_ == chunk_length_);
+    // Encrypt chunk in place.
+    int ae_err = ae_encrypt(ctx(), nonce, chunk, chunk_size, additional_data.data,
+                            additional_data.data_length, chunk, tag, AE_FINALIZE);
 
-        keymaster_error_t error = ProcessChunk(output);
-        if (error != KM_ERROR_OK)
-            return error;
-        plaintext += to_process;
-    }
-
-    // Copy remaining data into chunk_.
-    assert(plaintext_end - plaintext < chunk_unfilled_space());
-    memcpy(chunk_.get() + chunk_offset_, plaintext, plaintext_end - plaintext);
-    chunk_offset_ += (plaintext_end - plaintext);
-
-    *input_consumed = input.available_read();
-    return KM_ERROR_OK;
-}
-
-keymaster_error_t AesOcbEncryptOperation::Finish(const Buffer& /* signature */, Buffer* output) {
-    keymaster_error_t error = KM_ERROR_OK;
-    if (chunk_offset_ > 0)
-        error = ProcessChunk(output);
-    return error;
-}
-
-keymaster_error_t AesOcbEncryptOperation::ProcessChunk(Buffer* output) {
-    if (!nonce_written_) {
-        if (!output->reserve(NONCE_LENGTH + chunk_length_ + tag_length_))
-            return KM_ERROR_MEMORY_ALLOCATION_FAILED;
-        output->write(nonce_, NONCE_LENGTH);
-        nonce_written_ = true;
-    } else {
-        IncrementNonce();
-    }
-
-    if (!output->reserve(output->available_read() + chunk_offset_ + tag_length_))
-        return KM_ERROR_MEMORY_ALLOCATION_FAILED;
-
-    int ae_err = ae_encrypt(ctx_.get(), nonce_, chunk_.get(), chunk_offset_, additional_data_.data,
-                            additional_data_.data_length, output->peek_write(), tag_, AE_FINALIZE);
     if (ae_err < 0)
         return KM_ERROR_UNKNOWN_ERROR;
-    output->advance_write(chunk_offset_);
-    chunk_offset_ = 0;
+    assert(ae_err == (int)buffered_data_length());
 
-    // Output the tag.
-    output->write(tag_, tag_length_);
+    output->write(chunk, buffered_data_length());
+    output->write(tag, tag_length);
 
     return KM_ERROR_OK;
 }
 
-void AesOcbEncryptOperation::IncrementNonce() {
-    for (int i = NONCE_LENGTH - 1; i > 0; --i)
-        if (++nonce_[i])
-            break;
+keymaster_error_t AesOcbOperation::DecryptChunk(const uint8_t* nonce, size_t /* nonce_length */,
+                                                const uint8_t* tag, size_t /* tag_length */,
+                                                const keymaster_blob_t additional_data,
+                                                uint8_t* chunk, size_t chunk_size, Buffer* output) {
+    if (!ctx())
+        return KM_ERROR_UNKNOWN_ERROR;
+
+    // Decrypt chunk in place
+    int ae_err = ae_decrypt(ctx(), nonce, chunk, chunk_size, additional_data.data,
+                            additional_data.data_length, chunk, tag, AE_FINALIZE);
+    if (ae_err == AE_INVALID)
+        return KM_ERROR_VERIFICATION_FAILED;
+    else if (ae_err < 0)
+        return KM_ERROR_UNKNOWN_ERROR;
+    assert(ae_err == (int)buffered_data_length());
+    output->write(chunk, chunk_size);
+
+    return KM_ERROR_OK;
 }
 
 }  // namespace keymaster
diff --git a/aes_operation.h b/aes_operation.h
index 774e3ed..0c9168a 100644
--- a/aes_operation.h
+++ b/aes_operation.h
@@ -17,61 +17,44 @@
 #ifndef SYSTEM_KEYMASTER_AES_OPERATION_H_
 #define SYSTEM_KEYMASTER_AES_OPERATION_H_
 
-#include <keymaster/key_blob.h>
-
+#include "aead_mode_operation.h"
 #include "ocb_utils.h"
 #include "operation.h"
 
 namespace keymaster {
 
-class AesOcbEncryptOperation : public Operation {
+class AesOcbOperation : public AeadModeOperation {
   public:
     static const size_t NONCE_LENGTH = 12;
     static const size_t MAX_TAG_LENGTH = 16;
     static const size_t MAX_KEY_LENGTH = 32;
 
-    AesOcbEncryptOperation(const Logger& logger, uint8_t* key, size_t key_size, size_t chunk_length,
-                           size_t tag_length, keymaster_blob_t additional_data)
-        : Operation(KM_PURPOSE_ENCRYPT, logger), key_size_(key_size), chunk_length_(chunk_length),
-          chunk_offset_(0), tag_length_(tag_length), additional_data_(additional_data),
-          nonce_written_(false) {
-        assert(key_size <= MAX_KEY_LENGTH);
-        memcpy(key_, key, key_size);
-    }
-    ~AesOcbEncryptOperation() {
-        // Wipe sensitive buffers.
-        memset_s(chunk_.get(), 0, chunk_length_);
-        memset_s(const_cast<uint8_t*>(additional_data_.data), 0, additional_data_.data_length);
-        memset_s(key_, 0, MAX_KEY_LENGTH);
-        memset_s(tag_, 0, MAX_TAG_LENGTH);
-        delete[] additional_data_.data;
+    AesOcbOperation(keymaster_purpose_t purpose, const Logger& logger, uint8_t* key,
+                    size_t key_size, size_t chunk_length, size_t tag_length,
+                    keymaster_blob_t additional_data)
+        : AeadModeOperation(purpose, logger, key, key_size, chunk_length, tag_length, NONCE_LENGTH,
+                            additional_data) {}
+
+    virtual keymaster_error_t Abort() {
+        /* All cleanup is in the dtor */
+        return KM_ERROR_OK;
     }
 
-    virtual keymaster_error_t Begin();
-    virtual keymaster_error_t Update(const Buffer& input, Buffer* output, size_t* input_consumed);
-    virtual keymaster_error_t Finish(const Buffer& /* signature */, Buffer* output);
-    virtual keymaster_error_t Abort() { return KM_ERROR_UNIMPLEMENTED; }
+  protected:
+    ae_ctx* ctx() { return ctx_.get(); }
 
   private:
-    ptrdiff_t chunk_unfilled_space() { return chunk_length_ - chunk_offset_; }
-
-    keymaster_error_t StartIncrementalEncryption();
-    keymaster_error_t DoIncrementalEncryption(const uint8_t* input, size_t input_size,
-                                              Buffer* output, size_t* input_consumed);
-    keymaster_error_t ProcessChunk(Buffer* output);
-    void IncrementNonce();
-
+    virtual keymaster_error_t Initialize(uint8_t* key, size_t key_size, size_t nonce_length,
+                                         size_t tag_length);
+    virtual keymaster_error_t EncryptChunk(const uint8_t* nonce, size_t nonce_length,
+                                           size_t tag_length,
+                                           const keymaster_blob_t additional_data, uint8_t* chunk,
+                                           size_t chunk_size, Buffer* output);
+    virtual keymaster_error_t DecryptChunk(const uint8_t* nonce, size_t nonce_length,
+                                           const uint8_t* tag, size_t tag_length,
+                                           const keymaster_blob_t additional_data, uint8_t* chunk,
+                                           size_t chunk_size, Buffer* output);
     AeCtx ctx_;
-    size_t key_size_;
-    size_t chunk_length_;
-    UniquePtr<uint8_t[]> chunk_;
-    size_t chunk_offset_;
-    size_t tag_length_;
-    keymaster_blob_t additional_data_;
-    uint8_t __attribute__((aligned(16))) key_[MAX_KEY_LENGTH];
-    uint8_t __attribute__((aligned(16))) nonce_[NONCE_LENGTH];
-    bool nonce_written_;
-    uint8_t __attribute__((aligned(16))) tag_[MAX_TAG_LENGTH];
 };
 
 }  // namespace keymaster
diff --git a/google_keymaster_test.cpp b/google_keymaster_test.cpp
index 6c672be..f5a9f6b 100644
--- a/google_keymaster_test.cpp
+++ b/google_keymaster_test.cpp
@@ -99,6 +99,8 @@
         blob_.key_material = NULL;
     }
 
+    const keymaster_key_blob_t& key_blob() { return blob_; }
+
     SoftKeymasterDevice device_;
 
     AuthorizationSet params_;
@@ -923,6 +925,7 @@
   protected:
     void GenerateKey(keymaster_algorithm_t algorithm, keymaster_padding_t padding,
                      uint32_t key_size) {
+        params_.Clear();
         params_.push_back(Authorization(TAG_PURPOSE, KM_PURPOSE_ENCRYPT));
         params_.push_back(Authorization(TAG_PURPOSE, KM_PURPOSE_DECRYPT));
         params_.push_back(Authorization(TAG_ALGORITHM, algorithm));
@@ -933,12 +936,16 @@
         params_.push_back(Authorization(TAG_AUTH_TIMEOUT, 300));
         if (static_cast<int>(padding) != -1)
             params_.push_back(TAG_PADDING, padding);
+
+        FreeKeyBlob();
+        FreeCharacteristics();
         EXPECT_EQ(KM_ERROR_OK, device()->generate_key(device(), params_.data(), params_.size(),
                                                       &blob_, &characteristics_));
     }
 
     void GenerateSymmetricKey(keymaster_algorithm_t algorithm, uint32_t key_size,
                               keymaster_block_mode_t block_mode, uint32_t chunk_length) {
+        params_.Clear();
         params_.push_back(Authorization(TAG_PURPOSE, KM_PURPOSE_ENCRYPT));
         params_.push_back(Authorization(TAG_PURPOSE, KM_PURPOSE_DECRYPT));
         params_.push_back(Authorization(TAG_ALGORITHM, algorithm));
@@ -951,6 +958,8 @@
         params_.push_back(Authorization(TAG_APPLICATION_ID, "app_id", 6));
         params_.push_back(Authorization(TAG_AUTH_TIMEOUT, 300));
 
+        FreeKeyBlob();
+        FreeCharacteristics();
         EXPECT_EQ(KM_ERROR_OK, device()->generate_key(device(), params_.data(), params_.size(),
                                                       &blob_, &characteristics_));
     }
@@ -970,7 +979,7 @@
             device()->update(device(), op_handle, reinterpret_cast<const uint8_t*>(message), size,
                              input_consumed, &out_tmp, &out_length);
         if (out_tmp)
-            *output = string(reinterpret_cast<char*>(out_tmp), out_length);
+            output->append(reinterpret_cast<char*>(out_tmp), out_length);
         free(out_tmp);
         return error;
     }
@@ -981,7 +990,7 @@
         keymaster_error_t error = device()->finish(device(), op_handle, NULL /* signature */,
                                                    0 /* signature_length */, &out_tmp, &out_length);
         if (out_tmp)
-            *output = string(reinterpret_cast<char*>(out_tmp), out_length);
+            output->append(reinterpret_cast<char*>(out_tmp), out_length);
         free(out_tmp);
         return error;
     }
@@ -996,16 +1005,16 @@
         EXPECT_EQ(KM_ERROR_OK, UpdateOperation(op_handle, message, size, &result, &input_consumed));
         EXPECT_EQ(size, input_consumed);
         EXPECT_EQ(KM_ERROR_OK, FinishOperation(op_handle, &result));
-
+        EXPECT_EQ(KM_ERROR_INVALID_OPERATION_HANDLE, device()->abort(device(), op_handle));
         return result;
     }
 
-    string EncryptMessage(const void* message, size_t size) {
-        return ProcessMessage(KM_PURPOSE_ENCRYPT, blob_, message, size);
+    string EncryptMessage(const string& message) {
+        return ProcessMessage(KM_PURPOSE_ENCRYPT, blob_, message.c_str(), message.length());
     }
 
-    string DecryptMessage(const void* ciphertext, size_t size) {
-        return ProcessMessage(KM_PURPOSE_DECRYPT, blob_, ciphertext, size);
+    string DecryptMessage(const string& ciphertext) {
+        return ProcessMessage(KM_PURPOSE_DECRYPT, blob_, ciphertext.c_str(), ciphertext.length());
     }
 
     void AddClientParams(AuthorizationSet* set) { set->push_back(TAG_APPLICATION_ID, "app_id", 6); }
@@ -1027,10 +1036,10 @@
 TEST_F(EncryptionOperationsTest, RsaOaepSuccess) {
     GenerateKey(KM_ALGORITHM_RSA, KM_PAD_RSA_OAEP, 512);
     const char message[] = "Hello World!";
-    string ciphertext1 = EncryptMessage(message, strlen(message));
+    string ciphertext1 = EncryptMessage(string(message));
     EXPECT_EQ(512 / 8, ciphertext1.size());
 
-    string ciphertext2 = EncryptMessage(message, strlen(message));
+    string ciphertext2 = EncryptMessage(string(message));
     EXPECT_EQ(512 / 8, ciphertext2.size());
 
     // OAEP randomizes padding so every result should be different.
@@ -1040,10 +1049,10 @@
 TEST_F(EncryptionOperationsTest, RsaOaepRoundTrip) {
     GenerateKey(KM_ALGORITHM_RSA, KM_PAD_RSA_OAEP, 512);
     const char message[] = "Hello World!";
-    string ciphertext = EncryptMessage(message, strlen(message));
+    string ciphertext = EncryptMessage(string(message));
     EXPECT_EQ(512 / 8, ciphertext.size());
 
-    string plaintext = DecryptMessage(ciphertext.data(), ciphertext.size());
+    string plaintext = DecryptMessage(ciphertext);
     EXPECT_EQ(message, plaintext);
 }
 
@@ -1064,7 +1073,7 @@
 TEST_F(EncryptionOperationsTest, RsaOaepCorruptedDecrypt) {
     GenerateKey(KM_ALGORITHM_RSA, KM_PAD_RSA_OAEP, 512);
     const char message[] = "Hello World!";
-    string ciphertext = EncryptMessage(message, strlen(message));
+    string ciphertext = EncryptMessage(string(message));
     EXPECT_EQ(512 / 8, ciphertext.size());
 
     // Corrupt the ciphertext
@@ -1083,10 +1092,10 @@
 TEST_F(EncryptionOperationsTest, RsaPkcs1Success) {
     GenerateKey(KM_ALGORITHM_RSA, KM_PAD_RSA_PKCS1_1_5_ENCRYPT, 512);
     const char message[] = "Hello World!";
-    string ciphertext1 = EncryptMessage(message, strlen(message));
+    string ciphertext1 = EncryptMessage(string(message));
     EXPECT_EQ(512 / 8, ciphertext1.size());
 
-    string ciphertext2 = EncryptMessage(message, strlen(message));
+    string ciphertext2 = EncryptMessage(string(message));
     EXPECT_EQ(512 / 8, ciphertext2.size());
 
     // PKCS1 v1.5 randomizes padding so every result should be different.
@@ -1096,10 +1105,10 @@
 TEST_F(EncryptionOperationsTest, RsaPkcs1RoundTrip) {
     GenerateKey(KM_ALGORITHM_RSA, KM_PAD_RSA_PKCS1_1_5_ENCRYPT, 512);
     const char message[] = "Hello World!";
-    string ciphertext = EncryptMessage(message, strlen(message));
+    string ciphertext = EncryptMessage(string(message));
     EXPECT_EQ(512 / 8, ciphertext.size());
 
-    string plaintext = DecryptMessage(ciphertext.data(), ciphertext.size());
+    string plaintext = DecryptMessage(ciphertext);
     EXPECT_EQ(message, plaintext);
 }
 
@@ -1120,7 +1129,7 @@
 TEST_F(EncryptionOperationsTest, RsaPkcs1CorruptedDecrypt) {
     GenerateKey(KM_ALGORITHM_RSA, KM_PAD_RSA_PKCS1_1_5_ENCRYPT, 512);
     const char message[] = "Hello World!";
-    string ciphertext = EncryptMessage(message, strlen(message));
+    string ciphertext = EncryptMessage(string(message));
     EXPECT_EQ(512 / 8, ciphertext.size());
 
     // Corrupt the ciphertext
@@ -1139,15 +1148,146 @@
 TEST_F(EncryptionOperationsTest, AesOcbSuccess) {
     GenerateSymmetricKey(KM_ALGORITHM_AES, 128, KM_MODE_OCB, 4096);
     const char message[] = "Hello World!";
-    string ciphertext1 = EncryptMessage(message, strlen(message));
+    string ciphertext1 = EncryptMessage(string(message));
     EXPECT_EQ(12 /* nonce */ + strlen(message) + 16 /* tag */, ciphertext1.size());
 
-    string ciphertext2 = EncryptMessage(message, strlen(message));
+    string ciphertext2 = EncryptMessage(string(message));
     EXPECT_EQ(12 /* nonce */ + strlen(message) + 16 /* tag */, ciphertext2.size());
 
     // OCB uses a random nonce, so every output should be different
     EXPECT_NE(ciphertext1, ciphertext2);
 }
 
+TEST_F(EncryptionOperationsTest, AesOcbRoundTripSuccess) {
+    GenerateSymmetricKey(KM_ALGORITHM_AES, 128, KM_MODE_OCB, 4096);
+    string message = "Hello World!";
+    string ciphertext = EncryptMessage(message);
+    EXPECT_EQ(12 /* nonce */ + message.length() + 16 /* tag */, ciphertext.size());
+
+    string plaintext = DecryptMessage(ciphertext);
+    EXPECT_EQ(message, plaintext);
+}
+
+TEST_F(EncryptionOperationsTest, AesOcbRoundTripCorrupted) {
+    GenerateSymmetricKey(KM_ALGORITHM_AES, 128, KM_MODE_OCB, 4096);
+    const char message[] = "Hello World!";
+    string ciphertext = EncryptMessage(string(message));
+    EXPECT_EQ(12 /* nonce */ + strlen(message) + 16 /* tag */, ciphertext.size());
+
+    ciphertext[ciphertext.size() / 2]++;
+
+    uint64_t op_handle;
+    EXPECT_EQ(KM_ERROR_OK, BeginOperation(KM_PURPOSE_DECRYPT, key_blob(), &op_handle));
+
+    string result;
+    size_t input_consumed;
+    EXPECT_EQ(KM_ERROR_OK, UpdateOperation(op_handle, ciphertext.c_str(), ciphertext.length(),
+                                           &result, &input_consumed));
+    EXPECT_EQ(ciphertext.length(), input_consumed);
+    EXPECT_EQ(KM_ERROR_VERIFICATION_FAILED, FinishOperation(op_handle, &result));
+}
+
+TEST_F(EncryptionOperationsTest, AesDecryptGarbage) {
+    GenerateSymmetricKey(KM_ALGORITHM_AES, 128, KM_MODE_OCB, 4096);
+    string ciphertext(128, 'a');
+    uint64_t op_handle;
+    EXPECT_EQ(KM_ERROR_OK, BeginOperation(KM_PURPOSE_DECRYPT, key_blob(), &op_handle));
+
+    string result;
+    size_t input_consumed;
+    EXPECT_EQ(KM_ERROR_OK, UpdateOperation(op_handle, ciphertext.c_str(), ciphertext.length(),
+                                           &result, &input_consumed));
+    EXPECT_EQ(ciphertext.length(), input_consumed);
+    EXPECT_EQ(KM_ERROR_VERIFICATION_FAILED, FinishOperation(op_handle, &result));
+}
+
+TEST_F(EncryptionOperationsTest, AesDecryptTooShort) {
+    // Try decrypting garbage ciphertext that is too short to be valid (< nonce + tag).
+    GenerateSymmetricKey(KM_ALGORITHM_AES, 128, KM_MODE_OCB, 4096);
+    string ciphertext(12 + 15, 'a');
+    uint64_t op_handle;
+    EXPECT_EQ(KM_ERROR_OK, BeginOperation(KM_PURPOSE_DECRYPT, key_blob(), &op_handle));
+
+    string result;
+    size_t input_consumed;
+    EXPECT_EQ(KM_ERROR_OK, UpdateOperation(op_handle, ciphertext.c_str(), ciphertext.length(),
+                                           &result, &input_consumed));
+    EXPECT_EQ(ciphertext.length(), input_consumed);
+    EXPECT_EQ(KM_ERROR_INVALID_INPUT_LENGTH, FinishOperation(op_handle, &result));
+}
+
+TEST_F(EncryptionOperationsTest, AesOcbRoundTripEmptySuccess) {
+    // Empty messages should work fine.
+    GenerateSymmetricKey(KM_ALGORITHM_AES, 128, KM_MODE_OCB, 4096);
+    const char message[] = "";
+    string ciphertext = EncryptMessage(string(message));
+    EXPECT_EQ(12 /* nonce */ + strlen(message) + 16 /* tag */, ciphertext.size());
+
+    string plaintext = DecryptMessage(ciphertext);
+    EXPECT_EQ(message, plaintext);
+}
+
+TEST_F(EncryptionOperationsTest, AesOcbRoundTripEmptyCorrupted) {
+    // Should even detect corruption of empty messages.
+    GenerateSymmetricKey(KM_ALGORITHM_AES, 128, KM_MODE_OCB, 4096);
+    const char message[] = "";
+    string ciphertext = EncryptMessage(string(message));
+    EXPECT_EQ(12 /* nonce */ + strlen(message) + 16 /* tag */, ciphertext.size());
+
+    ciphertext[ciphertext.size() / 2]++;
+
+    uint64_t op_handle;
+    EXPECT_EQ(KM_ERROR_OK, BeginOperation(KM_PURPOSE_DECRYPT, key_blob(), &op_handle));
+
+    string result;
+    size_t input_consumed;
+    EXPECT_EQ(KM_ERROR_OK, UpdateOperation(op_handle, ciphertext.c_str(), ciphertext.length(),
+                                           &result, &input_consumed));
+    EXPECT_EQ(ciphertext.length(), input_consumed);
+    EXPECT_EQ(KM_ERROR_VERIFICATION_FAILED, FinishOperation(op_handle, &result));
+}
+
+TEST_F(EncryptionOperationsTest, AesOcbFullChunk) {
+    GenerateSymmetricKey(KM_ALGORITHM_AES, 128, KM_MODE_OCB, 4096);
+    string message(4096, 'a');
+    string ciphertext = EncryptMessage(message);
+    EXPECT_EQ(12 /* nonce */ + message.length() + 16 /* tag */, ciphertext.size());
+
+    string plaintext = DecryptMessage(ciphertext);
+    EXPECT_EQ(message, plaintext);
+}
+
+TEST_F(EncryptionOperationsTest, AesOcbVariousChunkLengths) {
+    for (unsigned chunk_length = 1; chunk_length <= 128; ++chunk_length) {
+        GenerateSymmetricKey(KM_ALGORITHM_AES, 128, KM_MODE_OCB, chunk_length);
+        string message(128, 'a');
+        string ciphertext = EncryptMessage(message);
+        int expected_tag_count = (message.length() + chunk_length - 1) / chunk_length;
+        EXPECT_EQ(12 /* nonce */ + message.length() + 16 * expected_tag_count, ciphertext.size())
+            << "Unexpected ciphertext size for chunk length " << chunk_length
+            << " expected tag count was " << expected_tag_count
+            << " but actual tag count was probably "
+            << (ciphertext.size() - message.length() - 12) / 16;
+
+        string plaintext = DecryptMessage(ciphertext);
+        EXPECT_EQ(message, plaintext);
+    }
+}
+
+TEST_F(EncryptionOperationsTest, AesOcbAbort) {
+    GenerateSymmetricKey(KM_ALGORITHM_AES, 128, KM_MODE_OCB, 4096);
+    const char message[] = "Hello";
+
+    uint64_t op_handle;
+    EXPECT_EQ(KM_ERROR_OK, BeginOperation(KM_PURPOSE_ENCRYPT, key_blob(), &op_handle));
+
+    string result;
+    size_t input_consumed;
+    EXPECT_EQ(KM_ERROR_OK,
+              UpdateOperation(op_handle, message, strlen(message), &result, &input_consumed));
+    EXPECT_EQ(strlen(message), input_consumed);
+    EXPECT_EQ(KM_ERROR_OK, device()->abort(device(), op_handle));
+}
+
 }  // namespace test
 }  // namespace keymaster