Validate input sizes for RSA and ECDSA signing/verification ops.

Bug: 21955742
Change-Id: I4385a6539229b174facd5f04ce0391e2e8c3608d
diff --git a/android_keymaster_test.cpp b/android_keymaster_test.cpp
index 542d430..2005e46 100644
--- a/android_keymaster_test.cpp
+++ b/android_keymaster_test.cpp
@@ -485,6 +485,40 @@
         EXPECT_EQ(3, GetParam()->keymaster0_calls());
 }
 
+TEST_P(SigningOperationsTest, RsaPkcs1NoDigestSuccess) {
+    ASSERT_EQ(KM_ERROR_OK, GenerateKey(AuthorizationSetBuilder()
+                                           .RsaSigningKey(512, 3)
+                                           .Digest(KM_DIGEST_NONE)
+                                           .Padding(KM_PAD_RSA_PKCS1_1_5_SIGN)));
+    string message(53, 'a');
+    string signature;
+    SignMessage(message, &signature, KM_DIGEST_NONE, KM_PAD_RSA_PKCS1_1_5_SIGN);
+
+    if (GetParam()->algorithm_in_hardware(KM_ALGORITHM_RSA))
+        EXPECT_EQ(3, GetParam()->keymaster0_calls());
+}
+
+TEST_P(SigningOperationsTest, RsaPkcs1NoDigestTooLarge) {
+    ASSERT_EQ(KM_ERROR_OK, GenerateKey(AuthorizationSetBuilder()
+                                           .RsaSigningKey(512, 3)
+                                           .Digest(KM_DIGEST_NONE)
+                                           .Padding(KM_PAD_RSA_PKCS1_1_5_SIGN)));
+    string message(54, 'a');
+
+    AuthorizationSet begin_params(client_params());
+    begin_params.push_back(TAG_DIGEST, KM_DIGEST_NONE);
+    begin_params.push_back(TAG_PADDING, KM_PAD_RSA_PKCS1_1_5_SIGN);
+    EXPECT_EQ(KM_ERROR_OK, BeginOperation(KM_PURPOSE_SIGN, begin_params));
+    string result;
+    size_t input_consumed;
+    EXPECT_EQ(KM_ERROR_OK, UpdateOperation(message, &result, &input_consumed));
+    string signature;
+    EXPECT_EQ(KM_ERROR_INVALID_INPUT_LENGTH, FinishOperation(&signature));
+
+    if (GetParam()->algorithm_in_hardware(KM_ALGORITHM_RSA))
+        EXPECT_EQ(2, GetParam()->keymaster0_calls());
+}
+
 TEST_P(SigningOperationsTest, RsaPssSha256TooSmallKey) {
     // Key must be at least 10 bytes larger than hash, to provide eight bytes of random salt, so
     // verify that nine bytes larger than hash won't work.
@@ -501,6 +535,25 @@
     EXPECT_EQ(KM_ERROR_INCOMPATIBLE_DIGEST, BeginOperation(KM_PURPOSE_SIGN, begin_params));
 }
 
