Test-only DTLS implementation in runner.go.

Tested against openssl s_client and openssl s_server. This works as a starting
point, although it may need to become more sophisticated to exercise more of
BoringSSL's implementation.

In particular, it assumes a reliable, in-order channel, and it requires that
the peer send handshake fragments in order. Retransmission and related
mechanisms are not implemented. The peer under test will be expected to handle
a lossy channel, but all loss will be under the test runner's control. MAC
errors, etc., are fatal.

Change-Id: I329233cfb0994938fd012667ddf7c6a791ac7164
Reviewed-on: https://boringssl-review.googlesource.com/1390
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/ssl/test/runner/conn.go b/ssl/test/runner/conn.go
index f3e2495..5371a64 100644
--- a/ssl/test/runner/conn.go
+++ b/ssl/test/runner/conn.go
@@ -24,6 +24,7 @@
 type Conn struct {
 	// constant
 	conn     net.Conn
+	isDTLS   bool
 	isClient bool
 
 	// constant after handshake; protected by handshakeMutex
@@ -49,8 +50,14 @@
 	// input/output
 	in, out  halfConn     // in.Mutex < out.Mutex
 	rawInput *block       // raw input, right off the wire
-	input    *block       // application data waiting to be read
-	hand     bytes.Buffer // handshake data waiting to be read
+	input    *block       // application record waiting to be read
+	hand     bytes.Buffer // handshake record waiting to be read
+
+	// DTLS state
+	sendHandshakeSeq uint16
+	recvHandshakeSeq uint16
+	handMsg          []byte // pending assembled handshake message
+	handMsgLen       int    // handshake message length, not including the header
 
 	tmp [16]byte
 }
