Use specified digest for RSA OAEP.

Bug: 22405614
Change-Id: Ia5eb67a571a9d46acca4b4e708bb8178bd3acd0d
diff --git a/android_keymaster_test.cpp b/android_keymaster_test.cpp
index 2ee8147..c0d6bb7 100644
--- a/android_keymaster_test.cpp
+++ b/android_keymaster_test.cpp
@@ -1861,17 +1861,18 @@
 }
 
 TEST_P(EncryptionOperationsTest, RsaOaepSuccess) {
+    size_t key_size = 768;
     ASSERT_EQ(KM_ERROR_OK, GenerateKey(AuthorizationSetBuilder()
-                                           .RsaEncryptionKey(512, 3)
+                                           .RsaEncryptionKey(key_size, 3)
                                            .Padding(KM_PAD_RSA_OAEP)
                                            .Digest(KM_DIGEST_SHA_2_256)));
 
-    string message = "Hello World!";
+    string message = "Hello";
     string ciphertext1 = EncryptMessage(string(message), KM_DIGEST_SHA_2_256, KM_PAD_RSA_OAEP);
-    EXPECT_EQ(512U / 8, ciphertext1.size());
+    EXPECT_EQ(key_size / 8, ciphertext1.size());
 
     string ciphertext2 = EncryptMessage(string(message), KM_DIGEST_SHA_2_256, KM_PAD_RSA_OAEP);
-    EXPECT_EQ(512U / 8, ciphertext2.size());
+    EXPECT_EQ(key_size / 8, ciphertext2.size());
 
     // OAEP randomizes padding so every result should be different.
     EXPECT_NE(ciphertext1, ciphertext2);
@@ -1881,13 +1882,14 @@
 }
 
 TEST_P(EncryptionOperationsTest, RsaOaepRoundTrip) {
+    size_t key_size = 768;
     ASSERT_EQ(KM_ERROR_OK, GenerateKey(AuthorizationSetBuilder()
-                                           .RsaEncryptionKey(512, 3)
+                                           .RsaEncryptionKey(key_size, 3)
                                            .Padding(KM_PAD_RSA_OAEP)
                                            .Digest(KM_DIGEST_SHA_2_256)));
     string message = "Hello World!";
     string ciphertext = EncryptMessage(string(message), KM_DIGEST_SHA_2_256, KM_PAD_RSA_OAEP);
-    EXPECT_EQ(512U / 8, ciphertext.size());
+    EXPECT_EQ(key_size / 8, ciphertext.size());
 
     string plaintext = DecryptMessage(ciphertext, KM_DIGEST_SHA_2_256, KM_PAD_RSA_OAEP);
     EXPECT_EQ(message, plaintext);
@@ -1912,18 +1914,59 @@
         EXPECT_EQ(2, GetParam()->keymaster0_calls());
 }
 
+TEST_P(EncryptionOperationsTest, RsaOaepUnauthorizedDigest) {
+    ASSERT_EQ(KM_ERROR_OK, GenerateKey(AuthorizationSetBuilder()
+                                       .RsaEncryptionKey(512, 3)
+                                           .Padding(KM_PAD_RSA_OAEP)
+                                           .Digest(KM_DIGEST_SHA_2_256)));
+    string message = "Hello World!";
+    // Works because encryption is a public key operation.
+    EncryptMessage(string(message), KM_DIGEST_SHA1, KM_PAD_RSA_OAEP);
+
+    AuthorizationSet begin_params(client_params());
+    begin_params.push_back(TAG_PADDING, KM_PAD_RSA_OAEP);
+    begin_params.push_back(TAG_DIGEST, KM_DIGEST_SHA1);
+    EXPECT_EQ(KM_ERROR_INCOMPATIBLE_DIGEST, BeginOperation(KM_PURPOSE_DECRYPT, begin_params));
+
+    if (GetParam()->algorithm_in_hardware(KM_ALGORITHM_RSA))
+        EXPECT_EQ(3, GetParam()->keymaster0_calls());
+}
+
+TEST_P(EncryptionOperationsTest, RsaOaepDecryptWithWrongDigest) {
+    ASSERT_EQ(KM_ERROR_OK, GenerateKey(AuthorizationSetBuilder()
+                                           .RsaEncryptionKey(768, 3)
+                                           .Padding(KM_PAD_RSA_OAEP)
+                                           .Digest(KM_DIGEST_SHA_2_256)
+                                           .Digest(KM_DIGEST_SHA_2_384)));
+    string message = "Hello World!";
+    string ciphertext = EncryptMessage(string(message), KM_DIGEST_SHA_2_256, KM_PAD_RSA_OAEP);
+
+    string result;
+    size_t input_consumed;
+    AuthorizationSet begin_params(client_params());
+    begin_params.push_back(TAG_PADDING, KM_PAD_RSA_OAEP);
+    begin_params.push_back(TAG_DIGEST, KM_DIGEST_SHA_2_384);
+    EXPECT_EQ(KM_ERROR_OK, BeginOperation(KM_PURPOSE_DECRYPT, begin_params));
+    EXPECT_EQ(KM_ERROR_OK, UpdateOperation(ciphertext, &result, &input_consumed));
+    EXPECT_EQ(KM_ERROR_UNKNOWN_ERROR, FinishOperation(&result));
+    EXPECT_EQ(0U, result.size());
+
+    if (GetParam()->algorithm_in_hardware(KM_ALGORITHM_RSA))
+        EXPECT_EQ(4, GetParam()->keymaster0_calls());
+}
+
 TEST_P(EncryptionOperationsTest, RsaOaepTooLarge) {
     ASSERT_EQ(KM_ERROR_OK, GenerateKey(AuthorizationSetBuilder()
                                            .RsaEncryptionKey(512, 3)
                                            .Padding(KM_PAD_RSA_OAEP)
-                                           .Digest(KM_DIGEST_SHA_2_256)));
+                                           .Digest(KM_DIGEST_SHA1)));
     string message = "12345678901234567890123";
     string result;
     size_t input_consumed;
 
     AuthorizationSet begin_params(client_params());
     begin_params.push_back(TAG_PADDING, KM_PAD_RSA_OAEP);
