Add support for SHA256 digests to RSA signing operations.

Change-Id: Iacca20554bef0bb3ea3c525af87c00f77df069f9
diff --git a/Makefile b/Makefile
index ad11e1b..8e282dc 100644
--- a/Makefile
+++ b/Makefile
@@ -50,6 +50,7 @@
 	key_blob.cpp \
 	key_blob_test.cpp \
 	logger.cpp \
+	openssl_err.cpp \
 	operation.cpp \
 	openssl_err.cpp \
 	rsa_key.cpp \
diff --git a/google_keymaster_test.cpp b/google_keymaster_test.cpp
index 2b055c8..53528ed 100644
--- a/google_keymaster_test.cpp
+++ b/google_keymaster_test.cpp
@@ -469,7 +469,7 @@
     keymaster_digest_t* digests;
     EXPECT_EQ(KM_ERROR_OK, device()->get_supported_digests(device(), KM_ALGORITHM_RSA,
                                                            KM_PURPOSE_SIGN, &digests, &len));
-    EXPECT_TRUE(ResponseContains({KM_DIGEST_NONE}, digests, len));
+    EXPECT_TRUE(ResponseContains({KM_DIGEST_NONE, KM_DIGEST_SHA_2_256}, digests, len));
     free(digests);
 
     EXPECT_EQ(KM_ERROR_UNSUPPORTED_ALGORITHM,
@@ -662,6 +662,15 @@
     SignMessage(message, &signature);
 }
 
+TEST_F(SigningOperationsTest, RsaSha256DigestSuccess) {
+    // Note that without padding, key size must exactly match digest size.
+    GenerateKey(ParamBuilder().RsaSigningKey(256, KM_DIGEST_SHA_2_256));
+    // Use large message, which won't work without digesting.
+    string message(1024, 'a');
+    string signature;
+    SignMessage(message, &signature);
+}
+
 TEST_F(SigningOperationsTest, EcdsaSuccess) {
     ASSERT_EQ(KM_ERROR_OK, GenerateKey(ParamBuilder().EcdsaSigningKey(224)));
     string message = "123456789012345678901234567890123456789012345678";
@@ -773,6 +782,8 @@
     EXPECT_EQ(0U, signature.length());
 }
 
+// TODO(swillden): Add more verification failure tests.
+
 typedef KeymasterTest VerificationOperationsTest;
 TEST_F(VerificationOperationsTest, RsaSuccess) {
     ASSERT_EQ(KM_ERROR_OK, GenerateKey(ParamBuilder().RsaSigningKey(256)));
@@ -782,6 +793,52 @@
     VerifyMessage(message, signature);
 }
 
