Replace init_msg/init_num with a get_message hook.

Rather than init_msg/init_num, there is a get_message function which
returns either success or "try again". This function does not advance
the current message (see the previous preparatory change). It only
completes the current one if necessary.

Being idempotent means it may be freely placed at the top of states
which otherwise have other asynchronous operations. It also eases
converting the TLS 1.2 state machine. See
https://docs.google.com/a/google.com/document/d/11n7LHsT3GwE34LAJIe3EFs4165TI4UR_3CqiM9LJVpI/edit?usp=sharing
for details.

The read_message hook (later to be replaced by something which doesn't
depend on BIO) intentionally does not finish the handshake, only "makes
progress". A follow-up change will align both TLS and DTLS on consuming
one handshake record and always consuming the entire record (so init_buf
may contain trailing data). In a few places I've gone ahead and
accounted for that case because it was more natural to do so.

This change also removes a couple of pointers of redundant state from
every socket.

Bug: 128
Change-Id: I89d8f3622d3b53147d69ee3ac34bb654ed044a71
Reviewed-on: https://boringssl-review.googlesource.com/18806
Reviewed-by: David Benjamin <davidben@google.com>
Commit-Queue: David Benjamin <davidben@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
diff --git a/ssl/s3_both.cc b/ssl/s3_both.cc
index 4ae6f70..9c4aa7f 100644
--- a/ssl/s3_both.cc
+++ b/ssl/s3_both.cc
@@ -187,12 +187,11 @@
 
 void ssl_handshake_free(SSL_HANDSHAKE *hs) { Delete(hs); }
 
-int ssl_check_message_type(SSL *ssl, int type) {
-  if (ssl->s3->tmp.message_type != type) {
+int ssl_check_message_type(SSL *ssl, const SSLMessage &msg, int type) {
+  if (msg.type != type) {
     ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_UNEXPECTED_MESSAGE);
     OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_MESSAGE);
-    ERR_add_error_dataf("got type %d, wanted type %d",
-                        ssl->s3->tmp.message_type, type);
+    ERR_add_error_dataf("got type %d, wanted type %d", msg.type, type);
     return 0;
   }
 
@@ -422,12 +421,13 @@
 
 int ssl3_get_finished(SSL_HANDSHAKE *hs) {
   SSL *const ssl = hs->ssl;
-  int ret = ssl->method->ssl_get_message(ssl);
+  SSLMessage msg;
+  int ret = ssl_read_message(ssl, &msg);
   if (ret <= 0) {
     return ret;
   }
 
-  if (!ssl_check_message_type(ssl, SSL3_MT_FINISHED)) {
+  if (!ssl_check_message_type(ssl, msg, SSL3_MT_FINISHED)) {
     return -1;
   }
 
@@ -437,12 +437,11 @@
   if (!hs->transcript.GetFinishedMAC(finished, &finished_len,
                                      SSL_get_session(ssl), !ssl->server,
                                      ssl3_protocol_version(ssl)) ||
-      !ssl_hash_current_message(hs)) {
+      !ssl_hash_message(hs, msg)) {
     return -1;
   }
 
-  int finished_ok = ssl->init_num == finished_len &&
-                    CRYPTO_memcmp(ssl->init_msg, finished, finished_len) == 0;
+  int finished_ok = CBS_mem_equal(&msg.body, finished, finished_len);
 #if defined(BORINGSSL_UNSAFE_FUZZER_MODE)
   finished_ok = 1;
 #endif
@@ -516,6 +515,16 @@
   return kMaxMessageLen;
 }
 
+int ssl_read_message(SSL *ssl, SSLMessage *out) {
+  while (!ssl->method->get_message(ssl, out)) {
+    int ret = ssl->method->read_message(ssl);
+    if (ret <= 0) {
+      return ret;
+    }
+  }
+  return 1;
+}
+
 static int extend_handshake_buffer(SSL *ssl, size_t length) {
   if (!BUF_MEM_reserve(ssl->init_buf, length)) {
     return -1;
@@ -683,7 +692,62 @@
   return 1;
 }
 
-int ssl3_get_message(SSL *ssl) {
+/* TODO(davidben): Remove |out_bytes_needed| and inline into |ssl3_get_message|
+ * when the entire record is copied into |init_buf|. */
+static bool parse_message(SSL *ssl, SSLMessage *out, size_t *out_bytes_needed) {
+  if (ssl->init_buf == NULL) {
+    *out_bytes_needed = 4;
+    return false;
+  }
+
+  CBS cbs;
+  uint32_t len;
+  CBS_init(&cbs, reinterpret_cast<const uint8_t *>(ssl->init_buf->data),
+           ssl->init_buf->length);
+  if (!CBS_get_u8(&cbs, &out->type) ||
+      !CBS_get_u24(&cbs, &len)) {
+    *out_bytes_needed = 4;
+    return false;
+  }
+
+  if (!CBS_get_bytes(&cbs, &out->body, len)) {
+    *out_bytes_needed = 4 + len;
+    return false;
+  }
+
+  CBS_init(&out->raw, reinterpret_cast<const uint8_t *>(ssl->init_buf->data),
+           4 + len);
+  out->is_v2_hello = ssl->s3->is_v2_hello;
+  if (!ssl->s3->has_message) {
+    if (!out->is_v2_hello) {
+      ssl_do_msg_callback(ssl, 0 /* read */, SSL3_RT_HANDSHAKE,
+                          CBS_data(&out->raw), CBS_len(&out->raw));
+    }
+    ssl->s3->has_message = 1;
+  }
+  return true;
+}
+
+bool ssl3_get_message(SSL *ssl, SSLMessage *out) {
+  size_t unused;
+  return parse_message(ssl, out, &unused);
+}
+
+int ssl3_read_message(SSL *ssl) {
+  SSLMessage msg;
+  size_t bytes_needed;
+  if (parse_message(ssl, &msg, &bytes_needed)) {
+    OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
+    return -1;
+  }
+
+  /* Enforce the limit so the peer cannot force us to buffer 16MB. */
+  if (bytes_needed > 4 + ssl_max_handshake_message_len(ssl)) {
+    ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_ILLEGAL_PARAMETER);
+    OPENSSL_PUT_ERROR(SSL, SSL_R_EXCESSIVE_MESSAGE_SIZE);
+    return -1;
+  }
+
   /* Re-create the handshake buffer if needed. */
   if (ssl->init_buf == NULL) {
     ssl->init_buf = BUF_MEM_new();
@@ -692,79 +756,45 @@
     }
   }
 
+  /* Bypass the record layer for the first message to handle V2ClientHello. */
   if (ssl->server && !ssl->s3->v2_hello_done) {
-    /* Bypass the record layer for the first message to handle V2ClientHello. */
     int ret = read_v2_client_hello(ssl);
-    if (ret <= 0) {
-      return ret;
+    if (ret > 0) {
+      ssl->s3->v2_hello_done = 1;
     }
-    ssl->s3->v2_hello_done = 1;
-  }
-
-  /* Read the message header, if we haven't yet. */
-  int ret = extend_handshake_buffer(ssl, SSL3_HM_HEADER_LENGTH);
-  if (ret <= 0) {
     return ret;
   }
 
-  /* Parse out the length. Cap it so the peer cannot force us to buffer up to
-   * 2^24 bytes. */
-  const uint8_t *p = (uint8_t *)ssl->init_buf->data;
-  size_t msg_len = (((uint32_t)p[1]) << 16) | (((uint32_t)p[2]) << 8) | p[3];
-  if (msg_len > ssl_max_handshake_message_len(ssl)) {
-    ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_ILLEGAL_PARAMETER);
-    OPENSSL_PUT_ERROR(SSL, SSL_R_EXCESSIVE_MESSAGE_SIZE);
-    return -1;
-  }
-
-  /* Read the message body, if we haven't yet. */
-  ret = extend_handshake_buffer(ssl, SSL3_HM_HEADER_LENGTH + msg_len);
-  if (ret <= 0) {
-    return ret;
-  }
-
-  /* We have now received a complete message. */
-  if (ssl->init_msg == NULL && !ssl->s3->is_v2_hello) {
-    ssl_do_msg_callback(ssl, 0 /* read */, SSL3_RT_HANDSHAKE,
-                        ssl->init_buf->data, ssl->init_buf->length);
-  }
-
-  ssl->s3->tmp.message_type = ((const uint8_t *)ssl->init_buf->data)[0];
-  ssl->init_msg = (uint8_t*)ssl->init_buf->data + SSL3_HM_HEADER_LENGTH;
-  ssl->init_num = ssl->init_buf->length - SSL3_HM_HEADER_LENGTH;
-  return 1;
+  return extend_handshake_buffer(ssl, bytes_needed);
 }
 
-void ssl3_get_current_message(const SSL *ssl, CBS *out) {
-  CBS_init(out, (uint8_t *)ssl->init_buf->data, ssl->init_buf->length);
-}
-
-int ssl_hash_current_message(SSL_HANDSHAKE *hs) {
-  /* V2ClientHellos are hashed implicitly. */
-  if (hs->ssl->s3->is_v2_hello) {
-    return 1;
+bool ssl_hash_message(SSL_HANDSHAKE *hs, const SSLMessage &msg) {
+  /* V2ClientHello messages are pre-hashed. */
+  if (msg.is_v2_hello) {
+    return true;
   }
 
-  CBS cbs;
-  hs->ssl->method->get_current_message(hs->ssl, &cbs);
-  return hs->transcript.Update(CBS_data(&cbs), CBS_len(&cbs));
+  return hs->transcript.Update(CBS_data(&msg.raw), CBS_len(&msg.raw));
 }
 
 void ssl3_next_message(SSL *ssl) {
-  assert(ssl->init_msg != NULL);
+  SSLMessage msg;
+  if (!ssl3_get_message(ssl, &msg) ||
+      ssl->init_buf == NULL ||
+      ssl->init_buf->length < CBS_len(&msg.raw)) {
+    assert(0);
+    return;
+  }
 
-  /* |init_buf| never contains data beyond the current message. */
-  assert(SSL3_HM_HEADER_LENGTH + ssl->init_num == ssl->init_buf->length);
-
-  /* Clear the current message. */
-  ssl->init_msg = NULL;
-  ssl->init_num = 0;
-  ssl->init_buf->length = 0;
+  OPENSSL_memmove(ssl->init_buf->data, ssl->init_buf->data + CBS_len(&msg.raw),
+                  ssl->init_buf->length - CBS_len(&msg.raw));
+  ssl->init_buf->length -= CBS_len(&msg.raw);
   ssl->s3->is_v2_hello = 0;
+  ssl->s3->has_message = 0;
 
   /* Post-handshake messages are rare, so release the buffer after every
    * message. During the handshake, |on_handshake_complete| will release it. */
-  if (!SSL_in_init(ssl)) {
+  if (!SSL_in_init(ssl) && ssl->init_buf->length == 0) {
     BUF_MEM_free(ssl->init_buf);
     ssl->init_buf = NULL;
   }