Convert BN_MONT_CTX to new-style locking.

This replaces the global CRYPTO_LOCK_* locks with a per-RSA/DSA/DH lock.
This reduces lock contention, although pthread locks are depressingly
bloated.

Change-Id: I07c4d1606fc35135fc141ebe6ba904a28c8f8a0c
Reviewed-on: https://boringssl-review.googlesource.com/4324
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/crypto/bn/montgomery.c b/crypto/bn/montgomery.c
index 5a9d686..152cf2d 100644
--- a/crypto/bn/montgomery.c
+++ b/crypto/bn/montgomery.c
@@ -114,6 +114,7 @@
 #include <openssl/thread.h>
 
 #include "internal.h"
+#include "../internal.h"
 
 
 #if !defined(OPENSSL_NO_ASM) && \
@@ -292,44 +293,36 @@
   return ret;
 }
 
-BN_MONT_CTX *BN_MONT_CTX_set_locked(BN_MONT_CTX **pmont, int lock,
-                                    const BIGNUM *mod, BN_CTX *ctx) {
-  BN_MONT_CTX *ret;
+BN_MONT_CTX *BN_MONT_CTX_set_locked(BN_MONT_CTX **pmont, CRYPTO_MUTEX *lock,
+                                    const BIGNUM *mod, BN_CTX *bn_ctx) {
+  CRYPTO_MUTEX_lock_read(lock);
+  BN_MONT_CTX *ctx = *pmont;
+  CRYPTO_MUTEX_unlock(lock);
 
-  CRYPTO_r_lock(lock);
-  ret = *pmont;
-  CRYPTO_r_unlock(lock);
-  if (ret) {
-    return ret;
+  if (ctx) {
+    return ctx;
   }
 
-  /* We don't want to serialise globally while doing our lazy-init math in
-   * BN_MONT_CTX_set. That punishes threads that are doing independent
-   * things. Instead, punish the case where more than one thread tries to
-   * lazy-init the same 'pmont', by having each do the lazy-init math work
-   * independently and only use the one from the thread that wins the race
-   * (the losers throw away the work they've done). */
-  ret = BN_MONT_CTX_new();
-  if (!ret) {
-    return NULL;
-  }
-  if (!BN_MONT_CTX_set(ret, mod, ctx)) {
-    BN_MONT_CTX_free(ret);
-    return NULL;
+  CRYPTO_MUTEX_lock_write(lock);
+  ctx = *pmont;
+  if (ctx) {
+    goto out;
   }
 
-  /* The locked compare-and-set, after the local work is done. */
-  CRYPTO_w_lock(lock);
-  if (*pmont) {
-    BN_MONT_CTX_free(ret);
-    ret = *pmont;
-  } else {
-    *pmont = ret;
+  ctx = BN_MONT_CTX_new();
+  if (ctx == NULL) {
+    goto out;
   }
+  if (!BN_MONT_CTX_set(ctx, mod, bn_ctx)) {
+    BN_MONT_CTX_free(ctx);
+    ctx = NULL;
+    goto out;
+  }
+  *pmont = ctx;
 
-  CRYPTO_w_unlock(lock);
-
-  return ret;
+out:
+  CRYPTO_MUTEX_unlock(lock);
+  return ctx;
 }
 
 int BN_to_montgomery(BIGNUM *ret, const BIGNUM *a, const BN_MONT_CTX *mont,
diff --git a/crypto/dh/dh.c b/crypto/dh/dh.c
index 7a50da7..86804bf 100644
--- a/crypto/dh/dh.c
+++ b/crypto/dh/dh.c
@@ -66,6 +66,7 @@
 #include <openssl/thread.h>
 
 #include "internal.h"
+#include "../internal.h"
 
 
 extern const DH_METHOD DH_default_method;
@@ -90,6 +91,8 @@
   }
   METHOD_ref(dh->meth);
 
+  CRYPTO_MUTEX_init(&dh->method_mont_p_lock);
+
   dh->references = 1;
   if (!CRYPTO_new_ex_data(CRYPTO_EX_INDEX_DH, dh, &dh->ex_data)) {
     OPENSSL_free(dh);
@@ -131,6 +134,7 @@
   if (dh->counter != NULL) BN_clear_free(dh->counter);
   if (dh->pub_key != NULL) BN_clear_free(dh->pub_key);
   if (dh->priv_key != NULL) BN_clear_free(dh->priv_key);
+  CRYPTO_MUTEX_cleanup(&dh->method_mont_p_lock);
 
   OPENSSL_free(dh);
 }
diff --git a/crypto/dh/dh_impl.c b/crypto/dh/dh_impl.c
index 5c4d637..81d777d 100644
--- a/crypto/dh/dh_impl.c
+++ b/crypto/dh/dh_impl.c
@@ -62,6 +62,7 @@
 
 #include "internal.h"
 
+
 #define OPENSSL_DH_MAX_MODULUS_BITS 10000
 
 static int generate_parameters(DH *ret, int prime_bits, int generator, BN_GENCB *cb) {
@@ -207,8 +208,8 @@
     pub_key = dh->pub_key;
   }
 
-  mont =
-      BN_MONT_CTX_set_locked(&dh->method_mont_p, CRYPTO_LOCK_DH, dh->p, ctx);
+  mont = BN_MONT_CTX_set_locked(&dh->method_mont_p, &dh->method_mont_p_lock,
+                                dh->p, ctx);
   if (!mont) {
     goto err;
   }
@@ -282,8 +283,8 @@
     goto err;
   }
 
-  mont =
-      BN_MONT_CTX_set_locked(&dh->method_mont_p, CRYPTO_LOCK_DH, dh->p, ctx);
+  mont = BN_MONT_CTX_set_locked(&dh->method_mont_p, &dh->method_mont_p_lock,
+                                dh->p, ctx);
   if (!mont) {
     goto err;
   }
diff --git a/crypto/dsa/dsa.c b/crypto/dsa/dsa.c
index 5303714..c580956 100644
--- a/crypto/dsa/dsa.c
+++ b/crypto/dsa/dsa.c
@@ -70,6 +70,7 @@
 #include <openssl/thread.h>
 
 #include "internal.h"
+#include "../internal.h"
 
 
 extern const DSA_METHOD DSA_default_method;
@@ -97,6 +98,8 @@
   dsa->write_params = 1;
   dsa->references = 1;
 
+  CRYPTO_MUTEX_init(&dsa->method_mont_p_lock);
+
   if (!CRYPTO_new_ex_data(CRYPTO_EX_INDEX_DSA, dsa, &dsa->ex_data)) {
     METHOD_unref(dsa->meth);
     OPENSSL_free(dsa);
@@ -150,6 +153,7 @@
   if (dsa->r != NULL) {
     BN_clear_free(dsa->r);
   }
+  CRYPTO_MUTEX_cleanup(&dsa->method_mont_p_lock);
   OPENSSL_free(dsa);
 }
 
diff --git a/crypto/dsa/dsa_impl.c b/crypto/dsa/dsa_impl.c
index aba7f85..c4df80b 100644
--- a/crypto/dsa/dsa_impl.c
+++ b/crypto/dsa/dsa_impl.c
@@ -123,8 +123,9 @@
 
   BN_set_flags(&k, BN_FLG_CONSTTIME);
 
-  if (!BN_MONT_CTX_set_locked((BN_MONT_CTX **)&dsa->method_mont_p,
-                              CRYPTO_LOCK_DSA, dsa->p, ctx)) {
+  if (BN_MONT_CTX_set_locked((BN_MONT_CTX **)&dsa->method_mont_p,
+                             (CRYPTO_MUTEX *)&dsa->method_mont_p_lock, dsa->p,
+                             ctx) == NULL) {
     goto err;
   }
 
@@ -365,12 +366,14 @@
   }
 
   mont = BN_MONT_CTX_set_locked((BN_MONT_CTX **)&dsa->method_mont_p,
-                                CRYPTO_LOCK_DSA, dsa->p, ctx);
+                                (CRYPTO_MUTEX *)&dsa->method_mont_p_lock,
+                                dsa->p, ctx);
   if (!mont) {
     goto err;
   }
 
-  if (!BN_mod_exp2_mont(&t1, dsa->g, &u1, dsa->pub_key, &u2, dsa->p, ctx, mont)) {
+  if (!BN_mod_exp2_mont(&t1, dsa->g, &u1, dsa->pub_key, &u2, dsa->p, ctx,
+                        mont)) {
     goto err;
   }
 
diff --git a/crypto/rsa/blinding.c b/crypto/rsa/blinding.c
index c5a1604..88682ef 100644
--- a/crypto/rsa/blinding.c
+++ b/crypto/rsa/blinding.c
@@ -414,6 +414,7 @@
   BIGNUM *e, *n;
   BN_CTX *ctx;
   BN_BLINDING *ret = NULL;
+  BN_MONT_CTX *mont_ctx = NULL;
 
   if (in_ctx == NULL) {
     ctx = BN_CTX_new();
@@ -445,14 +446,15 @@
   BN_with_flags(n, rsa->n, BN_FLG_CONSTTIME);
 
   if (rsa->flags & RSA_FLAG_CACHE_PUBLIC) {
-    if (!BN_MONT_CTX_set_locked(&rsa->_method_mod_n, CRYPTO_LOCK_RSA, rsa->n,
-                                ctx)) {
+    mont_ctx =
+        BN_MONT_CTX_set_locked(&rsa->_method_mod_n, &rsa->lock, rsa->n, ctx);
+    if (mont_ctx == NULL) {
       goto err;
     }
   }
 
   ret = BN_BLINDING_create_param(NULL, e, n, ctx, rsa->meth->bn_mod_exp,
-                                 rsa->_method_mod_n);
+                                 mont_ctx);
   if (ret == NULL) {
     OPENSSL_PUT_ERROR(RSA, rsa_setup_blinding, ERR_R_BN_LIB);
     goto err;
diff --git a/crypto/rsa/rsa.c b/crypto/rsa/rsa.c
index 884f67e..88d38a2 100644
--- a/crypto/rsa/rsa.c
+++ b/crypto/rsa/rsa.c
@@ -67,6 +67,7 @@
 #include <openssl/thread.h>
 
 #include "internal.h"
+#include "../internal.h"
 
 
 extern const RSA_METHOD RSA_default_method;
@@ -93,6 +94,7 @@
 
   rsa->references = 1;
   rsa->flags = rsa->meth->flags;
+  CRYPTO_MUTEX_init(&rsa->lock);
 
   if (!CRYPTO_new_ex_data(CRYPTO_EX_INDEX_RSA, rsa, &rsa->ex_data)) {
     METHOD_unref(rsa->meth);
@@ -161,6 +163,7 @@
   if (rsa->blindings_inuse != NULL) {
     OPENSSL_free(rsa->blindings_inuse);
   }
+  CRYPTO_MUTEX_cleanup(&rsa->lock);
   OPENSSL_free(rsa);
 }
 
diff --git a/crypto/rsa/rsa_impl.c b/crypto/rsa/rsa_impl.c
index f790e64..e8cbd97 100644
--- a/crypto/rsa/rsa_impl.c
+++ b/crypto/rsa/rsa_impl.c
@@ -64,6 +64,7 @@
 #include <openssl/thread.h>
 
 #include "internal.h"
+#include "../internal.h"
 
 
 #define OPENSSL_RSA_MAX_MODULUS_BITS 16384
@@ -166,13 +167,14 @@
   }
 
   if (rsa->flags & RSA_FLAG_CACHE_PUBLIC) {
-    if (!BN_MONT_CTX_set_locked(&rsa->_method_mod_n, CRYPTO_LOCK_RSA, rsa->n,
-                                ctx)) {
+    if (BN_MONT_CTX_set_locked(&rsa->_method_mod_n, &rsa->lock, rsa->n, ctx) ==
+        NULL) {
       goto err;
     }
   }
 
-  if (!rsa->meth->bn_mod_exp(result, f, rsa->e, rsa->n, ctx, rsa->_method_mod_n)) {
+  if (!rsa->meth->bn_mod_exp(result, f, rsa->e, rsa->n, ctx,
+                             rsa->_method_mod_n)) {
     goto err;
   }
 
@@ -218,7 +220,7 @@
   uint8_t *new_blindings_inuse;
   char overflow = 0;
 
-  CRYPTO_w_lock(CRYPTO_LOCK_RSA_BLINDING);
+  CRYPTO_MUTEX_lock_write(&rsa->lock);
 
   unsigned i;
   for (i = 0; i < rsa->num_blindings; i++) {
@@ -231,7 +233,7 @@
   }
 
   if (ret != NULL) {
-    CRYPTO_w_unlock(CRYPTO_LOCK_RSA_BLINDING);
+    CRYPTO_MUTEX_unlock(&rsa->lock);
     return ret;
   }
 
@@ -240,7 +242,7 @@
   /* We didn't find a free BN_BLINDING to use so increase the length of
    * the arrays by one and use the newly created element. */
 
-  CRYPTO_w_unlock(CRYPTO_LOCK_RSA_BLINDING);
+  CRYPTO_MUTEX_unlock(&rsa->lock);
   ret = rsa_setup_blinding(rsa, ctx);
   if (ret == NULL) {
     return NULL;
@@ -253,7 +255,7 @@
     return ret;
   }
 
-  CRYPTO_w_lock(CRYPTO_LOCK_RSA_BLINDING);
+  CRYPTO_MUTEX_lock_write(&rsa->lock);
 
   new_blindings =
       OPENSSL_malloc(sizeof(BN_BLINDING *) * (rsa->num_blindings + 1));
@@ -282,14 +284,14 @@
   rsa->blindings_inuse = new_blindings_inuse;
   rsa->num_blindings++;
 
-  CRYPTO_w_unlock(CRYPTO_LOCK_RSA_BLINDING);
+  CRYPTO_MUTEX_unlock(&rsa->lock);
   return ret;
 
 err2:
   OPENSSL_free(new_blindings);
 
 err1:
-  CRYPTO_w_unlock(CRYPTO_LOCK_RSA_BLINDING);
+  CRYPTO_MUTEX_unlock(&rsa->lock);
   BN_BLINDING_free(ret);
   return NULL;
 }
@@ -304,9 +306,9 @@
     return;
   }
 
-  CRYPTO_w_lock(CRYPTO_LOCK_RSA_BLINDING);
+  CRYPTO_MUTEX_lock_write(&rsa->lock);
   rsa->blindings_inuse[blinding_index] = 0;
-  CRYPTO_w_unlock(CRYPTO_LOCK_RSA_BLINDING);
+  CRYPTO_MUTEX_unlock(&rsa->lock);
 }
 
 /* signing */
@@ -479,8 +481,8 @@
   }
 
   if (rsa->flags & RSA_FLAG_CACHE_PUBLIC) {
-    if (!BN_MONT_CTX_set_locked(&rsa->_method_mod_n, CRYPTO_LOCK_RSA, rsa->n,
-                                ctx)) {
+    if (BN_MONT_CTX_set_locked(&rsa->_method_mod_n, &rsa->lock, rsa->n, ctx) ==
+        NULL) {
       goto err;
     }
   }
@@ -583,8 +585,8 @@
     BN_with_flags(d, rsa->d, BN_FLG_CONSTTIME);
 
     if (rsa->flags & RSA_FLAG_CACHE_PUBLIC) {
-      if (!BN_MONT_CTX_set_locked(&rsa->_method_mod_n, CRYPTO_LOCK_RSA, rsa->n,
-                                  ctx)) {
+      if (BN_MONT_CTX_set_locked(&rsa->_method_mod_n, &rsa->lock, rsa->n,
+                                 ctx) == NULL) {
         goto err;
       }
     }
@@ -645,18 +647,20 @@
     BN_with_flags(q, rsa->q, BN_FLG_CONSTTIME);
 
     if (rsa->flags & RSA_FLAG_CACHE_PRIVATE) {
-      if (!BN_MONT_CTX_set_locked(&rsa->_method_mod_p, CRYPTO_LOCK_RSA, p, ctx)) {
+      if (BN_MONT_CTX_set_locked(&rsa->_method_mod_p, &rsa->lock, p, ctx) ==
+          NULL) {
         goto err;
       }
-      if (!BN_MONT_CTX_set_locked(&rsa->_method_mod_q, CRYPTO_LOCK_RSA, q, ctx)) {
+      if (BN_MONT_CTX_set_locked(&rsa->_method_mod_q, &rsa->lock, q, ctx) ==
+          NULL) {
         goto err;
       }
     }
   }
 
   if (rsa->flags & RSA_FLAG_CACHE_PUBLIC) {
-    if (!BN_MONT_CTX_set_locked(&rsa->_method_mod_n, CRYPTO_LOCK_RSA, rsa->n,
-                                ctx)) {
+    if (BN_MONT_CTX_set_locked(&rsa->_method_mod_n, &rsa->lock, rsa->n, ctx) ==
+        NULL) {
       goto err;
     }
   }
diff --git a/include/openssl/bn.h b/include/openssl/bn.h
index 838870d..917beaf 100644
--- a/include/openssl/bn.h
+++ b/include/openssl/bn.h
@@ -124,6 +124,7 @@
 #define OPENSSL_HEADER_BN_H
 
 #include <openssl/base.h>
+#include <openssl/thread.h>
 
 #include <stdio.h>  /* for FILE* */
 
@@ -711,15 +712,13 @@
 OPENSSL_EXPORT int BN_MONT_CTX_set(BN_MONT_CTX *mont, const BIGNUM *mod,
                                    BN_CTX *ctx);
 
-/* BN_MONT_CTX_set_locked takes the lock indicated by |lock| and checks whether
- * |*pmont| is NULL. If so, it creates a new |BN_MONT_CTX| and sets the modulus
- * for it to |mod|. It then stores it as |*pmont| and returns it, or NULL on
- * error.
+/* BN_MONT_CTX_set_locked takes |lock| and checks whether |*pmont| is NULL. If
+ * so, it creates a new |BN_MONT_CTX| and sets the modulus for it to |mod|. It
+ * then stores it as |*pmont| and returns it, or NULL on error.
  *
  * If |*pmont| is already non-NULL then the existing value is returned. */
-OPENSSL_EXPORT BN_MONT_CTX *BN_MONT_CTX_set_locked(BN_MONT_CTX **pmont,
-                                                   int lock, const BIGNUM *mod,
-                                                   BN_CTX *ctx);
+BN_MONT_CTX *BN_MONT_CTX_set_locked(BN_MONT_CTX **pmont, CRYPTO_MUTEX *lock,
+                                    const BIGNUM *mod, BN_CTX *bn_ctx);
 
 /* BN_to_montgomery sets |ret| equal to |a| in the Montgomery domain. It
  * returns one on success and zero on error. */
diff --git a/include/openssl/dh.h b/include/openssl/dh.h
index 3c8f290..39614ff 100644
--- a/include/openssl/dh.h
+++ b/include/openssl/dh.h
@@ -61,6 +61,7 @@
 
 #include <openssl/engine.h>
 #include <openssl/ex_data.h>
+#include <openssl/thread.h>
 
 #if defined(__cplusplus)
 extern "C" {
@@ -236,6 +237,8 @@
   /* priv_length contains the length, in bits, of the private value. If zero,
    * the private value will be the same length as |p|. */
   unsigned priv_length;
+
+  CRYPTO_MUTEX method_mont_p_lock;
   BN_MONT_CTX *method_mont_p;
 
   /* Place holders if we want to do X9.42 DH */
diff --git a/include/openssl/dsa.h b/include/openssl/dsa.h
index 69dd56b..47270f8 100644
--- a/include/openssl/dsa.h
+++ b/include/openssl/dsa.h
@@ -64,6 +64,7 @@
 
 #include <openssl/engine.h>
 #include <openssl/ex_data.h>
+#include <openssl/thread.h>
 
 #if defined(__cplusplus)
 extern "C" {
@@ -351,6 +352,7 @@
 
   int flags;
   /* Normally used to cache montgomery values */
+  CRYPTO_MUTEX method_mont_p_lock;
   BN_MONT_CTX *method_mont_p;
   int references;
   CRYPTO_EX_DATA ex_data;
diff --git a/include/openssl/rsa.h b/include/openssl/rsa.h
index f49eb14..889ad19 100644
--- a/include/openssl/rsa.h
+++ b/include/openssl/rsa.h
@@ -61,6 +61,7 @@
 
 #include <openssl/engine.h>
 #include <openssl/ex_data.h>
+#include <openssl/thread.h>
 
 #if defined(__cplusplus)
 extern "C" {
@@ -471,18 +472,21 @@
   int references;
   int flags;
 
-  /* Used to cache montgomery values */
+  CRYPTO_MUTEX lock;
+
+  /* Used to cache montgomery values. The creation of these values is protected
+   * by |lock|. */
   BN_MONT_CTX *_method_mod_n;
   BN_MONT_CTX *_method_mod_p;
   BN_MONT_CTX *_method_mod_q;
 
   /* num_blindings contains the size of the |blindings| and |blindings_inuse|
    * arrays. This member and the |blindings_inuse| array are protected by
-   * CRYPTO_LOCK_RSA_BLINDING. */
+   * |lock|. */
   unsigned num_blindings;
   /* blindings is an array of BN_BLINDING structures that can be reserved by a
-   * thread by locking CRYPTO_LOCK_RSA_BLINDING and changing the corresponding
-   * element in |blindings_inuse| from 0 to 1. */
+   * thread by locking |lock| and changing the corresponding element in
+   * |blindings_inuse| from 0 to 1. */
   BN_BLINDING **blindings;
   unsigned char *blindings_inuse;
 };