+TEST_F(VerificationOperationsTest, RsaSha256DigestSuccess) {
+    // Note that without padding, key size must exactly match digest size.
+    GenerateKey(ParamBuilder().RsaSigningKey(256, KM_DIGEST_SHA_2_256));
+    // Use large message, which won't work without digesting.
+    string message(1024, 'a');
+    string signature;
+    SignMessage(message, &signature);
+    VerifyMessage(message, signature);
+}
+
+TEST_F(VerificationOperationsTest, RsaSha256DigestCorruptSignature) {
+    // Note that without padding, key size must exactly match digest size.
+    GenerateKey(ParamBuilder().RsaSigningKey(256, KM_DIGEST_SHA_2_256));
+    // Use large message, which won't work without digesting.
+    string message(1024, 'a');
+    string signature;
+    SignMessage(message, &signature);
+    ++signature[signature.size() / 2];
+
+    EXPECT_EQ(KM_ERROR_OK, BeginOperation(KM_PURPOSE_VERIFY));
+
+    string result;
+    size_t input_consumed;
+    EXPECT_EQ(KM_ERROR_OK, UpdateOperation(message, &result, &input_consumed));
+    EXPECT_EQ(message.size(), input_consumed);
+    EXPECT_EQ(KM_ERROR_VERIFICATION_FAILED, FinishOperation(signature, &result));
+}
+
+TEST_F(VerificationOperationsTest, RsaSha256DigestCorruptInput) {
+    // Note that without padding, key size must exactly match digest size.
+    GenerateKey(ParamBuilder().RsaSigningKey(256, KM_DIGEST_SHA_2_256));
+    // Use large message, which won't work without digesting.
+    string message(1024, 'a');
+    string signature;
+    SignMessage(message, &signature);
+    ++message[message.size() / 2];
+
+    EXPECT_EQ(KM_ERROR_OK, BeginOperation(KM_PURPOSE_VERIFY));
+
+    string result;
+    size_t input_consumed;
+    EXPECT_EQ(KM_ERROR_OK, UpdateOperation(message, &result, &input_consumed));
+    EXPECT_EQ(message.size(), input_consumed);
+    EXPECT_EQ(KM_ERROR_VERIFICATION_FAILED, FinishOperation(signature, &result));
+}
+
 TEST_F(VerificationOperationsTest, EcdsaSuccess) {
     ASSERT_EQ(KM_ERROR_OK, GenerateKey(ParamBuilder().EcdsaSigningKey(256)));
     string message = "123456789012345678901234567890123456789012345678";
diff --git a/rsa_key.cpp b/rsa_key.cpp
index 614dbe8..19a3394 100644
--- a/rsa_key.cpp
+++ b/rsa_key.cpp
@@ -183,7 +183,7 @@
     switch (purpose) {
     case KM_PURPOSE_SIGN:
     case KM_PURPOSE_VERIFY:
-        return digest == KM_DIGEST_NONE;
+        return digest == KM_DIGEST_NONE || digest == KM_DIGEST_SHA_2_256;
         break;
     case KM_PURPOSE_ENCRYPT:
     case KM_PURPOSE_DECRYPT:
diff --git a/rsa_operation.cpp b/rsa_operation.cpp
index a25231b..d13a420 100644
--- a/rsa_operation.cpp
+++ b/rsa_operation.cpp
@@ -91,7 +91,7 @@
     return rsa_key->key();
 }
 
-static const keymaster_digest_t supported_digests[] = {KM_DIGEST_NONE};
+static const keymaster_digest_t supported_digests[] = {KM_DIGEST_NONE, KM_DIGEST_SHA_2_256};
 static const keymaster_padding_t supported_sig_padding[] = {KM_PAD_NONE};
 
 /**
@@ -250,35 +250,96 @@
     return KM_ERROR_OK;
 }
 
+RsaDigestingOperation::RsaDigestingOperation(keymaster_purpose_t purpose, keymaster_digest_t digest,
+                                             keymaster_padding_t padding, RSA* key)
+    : RsaOperation(purpose, padding, key), digest_(digest), digest_algorithm_(NULL) {
+    EVP_MD_CTX_init(&digest_ctx_);
+}
+RsaDigestingOperation::~RsaDigestingOperation() {
+    EVP_MD_CTX_cleanup(&digest_ctx_);
+}
+
+keymaster_error_t RsaDigestingOperation::Begin(const AuthorizationSet& /* input_params */,
+                                               AuthorizationSet* /* output_params */) {
+    if (digest_ == KM_DIGEST_NONE)
+        return KM_ERROR_OK;
+
+    // TODO(swillden): Factor out EVP_MD selection.  It will be done for many operations.
+    switch (digest_) {
+    case KM_DIGEST_SHA_2_256:
+        digest_algorithm_ = EVP_sha256();
+        break;
+    default:
+        return KM_ERROR_UNSUPPORTED_DIGEST;
+    }
+
+    if (!EVP_DigestInit_ex(&digest_ctx_, digest_algorithm_, NULL /* engine */)) {
+        int err = ERR_get_error();
+        LOG_E("Failed to initialize digest: %d %s", err, ERR_error_string(err, NULL));
+        return KM_ERROR_UNKNOWN_ERROR;
+    }
+
+    return KM_ERROR_OK;
+}
+
+keymaster_error_t RsaDigestingOperation::Update(const AuthorizationSet& additional_params,
+                                                const Buffer& input, Buffer* output,
+                                                size_t* input_consumed) {
+    if (digest_ == KM_DIGEST_NONE)
+        return RsaOperation::Update(additional_params, input, output, input_consumed);
+    if (!EVP_DigestUpdate(&digest_ctx_, input.peek_read(), input.available_read())) {
+        int err = ERR_get_error();
+        LOG_E("Failed to update digest: %d %s", err, ERR_error_string(err, NULL));
+        return KM_ERROR_UNKNOWN_ERROR;
+    }
+    *input_consumed = input.available_read();
+    return KM_ERROR_OK;
+}
+
+uint8_t* RsaDigestingOperation::FinishDigest(unsigned* digest_size) {
+    assert(digest_algorithm_ != NULL);
+    UniquePtr<uint8_t[]> digest(new uint8_t[EVP_MAX_MD_SIZE]);
+    if (!EVP_DigestFinal_ex(&digest_ctx_, digest.get(), digest_size)) {
+        int err = ERR_get_error();
+        LOG_E("Failed to finalize digest: %d %s", err, ERR_error_string(err, NULL));
+        return NULL;
+    }
+    assert(*digest_size == static_cast<unsigned>(EVP_MD_size(digest_algorithm_)));
+    return digest.release();
+}
+
 keymaster_error_t RsaSignOperation::Finish(const AuthorizationSet& /* additional_params */,
                                            const Buffer& /* signature */, Buffer* output) {
     assert(output);
     output->Reinitialize(RSA_size(rsa_key_));
-    int bytes_encrypted = RSA_private_encrypt(data_.available_read(), data_.peek_read(),
-                                              output->peek_write(), rsa_key_, RSA_NO_PADDING);
+
+    int bytes_encrypted =
+        (digest_ == KM_DIGEST_NONE) ? SignUndigested(output) : SignDigested(output);
+
     if (bytes_encrypted < 0)
         return KM_ERROR_UNKNOWN_ERROR;
+
     assert(bytes_encrypted == RSA_size(rsa_key_));
     output->advance_write(bytes_encrypted);
     return KM_ERROR_OK;
 }
 