+TEST_P(SigningOperationsTest, RsaNoPaddingHugeData) {
+    ASSERT_EQ(KM_ERROR_OK, GenerateKey(AuthorizationSetBuilder()
+                                           .RsaSigningKey(256, 3)
+                                           .Digest(KM_DIGEST_NONE)
+                                           .Padding(KM_PAD_RSA_PKCS1_1_5_SIGN)));
+    string message(64 * 1024, 'a');
+    string signature;
+    AuthorizationSet begin_params(client_params());
+    begin_params.push_back(TAG_DIGEST, KM_DIGEST_NONE);
+    begin_params.push_back(TAG_PADDING, KM_PAD_RSA_PKCS1_1_5_SIGN);
+    ASSERT_EQ(KM_ERROR_OK, BeginOperation(KM_PURPOSE_SIGN, begin_params));
+    string result;
+    size_t input_consumed;
+    EXPECT_EQ(KM_ERROR_INVALID_INPUT_LENGTH, UpdateOperation(message, &result, &input_consumed));
+
+    if (GetParam()->algorithm_in_hardware(KM_ALGORITHM_RSA))
+        EXPECT_EQ(2, GetParam()->keymaster0_calls());
+}
+
 TEST_P(SigningOperationsTest, RsaAbort) {
     ASSERT_EQ(KM_ERROR_OK, GenerateKey(AuthorizationSetBuilder()
                                            .RsaSigningKey(256, 3)
@@ -587,7 +640,7 @@
     EXPECT_EQ(31U, input_consumed);
 
     string signature;
-    ASSERT_EQ(KM_ERROR_UNKNOWN_ERROR, FinishOperation(&signature));
+    ASSERT_EQ(KM_ERROR_INVALID_INPUT_LENGTH, FinishOperation(&signature));
     EXPECT_EQ(0U, signature.length());
 
     if (GetParam()->algorithm_in_hardware(KM_ALGORITHM_RSA))
@@ -608,7 +661,7 @@
 TEST_P(SigningOperationsTest, EcdsaSuccess) {
     ASSERT_EQ(KM_ERROR_OK,
               GenerateKey(AuthorizationSetBuilder().EcdsaSigningKey(224).Digest(KM_DIGEST_NONE)));
-    string message(1024, 'a');
+    string message(224 / 8, 'a');
     string signature;
     SignMessage(message, &signature, KM_DIGEST_NONE);
 
@@ -627,6 +680,22 @@
         EXPECT_EQ(3, GetParam()->keymaster0_calls());
 }
 
+TEST_P(SigningOperationsTest, EcdsaNoPaddingHugeData) {
+    ASSERT_EQ(KM_ERROR_OK,
+              GenerateKey(AuthorizationSetBuilder().EcdsaSigningKey(224).Digest(KM_DIGEST_NONE)));
+    string message(64 * 1024, 'a');
+    string signature;
+    AuthorizationSet begin_params(client_params());
+    begin_params.push_back(TAG_DIGEST, KM_DIGEST_NONE);
+    ASSERT_EQ(KM_ERROR_OK, BeginOperation(KM_PURPOSE_SIGN, begin_params));
+    string result;
+    size_t input_consumed;
+    EXPECT_EQ(KM_ERROR_INVALID_INPUT_LENGTH, UpdateOperation(message, &result, &input_consumed));
+
+    if (GetParam()->algorithm_in_hardware(KM_ALGORITHM_EC))
+        EXPECT_EQ(2, GetParam()->keymaster0_calls());
+}
+
 TEST_P(SigningOperationsTest, AesEcbSign) {
     ASSERT_EQ(KM_ERROR_OK,
               GenerateKey(AuthorizationSetBuilder().AesEncryptionKey(128).Authorization(
@@ -1227,7 +1296,7 @@
 TEST_P(VerificationOperationsTest, EcdsaSuccess) {
     ASSERT_EQ(KM_ERROR_OK,
               GenerateKey(AuthorizationSetBuilder().EcdsaSigningKey(256).Digest(KM_DIGEST_NONE)));
-    string message = "123456789012345678901234567890123456789012345678";
+    string message = "12345678901234567890123456789012";
     string signature;
     SignMessage(message, &signature, KM_DIGEST_NONE);
     VerifyMessage(message, signature, KM_DIGEST_NONE);
@@ -1241,7 +1310,7 @@
                                            .EcdsaSigningKey(256)
                                            .Digest(KM_DIGEST_SHA_2_256)
                                            .Digest(KM_DIGEST_NONE)));
-    string message = "123456789012345678901234567890123456789012345678";
+    string message = "12345678901234567890123456789012";
     string signature;
     SignMessage(message, &signature, KM_DIGEST_SHA_2_256);
     VerifyMessage(message, signature, KM_DIGEST_SHA_2_256);
@@ -1496,7 +1565,7 @@
     EXPECT_TRUE(contains(sw_enforced(), TAG_ORIGIN, KM_ORIGIN_IMPORTED));
     EXPECT_TRUE(contains(sw_enforced(), KM_TAG_CREATION_DATETIME));
 
-    string message(1024 / 8, 'a');
+    string message(32, 'a');
     string signature;
     SignMessage(message, &signature, KM_DIGEST_NONE);
     VerifyMessage(message, signature, KM_DIGEST_NONE);
@@ -1525,7 +1594,7 @@
     EXPECT_TRUE(contains(sw_enforced(), TAG_ORIGIN, KM_ORIGIN_IMPORTED));
     EXPECT_TRUE(contains(sw_enforced(), KM_TAG_CREATION_DATETIME));
 
-    string message(1024 / 8, 'a');
+    string message(32, 'a');
     string signature;
     SignMessage(message, &signature, KM_DIGEST_NONE);
     VerifyMessage(message, signature, KM_DIGEST_NONE);
@@ -1638,9 +1707,7 @@
 
     string result;
     size_t input_consumed;
-    EXPECT_EQ(KM_ERROR_OK, UpdateOperation(message, &result, &input_consumed));
-    EXPECT_EQ(KM_ERROR_INVALID_INPUT_LENGTH, FinishOperation(&result));
-    EXPECT_EQ(0U, result.size());
+    EXPECT_EQ(KM_ERROR_INVALID_INPUT_LENGTH, UpdateOperation(message, &result, &input_consumed));
 
     if (GetParam()->algorithm_in_hardware(KM_ALGORITHM_RSA))
         EXPECT_EQ(2, GetParam()->keymaster0_calls());
@@ -2666,7 +2733,7 @@
     memcpy(key_data, km1_sw.data(), km1_sw.length());
     set_key_blob(key_data, km1_sw.length());
 
-    string message(64, 'a');
+    string message(32, static_cast<char>(0xFF));
     string signature;
     SignMessage(message, &signature, KM_DIGEST_NONE, KM_PAD_NONE);
 
diff --git a/ecdsa_operation.cpp b/ecdsa_operation.cpp
index 89bcfa1..5833716 100644
--- a/ecdsa_operation.cpp
+++ b/ecdsa_operation.cpp
@@ -93,9 +93,15 @@
 }
 
 keymaster_error_t EcdsaOperation::StoreData(const Buffer& input, size_t* input_consumed) {
-    if (!data_.reserve(data_.available_read() + input.available_read()) ||
-        !data_.write(input.peek_read(), input.available_read()))
+    if (!data_.reserve(EVP_PKEY_bits(ecdsa_key_) / 8))
         return KM_ERROR_MEMORY_ALLOCATION_FAILED;
+
+    // If the write fails, it's because input length exceeds key size.
+    if (!data_.write(input.peek_read(), input.available_read())) {
+        LOG_E("Input too long: cannot sign %u bytes of data with %u-bit ECDSA key",
+              input.available_read() + data_.available_read(), EVP_PKEY_bits(ecdsa_key_));
+        return KM_ERROR_INVALID_INPUT_LENGTH;
+    }
     *input_consumed = input.available_read();
     return KM_ERROR_OK;
 }
diff --git a/openssl_utils.h b/openssl_utils.h
index 304c241..97c2b5a 100644
--- a/openssl_utils.h
+++ b/openssl_utils.h
@@ -39,6 +39,10 @@
     void operator()(BIGNUM* p) const { BN_free(p); }
 };
 
+struct BN_CTX_Delete {
+    void operator()(BN_CTX* p) const { BN_CTX_free(p); }
+};
+
 struct PKCS8_PRIV_KEY_INFO_Delete {
     void operator()(PKCS8_PRIV_KEY_INFO* p) const { PKCS8_PRIV_KEY_INFO_free(p); }
 };
diff --git a/rsa_key_factory.cpp b/rsa_key_factory.cpp
index dfe2ddd..c17d9e8 100644
--- a/rsa_key_factory.cpp
+++ b/rsa_key_factory.cpp
@@ -25,14 +25,10 @@
 #include "rsa_key.h"
 #include "rsa_operation.h"
 
-#if defined(OPENSSL_IS_BORINGSSL)
-typedef size_t openssl_size_t;
-#else
-typedef int openssl_size_t;
-#endif
-
 namespace keymaster {
 
+const int kMaximumRsaKeySize = 16 * 1024;  // 16kbits should be enough for anyone.
+
 static RsaSigningOperationFactory sign_factory;
 static RsaVerificationOperationFactory verify_factory;
 static RsaEncryptionOperationFactory encrypt_factory;
@@ -70,7 +66,11 @@
 
     uint32_t key_size;
     if (!authorizations.GetTagValue(TAG_KEY_SIZE, &key_size)) {
-        LOG_E("%s", "No key size specified for RSA key generation");
+        LOG_E("No key size specified for RSA key generation", 0);
+        return KM_ERROR_UNSUPPORTED_KEY_SIZE;
+    }
+    if (key_size % 8 != 0 || key_size > kMaximumRsaKeySize) {
+        LOG_E("Invalid key size of %u bits specified for RSA key generation", key_size);
         return KM_ERROR_UNSUPPORTED_KEY_SIZE;
     }
 
@@ -143,14 +143,20 @@
         return KM_ERROR_INVALID_KEY_BLOB;
     if (!updated_description->GetTagValue(TAG_RSA_PUBLIC_EXPONENT, public_exponent))
         updated_description->push_back(TAG_RSA_PUBLIC_EXPONENT, *public_exponent);
-    if (*public_exponent != BN_get_word(rsa_key->e))
+    if (*public_exponent != BN_get_word(rsa_key->e)) {
+        LOG_E("Imported public exponent (%u) does not match specified public exponent (%u)",
+              *public_exponent, BN_get_word(rsa_key->e));
         return KM_ERROR_IMPORT_PARAMETER_MISMATCH;
+    }
 
     *key_size = RSA_size(rsa_key.get()) * 8;
     if (!updated_description->GetTagValue(TAG_KEY_SIZE, key_size))
         updated_description->push_back(TAG_KEY_SIZE, *key_size);
-    if (RSA_size(rsa_key.get()) * 8 != (openssl_size_t)*key_size)
+    if (RSA_size(rsa_key.get()) * 8 != *key_size) {
+        LOG_E("Imported key size (%u bits) does not match specified key size (%u bits)",
+              RSA_size(rsa_key.get()) * 8, *key_size);
         return KM_ERROR_IMPORT_PARAMETER_MISMATCH;
+    }
 
     keymaster_algorithm_t algorithm = KM_ALGORITHM_RSA;
     if (!updated_description->GetTagValue(TAG_ALGORITHM, &algorithm))
diff --git a/rsa_keymaster0_key.cpp b/rsa_keymaster0_key.cpp
index a905f3f..f882f94 100644
--- a/rsa_keymaster0_key.cpp
+++ b/rsa_keymaster0_key.cpp
@@ -123,7 +123,7 @@
     keymaster_error_t error;
     key->reset(new (std::nothrow)
                    RsaKeymaster0Key(rsa.release(), hw_enforced, sw_enforced, engine_, &error));
-    if (!key.get())
+    if (!key->get())
         error = KM_ERROR_MEMORY_ALLOCATION_FAILED;
 
     if (error != KM_ERROR_OK)
diff --git a/rsa_operation.cpp b/rsa_operation.cpp
index c99fbfc..7c80ec5 100644
--- a/rsa_operation.cpp
+++ b/rsa_operation.cpp
@@ -30,6 +30,13 @@
 
 namespace keymaster {
 
+const size_t kPssOverhead = 2;
+const size_t kMinPssSaltSize = 8;
+
+// Overhead for PKCS#1 v1.5 signature padding of undigested messages.  Digested messages have
+// additional overhead, for the digest algorithmIdentifier required by PKCS#1.
+const size_t kPkcs1UndigestedSignaturePaddingOverhead = 11;
+
 /* static */
 EVP_PKEY* RsaOperationFactory::GetRsaKey(const Key& key, keymaster_error_t* error) {
     const RsaKey* rsa_key = static_cast<const RsaKey*>(&key);
@@ -138,9 +145,16 @@
 
 keymaster_error_t RsaOperation::StoreData(const Buffer& input, size_t* input_consumed) {
     assert(input_consumed);
-    if (!data_.reserve(data_.available_read() + input.available_read()) ||
-        !data_.write(input.peek_read(), input.available_read()))
+
+    if (!data_.reserve(EVP_PKEY_size(rsa_key_)))
         return KM_ERROR_MEMORY_ALLOCATION_FAILED;
+    // If the write fails, it's because input length exceeds key size.
+    if (!data_.write(input.peek_read(), input.available_read())) {
+        LOG_E("Input too long: cannot operate on %u bytes of data with %u-bit RSA key",
+              input.available_read() + data_.available_read());
+        return KM_ERROR_INVALID_INPUT_LENGTH;
+    }
+
     *input_consumed = input.available_read();
     return KM_ERROR_OK;
 }
@@ -198,25 +212,22 @@
     }
 }
 
-const size_t PSS_OVERHEAD = 2;
-const size_t MIN_SALT_SIZE = 8;
-
 int RsaDigestingOperation::GetOpensslPadding(keymaster_error_t* error) {
     *error = KM_ERROR_OK;
     switch (padding_) {
     case KM_PAD_NONE:
         return RSA_NO_PADDING;
     case KM_PAD_RSA_PKCS1_1_5_SIGN:
-
         return RSA_PKCS1_PADDING;
     case KM_PAD_RSA_PSS:
         if (digest_ == KM_DIGEST_NONE) {
             *error = KM_ERROR_INCOMPATIBLE_PADDING_MODE;
             return -1;
         }
-        if (EVP_MD_size(digest_algorithm_) + PSS_OVERHEAD + MIN_SALT_SIZE >
+        if (EVP_MD_size(digest_algorithm_) + kPssOverhead + kMinPssSaltSize >
             (size_t)EVP_PKEY_size(rsa_key_)) {
-            LOG_E("%d-byte digest cannot be used with %d-byte RSA key in PSS padding mode",
+            LOG_E("Input too long: %d-byte digest cannot be used with %d-byte RSA key in PSS "
+                  "padding mode",
                   EVP_MD_size(digest_algorithm_), EVP_PKEY_size(rsa_key_));
             *error = KM_ERROR_INCOMPATIBLE_DIGEST;
             return -1;
@@ -284,6 +295,12 @@
         break;
     case KM_PAD_RSA_PKCS1_1_5_SIGN:
         // Does PKCS1 padding without digesting even make sense?  Dunno.  We'll support it.
+        if (data_.available_read() + kPkcs1UndigestedSignaturePaddingOverhead >
+            static_cast<size_t>(EVP_PKEY_size(rsa_key_))) {
+            LOG_E("Input too long: cannot sign %u-byte message with PKCS1 padding with %u-bit key",
+                  data_.available_read(), EVP_PKEY_size(rsa_key_) * 8);
+            return KM_ERROR_INVALID_INPUT_LENGTH;
+        }
         bytes_encrypted = RSA_private_encrypt(data_.available_read(), data_.peek_read(),
                                               output->peek_write(), rsa.get(), RSA_PKCS1_PADDING);
         break;
@@ -369,6 +386,11 @@
         openssl_padding = RSA_NO_PADDING;
         break;
     case KM_PAD_RSA_PKCS1_1_5_SIGN:
+        if (data_.available_read() + kPkcs1UndigestedSignaturePaddingOverhead > key_len) {
+            LOG_E("Input too long: cannot verify %u-byte message with PKCS1 padding && %u-bit key",
+                  data_.available_read(), key_len * 8);
+            return KM_ERROR_INVALID_INPUT_LENGTH;
+        }
         openssl_padding = RSA_PKCS1_PADDING;
         break;
     default: