Simplify handshake hash handling.

Rather than supporting arbitrarily many handshake hashes in the general
case (the PRF logic already assumes there are at most two), special-case
the MD5/SHA-1 two-hash combination and otherwise maintain a single
rolling hash.

Change-Id: Ide9475565b158f6839bb10b8b22f324f89399f92
Reviewed-on: https://boringssl-review.googlesource.com/5618
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/ssl/t1_enc.c b/ssl/t1_enc.c
index febd54d..aa6095d 100644
--- a/ssl/t1_enc.c
+++ b/ssl/t1_enc.c
@@ -149,7 +149,7 @@
 
 
 /* tls1_P_hash computes the TLS P_<hash> function as described in RFC 5246,
- * section 5. It writes |out_len| bytes to |out|, using |md| as the hash and
+ * section 5. It XORs |out_len| bytes into |out|, using |md| as the hash and
  * |secret| as the secret. |seed1| through |seed3| are concatenated to form the
  * seed parameter. It returns one on success and zero on failure. */
 static int tls1_P_hash(uint8_t *out, size_t out_len, const EVP_MD *md,
@@ -188,26 +188,32 @@
       goto err;
     }
 
-    if (out_len > chunk) {
-      unsigned len;
-      if (!HMAC_Final(&ctx, out, &len)) {
-        goto err;
-      }
-      assert(len == chunk);
-      out += len;
-      out_len -= len;
-      /* Calculate the next A1 value. */
-      if (!HMAC_Final(&ctx_tmp, A1, &A1_len)) {
-        goto err;
-      }
-    } else {
-      /* Last chunk. */
-      if (!HMAC_Final(&ctx, A1, &A1_len)) {
-        goto err;
-      }
-      memcpy(out, A1, out_len);
+    unsigned len;
+    uint8_t hmac[EVP_MAX_MD_SIZE];
+    if (!HMAC_Final(&ctx, hmac, &len)) {
+      goto err;
+    }
+    assert(len == chunk);
+
+    /* XOR the result into |out|. */
+    if (len > out_len) {
+      len = out_len;
+    }
+    unsigned i;
+    for (i = 0; i < len; i++) {
+      out[i] ^= hmac[i];
+    }
+    out += len;
+    out_len -= len;
+
+    if (out_len == 0) {
       break;
     }
+
+    /* Calculate the next A1 value. */
+    if (!HMAC_Final(&ctx_tmp, A1, &A1_len)) {
+      goto err;
+    }
   }
 
   ret = 1;