+int RsaSignOperation::SignUndigested(Buffer* output) {
+    return RSA_private_encrypt(data_.available_read(), data_.peek_read(), output->peek_write(),
+                               rsa_key_, RSA_NO_PADDING);
+}
+
+int RsaSignOperation::SignDigested(Buffer* output) {
+    unsigned digest_size = 0;
+    UniquePtr<uint8_t[]> digest(FinishDigest(&digest_size));
+    if (!digest.get())
+        return KM_ERROR_UNKNOWN_ERROR;
+    return RSA_private_encrypt(digest_size, digest.get(), output->peek_write(), rsa_key_,
+                               RSA_NO_PADDING);
+}
+
 keymaster_error_t RsaVerifyOperation::Finish(const AuthorizationSet& /* additional_params */,
                                              const Buffer& signature, Buffer* /* output */) {
-#if defined(OPENSSL_IS_BORINGSSL)
-    size_t message_size = data_.available_read();
-#else
-    if (data_.available_read() > INT_MAX)
-        return KM_ERROR_INVALID_INPUT_LENGTH;
-    int message_size = (int)data_.available_read();
-#endif
-
-    if (message_size != RSA_size(rsa_key_))
-        return KM_ERROR_INVALID_INPUT_LENGTH;
-
-    if (data_.available_read() != signature.available_read())
-        return KM_ERROR_VERIFICATION_FAILED;
-
     UniquePtr<uint8_t[]> decrypted_data(new uint8_t[RSA_size(rsa_key_)]);
     int bytes_decrypted = RSA_public_decrypt(signature.available_read(), signature.peek_read(),
                                              decrypted_data.get(), rsa_key_, RSA_NO_PADDING);
@@ -286,8 +347,40 @@
         return KM_ERROR_UNKNOWN_ERROR;
     assert(bytes_decrypted == RSA_size(rsa_key_));
 
-    if (memcmp_s(decrypted_data.get(), data_.peek_read(), data_.available_read()) == 0)
+    if (digest_ == KM_DIGEST_NONE) {
+        if (data_.available_read() != signature.available_read())
+            return KM_ERROR_VERIFICATION_FAILED;
+        return VerifyUndigested(decrypted_data.get());
+    }
+    return VerifyDigested(decrypted_data.get());
+}
+
+keymaster_error_t RsaVerifyOperation::VerifyUndigested(uint8_t* decrypted_data) {
+#if defined(OPENSSL_IS_BORINGSSL)
+    size_t message_size = data_.available_read();
+#else
+    if (data_.available_read() > INT_MAX)
+        return KM_ERROR_INVALID_INPUT_LENGTH;
+    int message_size = (int)data_.available_read();
+#endif
+    if (message_size != RSA_size(rsa_key_))
+        return KM_ERROR_INVALID_INPUT_LENGTH;
+
+    if (memcmp_s(decrypted_data, data_.peek_read(), data_.available_read()) == 0)
         return KM_ERROR_OK;
+
+    return KM_ERROR_VERIFICATION_FAILED;
+}
+
+keymaster_error_t RsaVerifyOperation::VerifyDigested(uint8_t* decrypted_data) {
+    unsigned digest_size = 0;
+    UniquePtr<uint8_t[]> digest(FinishDigest(&digest_size));
+    if (!digest.get())
+        return KM_ERROR_UNKNOWN_ERROR;
+
+    if (memcmp_s(decrypted_data, digest.get(), digest_size) == 0)
+        return KM_ERROR_OK;
+
     return KM_ERROR_VERIFICATION_FAILED;
 }
 