@@ -94,8 +101,9 @@
 type halfConn struct {
 	sync.Mutex
 
-	err     error       // first permanent error
-	version uint16      // protocol version
+	err     error  // first permanent error
+	version uint16 // protocol version
+	isDTLS  bool
 	cipher  interface{} // cipher algorithm
 	mac     macFunction
 	seq     [8]byte // 64-bit sequence number
@@ -141,15 +149,18 @@
 	hc.nextCipher = nil
 	hc.nextMac = nil
 	hc.config = config
-	for i := range hc.seq {
-		hc.seq[i] = 0
-	}
+	hc.incEpoch()
 	return nil
 }
 
 // incSeq increments the sequence number.
 func (hc *halfConn) incSeq() {
-	for i := 7; i >= 0; i-- {
+	limit := 0
+	if hc.isDTLS {
+		// Increment up to the epoch in DTLS.
+		limit = 2
+	}
+	for i := 7; i >= limit; i-- {
 		hc.seq[i]++
 		if hc.seq[i] != 0 {
 			return
@@ -162,11 +173,33 @@
 	panic("TLS: sequence number wraparound")
 }
 
-// resetSeq resets the sequence number to zero.
-func (hc *halfConn) resetSeq() {
-	for i := range hc.seq {
-		hc.seq[i] = 0
+// incEpoch starts a new epoch. For TLS it zeroes all 8 bytes of hc.seq; for
+// DTLS it increments the 2-byte epoch in hc.seq[0:2] and zeroes hc.seq[2:8].
+func (hc *halfConn) incEpoch() {
+	limit := 0 // index of the first sequence byte to zero
+	if hc.isDTLS {
+		for i := 1; i >= 0; i-- { // big-endian increment of hc.seq[0:2] with carry
+			hc.seq[i]++
+			if hc.seq[i] != 0 {
+				break
+			}
+			if i == 0 {
+				panic("TLS: epoch number wraparound")
+			}
+		}
+		limit = 2 // keep the epoch bytes that were just incremented
 	}
+	seq := hc.seq[limit:]
+	for i := range seq {
+		seq[i] = 0
+	}
+}
+
+func (hc *halfConn) recordHeaderLen() int { // header size for this half's record framing
+	if hc.isDTLS {
+		return dtlsRecordHeaderLen // type, version, 8-byte epoch+sequence, length
+	}
+	return tlsRecordHeaderLen // type, version, length
 }
 
 // removePadding returns an unpadded slice, in constant time, which is a prefix
@@ -237,6 +270,8 @@
 // success boolean, the number of bytes to skip from the start of the record in
 // order to get the application payload, and an optional alert value.
 func (hc *halfConn) decrypt(b *block) (ok bool, prefixLen int, alertValue alert) {
+	recordHeaderLen := hc.recordHeaderLen()
+
 	// pull out payload
 	payload := b.data[recordHeaderLen:]
 
@@ -248,6 +283,12 @@
 	paddingGood := byte(255)
 	explicitIVLen := 0
 
+	seq := hc.seq[:]
+	if hc.isDTLS {
+		// DTLS sequence numbers are explicit.
+		seq = b.data[3:11]
+	}
+
 	// decrypt
 	if hc.cipher != nil {
 		switch c := hc.cipher.(type) {
@@ -262,7 +303,7 @@
 			payload = payload[8:]
 
 			var additionalData [13]byte
-			copy(additionalData[:], hc.seq[:])
+			copy(additionalData[:], seq)
 			copy(additionalData[8:], b.data[:3])
 			n := len(payload) - c.Overhead()
 			additionalData[11] = byte(n >> 8)
@@ -275,7 +316,7 @@
 			b.resize(recordHeaderLen + explicitIVLen + len(payload))
 		case cbcMode:
 			blockSize := c.BlockSize()
-			if hc.version >= VersionTLS11 {
+			if hc.version >= VersionTLS11 || hc.isDTLS {
 				explicitIVLen = blockSize
 			}
 
@@ -318,11 +359,11 @@
 
 		// strip mac off payload, b.data
 		n := len(payload) - macSize
-		b.data[3] = byte(n >> 8)
-		b.data[4] = byte(n)
+		b.data[recordHeaderLen-2] = byte(n >> 8)
+		b.data[recordHeaderLen-1] = byte(n)
 		b.resize(recordHeaderLen + explicitIVLen + n)
 		remoteMAC := payload[n:]
-		localMAC := hc.mac.MAC(hc.inDigestBuf, hc.seq[0:], b.data[:recordHeaderLen], payload[:n])
+		localMAC := hc.mac.MAC(hc.inDigestBuf, seq, b.data[:3], b.data[recordHeaderLen-2:recordHeaderLen], payload[:n])
 
 		if subtle.ConstantTimeCompare(localMAC, remoteMAC) != 1 || paddingGood != 255 {
 			return false, 0, alertBadRecordMAC
@@ -364,9 +405,11 @@
 
 // encrypt encrypts and macs the data in b.
 func (hc *halfConn) encrypt(b *block, explicitIVLen int) (bool, alert) {
+	recordHeaderLen := hc.recordHeaderLen()
+
 	// mac
 	if hc.mac != nil {
-		mac := hc.mac.MAC(hc.outDigestBuf, hc.seq[0:], b.data[:recordHeaderLen], b.data[recordHeaderLen+explicitIVLen:])
+		mac := hc.mac.MAC(hc.outDigestBuf, hc.seq[0:], b.data[:3], b.data[recordHeaderLen-2:recordHeaderLen], b.data[recordHeaderLen+explicitIVLen:])
 
 		n := len(b.data)
 		b.resize(n + len(mac))
@@ -412,8 +455,8 @@
 
 	// update length to include MAC and any block padding needed.
 	n := len(b.data) - recordHeaderLen
-	b.data[3] = byte(n >> 8)
-	b.data[4] = byte(n)
+	b.data[recordHeaderLen-2] = byte(n >> 8)
+	b.data[recordHeaderLen-1] = byte(n)
 	hc.incSeq()
 
 	return true, 0
@@ -517,6 +560,86 @@
 	return b, bb
 }
 
+func (c *Conn) doReadRecord(want recordType) (recordType, *block, error) { // read, check and decrypt the next record
+	if c.isDTLS { // DTLS uses a different record header; handled by dtlsDoReadRecord
+		return c.dtlsDoReadRecord(want)
+	}
+
+	recordHeaderLen := tlsRecordHeaderLen // only the TLS path reaches this point
+
+	if c.rawInput == nil {
+		c.rawInput = c.in.newBlock()
+	}
+	b := c.rawInput
+
+	// Read the header first so type, version and length can be checked before the body.
+	if err := b.readFromUntil(c.conn, recordHeaderLen); err != nil {
+		// RFC suggests that EOF without an alertCloseNotify is
+		// an error, but popular web sites seem to do this,
+		// so we can't make it an error.
+		// if err == io.EOF {
+		// 	err = io.ErrUnexpectedEOF
+		// }
+		if e, ok := err.(net.Error); !ok || !e.Temporary() { // only non-temporary errors are recorded
+			c.in.setErrorLocked(err)
+		}
+		return 0, nil, err
+	}
+	typ := recordType(b.data[0])
+
+	// No valid TLS record has a type of 0x80, however SSLv2 handshakes
+	// start with a uint16 length where the MSB is set and the first record
+	// is always < 256 bytes long. Therefore typ == 0x80 strongly suggests
+	// an SSLv2 client.
+	if want == recordTypeHandshake && typ == 0x80 {
+		c.sendAlert(alertProtocolVersion)
+		return 0, nil, c.in.setErrorLocked(errors.New("tls: unsupported SSLv2 handshake received"))
+	}
+
+	vers := uint16(b.data[1])<<8 | uint16(b.data[2])
+	n := int(b.data[3])<<8 | int(b.data[4]) // body length, excluding the header
+	if c.haveVers && vers != c.vers {
+		c.sendAlert(alertProtocolVersion)
+		return 0, nil, c.in.setErrorLocked(fmt.Errorf("tls: received record with version %x when expecting version %x", vers, c.vers))
+	}
+	if n > maxCiphertext {
+		c.sendAlert(alertRecordOverflow)
+		return 0, nil, c.in.setErrorLocked(fmt.Errorf("tls: oversized record received with length %d", n))
+	}
+	if !c.haveVers {
+		// First message, be extra suspicious:
+		// this might not be a TLS client.
+		// Bail out before reading a full 'body', if possible.
+		// Real protocol versions are far below 16.0:
+		// If the version is >= 16.0, it's probably not real.
+		// Similarly, a clientHello message encodes in
+		// well under a kilobyte.  If the length is >= 12 kB,
+		// it's probably not real.
+		if (typ != recordTypeAlert && typ != want) || vers >= 0x1000 || n >= 0x3000 {
+			c.sendAlert(alertUnexpectedMessage)
+			return 0, nil, c.in.setErrorLocked(fmt.Errorf("tls: first record does not look like a TLS handshake"))
+		}
+	}
+	if err := b.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
+		if err == io.EOF {
+			err = io.ErrUnexpectedEOF
+		}
+		if e, ok := err.(net.Error); !ok || !e.Temporary() {
+			c.in.setErrorLocked(err)
+		}
+		return 0, nil, err
+	}
+
+	// Split off the complete record (header plus n body bytes) and decrypt it.
+	b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n)
+	ok, off, err := c.in.decrypt(b)
+	if !ok {
+		c.in.setErrorLocked(c.sendAlert(err)) // err is the alert value returned by decrypt
+	}
+	b.off = off
+	return typ, b, nil // a decrypt failure is recorded via setErrorLocked above, not returned
+}
+
 // readRecord reads the next TLS record from the connection
 // and updates the record layer state.
 // c.in.Mutex <= L; c.input == nil.
@@ -541,76 +664,10 @@
 	}
 
 Again:
-	if c.rawInput == nil {
-		c.rawInput = c.in.newBlock()
-	}
-	b := c.rawInput
-
-	// Read header, payload.
-	if err := b.readFromUntil(c.conn, recordHeaderLen); err != nil {
-		// RFC suggests that EOF without an alertCloseNotify is
-		// an error, but popular web sites seem to do this,
-		// so we can't make it an error.
-		// if err == io.EOF {
-		// 	err = io.ErrUnexpectedEOF
-		// }
-		if e, ok := err.(net.Error); !ok || !e.Temporary() {
-			c.in.setErrorLocked(err)
-		}
+	typ, b, err := c.doReadRecord(want)
+	if err != nil {
 		return err
 	}
-	typ := recordType(b.data[0])
-
-	// No valid TLS record has a type of 0x80, however SSLv2 handshakes
-	// start with a uint16 length where the MSB is set and the first record
-	// is always < 256 bytes long. Therefore typ == 0x80 strongly suggests
-	// an SSLv2 client.
-	if want == recordTypeHandshake && typ == 0x80 {
-		c.sendAlert(alertProtocolVersion)
-		return c.in.setErrorLocked(errors.New("tls: unsupported SSLv2 handshake received"))
-	}
-
-	vers := uint16(b.data[1])<<8 | uint16(b.data[2])
-	n := int(b.data[3])<<8 | int(b.data[4])
-	if c.haveVers && vers != c.vers {
-		c.sendAlert(alertProtocolVersion)
-		return c.in.setErrorLocked(fmt.Errorf("tls: received record with version %x when expecting version %x", vers, c.vers))
-	}
-	if n > maxCiphertext {
-		c.sendAlert(alertRecordOverflow)
-		return c.in.setErrorLocked(fmt.Errorf("tls: oversized record received with length %d", n))
-	}
-	if !c.haveVers {
-		// First message, be extra suspicious:
-		// this might not be a TLS client.
-		// Bail out before reading a full 'body', if possible.
-		// The current max version is 3.1.
-		// If the version is >= 16.0, it's probably not real.
-		// Similarly, a clientHello message encodes in
-		// well under a kilobyte.  If the length is >= 12 kB,
-		// it's probably not real.
-		if (typ != recordTypeAlert && typ != want) || vers >= 0x1000 || n >= 0x3000 {
-			c.sendAlert(alertUnexpectedMessage)
-			return c.in.setErrorLocked(fmt.Errorf("tls: first record does not look like a TLS handshake"))
-		}
-	}
-	if err := b.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
-		if err == io.EOF {
-			err = io.ErrUnexpectedEOF
-		}
-		if e, ok := err.(net.Error); !ok || !e.Temporary() {
-			c.in.setErrorLocked(err)
-		}
-		return err
-	}
-
-	// Process message.
-	b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n)
-	ok, off, err := c.in.decrypt(b)
-	if !ok {
-		c.in.setErrorLocked(c.sendAlert(err))
-	}
-	b.off = off
 	data := b.data[b.off:]
 	if len(data) > maxPlaintext {
 		err := c.sendAlert(alertRecordOverflow)
@@ -713,6 +770,11 @@
 // to the connection and updates the record layer state.
 // c.out.Mutex <= L.
 func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err error) {
+	if c.isDTLS {
+		return c.dtlsWriteRecord(typ, data)
+	}
+
+	recordHeaderLen := tlsRecordHeaderLen
 	b := c.out.newBlock()
 	first := true
 	isClientHello := typ == recordTypeHandshake && len(data) > 0 && data[0] == typeClientHello
@@ -800,10 +862,11 @@
 	return
 }
 
-// readHandshake reads the next handshake message from
-// the record layer.
-// c.in.Mutex < L; c.out.Mutex < L.
-func (c *Conn) readHandshake() (interface{}, error) {
+func (c *Conn) doReadHandshake() ([]byte, error) {
+	if c.isDTLS {
+		return c.dtlsDoReadHandshake()
+	}
+
 	for c.hand.Len() < 4 {
 		if err := c.in.err; err != nil {
 			return nil, err
@@ -826,13 +889,28 @@
 			return nil, err
 		}
 	}
-	data = c.hand.Next(4 + n)
+	return c.hand.Next(4 + n), nil
+}
+
+// readHandshake reads the next handshake message from
+// the record layer.
+// c.in.Mutex < L; c.out.Mutex < L.
+func (c *Conn) readHandshake() (interface{}, error) {
+	data, err := c.doReadHandshake()
+	if err != nil {
+		return nil, err
+	}
+
 	var m handshakeMessage
 	switch data[0] {
 	case typeClientHello:
-		m = new(clientHelloMsg)
+		m = &clientHelloMsg{
+			isDTLS: c.isDTLS,
+		}
 	case typeServerHello:
-		m = new(serverHelloMsg)
+		m = &serverHelloMsg{
+			isDTLS: c.isDTLS,
+		}
 	case typeNewSessionTicket:
 		m = new(newSessionTicketMsg)
 	case typeCertificate:
@@ -857,6 +935,8 @@
 		m = new(nextProtoMsg)
 	case typeFinished:
 		m = new(finishedMsg)
+	case typeHelloVerifyRequest:
+		m = new(helloVerifyRequestMsg)
 	default:
 		return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
 	}
@@ -899,7 +979,7 @@
 	// http://www.imperialviolet.org/2012/01/15/beastfollowup.html
 
 	var m int
-	if len(b) > 1 && c.vers <= VersionTLS10 {
+	if len(b) > 1 && c.vers <= VersionTLS10 && !c.isDTLS {
 		if _, ok := c.out.cipher.(cipher.BlockMode); ok {
 			n, err := c.writeRecord(recordTypeApplicationData, b[:1])
 			if err != nil {
@@ -938,7 +1018,7 @@
 		}
 
 		n, err = c.input.Read(b)
-		if c.input.off >= len(c.input.data) {
+		if c.input.off >= len(c.input.data) || c.isDTLS {
 			c.in.freeBlock(c.input)
 			c.input = nil
 		}