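Add support for the P-256 curve in RKP for StrongBox

Add ES256 (ECDSA P-256) COSE_Sign1 signing and verification, P-256 (EC2)
sender keys in COSE_Encrypt, and P-256 ECDH + HKDF key derivation alongside
the existing Ed25519/X25519 support.

Test: Run the RKP VTS tests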
Change-Id: I0972d4e7755d02e138aeb7e9780adfafd1fbfce0
diff --git a/cppcose/cppcose.cpp b/cppcose/cppcose.cpp
index bfe9928..411dc01 100644
--- a/cppcose/cppcose.cpp
+++ b/cppcose/cppcose.cpp
@@ -21,10 +21,19 @@
 
 #include <cppbor.h>
 #include <cppbor_parse.h>
 
+#include <openssl/ecdsa.h>
 #include <openssl/err.h>
 
 namespace cppcose {
+// Size in bytes of each affine coordinate (x or y) of a P-256 public key.
+constexpr size_t kP256AffinePointSize = 32;
+
+// Owning smart pointers for the BoringSSL objects used by the EC helpers in this file.
+using EVP_PKEY_Ptr = bssl::UniquePtr<EVP_PKEY>;
+using EVP_PKEY_CTX_Ptr = bssl::UniquePtr<EVP_PKEY_CTX>;
+using ECDSA_SIG_Ptr = bssl::UniquePtr<ECDSA_SIG>;
+using EC_KEY_Ptr = bssl::UniquePtr<EC_KEY>;
 
 namespace {
 
@@ -51,8 +58,168 @@
     return std::move(ctx);
 }
 
+// Signs the precomputed digest `data` with the P-256 private key `key` (big-endian scalar
+// bytes) and returns the DER-encoded ECDSA signature.
+ErrMsgOr<bytevec> signEcdsaDigest(const bytevec& key, const bytevec& data) {
+    auto bn = BIGNUM_Ptr(BN_bin2bn(key.data(), key.size(), nullptr));
+    if (bn.get() == nullptr) {
+        return "Error creating BIGNUM";
+    }
+
+    auto ec_key = EC_KEY_Ptr(EC_KEY_new_by_curve_name(NID_X9_62_prime256v1));
+    if (ec_key.get() == nullptr) {
+        return "Error creating EC key";
+    }
+    if (EC_KEY_set_private_key(ec_key.get(), bn.get()) != 1) {
+        return "Error setting private key from BIGNUM";
+    }
+
+    auto sig = ECDSA_SIG_Ptr(ECDSA_do_sign(data.data(), data.size(), ec_key.get()));
+    if (sig == nullptr) {
+        return "Error signing digest";
+    }
+    // i2d_ECDSA_SIG returns a negative value on failure, which must not be used as a size.
+    int len = i2d_ECDSA_SIG(sig.get(), nullptr);
+    if (len <= 0) {
+        return "Error encoding signature";
+    }
+    bytevec signature(static_cast<size_t>(len));
+    unsigned char* p = signature.data();
+    if (i2d_ECDSA_SIG(sig.get(), &p) != len) {
+        return "Error encoding signature";
+    }
+    return signature;
+}
+
+// Computes the raw P-256 ECDH shared secret between `publicKey` (an encoded point accepted by
+// EC_POINT_oct2point, e.g. 0x04 || X || Y) and `privateKey` (big-endian scalar bytes).
+ErrMsgOr<bytevec> ecdh(const bytevec& publicKey, const bytevec& privateKey) {
+    auto group = EC_GROUP_Ptr(EC_GROUP_new_by_curve_name(NID_X9_62_prime256v1));
+    if (group.get() == nullptr) {
+        return "Error creating EC group";
+    }
+    auto point = EC_POINT_Ptr(EC_POINT_new(group.get()));
+    if (point.get() == nullptr) {
+        return "Memory allocation failed";
+    }
+    if (EC_POINT_oct2point(group.get(), point.get(), publicKey.data(), publicKey.size(), nullptr) !=
+        1) {
+        return "Error decoding publicKey";
+    }
+    auto ecKey = EC_KEY_Ptr(EC_KEY_new());
+    auto pkey = EVP_PKEY_Ptr(EVP_PKEY_new());
+    if (ecKey.get() == nullptr || pkey.get() == nullptr) {
+        return "Memory allocation failed";
+    }
+    if (EC_KEY_set_group(ecKey.get(), group.get()) != 1) {
+        return "Error setting group";
+    }
+    if (EC_KEY_set_public_key(ecKey.get(), point.get()) != 1) {
+        return "Error setting point";
+    }
+    if (EVP_PKEY_set1_EC_KEY(pkey.get(), ecKey.get()) != 1) {
+        return "Error setting key";
+    }
+
+    auto bn = BIGNUM_Ptr(BN_bin2bn(privateKey.data(), privateKey.size(), nullptr));
+    if (bn.get() == nullptr) {
+        return "Error creating BIGNUM for private key";
+    }
+    auto privEcKey = EC_KEY_Ptr(EC_KEY_new_by_curve_name(NID_X9_62_prime256v1));
+    auto privPkey = EVP_PKEY_Ptr(EVP_PKEY_new());
+    if (privEcKey.get() == nullptr || privPkey.get() == nullptr) {
+        return "Memory allocation failed";
+    }
+    if (EC_KEY_set_private_key(privEcKey.get(), bn.get()) != 1) {
+        return "Error setting private key from BIGNUM";
+    }
+    if (EVP_PKEY_set1_EC_KEY(privPkey.get(), privEcKey.get()) != 1) {
+        return "Error setting private key";
+    }
+
+    auto ctx = EVP_PKEY_CTX_Ptr(EVP_PKEY_CTX_new(privPkey.get(), nullptr));
+    if (ctx.get() == nullptr) {
+        return "Error creating context";
+    }
+
+    if (EVP_PKEY_derive_init(ctx.get()) != 1) {
+        return "Error initializing context";
+    }
+
+    if (EVP_PKEY_derive_set_peer(ctx.get(), pkey.get()) != 1) {
+        return "Error setting peer";
+    }
+
+    // Query the shared secret length first, then derive into a buffer of that size.
+    size_t secretLen = 0;
+    if (EVP_PKEY_derive(ctx.get(), nullptr, &secretLen) != 1) {
+        return "Error determining length of shared secret";
+    }
+    bytevec sharedSecret(secretLen);
+
+    if (EVP_PKEY_derive(ctx.get(), sharedSecret.data(), &secretLen) != 1) {
+        return "Error deriving shared secret";
+    }
+    // On success secretLen is updated to the number of bytes actually written.
+    sharedSecret.resize(secretLen);
+    return sharedSecret;
+}
+
 }  // namespace
 
+// Converts a COSE ECDSA P-256 signature (r || s, 32 big-endian bytes each) into a DER-encoded
+// ECDSA signature.
+ErrMsgOr<bytevec> ecdsaCoseSignatureToDer(const bytevec& ecdsaCoseSignature) {
+    if (ecdsaCoseSignature.size() != 64) {
+        return "COSE signature wrong length";
+    }
+
+    auto rBn = BIGNUM_Ptr(BN_bin2bn(ecdsaCoseSignature.data(), 32, nullptr));
+    if (rBn.get() == nullptr) {
+        return "Error creating BIGNUM for r";
+    }
+
+    auto sBn = BIGNUM_Ptr(BN_bin2bn(ecdsaCoseSignature.data() + 32, 32, nullptr));
+    if (sBn.get() == nullptr) {
+        return "Error creating BIGNUM for s";
+    }
+
+    // Stack-allocated wrapper: r and s are borrowed and remain owned by rBn / sBn.
+    ECDSA_SIG sig;
+    sig.r = rBn.get();
+    sig.s = sBn.get();
+
+    int len = i2d_ECDSA_SIG(&sig, nullptr);
+    if (len <= 0) {
+        return "Error encoding DER signature";
+    }
+    bytevec derSignature(static_cast<size_t>(len));
+    unsigned char* p = derSignature.data();
+    if (i2d_ECDSA_SIG(&sig, &p) != len) {
+        return "Error encoding DER signature";
+    }
+    return derSignature;
+}
+
+// Converts a DER-encoded ECDSA P-256 signature into the COSE form r || s, each value
+// left-padded with zeros to 32 bytes.
+ErrMsgOr<bytevec> ecdsaDerSignatureToCose(const bytevec& ecdsaSignature) {
+    const unsigned char* p = ecdsaSignature.data();
+    auto sig = ECDSA_SIG_Ptr(d2i_ECDSA_SIG(nullptr, &p, ecdsaSignature.size()));
+    if (sig == nullptr) {
+        return "Error decoding DER signature";
+    }
+
+    bytevec ecdsaCoseSignature(64, 0);
+    if (BN_bn2binpad(ECDSA_SIG_get0_r(sig.get()), ecdsaCoseSignature.data(), 32) != 32) {
+        return "Error encoding r";
+    }
+    if (BN_bn2binpad(ECDSA_SIG_get0_s(sig.get()), ecdsaCoseSignature.data() + 32, 32) != 32) {
+        return "Error encoding s";
+    }
+    return ecdsaCoseSignature;
+}
+
 ErrMsgOr<HmacSha256> generateHmacSha256(const bytevec& key, const bytevec& data) {
     HmacSha256 digest;
     unsigned int outLen;
@@ -134,6 +267,21 @@
     return payload->value();
 }
 
+// Signs the COSE_Sign1 Sig_structure ["Signature1", protectedParams, aad, payload] with ES256:
+// SHA-256 digest, ECDSA with the P-256 private key `key`, signature returned as r || s.
+ErrMsgOr<bytevec> createECDSACoseSign1Signature(const bytevec& key, const bytevec& protectedParams,
+                                                const bytevec& payload, const bytevec& aad) {
+    cppbor::Array sigStructure;
+    sigStructure.add("Signature1").add(protectedParams).add(aad).add(payload);
+    const bytevec digest = sha256(sigStructure.encode());
+
+    auto derSignature = signEcdsaDigest(key, digest);
+    if (!derSignature) {
+        return derSignature.moveMessage();
+    }
+    return ecdsaDerSignatureToCose(*derSignature);
+}
+
 ErrMsgOr<bytevec> createCoseSign1Signature(const bytevec& key, const bytevec& protectedParams,
                                            const bytevec& payload, const bytevec& aad) {
     bytevec signatureInput = cppbor::Array()
@@ -152,6 +299,21 @@
     return signature;
 }
 
+// Builds a COSE_Sign1 array signed with ES256 using the raw P-256 private key `key`.
+// ALGORITHM = ES256 is added to `protectedParams` before they are encoded and signed.
+ErrMsgOr<cppbor::Array> constructECDSACoseSign1(const bytevec& key, cppbor::Map protectedParams,
+                                                const bytevec& payload, const bytevec& aad) {
+    bytevec protParms = protectedParams.add(ALGORITHM, ES256).canonicalize().encode();
+    auto signature = createECDSACoseSign1Signature(key, protParms, payload, aad);
+    if (!signature) return signature.moveMessage();
+
+    return cppbor::Array()
+        .add(std::move(protParms))
+        .add(cppbor::Map() /* unprotected parameters */)
+        .add(payload)
+        .add(std::move(*signature));
+}
+
 ErrMsgOr<cppbor::Array> constructCoseSign1(const bytevec& key, cppbor::Map protectedParams,
                                            const bytevec& payload, const bytevec& aad) {
     bytevec protParms = protectedParams.add(ALGORITHM, EDDSA).canonicalize().encode();
@@ -193,7 +353,9 @@
     }
 
     auto& algorithm = parsedProtParams->asMap()->get(ALGORITHM);
-    if (!algorithm || !algorithm->asInt() || algorithm->asInt()->value() != EDDSA) {
+    // Accept only EdDSA (Ed25519) and ES256 (ECDSA P-256) signatures.
+    if (!algorithm || !algorithm->asInt() ||
+        (algorithm->asInt()->value() != EDDSA && algorithm->asInt()->value() != ES256)) {
         return "Unsupported signature algorithm";
     }
 
@@ -203,17 +364,43 @@
     }
 
     bool selfSigned = signingCoseKey.empty();
-    auto key = CoseKey::parseEd25519(selfSigned ? payload->value() : signingCoseKey);
-    if (!key || key->getBstrValue(CoseKey::PUBKEY_X)->empty()) {
-        return "Bad signing key: " + key.moveMessage();
-    }
-
     bytevec signatureInput =
         cppbor::Array().add("Signature1").add(*protectedParams).add(aad).add(*payload).encode();
+    // A self-signed COSE_Sign1 is verified with the COSE_Key carried in its own payload.
+    const bytevec& signingKeyBytes = selfSigned ? payload->value() : signingCoseKey;
+    const auto signatureAlgorithm = algorithm->asInt()->value();
+    if (signatureAlgorithm == EDDSA) {
+        auto edKey = CoseKey::parseEd25519(signingKeyBytes);
+        if (!edKey || edKey->getBstrValue(CoseKey::PUBKEY_X)->empty()) {
+            return "Bad signing key: " + edKey.moveMessage();
+        }
 
-    if (!ED25519_verify(signatureInput.data(), signatureInput.size(), signature->value().data(),
-                        key->getBstrValue(CoseKey::PUBKEY_X)->data())) {
-        return "Signature verification failed";
+        const bool edOk = ED25519_verify(signatureInput.data(), signatureInput.size(),
+                                         signature->value().data(),
+                                         edKey->getBstrValue(CoseKey::PUBKEY_X)->data());
+        if (!edOk) {
+            return "Signature verification failed";
+        }
+    } else {  // ES256: ECDSA over P-256 with SHA-256
+        auto ecKey = CoseKey::parseP256(signingKeyBytes);
+        if (!ecKey || ecKey->getBstrValue(CoseKey::PUBKEY_X)->empty() ||
+            ecKey->getBstrValue(CoseKey::PUBKEY_Y)->empty()) {
+            return "Bad signing key: " + ecKey.moveMessage();
+        }
+        auto rawPublicKey = ecKey->getEcPublicKey();
+        if (!rawPublicKey) return rawPublicKey.moveMessage();
+
+        auto derSignature = ecdsaCoseSignatureToDer(signature->value());
+        if (!derSignature) return derSignature.moveMessage();
+
+        // verifyEcdsaDigest() needs the uncompressed point form: 0x04 followed by the raw key.
+        bytevec uncompressedPublicKey{0x04};
+        uncompressedPublicKey.insert(uncompressedPublicKey.end(), rawPublicKey->begin(),
+                                     rawPublicKey->end());
+
+        if (!verifyEcdsaDigest(uncompressedPublicKey, sha256(signatureInput), *derSignature)) {
+            return "Signature verification failed";
+        }
     }
 
     return payload->value();
@@ -294,28 +474,52 @@
     if (!senderCoseKey || !senderCoseKey->asMap()) return "Invalid sender COSE_Key";
 
     auto& keyType = senderCoseKey->asMap()->get(CoseKey::KEY_TYPE);
-    if (!keyType || !keyType->asInt() || keyType->asInt()->value() != OCTET_KEY_PAIR) {
+    if (!keyType || !keyType->asInt() ||
+        !(keyType->asInt()->value() == OCTET_KEY_PAIR || keyType->asInt()->value() == EC2)) {
         return "Invalid key type";
     }
+    // Only OCTET_KEY_PAIR and EC2 key types pass the check above; the curve is checked next.
+    const bool isEc2 = keyType->asInt()->value() == EC2;
 
     auto& curve = senderCoseKey->asMap()->get(CoseKey::CURVE);
-    if (!curve || !curve->asInt() || curve->asInt()->value() != X25519) {
+    const bool curveMatches = curve && curve->asInt() &&
+                              (isEc2 ? curve->asInt()->value() == P256
+                                     : curve->asInt()->value() == X25519);
+    if (!curveMatches) {
         return "Unsupported curve";
     }
 
-    auto& pubkey = senderCoseKey->asMap()->get(CoseKey::PUBKEY_X);
-    if (!pubkey || !pubkey->asBstr() ||
-        pubkey->asBstr()->value().size() != X25519_PUBLIC_VALUE_LEN) {
-        return "Invalid X25519 public key";
+    bytevec publicKey;
+    if (isEc2) {
+        // EC2 keys carry x and y separately, kP256AffinePointSize bytes each.
+        auto& pubX = senderCoseKey->asMap()->get(CoseKey::PUBKEY_X);
+        auto& pubY = senderCoseKey->asMap()->get(CoseKey::PUBKEY_Y);
+        const bool validX = pubX && pubX->asBstr() &&
+                            pubX->asBstr()->value().size() == kP256AffinePointSize;
+        const bool validY = pubY && pubY->asBstr() &&
+                            pubY->asBstr()->value().size() == kP256AffinePointSize;
+        if (!validX || !validY) {
+            return "Invalid EC public key";
+        }
+        auto ecPubKey = CoseKey::getEcPublicKey(pubX->asBstr()->value(), pubY->asBstr()->value());
+        if (!ecPubKey) return ecPubKey.moveMessage();
+        publicKey = ecPubKey.moveValue();
+    } else {
+        auto& pubkey = senderCoseKey->asMap()->get(CoseKey::PUBKEY_X);
+        if (!pubkey || !pubkey->asBstr() ||
+            pubkey->asBstr()->value().size() != X25519_PUBLIC_VALUE_LEN) {
+            return "Invalid X25519 public key";
+        }
+        publicKey = pubkey->asBstr()->value();
     }
 
     auto& key_id = unprotParms->asMap()->get(KEY_ID);
     if (key_id && key_id->asBstr()) {
-        return std::make_pair(pubkey->asBstr()->value(), key_id->asBstr()->value());
+        return std::make_pair(std::move(publicKey), key_id->asBstr()->value());
     }
 
     // If no key ID, just return an empty vector.
-    return std::make_pair(pubkey->asBstr()->value(), bytevec{});
+    return std::make_pair(std::move(publicKey), bytevec{});
 }
 
 ErrMsgOr<bytevec> decryptCoseEncrypt(const bytevec& key, const cppbor::Item* coseEncrypt,
@@ -367,17 +566,16 @@
     return aesGcmDecrypt(key, nonce->asBstr()->value(), aad, ciphertext->asBstr()->value());
 }
 
-ErrMsgOr<bytevec> x25519_HKDF_DeriveKey(const bytevec& pubKeyA, const bytevec& privKeyA,
-                                        const bytevec& pubKeyB, bool senderIsA) {
+// Builds the CBOR-encoded KDF context (AES_GCM_256, a "Sender Info" array, ..., the output key
+// length) that the *_HKDF_DeriveKey functions pass to HKDF as the "info" input.
+// NOTE(review): privKeyA only appears in the presence check below; confirm it is otherwise unused.
+// NOTE(review): "consruct" is a typo for "construct"; renaming requires updating every caller.
+ErrMsgOr<bytevec> consructKdfContext(const bytevec& pubKeyA, const bytevec& privKeyA,
+                                     const bytevec& pubKeyB, bool senderIsA) {
     if (privKeyA.empty() || pubKeyA.empty() || pubKeyB.empty()) {
         return "Missing input key parameters";
     }
 
-    bytevec rawSharedKey(X25519_SHARED_KEY_LEN);
-    if (!::X25519(rawSharedKey.data(), privKeyA.data(), pubKeyB.data())) {
-        return "ECDH operation failed";
-    }
-
     bytevec kdfContext = cppbor::Array()
                              .add(AES_GCM_256)
                              .add(cppbor::Array()  // Sender Info
@@ -392,6 +586,61 @@
                                       .add(kAesGcmKeySizeBits)  // output key length
                                       .add(bytevec{}))          // protected
                              .encode();
+    return kdfContext;
+}
+
+// Derives a SHA256_DIGEST_LENGTH-byte key from a P-256 ECDH agreement between privKeyA
+// (big-endian scalar) and pubKeyB, using HKDF-SHA256 with an empty salt and the KDF context
+// as info. pubKeyB is the raw point X || Y without the SEC1 0x04 prefix, which is added here.
+ErrMsgOr<bytevec> ECDH_HKDF_DeriveKey(const bytevec& pubKeyA, const bytevec& privKeyA,
+                                      const bytevec& pubKeyB, bool senderIsA) {
+    if (privKeyA.empty() || pubKeyA.empty() || pubKeyB.empty()) {
+        return "Missing input key parameters";
+    }
+
+    bytevec publicKey;
+    publicKey.reserve(pubKeyB.size() + 1);
+    publicKey.push_back(0x04);  // SEC1 uncompressed-point marker
+    publicKey.insert(publicKey.end(), pubKeyB.begin(), pubKeyB.end());
+    auto rawSharedKey = ecdh(publicKey, privKeyA);
+    if (!rawSharedKey) return rawSharedKey.moveMessage();
+
+    auto kdfContext = consructKdfContext(pubKeyA, privKeyA, pubKeyB, senderIsA);
+    if (!kdfContext) return kdfContext.moveMessage();
+
+    bytevec retval(SHA256_DIGEST_LENGTH);
+    bytevec salt{};
+    if (!HKDF(retval.data(), retval.size(),                //
+              EVP_sha256(),                                //
+              rawSharedKey->data(), rawSharedKey->size(),  //
+              salt.data(), salt.size(),                    //
+              kdfContext->data(), kdfContext->size())) {
+        return "ECDH HKDF failed";
+    }
+
+    return retval;
+}
+
+// X25519 counterpart of ECDH_HKDF_DeriveKey: raw X25519 agreement between privKeyA and pubKeyB,
+// followed by HKDF-SHA256 over the same KDF context.
+ErrMsgOr<bytevec> x25519_HKDF_DeriveKey(const bytevec& pubKeyA, const bytevec& privKeyA,
+                                        const bytevec& pubKeyB, bool senderIsA) {
+    if (privKeyA.empty() || pubKeyA.empty() || pubKeyB.empty()) {
+        return "Missing input key parameters";
+    }
+    // ::X25519 reads a fixed number of bytes from both key buffers; shorter inputs would be read
+    // out of bounds.
+    if (privKeyA.size() < X25519_PRIVATE_KEY_LEN || pubKeyB.size() < X25519_PUBLIC_VALUE_LEN) {
+        return "Invalid X25519 key length";
+    }
+
+    bytevec rawSharedKey(X25519_SHARED_KEY_LEN);
+    if (!::X25519(rawSharedKey.data(), privKeyA.data(), pubKeyB.data())) {
+        return "ECDH operation failed";
+    }
+
+    auto kdfContext = consructKdfContext(pubKeyA, privKeyA, pubKeyB, senderIsA);
+    if (!kdfContext) return kdfContext.moveMessage();
 
     bytevec retval(SHA256_DIGEST_LENGTH);
     bytevec salt{};
@@ -399,7 +638,7 @@
               EVP_sha256(),                              //
               rawSharedKey.data(), rawSharedKey.size(),  //
               salt.data(), salt.size(),                  //
-              kdfContext.data(), kdfContext.size())) {
+              kdfContext->data(), kdfContext->size())) {  // info: the encoded KDF context
         return "ECDH HKDF failed";
     }
 
@@ -460,4 +699,44 @@
     return plaintext;
 }
 
+// Returns the SHA-256 digest (SHA256_DIGEST_LENGTH bytes) of `data`.
+bytevec sha256(const bytevec& data) {
+    bytevec ret(SHA256_DIGEST_LENGTH);
+    ::SHA256(data.data(), data.size(), ret.data());
+    return ret;
+}
+
+// Verifies the DER-encoded ECDSA `signature` over the precomputed `digest` with the P-256 public
+// key `key`, an encoded point accepted by EC_POINT_oct2point (e.g. 0x04 || X || Y).
+// Returns false on any decoding or allocation failure, or if the signature does not verify.
+bool verifyEcdsaDigest(const bytevec& key, const bytevec& digest, const bytevec& signature) {
+    const unsigned char* p = signature.data();
+    auto sig = ECDSA_SIG_Ptr(d2i_ECDSA_SIG(nullptr, &p, signature.size()));
+    if (sig.get() == nullptr) {
+        return false;
+    }
+
+    auto group = EC_GROUP_Ptr(EC_GROUP_new_by_curve_name(NID_X9_62_prime256v1));
+    if (group.get() == nullptr) {
+        return false;
+    }
+    auto point = EC_POINT_Ptr(EC_POINT_new(group.get()));
+    if (point.get() == nullptr ||
+        EC_POINT_oct2point(group.get(), point.get(), key.data(), key.size(), nullptr) != 1) {
+        return false;
+    }
+    auto ecKey = EC_KEY_Ptr(EC_KEY_new());
+    if (ecKey.get() == nullptr) {
+        return false;
+    }
+    if (EC_KEY_set_group(ecKey.get(), group.get()) != 1) {
+        return false;
+    }
+    if (EC_KEY_set_public_key(ecKey.get(), point.get()) != 1) {
+        return false;
+    }
+
+    return ECDSA_do_verify(digest.data(), digest.size(), sig.get(), ecKey.get()) == 1;
+}
+
 }  // namespace cppcose