diff --git a/rsa_operation.h b/rsa_operation.h
index 99530f8..f890e2f 100644
--- a/rsa_operation.h
+++ b/rsa_operation.h
@@ -28,6 +28,11 @@
 
 namespace keymaster {
 
+/**
+ * Base class for all RSA operations.
+ *
+ * This class provides RSA key management, plus buffering of data for non-digesting modes.
+ */
 class RsaOperation : public Operation {
   public:
     RsaOperation(keymaster_purpose_t purpose, keymaster_padding_t padding, RSA* key)
@@ -50,28 +55,64 @@
     Buffer data_;
 };
 
-class RsaSignOperation : public RsaOperation {
+/**
+ * Base class for all RSA operations.
+ *
+ * This class adds digesting support, for digesting modes.  For non-digesting modes, it falls back
+ * on the RsaOperation input buffering.
+ */
+class RsaDigestingOperation : public RsaOperation {
+  public:
+    RsaDigestingOperation(keymaster_purpose_t purpose, keymaster_digest_t digest,
+                          keymaster_padding_t padding, RSA* key);
+    ~RsaDigestingOperation();
+
+    virtual keymaster_error_t Begin(const AuthorizationSet& input_params,
+                                    AuthorizationSet* output_params);
+    virtual keymaster_error_t Update(const AuthorizationSet& additional_params, const Buffer& input,
+                                     Buffer* output, size_t* input_consumed);
+
+  protected:
+    uint8_t* FinishDigest(unsigned* digest_size);
+
+    const keymaster_digest_t digest_;
+    const EVP_MD* digest_algorithm_;
+    EVP_MD_CTX digest_ctx_;
+};
+
+/**
+ * RSA private key signing operation.
+ */
+class RsaSignOperation : public RsaDigestingOperation {
   public:
     RsaSignOperation(keymaster_digest_t digest, keymaster_padding_t padding, RSA* key)
-        : RsaOperation(KM_PURPOSE_SIGN, padding, key), digest_(digest) {}
+        : RsaDigestingOperation(KM_PURPOSE_SIGN, digest, padding, key) {}
     virtual keymaster_error_t Finish(const AuthorizationSet& additional_params,
                                      const Buffer& signature, Buffer* output);
 
   private:
-    keymaster_digest_t digest_;
+    int SignUndigested(Buffer* output);
+    int SignDigested(Buffer* output);
 };
 
-class RsaVerifyOperation : public RsaOperation {
+/**
+ * RSA public key verification operation.
+ */
+class RsaVerifyOperation : public RsaDigestingOperation {
   public:
     RsaVerifyOperation(keymaster_digest_t digest, keymaster_padding_t padding, RSA* key)
-        : RsaOperation(KM_PURPOSE_VERIFY, padding, key), digest_(digest) {}
+        : RsaDigestingOperation(KM_PURPOSE_VERIFY, digest, padding, key) {}
     virtual keymaster_error_t Finish(const AuthorizationSet& additional_params,
                                      const Buffer& signature, Buffer* output);
 
   private:
-    keymaster_digest_t digest_;
+    keymaster_error_t VerifyUndigested(uint8_t* decrypted_data);
+    keymaster_error_t VerifyDigested(uint8_t* decrypted_data);
 };
 
+/**
+ * RSA public key encryption operation.
+ */
 class RsaEncryptOperation : public RsaOperation {
   public:
     RsaEncryptOperation(keymaster_padding_t padding, RSA* key)
@@ -80,6 +121,9 @@
                                      const Buffer& signature, Buffer* output);
 };
 
+/**
+ * RSA private key decryption operation.
+ */
 class RsaDecryptOperation : public RsaOperation {
   public:
     RsaDecryptOperation(keymaster_padding_t padding, RSA* key)