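Use the specified digest for RSA OAEP padding.

Previously RSA encrypt/decrypt with OAEP padding required a TAG_DIGEST at
begin() but never configured an OAEP digest on the EVP_PKEY_CTX, so the
library default (SHA-1) was used regardless of the digest requested. The
requested digest is now set as the OAEP hash via
EVP_PKEY_CTX_set_rsa_oaep_md(); the MGF1 digest remains SHA-1. The
InitDigest() call moves from the sign/verify Begin() methods into
RsaOperation::Begin().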

Bug: 22405614
Change-Id: Ia5eb67a571a9d46acca4b4e708bb8178bd3acd0d
diff --git a/rsa_operation.cpp b/rsa_operation.cpp
index ce3e2a1..d9217fd 100644
--- a/rsa_operation.cpp
+++ b/rsa_operation.cpp
@@ -67,13 +67,13 @@
                                                       const AuthorizationSet& begin_params,
                                                       keymaster_error_t* error) {
     keymaster_padding_t padding;
-    keymaster_digest_t digest = KM_DIGEST_NONE;
     if (!GetAndValidatePadding(begin_params, key, &padding, error))
         return nullptr;
 
     bool require_digest = (purpose() == KM_PURPOSE_SIGN || purpose() == KM_PURPOSE_VERIFY ||
                            padding == KM_PAD_RSA_OAEP);
 
+    keymaster_digest_t digest = KM_DIGEST_NONE;
     if (require_digest && !GetAndValidateDigest(begin_params, key, &digest, error))
         return nullptr;
     if (!require_digest && begin_params.find(TAG_DIGEST) != -1) {
@@ -141,6 +141,11 @@
         EVP_PKEY_free(rsa_key_);
 }
 
+keymaster_error_t RsaOperation::Begin(const AuthorizationSet& /* input_params */,
+                                      AuthorizationSet* /* output_params */) {
+    return InitDigest();  // Shared digest setup; OAEP relies on digest_algorithm_ (confirm).
+}
+
 keymaster_error_t RsaOperation::Update(const AuthorizationSet& /* additional_params */,
                                        const Buffer& input, AuthorizationSet* /* output_params */,
                                        Buffer* /* output */, size_t* input_consumed) {
@@ -251,9 +256,9 @@
     }
 }
 
-keymaster_error_t RsaSignOperation::Begin(const AuthorizationSet& /* input_params */,
-                                          AuthorizationSet* /* output_params */) {
-    keymaster_error_t error = InitDigest();
+keymaster_error_t RsaSignOperation::Begin(const AuthorizationSet& input_params,
+                                          AuthorizationSet* output_params) {
+    keymaster_error_t error = RsaDigestingOperation::Begin(input_params, output_params);
     if (error != KM_ERROR_OK)
         return error;
 
@@ -344,9 +349,9 @@
     return KM_ERROR_OK;
 }
 
-keymaster_error_t RsaVerifyOperation::Begin(const AuthorizationSet& /* input_params */,
-                                            AuthorizationSet* /* output_params */) {
-    keymaster_error_t error = InitDigest();
+keymaster_error_t RsaVerifyOperation::Begin(const AuthorizationSet& input_params,
+                                            AuthorizationSet* output_params) {
+    keymaster_error_t error = RsaDigestingOperation::Begin(input_params, output_params);
     if (error != KM_ERROR_OK)
         return error;
 
@@ -429,6 +434,21 @@
     return KM_ERROR_OK;
 }
 
+keymaster_error_t RsaCryptOperation::SetOaepDigestIfRequired(EVP_PKEY_CTX* pkey_ctx) {
+    if (padding() != KM_PAD_RSA_OAEP)
+        return KM_ERROR_OK;
+
+    assert(digest_algorithm_ != nullptr);
+    if (!EVP_PKEY_CTX_set_rsa_oaep_md(pkey_ctx, digest_algorithm_))
+        return TranslateLastOpenSslError();
+
+    // MGF1 MD is always SHA1.
+    if (!EVP_PKEY_CTX_set_rsa_mgf1_md(pkey_ctx, EVP_sha1()))
+        return TranslateLastOpenSslError();
+
+    return KM_ERROR_OK;
+}
+
 int RsaCryptOperation::GetOpensslPadding(keymaster_error_t* error) {
     *error = KM_ERROR_OK;
     switch (padding_) {
@@ -464,6 +484,9 @@
     keymaster_error_t error = SetRsaPaddingInEvpContext(ctx.get());
     if (error != KM_ERROR_OK)
         return error;
+    error = SetOaepDigestIfRequired(ctx.get());
+    if (error != KM_ERROR_OK)
+        return error;
 
     size_t outlen;
     if (EVP_PKEY_encrypt(ctx.get(), nullptr /* out */, &outlen, data_.peek_read(),
@@ -499,6 +522,9 @@
     keymaster_error_t error = SetRsaPaddingInEvpContext(ctx.get());
     if (error != KM_ERROR_OK)
         return error;
+    error = SetOaepDigestIfRequired(ctx.get());
+    if (error != KM_ERROR_OK)
+        return error;
 
     size_t outlen;
     if (EVP_PKEY_decrypt(ctx.get(), nullptr /* out */, &outlen, data_.peek_read(),