Don't use init_buf in DTLS.

This machinery is so different between TLS and DTLS that there is no
sense in having them share structures. This switches DTLS to keeping the
full reassembled message (header included) in hm_fragment; get_message
now just lets the caller read out of that buffer once it is complete.

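This removes the last direct handshake dependency on init_buf in DTLS:
the call to ssl3_hash_current_message, now replaced by
dtls1_hash_current_message, which hashes straight out of hm_fragment.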

Change-Id: I4eccfb6e6021116255daead5359a0aa3f4d5be7b
Reviewed-on: https://boringssl-review.googlesource.com/8667
Reviewed-by: Steven Valdez <svaldez@google.com>
Reviewed-by: David Benjamin <davidben@google.com>
diff --git a/ssl/d1_both.c b/ssl/d1_both.c
index 6aa4cc6..3da1fd8 100644
--- a/ssl/d1_both.c
+++ b/ssl/d1_both.c
@@ -146,34 +146,50 @@
   if (frag == NULL) {
     return;
   }
-  OPENSSL_free(frag->fragment);
+  OPENSSL_free(frag->data);
   OPENSSL_free(frag->reassembly);
   OPENSSL_free(frag);
 }
 
-static hm_fragment *dtls1_hm_fragment_new(size_t frag_len) {
+static hm_fragment *dtls1_hm_fragment_new(const struct hm_header_st *msg_hdr) {
   hm_fragment *frag = OPENSSL_malloc(sizeof(hm_fragment));
   if (frag == NULL) {
     OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE);
     return NULL;
   }
   memset(frag, 0, sizeof(hm_fragment));
+  frag->type = msg_hdr->type;
+  frag->seq = msg_hdr->seq;
+  frag->msg_len = msg_hdr->msg_len;
 
-  /* If the handshake message is empty, |frag->fragment| and |frag->reassembly|
-   * are NULL. */
-  if (frag_len > 0) {
-    frag->fragment = OPENSSL_malloc(frag_len);
-    if (frag->fragment == NULL) {
-      OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE);
-      goto err;
-    }
+  /* Allocate space for the reassembled message and fill in the header. */
+  frag->data = OPENSSL_malloc(DTLS1_HM_HEADER_LENGTH + msg_hdr->msg_len);
+  if (frag->data == NULL) {
+    OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE);
+    goto err;
+  }
 
+  CBB cbb;
+  if (!CBB_init_fixed(&cbb, frag->data, DTLS1_HM_HEADER_LENGTH) ||
+      !CBB_add_u8(&cbb, msg_hdr->type) ||
+      !CBB_add_u24(&cbb, msg_hdr->msg_len) ||
+      !CBB_add_u16(&cbb, msg_hdr->seq) ||
+      !CBB_add_u24(&cbb, 0 /* frag_off */) ||
+      !CBB_add_u24(&cbb, msg_hdr->msg_len) ||
+      !CBB_finish(&cbb, NULL, NULL)) {
+    CBB_cleanup(&cbb);
+    OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE);
+    goto err;
+  }
+
+  /* If the handshake message is empty, |frag->reassembly| is NULL. */
+  if (msg_hdr->msg_len > 0) {
     /* Initialize reassembly bitmask. */
-    if (frag_len + 7 < frag_len) {
+    if (msg_hdr->msg_len + 7 < msg_hdr->msg_len) {
       OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW);
       goto err;
     }
-    size_t bitmask_len = (frag_len + 7) / 8;
+    size_t bitmask_len = (msg_hdr->msg_len + 7) / 8;
     frag->reassembly = OPENSSL_malloc(bitmask_len);
     if (frag->reassembly == NULL) {
       OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE);
@@ -202,7 +218,7 @@
 static void dtls1_hm_fragment_mark(hm_fragment *frag, size_t start,
                                    size_t end) {
   size_t i;
-  size_t msg_len = frag->msg_header.msg_len;
+  size_t msg_len = frag->msg_len;
 
   if (frag->reassembly == NULL || start > end || end > msg_len) {
     assert(0);
@@ -238,14 +254,25 @@
   frag->reassembly = NULL;
 }
 
-/* dtls1_is_next_message_complete returns one if the next handshake message is
- * complete and zero otherwise. */
-static int dtls1_is_next_message_complete(SSL *ssl) {
+/* dtls1_is_current_message_complete returns one if the current handshake
+ * message is complete and zero otherwise. */
+static int dtls1_is_current_message_complete(SSL *ssl) {
   hm_fragment *frag = ssl->d1->incoming_messages[ssl->d1->handshake_read_seq %
                                                  SSL_MAX_HANDSHAKE_FLIGHT];
   return frag != NULL && frag->reassembly == NULL;
 }
 
+/* dtls1_pop_message removes the current handshake message, which must be
+ * complete, and advances to the next one. */
+static void dtls1_pop_message(SSL *ssl) {
+  assert(dtls1_is_current_message_complete(ssl));
+
+  size_t index = ssl->d1->handshake_read_seq % SSL_MAX_HANDSHAKE_FLIGHT;
+  dtls1_hm_fragment_free(ssl->d1->incoming_messages[index]);
+  ssl->d1->incoming_messages[index] = NULL;
+  ssl->d1->handshake_read_seq++;
+}
+
 /* dtls1_get_incoming_message returns the incoming message corresponding to
  * |msg_hdr|. If none exists, it creates a new one and inserts it in the
  * queue. Otherwise, it checks |msg_hdr| is consistent with the existing one. It
@@ -260,11 +287,11 @@
   size_t idx = msg_hdr->seq % SSL_MAX_HANDSHAKE_FLIGHT;
   hm_fragment *frag = ssl->d1->incoming_messages[idx];
   if (frag != NULL) {
-    assert(frag->msg_header.seq == msg_hdr->seq);
+    assert(frag->seq == msg_hdr->seq);
     /* The new fragment must be compatible with the previous fragments from this
      * message. */
-    if (frag->msg_header.type != msg_hdr->type ||
-        frag->msg_header.msg_len != msg_hdr->msg_len) {
+    if (frag->type != msg_hdr->type ||
+        frag->msg_len != msg_hdr->msg_len) {
       OPENSSL_PUT_ERROR(SSL, SSL_R_FRAGMENT_MISMATCH);
       ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_ILLEGAL_PARAMETER);
       return NULL;
@@ -273,11 +300,10 @@
   }
 
   /* This is the first fragment from this message. */
-  frag = dtls1_hm_fragment_new(msg_hdr->msg_len);
+  frag = dtls1_hm_fragment_new(msg_hdr);
   if (frag == NULL) {
     return NULL;
   }
-  memcpy(&frag->msg_header, msg_hdr, sizeof(*msg_hdr));
   ssl->d1->incoming_messages[idx] = frag;
   return frag;
 }
@@ -353,7 +379,7 @@
     if (frag == NULL) {
       return -1;
     }
-    assert(frag->msg_header.msg_len == msg_len);
+    assert(frag->msg_len == msg_len);
 
     if (frag->reassembly == NULL) {
       /* The message is already assembled. */
@@ -362,7 +388,8 @@
     assert(msg_len > 0);
 
     /* Copy the body into the fragment. */
-    memcpy(frag->fragment + frag_off, CBS_data(&body), CBS_len(&body));
+    memcpy(frag->data + DTLS1_HM_HEADER_LENGTH + frag_off, CBS_data(&body),
+           CBS_len(&body));
     dtls1_hm_fragment_mark(frag, frag_off, frag_off + frag_len);
   }
 
@@ -376,101 +403,63 @@
  * arrive in fragments. */
 long dtls1_get_message(SSL *ssl, int msg_type,
                        enum ssl_hash_message_t hash_message, int *ok) {
-  hm_fragment *frag = NULL;
-  int al;
+  *ok = 0;
 
-  /* s3->tmp is used to store messages that are unexpected, caused
-   * by the absence of an optional handshake message */
   if (ssl->s3->tmp.reuse_message) {
     /* A ssl_dont_hash_message call cannot be combined with reuse_message; the
      * ssl_dont_hash_message would have to have been applied to the previous
      * call. */
     assert(hash_message == ssl_hash_message);
+    assert(dtls1_is_current_message_complete(ssl));
+
     ssl->s3->tmp.reuse_message = 0;
-    if (msg_type >= 0 && ssl->s3->tmp.message_type != msg_type) {
-      al = SSL_AD_UNEXPECTED_MESSAGE;
-      OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_MESSAGE);
-      goto f_err;
-    }
-    *ok = 1;
-    assert(ssl->init_buf->length >= DTLS1_HM_HEADER_LENGTH);
-    ssl->init_msg = (uint8_t *)ssl->init_buf->data + DTLS1_HM_HEADER_LENGTH;
-    ssl->init_num = (int)ssl->init_buf->length - DTLS1_HM_HEADER_LENGTH;
-    return ssl->init_num;
+    hash_message = ssl_dont_hash_message;
+  } else if (dtls1_is_current_message_complete(ssl)) {
+    dtls1_pop_message(ssl);
   }
 
-  /* Process handshake records until the next message is ready. */
-  while (!dtls1_is_next_message_complete(ssl)) {
+  /* Process handshake records until the current message is ready. */
+  while (!dtls1_is_current_message_complete(ssl)) {
     int ret = dtls1_process_handshake_record(ssl);
     if (ret <= 0) {
-      *ok = 0;
       return ret;
     }
   }
 
-  /* Pop an entry from the ring buffer. */
-  frag = ssl->d1->incoming_messages[ssl->d1->handshake_read_seq %
-                                    SSL_MAX_HANDSHAKE_FLIGHT];
-  ssl->d1->incoming_messages[ssl->d1->handshake_read_seq %
-                             SSL_MAX_HANDSHAKE_FLIGHT] = NULL;
-
+  hm_fragment *frag = ssl->d1->incoming_messages[ssl->d1->handshake_read_seq %
+                                                 SSL_MAX_HANDSHAKE_FLIGHT];
   assert(frag != NULL);
   assert(frag->reassembly == NULL);
-  assert(ssl->d1->handshake_read_seq == frag->msg_header.seq);
-
-  ssl->d1->handshake_read_seq++;
-
-  /* Reconstruct the assembled message. */
-  CBB cbb;
-  CBB_zero(&cbb);
-  if (!BUF_MEM_reserve(ssl->init_buf, (size_t)frag->msg_header.msg_len +
-                                          DTLS1_HM_HEADER_LENGTH) ||
-      !CBB_init_fixed(&cbb, (uint8_t *)ssl->init_buf->data,
-                      ssl->init_buf->max) ||
-      !CBB_add_u8(&cbb, frag->msg_header.type) ||
-      !CBB_add_u24(&cbb, frag->msg_header.msg_len) ||
-      !CBB_add_u16(&cbb, frag->msg_header.seq) ||
-      !CBB_add_u24(&cbb, 0 /* frag_off */) ||
-      !CBB_add_u24(&cbb, frag->msg_header.msg_len) ||
-      !CBB_add_bytes(&cbb, frag->fragment, frag->msg_header.msg_len) ||
-      !CBB_finish(&cbb, NULL, &ssl->init_buf->length)) {
-    CBB_cleanup(&cbb);
-    OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE);
-    goto err;
-  }
-  assert(ssl->init_buf->length ==
-         (size_t)frag->msg_header.msg_len + DTLS1_HM_HEADER_LENGTH);
+  assert(ssl->d1->handshake_read_seq == frag->seq);
 
   /* TODO(davidben): This function has a lot of implicit outputs. Simplify the
    * |ssl_get_message| API. */
-  ssl->s3->tmp.message_type = frag->msg_header.type;
-  ssl->init_msg = (uint8_t *)ssl->init_buf->data + DTLS1_HM_HEADER_LENGTH;
-  ssl->init_num = frag->msg_header.msg_len;
+  ssl->s3->tmp.message_type = frag->type;
+  ssl->init_msg = frag->data + DTLS1_HM_HEADER_LENGTH;
+  ssl->init_num = frag->msg_len;
 
   if (msg_type >= 0 && ssl->s3->tmp.message_type != msg_type) {
-    al = SSL_AD_UNEXPECTED_MESSAGE;
+    ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_UNEXPECTED_MESSAGE);
     OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_MESSAGE);
-    goto f_err;
+    return -1;
   }
-  if (hash_message == ssl_hash_message && !ssl3_hash_current_message(ssl)) {
-    goto err;
+  if (hash_message == ssl_hash_message && !dtls1_hash_current_message(ssl)) {
+    return -1;
   }
 
   ssl_do_msg_callback(ssl, 0 /* read */, ssl->version, SSL3_RT_HANDSHAKE,
-                      ssl->init_buf->data,
-                      ssl->init_num + DTLS1_HM_HEADER_LENGTH);
-
-  dtls1_hm_fragment_free(frag);
-
+                      frag->data, ssl->init_num + DTLS1_HM_HEADER_LENGTH);
   *ok = 1;
   return ssl->init_num;
+}
 
-f_err:
-  ssl3_send_alert(ssl, SSL3_AL_FATAL, al);
-err:
-  dtls1_hm_fragment_free(frag);
-  *ok = 0;
-  return -1;
+int dtls1_hash_current_message(SSL *ssl) {
+  assert(dtls1_is_current_message_complete(ssl));
+
+  hm_fragment *frag = ssl->d1->incoming_messages[ssl->d1->handshake_read_seq %
+                                                 SSL_MAX_HANDSHAKE_FLIGHT];
+  return ssl3_update_handshake_hash(ssl, frag->data,
+                                    DTLS1_HM_HEADER_LENGTH + frag->msg_len);
 }
 
 void dtls_clear_incoming_messages(SSL *ssl) {