Add bssl::Span<T>::subspan and use it.

This roughly aligns with absl::Span<T>::subspan.

Bug: 132
Change-Id: Iaf29418c1b10e2d357763dec90b6cb1371b86c3b
Reviewed-on: https://boringssl-review.googlesource.com/20824
Commit-Queue: David Benjamin <davidben@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
Reviewed-by: Martin Kreichgauer <martinkr@google.com>
diff --git a/include/openssl/span.h b/include/openssl/span.h
index d447314..97361c2 100644
--- a/include/openssl/span.h
+++ b/include/openssl/span.h
@@ -22,6 +22,7 @@
 extern "C++" {
 
 #include <algorithm>
+#include <cstdlib>
 #include <type_traits>
 
 namespace bssl {
@@ -104,6 +105,8 @@
       std::is_convertible<decltype(std::declval<C>().data()), T *>::value &&
       std::is_integral<decltype(std::declval<C>().size())>::value>;
 
+  static const size_t npos = -1;
+
  public:
   constexpr Span() : Span(nullptr, 0) {}
   constexpr Span(T *ptr, size_t len) : data_(ptr), size_(len) {}
@@ -124,6 +127,7 @@
 
   T *data() const { return data_; }
   size_t size() const { return size_; }
+  bool empty() const { return size_ == 0; }
 
   T *begin() const { return data_; }
   const T *cbegin() const { return data_; }
@@ -133,12 +137,22 @@
   T &operator[](size_t i) const { return data_[i]; }
   T &at(size_t i) const { return data_[i]; }
 
+  Span subspan(size_t pos = 0, size_t len = npos) const {
+    if (pos > size_) {
+      abort();  // absl::Span throws an exception here.
+    }
+    return Span(data_ + pos, std::min(size_ - pos, len));
+  }
+
  private:
   T *data_;
   size_t size_;
 };
 
 template <typename T>
+const size_t Span<T>::npos;
+
+template <typename T>
 Span<T> MakeSpan(T *ptr, size_t size) {
   return Span<T>(ptr, size);
 }
diff --git a/ssl/handshake.cc b/ssl/handshake.cc
index c10c40f..a03b140 100644
--- a/ssl/handshake.cc
+++ b/ssl/handshake.cc
@@ -147,7 +147,6 @@
 
 SSL_HANDSHAKE::~SSL_HANDSHAKE() {
   ssl->ctx->x509_method->hs_flush_cached_ca_names(this);
-  OPENSSL_free(key_block);
 }
 
 SSL_HANDSHAKE *ssl_handshake_new(SSL *ssl) {
diff --git a/ssl/internal.h b/ssl/internal.h
index 13e6655..97e3be1 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -258,6 +258,7 @@
   const T *data() const { return data_; }
   T *data() { return data_; }
   size_t size() const { return size_; }
+  bool empty() const { return size_ == 0; }
 
   const T &operator[](size_t i) const { return data_[i]; }
   T &operator[](size_t i) { return data_[i]; }
@@ -618,11 +619,12 @@
   // returns nullptr on error. Only one of |Open| or |Seal| may be used with the
   // resulting object, depending on |direction|. |version| is the normalized
   // protocol version, so DTLS 1.0 is represented as 0x0301, not 0xffef.
-  static UniquePtr<SSLAEADContext> Create(
-      enum evp_aead_direction_t direction, uint16_t version, int is_dtls,
-      const SSL_CIPHER *cipher, const uint8_t *enc_key, size_t enc_key_len,
-      const uint8_t *mac_key, size_t mac_key_len, const uint8_t *fixed_iv,
-      size_t fixed_iv_len);
+  static UniquePtr<SSLAEADContext> Create(enum evp_aead_direction_t direction,
+                                          uint16_t version, int is_dtls,
+                                          const SSL_CIPHER *cipher,
+                                          Span<const uint8_t> enc_key,
+                                          Span<const uint8_t> mac_key,
+                                          Span<const uint8_t> fixed_iv);
 
   // SetVersionIfNullCipher sets the version the SSLAEADContext for the null
   // cipher, to make version-specific determinations in the record layer prior
@@ -1335,8 +1337,9 @@
   // CertificateRequest message.
   UniquePtr<STACK_OF(CRYPTO_BUFFER)> ca_names;
 
-  // cached_x509_ca_names contains a cache of parsed versions of the elements
-  // of |ca_names|.
+  // cached_x509_ca_names contains a cache of parsed versions of the elements of
+  // |ca_names|. This pointer is left non-owning so only
+  // |ssl_crypto_x509_method| needs to link against crypto/x509.
   STACK_OF(X509_NAME) *cached_x509_ca_names = nullptr;
 
   // certificate_types, on the client, contains the set of certificate types
@@ -1361,8 +1364,7 @@
   const SSL_CIPHER *new_cipher = nullptr;
 
   // key_block is the record-layer key block for TLS 1.2 and earlier.
-  uint8_t *key_block = nullptr;
-  uint8_t key_block_len = 0;
+  Array<uint8_t> key_block;
 
   // scts_requested is true if the SCT extension is in the ClientHello.
   bool scts_requested:1;
diff --git a/ssl/s3_both.cc b/ssl/s3_both.cc
index f63ed26..f48e5e7 100644
--- a/ssl/s3_both.cc
+++ b/ssl/s3_both.cc
@@ -132,8 +132,8 @@
 
 namespace bssl {
 
-static int add_record_to_flight(SSL *ssl, uint8_t type, const uint8_t *in,
-                                size_t in_len) {
+static int add_record_to_flight(SSL *ssl, uint8_t type,
+                                Span<const uint8_t> in) {
   // We'll never add a flight while in the process of writing it out.
   assert(ssl->s3->pending_flight_offset == 0);
 
@@ -144,18 +144,19 @@
     }
   }
 
-  size_t max_out = in_len + SSL_max_seal_overhead(ssl);
+  size_t max_out = in.size() + SSL_max_seal_overhead(ssl);
   size_t new_cap = ssl->s3->pending_flight->length + max_out;
-  if (max_out < in_len || new_cap < max_out) {
+  if (max_out < in.size() || new_cap < max_out) {
     OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW);
     return 0;
   }
 
   size_t len;
   if (!BUF_MEM_reserve(ssl->s3->pending_flight, new_cap) ||
-      !tls_seal_record(ssl, (uint8_t *)ssl->s3->pending_flight->data +
-                                ssl->s3->pending_flight->length,
-                       &len, max_out, type, in, in_len)) {
+      !tls_seal_record(ssl,
+                       (uint8_t *)ssl->s3->pending_flight->data +
+                           ssl->s3->pending_flight->length,
+                       &len, max_out, type, in.data(), in.size())) {
     return 0;
   }
 
@@ -183,12 +184,10 @@
 int ssl3_add_message(SSL *ssl, Array<uint8_t> msg) {
   // Add the message to the current flight, splitting into several records if
   // needed.
-  size_t added = 0;
+  Span<const uint8_t> rest = msg;
   do {
-    size_t todo = msg.size() - added;
-    if (todo > ssl->max_send_fragment) {
-      todo = ssl->max_send_fragment;
-    }
+    Span<const uint8_t> chunk = rest.subspan(0, ssl->max_send_fragment);
+    rest = rest.subspan(chunk.size());
 
     uint8_t type = SSL3_RT_HANDSHAKE;
     if (ssl->server &&
@@ -198,11 +197,10 @@
       type = SSL3_RT_PLAINTEXT_HANDSHAKE;
     }
 
-    if (!add_record_to_flight(ssl, type, msg.data() + added, todo)) {
+    if (!add_record_to_flight(ssl, type, chunk)) {
       return 0;
     }
-    added += todo;
-  } while (added < msg.size());
+  } while (!rest.empty());
 
   ssl_do_msg_callback(ssl, 1 /* write */, SSL3_RT_HANDSHAKE, msg.data(),
                       msg.size());
@@ -218,8 +216,8 @@
 int ssl3_add_change_cipher_spec(SSL *ssl) {
   static const uint8_t kChangeCipherSpec[1] = {SSL3_MT_CCS};
 
-  if (!add_record_to_flight(ssl, SSL3_RT_CHANGE_CIPHER_SPEC, kChangeCipherSpec,
-                            sizeof(kChangeCipherSpec))) {
+  if (!add_record_to_flight(ssl, SSL3_RT_CHANGE_CIPHER_SPEC,
+                            kChangeCipherSpec)) {
     return 0;
   }
 
@@ -230,7 +228,7 @@
 
 int ssl3_add_alert(SSL *ssl, uint8_t level, uint8_t desc) {
   uint8_t alert[2] = {level, desc};
-  if (!add_record_to_flight(ssl, SSL3_RT_ALERT, alert, sizeof(alert))) {
+  if (!add_record_to_flight(ssl, SSL3_RT_ALERT, alert)) {
     return 0;
   }
 
diff --git a/ssl/ssl_aead_ctx.cc b/ssl/ssl_aead_ctx.cc
index d03a4a0..8856f74 100644
--- a/ssl/ssl_aead_ctx.cc
+++ b/ssl/ssl_aead_ctx.cc
@@ -56,10 +56,8 @@
 
 UniquePtr<SSLAEADContext> SSLAEADContext::Create(
     enum evp_aead_direction_t direction, uint16_t version, int is_dtls,
-    const SSL_CIPHER *cipher, const uint8_t *enc_key, size_t enc_key_len,
-    const uint8_t *mac_key, size_t mac_key_len, const uint8_t *fixed_iv,
-    size_t fixed_iv_len) {
-
+    const SSL_CIPHER *cipher, Span<const uint8_t> enc_key,
+    Span<const uint8_t> mac_key, Span<const uint8_t> fixed_iv) {
   const EVP_AEAD *aead;
   uint16_t protocol_version;
   size_t expected_mac_key_len, expected_fixed_iv_len;
@@ -68,27 +66,27 @@
                                &expected_fixed_iv_len, cipher, protocol_version,
                                is_dtls) ||
       // Ensure the caller returned correct key sizes.
-      expected_fixed_iv_len != fixed_iv_len ||
-      expected_mac_key_len != mac_key_len) {
+      expected_fixed_iv_len != fixed_iv.size() ||
+      expected_mac_key_len != mac_key.size()) {
     OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
     return nullptr;
   }
 
   uint8_t merged_key[EVP_AEAD_MAX_KEY_LENGTH];
-  if (mac_key_len > 0) {
+  if (!mac_key.empty()) {
     // This is a "stateful" AEAD (for compatibility with pre-AEAD cipher
     // suites).
-    if (mac_key_len + enc_key_len + fixed_iv_len > sizeof(merged_key)) {
+    if (mac_key.size() + enc_key.size() + fixed_iv.size() >
+        sizeof(merged_key)) {
       OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
       return nullptr;
     }
-    OPENSSL_memcpy(merged_key, mac_key, mac_key_len);
-    OPENSSL_memcpy(merged_key + mac_key_len, enc_key, enc_key_len);
-    OPENSSL_memcpy(merged_key + mac_key_len + enc_key_len, fixed_iv,
-                   fixed_iv_len);
-    enc_key = merged_key;
-    enc_key_len += mac_key_len;
-    enc_key_len += fixed_iv_len;
+    OPENSSL_memcpy(merged_key, mac_key.data(), mac_key.size());
+    OPENSSL_memcpy(merged_key + mac_key.size(), enc_key.data(), enc_key.size());
+    OPENSSL_memcpy(merged_key + mac_key.size() + enc_key.size(),
+                   fixed_iv.data(), fixed_iv.size());
+    enc_key = MakeConstSpan(merged_key,
+                            enc_key.size() + mac_key.size() + fixed_iv.size());
   }
 
   UniquePtr<SSLAEADContext> aead_ctx =
@@ -101,7 +99,7 @@
   assert(aead_ctx->ProtocolVersion() == protocol_version);
 
   if (!EVP_AEAD_CTX_init_with_direction(
-          aead_ctx->ctx_.get(), aead, enc_key, enc_key_len,
+          aead_ctx->ctx_.get(), aead, enc_key.data(), enc_key.size(),
           EVP_AEAD_DEFAULT_TAG_LENGTH, direction)) {
     return nullptr;
   }
@@ -110,10 +108,10 @@
   static_assert(EVP_AEAD_MAX_NONCE_LENGTH < 256,
                 "variable_nonce_len doesn't fit in uint8_t");
   aead_ctx->variable_nonce_len_ = (uint8_t)EVP_AEAD_nonce_length(aead);
-  if (mac_key_len == 0) {
-    assert(fixed_iv_len <= sizeof(aead_ctx->fixed_nonce_));
-    OPENSSL_memcpy(aead_ctx->fixed_nonce_, fixed_iv, fixed_iv_len);
-    aead_ctx->fixed_nonce_len_ = fixed_iv_len;
+  if (mac_key.empty()) {
+    assert(fixed_iv.size() <= sizeof(aead_ctx->fixed_nonce_));
+    OPENSSL_memcpy(aead_ctx->fixed_nonce_, fixed_iv.data(), fixed_iv.size());
+    aead_ctx->fixed_nonce_len_ = fixed_iv.size();
 
     if (cipher->algorithm_enc & SSL_CHACHA20POLY1305) {
       // The fixed nonce into the actual nonce (the sequence number).
@@ -121,8 +119,8 @@
       aead_ctx->variable_nonce_len_ = 8;
     } else {
       // The fixed IV is prepended to the nonce.
-      assert(fixed_iv_len <= aead_ctx->variable_nonce_len_);
-      aead_ctx->variable_nonce_len_ -= fixed_iv_len;
+      assert(fixed_iv.size() <= aead_ctx->variable_nonce_len_);
+      aead_ctx->variable_nonce_len_ -= fixed_iv.size();
     }
 
     // AES-GCM uses an explicit nonce.
@@ -137,7 +135,7 @@
       aead_ctx->variable_nonce_len_ = 8;
       aead_ctx->variable_nonce_included_in_record_ = false;
       aead_ctx->omit_ad_ = true;
-      assert(fixed_iv_len >= aead_ctx->variable_nonce_len_);
+      assert(fixed_iv.size() >= aead_ctx->variable_nonce_len_);
     }
   } else {
     assert(protocol_version < TLS1_3_VERSION);
diff --git a/ssl/t1_enc.cc b/ssl/t1_enc.cc
index d693007..8f8d328 100644
--- a/ssl/t1_enc.cc
+++ b/ssl/t1_enc.cc
@@ -318,7 +318,7 @@
 
 static int tls1_setup_key_block(SSL_HANDSHAKE *hs) {
   SSL *const ssl = hs->ssl;
-  if (hs->key_block_len != 0) {
+  if (!hs->key_block.empty()) {
     return 1;
   }
 
@@ -356,22 +356,13 @@
   ssl->s3->tmp.new_key_len = (uint8_t)key_len;
   ssl->s3->tmp.new_fixed_iv_len = (uint8_t)fixed_iv_len;
 
-  size_t key_block_len = SSL_get_key_block_len(ssl);
-
-  uint8_t *keyblock = (uint8_t *)OPENSSL_malloc(key_block_len);
-  if (keyblock == NULL) {
-    OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE);
+  Array<uint8_t> key_block;
+  if (!key_block.Init(SSL_get_key_block_len(ssl)) ||
+      !SSL_generate_key_block(ssl, key_block.data(), key_block.size())) {
     return 0;
   }
 
-  if (!SSL_generate_key_block(ssl, keyblock, key_block_len)) {
-    OPENSSL_free(keyblock);
-    return 0;
-  }
-
-  assert(key_block_len < 256);
-  hs->key_block_len = (uint8_t)key_block_len;
-  hs->key_block = keyblock;
+  hs->key_block = std::move(key_block);
   return 1;
 }
 
@@ -383,45 +374,28 @@
     return 0;
   }
 
-  // use_client_keys is true if we wish to use the keys for the "client write"
-  // direction. This is the case if we're a client sending a ChangeCipherSpec,
-  // or a server reading a client's ChangeCipherSpec.
-  const bool use_client_keys =
-      direction == (ssl->server ? evp_aead_open : evp_aead_seal);
-
   size_t mac_secret_len = ssl->s3->tmp.new_mac_secret_len;
   size_t key_len = ssl->s3->tmp.new_key_len;
   size_t iv_len = ssl->s3->tmp.new_fixed_iv_len;
-  assert((mac_secret_len + key_len + iv_len) * 2 == hs->key_block_len);
+  assert((mac_secret_len + key_len + iv_len) * 2 == hs->key_block.size());
 
-  const uint8_t *key_data = hs->key_block;
-  const uint8_t *client_write_mac_secret = key_data;
-  key_data += mac_secret_len;
-  const uint8_t *server_write_mac_secret = key_data;
-  key_data += mac_secret_len;
-  const uint8_t *client_write_key = key_data;
-  key_data += key_len;
-  const uint8_t *server_write_key = key_data;
-  key_data += key_len;
-  const uint8_t *client_write_iv = key_data;
-  key_data += iv_len;
-  const uint8_t *server_write_iv = key_data;
-  key_data += iv_len;
-
-  const uint8_t *mac_secret, *key, *iv;
-  if (use_client_keys) {
-    mac_secret = client_write_mac_secret;
-    key = client_write_key;
-    iv = client_write_iv;
+  Span<const uint8_t> key_block = hs->key_block;
+  Span<const uint8_t> mac_secret, key, iv;
+  if (direction == (ssl->server ? evp_aead_open : evp_aead_seal)) {
+    // Use the client write (server read) keys.
+    mac_secret = key_block.subspan(0, mac_secret_len);
+    key = key_block.subspan(2 * mac_secret_len, key_len);
+    iv = key_block.subspan(2 * mac_secret_len + 2 * key_len, iv_len);
   } else {
-    mac_secret = server_write_mac_secret;
-    key = server_write_key;
-    iv = server_write_iv;
+    // Use the server write (client read) keys.
+    mac_secret = key_block.subspan(mac_secret_len, mac_secret_len);
+    key = key_block.subspan(2 * mac_secret_len + key_len, key_len);
+    iv = key_block.subspan(2 * mac_secret_len + 2 * key_len + iv_len, iv_len);
   }
 
-  UniquePtr<SSLAEADContext> aead_ctx = SSLAEADContext::Create(
-      direction, ssl->version, SSL_is_dtls(ssl), hs->new_cipher, key, key_len,
-      mac_secret, mac_secret_len, iv, iv_len);
+  UniquePtr<SSLAEADContext> aead_ctx =
+      SSLAEADContext::Create(direction, ssl->version, SSL_is_dtls(ssl),
+                             hs->new_cipher, key, mac_secret, iv);
   if (!aead_ctx) {
     return 0;
   }
diff --git a/ssl/t1_lib.cc b/ssl/t1_lib.cc
index 1fe360b..3b0a335 100644
--- a/ssl/t1_lib.cc
+++ b/ssl/t1_lib.cc
@@ -2111,7 +2111,7 @@
 
     // Predict the most preferred group.
     Span<const uint16_t> groups = tls1_get_grouplist(ssl);
-    if (groups.size() == 0) {
+    if (groups.empty()) {
       OPENSSL_PUT_ERROR(SSL, SSL_R_NO_GROUPS_SPECIFIED);
       return 0;
     }
@@ -2290,7 +2290,7 @@
 // https://tools.ietf.org/html/draft-ietf-tls-tls13-16#section-4.2.2
 
 static int ext_cookie_add_clienthello(SSL_HANDSHAKE *hs, CBB *out) {
-  if (hs->cookie.size() == 0) {
+  if (hs->cookie.empty()) {
     return 1;
   }
 
@@ -3184,7 +3184,7 @@
   }
 
   Span<const uint16_t> peer_sigalgs = hs->peer_sigalgs;
-  if (peer_sigalgs.size() == 0 && ssl3_protocol_version(ssl) < TLS1_3_VERSION) {
+  if (peer_sigalgs.empty() && ssl3_protocol_version(ssl) < TLS1_3_VERSION) {
     // If the client didn't specify any signature_algorithms extension then
     // we can assume that it supports SHA1. See
     // http://tools.ietf.org/html/rfc5246#section-7.4.1.4.1
diff --git a/ssl/tls13_enc.cc b/ssl/tls13_enc.cc
index 6ff9972..b68a39e 100644
--- a/ssl/tls13_enc.cc
+++ b/ssl/tls13_enc.cc
@@ -149,9 +149,10 @@
     return 0;
   }
 
-  UniquePtr<SSLAEADContext> traffic_aead = SSLAEADContext::Create(
-      direction, session->ssl_version, SSL_is_dtls(ssl), session->cipher, key,
-      key_len, NULL, 0, iv, iv_len);
+  UniquePtr<SSLAEADContext> traffic_aead =
+      SSLAEADContext::Create(direction, session->ssl_version, SSL_is_dtls(ssl),
+                             session->cipher, MakeConstSpan(key, key_len),
+                             Span<const uint8_t>(), MakeConstSpan(iv, iv_len));
   if (!traffic_aead) {
     return 0;
   }