-    begin_params.push_back(TAG_DIGEST, KM_DIGEST_SHA_2_256);
+    begin_params.push_back(TAG_DIGEST, KM_DIGEST_SHA1);
     EXPECT_EQ(KM_ERROR_OK, BeginOperation(KM_PURPOSE_ENCRYPT, begin_params));
     EXPECT_EQ(KM_ERROR_OK, UpdateOperation(message, &result, &input_consumed));
     EXPECT_EQ(KM_ERROR_INVALID_INPUT_LENGTH, FinishOperation(&result));
@@ -1934,16 +1977,17 @@
 }
 
 TEST_P(EncryptionOperationsTest, RsaOaepCorruptedDecrypt) {
+    size_t key_size = 768;
     ASSERT_EQ(KM_ERROR_OK, GenerateKey(AuthorizationSetBuilder()
-                                           .RsaEncryptionKey(512, 3)
+                                           .RsaEncryptionKey(768, 3)
                                            .Padding(KM_PAD_RSA_OAEP)
                                            .Digest(KM_DIGEST_SHA_2_256)));
     string message = "Hello World!";
     string ciphertext = EncryptMessage(string(message), KM_DIGEST_SHA_2_256, KM_PAD_RSA_OAEP);
-    EXPECT_EQ(512U / 8, ciphertext.size());
+    EXPECT_EQ(key_size / 8, ciphertext.size());
 
     // Corrupt the ciphertext
-    ciphertext[512 / 8 / 2]++;
+    ciphertext[key_size / 8 / 2]++;
 
     string result;
     size_t input_consumed;
diff --git a/openssl_err.cpp b/openssl_err.cpp
index 38edc05..2548d5c 100644
--- a/openssl_err.cpp
+++ b/openssl_err.cpp
@@ -145,6 +145,9 @@
 
 keymaster_error_t TranslateRsaError(int reason) {
     switch (reason) {
+    case RSA_R_KEY_SIZE_TOO_SMALL:
+        LOG_W("RSA key is too small to use with selected padding/digest", 0);
+        return KM_ERROR_INCOMPATIBLE_PADDING_MODE;
     case RSA_R_DATA_TOO_LARGE_FOR_KEY_SIZE:
     case RSA_R_DATA_TOO_SMALL_FOR_KEY_SIZE:
         return KM_ERROR_INVALID_INPUT_LENGTH;
diff --git a/rsa_operation.cpp b/rsa_operation.cpp
index ce3e2a1..d9217fd 100644
--- a/rsa_operation.cpp
+++ b/rsa_operation.cpp
@@ -67,13 +67,13 @@
                                                       const AuthorizationSet& begin_params,
                                                       keymaster_error_t* error) {
     keymaster_padding_t padding;
-    keymaster_digest_t digest = KM_DIGEST_NONE;
     if (!GetAndValidatePadding(begin_params, key, &padding, error))
         return nullptr;
 
     bool require_digest = (purpose() == KM_PURPOSE_SIGN || purpose() == KM_PURPOSE_VERIFY ||
                            padding == KM_PAD_RSA_OAEP);
 
+    keymaster_digest_t digest = KM_DIGEST_NONE;
     if (require_digest && !GetAndValidateDigest(begin_params, key, &digest, error))
         return nullptr;
     if (!require_digest && begin_params.find(TAG_DIGEST) != -1) {
@@ -141,6 +141,11 @@
         EVP_PKEY_free(rsa_key_);
 }
 
+keymaster_error_t RsaOperation::Begin(const AuthorizationSet& /* input_params */,
+                                      AuthorizationSet* /* output_params */) {
+    return InitDigest();
+}
+
 keymaster_error_t RsaOperation::Update(const AuthorizationSet& /* additional_params */,
                                        const Buffer& input, AuthorizationSet* /* output_params */,
                                        Buffer* /* output */, size_t* input_consumed) {
@@ -251,9 +256,9 @@
     }
 }
 
-keymaster_error_t RsaSignOperation::Begin(const AuthorizationSet& /* input_params */,
-                                          AuthorizationSet* /* output_params */) {
-    keymaster_error_t error = InitDigest();
+keymaster_error_t RsaSignOperation::Begin(const AuthorizationSet& input_params,
+                                          AuthorizationSet* output_params) {
+    keymaster_error_t error = RsaDigestingOperation::Begin(input_params, output_params);
     if (error != KM_ERROR_OK)
         return error;
 
@@ -344,9 +349,9 @@
     return KM_ERROR_OK;
 }
 
-keymaster_error_t RsaVerifyOperation::Begin(const AuthorizationSet& /* input_params */,
-                                            AuthorizationSet* /* output_params */) {
-    keymaster_error_t error = InitDigest();
+keymaster_error_t RsaVerifyOperation::Begin(const AuthorizationSet& input_params,
+                                            AuthorizationSet* output_params) {
+    keymaster_error_t error = RsaDigestingOperation::Begin(input_params, output_params);
     if (error != KM_ERROR_OK)
         return error;
 
@@ -429,6 +434,21 @@
     return KM_ERROR_OK;
 }
 
+keymaster_error_t RsaCryptOperation::SetOaepDigestIfRequired(EVP_PKEY_CTX* pkey_ctx) {
+    if (padding() != KM_PAD_RSA_OAEP)
+        return KM_ERROR_OK;
+
+    assert(digest_algorithm_ != nullptr);
+    if (!EVP_PKEY_CTX_set_rsa_oaep_md(pkey_ctx, digest_algorithm_))
+        return TranslateLastOpenSslError();
+
+    // MGF1 MD is always SHA1.
+    if (!EVP_PKEY_CTX_set_rsa_mgf1_md(pkey_ctx, EVP_sha1()))
+        return TranslateLastOpenSslError();
+
+    return KM_ERROR_OK;
+}
+
 int RsaCryptOperation::GetOpensslPadding(keymaster_error_t* error) {
     *error = KM_ERROR_OK;
     switch (padding_) {
@@ -464,6 +484,9 @@
     keymaster_error_t error = SetRsaPaddingInEvpContext(ctx.get());
     if (error != KM_ERROR_OK)
         return error;
+    error = SetOaepDigestIfRequired(ctx.get());
+    if (error != KM_ERROR_OK)
+        return error;
 
     size_t outlen;
     if (EVP_PKEY_encrypt(ctx.get(), nullptr /* out */, &outlen, data_.peek_read(),
@@ -499,6 +522,9 @@
     keymaster_error_t error = SetRsaPaddingInEvpContext(ctx.get());
     if (error != KM_ERROR_OK)
         return error;
+    error = SetOaepDigestIfRequired(ctx.get());
+    if (error != KM_ERROR_OK)
+        return error;
 
     size_t outlen;
     if (EVP_PKEY_decrypt(ctx.get(), nullptr /* out */, &outlen, data_.peek_read(),
diff --git a/rsa_operation.h b/rsa_operation.h
index 89d4f37..30ea3c5 100644
--- a/rsa_operation.h
+++ b/rsa_operation.h
@@ -39,10 +39,8 @@
           digest_algorithm_(nullptr) {}
     ~RsaOperation();
 
-    keymaster_error_t Begin(const AuthorizationSet& /* input_params */,
-                            AuthorizationSet* /* output_params */) override {
-        return KM_ERROR_OK;
-    }
+    keymaster_error_t Begin(const AuthorizationSet& input_params,
+                            AuthorizationSet* output_params) override;
     keymaster_error_t Update(const AuthorizationSet& additional_params, const Buffer& input,
                              AuthorizationSet* output_params, Buffer* output,
                              size_t* input_consumed) override;
@@ -135,6 +133,9 @@
                       EVP_PKEY* key)
         : RsaOperation(KM_PURPOSE_ENCRYPT, digest, padding, key) {}
 
+  protected:
+    keymaster_error_t SetOaepDigestIfRequired(EVP_PKEY_CTX* pkey_ctx);
+
   private:
     int GetOpensslPadding(keymaster_error_t* error) override;
     bool require_digest() const override { return padding_ == KM_PAD_RSA_OAEP; }