@@ -224,62 +230,36 @@
              size_t secret_len, const char *label, size_t label_len,
              const uint8_t *seed1, size_t seed1_len,
              const uint8_t *seed2, size_t seed2_len) {
-  size_t idx, len, count, i;
-  const uint8_t *S1;
-  uint32_t m;
-  const EVP_MD *md;
-  int ret = 0;
-  uint8_t *tmp;
 
   if (out_len == 0) {
     return 1;
   }
 
-  /* Allocate a temporary buffer. */
-  tmp = OPENSSL_malloc(out_len);
-  if (tmp == NULL) {
-    OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE);
+  memset(out, 0, out_len);
+
+  uint32_t algorithm_prf = ssl_get_algorithm_prf(s);
+  if (algorithm_prf == SSL_HANDSHAKE_MAC_DEFAULT) {
+    /* If using the MD5/SHA1 PRF, |secret| is partitioned between SHA-1 and
+     * MD5, MD5 first. */
+    size_t secret_half = secret_len - (secret_len / 2);
+    if (!tls1_P_hash(out, out_len, EVP_md5(), secret, secret_half,
+                     (const uint8_t *)label, label_len, seed1, seed1_len, seed2,
+                     seed2_len)) {
+      return 0;
+    }
+
+    /* Note that, if |secret_len| is odd, the two halves share a byte. */
+    secret = secret + (secret_len - secret_half);
+    secret_len = secret_half;
+  }
+
+  if (!tls1_P_hash(out, out_len, ssl_get_handshake_digest(algorithm_prf),
+                   secret, secret_len, (const uint8_t *)label, label_len,
+                   seed1, seed1_len, seed2, seed2_len)) {
     return 0;
   }
 
-  /* Count number of digests and partition |secret| evenly. */
-  count = 0;
-  for (idx = 0; ssl_get_handshake_digest(&m, &md, idx); idx++) {
-    if (m & ssl_get_algorithm_prf(s)) {
-      count++;
-    }
-  }
-  /* TODO(davidben): The only case where count isn't 1 is the old MD5/SHA-1
-   * combination. The logic around multiple handshake digests can probably be
-   * simplified. */
-  assert(count == 1 || count == 2);
-  len = secret_len / count;
-  if (count == 1) {
-    secret_len = 0;
-  }
-  S1 = secret;
-  memset(out, 0, out_len);
-  for (idx = 0; ssl_get_handshake_digest(&m, &md, idx); idx++) {
-    if (m & ssl_get_algorithm_prf(s)) {
-      /* If |count| is 2 and |secret_len| is odd, |secret| is partitioned into
-       * two halves with an overlapping byte. */
-      if (!tls1_P_hash(tmp, out_len, md, S1, len + (secret_len & 1),
-                       (const uint8_t *)label, label_len, seed1, seed1_len,
-                       seed2, seed2_len)) {
-        goto err;
-      }
-      S1 += len;
-      for (i = 0; i < out_len; i++) {
-        out[i] ^= tmp[i];
-      }
-    }
-  }
-  ret = 1;
-
-err:
-  OPENSSL_cleanse(tmp, out_len);
-  OPENSSL_free(tmp);
-  return ret;
+  return 1;
 }
 
 static int tls1_generate_key_block(SSL *s, uint8_t *out, size_t out_len) {
@@ -469,31 +449,50 @@
 }
 
 int tls1_cert_verify_mac(SSL *s, int md_nid, uint8_t *out) {
-  unsigned int ret;
-  EVP_MD_CTX ctx, *d = NULL;
-  int i;
-
-  for (i = 0; i < SSL_MAX_DIGEST; i++) {
-    if (s->s3->handshake_dgst[i] &&
-        EVP_MD_CTX_type(s->s3->handshake_dgst[i]) == md_nid) {
-      d = s->s3->handshake_dgst[i];
-      break;
-    }
-  }
-
-  if (!d) {
+  const EVP_MD_CTX *ctx_template;
+  if (md_nid == NID_md5) {
+    ctx_template = &s->s3->handshake_md5;
+  } else if (md_nid == EVP_MD_CTX_type(&s->s3->handshake_hash)) {
+    ctx_template = &s->s3->handshake_hash;
+  } else {
     OPENSSL_PUT_ERROR(SSL, SSL_R_NO_REQUIRED_DIGEST);
     return 0;
   }
 
+  EVP_MD_CTX ctx;
   EVP_MD_CTX_init(&ctx);
-  if (!EVP_MD_CTX_copy_ex(&ctx, d)) {
+  if (!EVP_MD_CTX_copy_ex(&ctx, ctx_template)) {
     EVP_MD_CTX_cleanup(&ctx);
     return 0;
   }
+  unsigned ret;
   EVP_DigestFinal_ex(&ctx, out, &ret);
   EVP_MD_CTX_cleanup(&ctx);
+  return ret;
+}
 
+static int append_digest(const EVP_MD_CTX *ctx, uint8_t *out, size_t *out_len,
+                         size_t max_out) {
+  int ret = 0;
+  EVP_MD_CTX ctx_copy;
+  EVP_MD_CTX_init(&ctx_copy);
+
+  if (EVP_MD_CTX_size(ctx) > max_out) {
+    OPENSSL_PUT_ERROR(SSL, SSL_R_BUFFER_TOO_SMALL);
+    goto err;
+  }
+  unsigned len;
+  if (!EVP_MD_CTX_copy_ex(&ctx_copy, ctx) ||
+      !EVP_DigestFinal_ex(&ctx_copy, out, &len)) {
+    goto err;
+  }
+  assert(len == EVP_MD_CTX_size(ctx));
+
+  *out_len = len;
+  ret = 1;
+
+err:
+  EVP_MD_CTX_cleanup(&ctx_copy);
   return ret;
 }
 
@@ -503,44 +502,19 @@
  * underlying digests so can be called multiple times and prior to the final
  * update etc. */
 int tls1_handshake_digest(SSL *s, uint8_t *out, size_t out_len) {
-  const EVP_MD *md;
-  EVP_MD_CTX ctx;
-  int err = 0, len = 0;
-  size_t i;
-  uint32_t mask;
-
-  EVP_MD_CTX_init(&ctx);
-
-  for (i = 0; ssl_get_handshake_digest(&mask, &md, i); i++) {
-    size_t hash_size;
-    unsigned int digest_len;
-    EVP_MD_CTX *hdgst = s->s3->handshake_dgst[i];
-
-    if ((mask & ssl_get_algorithm_prf(s)) == 0) {
-      continue;
-    }
-
-    hash_size = EVP_MD_size(md);
-    if (!hdgst ||
-        hash_size > out_len ||
-        !EVP_MD_CTX_copy_ex(&ctx, hdgst) ||
-        !EVP_DigestFinal_ex(&ctx, out, &digest_len) ||
-        digest_len != hash_size /* internal error */) {
-      err = 1;
-      break;
-    }
-
-    out += digest_len;
-    out_len -= digest_len;
-    len += digest_len;
-  }
-
-  EVP_MD_CTX_cleanup(&ctx);
-
-  if (err != 0) {
+  size_t md5_len = 0;
+  if (EVP_MD_CTX_md(&s->s3->handshake_md5) != NULL &&
+      !append_digest(&s->s3->handshake_md5, out, &md5_len, out_len)) {
     return -1;
   }
-  return len;
+
+  size_t len;
+  if (!append_digest(&s->s3->handshake_hash, out + md5_len, &len,
+                     out_len - md5_len)) {
+    return -1;
+  }
+
+  return (int)(md5_len + len);
 }
 
 int tls1_final_finish_mac(SSL *s, const char *str, int slen, uint8_t *out) {