Test-only DTLS implementation in runner.go.
Tested against openssl s_client and openssl s_server. This seems to work as a
start, although it may need to become more sophisticated to stress more of
BoringSSL's implementation for test purposes.
In particular, it assumes a reliable, in-order channel, and it requires that
the peer send handshake fragments in order. Retransmission and related
mechanisms are not implemented. The peer under test will be expected to handle
a lossy channel, but all loss in the channel will be controlled. MAC errors and
other such failures are fatal.
Change-Id: I329233cfb0994938fd012667ddf7c6a791ac7164
Reviewed-on: https://boringssl-review.googlesource.com/1390
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/ssl/test/runner/conn.go b/ssl/test/runner/conn.go
index f3e2495..5371a64 100644
--- a/ssl/test/runner/conn.go
+++ b/ssl/test/runner/conn.go
@@ -24,6 +24,7 @@
type Conn struct {
// constant
conn net.Conn
+ isDTLS bool
isClient bool
// constant after handshake; protected by handshakeMutex
@@ -49,8 +50,14 @@
// input/output
in, out halfConn // in.Mutex < out.Mutex
rawInput *block // raw input, right off the wire
- input *block // application data waiting to be read
- hand bytes.Buffer // handshake data waiting to be read
+ input *block // application record waiting to be read
+ hand bytes.Buffer // handshake record waiting to be read
+
+ // DTLS state
+ sendHandshakeSeq uint16
+ recvHandshakeSeq uint16
+ handMsg []byte // pending assembled handshake message
+ handMsgLen int // handshake message length, not including the header
tmp [16]byte
}
@@ -94,8 +101,9 @@
type halfConn struct {
sync.Mutex
- err error // first permanent error
- version uint16 // protocol version
+ err error // first permanent error
+ version uint16 // protocol version
+ isDTLS bool
cipher interface{} // cipher algorithm
mac macFunction
seq [8]byte // 64-bit sequence number
@@ -141,15 +149,18 @@
hc.nextCipher = nil
hc.nextMac = nil
hc.config = config
- for i := range hc.seq {
- hc.seq[i] = 0
- }
+ hc.incEpoch()
return nil
}
// incSeq increments the sequence number.
func (hc *halfConn) incSeq() {
- for i := 7; i >= 0; i-- {
+ limit := 0
+ if hc.isDTLS {
+ // Increment up to the epoch in DTLS.
+ limit = 2
+ }
+ for i := 7; i >= limit; i-- {
hc.seq[i]++
if hc.seq[i] != 0 {
return
@@ -162,11 +173,33 @@
panic("TLS: sequence number wraparound")
}
-// resetSeq resets the sequence number to zero.
-func (hc *halfConn) resetSeq() {
- for i := range hc.seq {
- hc.seq[i] = 0
+// incEpoch resets the sequence number. In DTLS, it first increments the epoch
+// (the top two bytes of hc.seq) and then resets only the remaining six bytes.
+func (hc *halfConn) incEpoch() {
+ limit := 0
+ if hc.isDTLS {
+ for i := 1; i >= 0; i-- {
+ hc.seq[i]++
+ if hc.seq[i] != 0 {
+ break
+ }
+ if i == 0 {
+ panic("TLS: epoch number wraparound")
+ }
+ }
+ limit = 2
}
+ seq := hc.seq[limit:]
+ for i := range seq {
+ seq[i] = 0
+ }
+}
+
+func (hc *halfConn) recordHeaderLen() int {
+ if hc.isDTLS {
+ return dtlsRecordHeaderLen
+ }
+ return tlsRecordHeaderLen
}
// removePadding returns an unpadded slice, in constant time, which is a prefix
@@ -237,6 +270,8 @@
// success boolean, the number of bytes to skip from the start of the record in
// order to get the application payload, and an optional alert value.
func (hc *halfConn) decrypt(b *block) (ok bool, prefixLen int, alertValue alert) {
+ recordHeaderLen := hc.recordHeaderLen()
+
// pull out payload
payload := b.data[recordHeaderLen:]
@@ -248,6 +283,12 @@
paddingGood := byte(255)
explicitIVLen := 0
+ seq := hc.seq[:]
+ if hc.isDTLS {
+ // DTLS sequence numbers are explicit.
+ seq = b.data[3:11]
+ }
+
// decrypt
if hc.cipher != nil {
switch c := hc.cipher.(type) {
@@ -262,7 +303,7 @@
payload = payload[8:]
var additionalData [13]byte
- copy(additionalData[:], hc.seq[:])
+ copy(additionalData[:], seq)
copy(additionalData[8:], b.data[:3])
n := len(payload) - c.Overhead()
additionalData[11] = byte(n >> 8)
@@ -275,7 +316,7 @@
b.resize(recordHeaderLen + explicitIVLen + len(payload))
case cbcMode:
blockSize := c.BlockSize()
- if hc.version >= VersionTLS11 {
+ if hc.version >= VersionTLS11 || hc.isDTLS {
explicitIVLen = blockSize
}
@@ -318,11 +359,11 @@
// strip mac off payload, b.data
n := len(payload) - macSize
- b.data[3] = byte(n >> 8)
- b.data[4] = byte(n)
+ b.data[recordHeaderLen-2] = byte(n >> 8)
+ b.data[recordHeaderLen-1] = byte(n)
b.resize(recordHeaderLen + explicitIVLen + n)
remoteMAC := payload[n:]
- localMAC := hc.mac.MAC(hc.inDigestBuf, hc.seq[0:], b.data[:recordHeaderLen], payload[:n])
+ localMAC := hc.mac.MAC(hc.inDigestBuf, seq, b.data[:3], b.data[recordHeaderLen-2:recordHeaderLen], payload[:n])
if subtle.ConstantTimeCompare(localMAC, remoteMAC) != 1 || paddingGood != 255 {
return false, 0, alertBadRecordMAC
@@ -364,9 +405,11 @@
// encrypt encrypts and macs the data in b.
func (hc *halfConn) encrypt(b *block, explicitIVLen int) (bool, alert) {
+ recordHeaderLen := hc.recordHeaderLen()
+
// mac
if hc.mac != nil {
- mac := hc.mac.MAC(hc.outDigestBuf, hc.seq[0:], b.data[:recordHeaderLen], b.data[recordHeaderLen+explicitIVLen:])
+ mac := hc.mac.MAC(hc.outDigestBuf, hc.seq[0:], b.data[:3], b.data[recordHeaderLen-2:recordHeaderLen], b.data[recordHeaderLen+explicitIVLen:])
n := len(b.data)
b.resize(n + len(mac))
@@ -412,8 +455,8 @@
// update length to include MAC and any block padding needed.
n := len(b.data) - recordHeaderLen
- b.data[3] = byte(n >> 8)
- b.data[4] = byte(n)
+ b.data[recordHeaderLen-2] = byte(n >> 8)
+ b.data[recordHeaderLen-1] = byte(n)
hc.incSeq()
return true, 0
@@ -517,6 +560,86 @@
return b, bb
}
+func (c *Conn) doReadRecord(want recordType) (recordType, *block, error) {
+ if c.isDTLS {
+ return c.dtlsDoReadRecord(want)
+ }
+
+ recordHeaderLen := tlsRecordHeaderLen
+
+ if c.rawInput == nil {
+ c.rawInput = c.in.newBlock()
+ }
+ b := c.rawInput
+
+ // Read header, payload.
+ if err := b.readFromUntil(c.conn, recordHeaderLen); err != nil {
+ // RFC suggests that EOF without an alertCloseNotify is
+ // an error, but popular web sites seem to do this,
+ // so we can't make it an error.
+ // if err == io.EOF {
+ // err = io.ErrUnexpectedEOF
+ // }
+ if e, ok := err.(net.Error); !ok || !e.Temporary() {
+ c.in.setErrorLocked(err)
+ }
+ return 0, nil, err
+ }
+ typ := recordType(b.data[0])
+
+ // No valid TLS record has a type of 0x80, however SSLv2 handshakes
+ // start with a uint16 length where the MSB is set and the first record
+ // is always < 256 bytes long. Therefore typ == 0x80 strongly suggests
+ // an SSLv2 client.
+ if want == recordTypeHandshake && typ == 0x80 {
+ c.sendAlert(alertProtocolVersion)
+ return 0, nil, c.in.setErrorLocked(errors.New("tls: unsupported SSLv2 handshake received"))
+ }
+
+ vers := uint16(b.data[1])<<8 | uint16(b.data[2])
+ n := int(b.data[3])<<8 | int(b.data[4])
+ if c.haveVers && vers != c.vers {
+ c.sendAlert(alertProtocolVersion)
+ return 0, nil, c.in.setErrorLocked(fmt.Errorf("tls: received record with version %x when expecting version %x", vers, c.vers))
+ }
+ if n > maxCiphertext {
+ c.sendAlert(alertRecordOverflow)
+ return 0, nil, c.in.setErrorLocked(fmt.Errorf("tls: oversized record received with length %d", n))
+ }
+ if !c.haveVers {
+ // First message, be extra suspicious:
+ // this might not be a TLS client.
+ // Bail out before reading a full 'body', if possible.
+ // The current max version is 3.1.
+ // If the version is >= 16.0, it's probably not real.
+ // Similarly, a clientHello message encodes in
+ // well under a kilobyte. If the length is >= 12 kB,
+ // it's probably not real.
+ if (typ != recordTypeAlert && typ != want) || vers >= 0x1000 || n >= 0x3000 {
+ c.sendAlert(alertUnexpectedMessage)
+ return 0, nil, c.in.setErrorLocked(fmt.Errorf("tls: first record does not look like a TLS handshake"))
+ }
+ }
+ if err := b.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
+ if err == io.EOF {
+ err = io.ErrUnexpectedEOF
+ }
+ if e, ok := err.(net.Error); !ok || !e.Temporary() {
+ c.in.setErrorLocked(err)
+ }
+ return 0, nil, err
+ }
+
+ // Process message.
+ b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n)
+ ok, off, err := c.in.decrypt(b)
+ if !ok {
+ c.in.setErrorLocked(c.sendAlert(err))
+ }
+ b.off = off
+ return typ, b, nil
+}
+
// readRecord reads the next TLS record from the connection
// and updates the record layer state.
// c.in.Mutex <= L; c.input == nil.
@@ -541,76 +664,10 @@
}
Again:
- if c.rawInput == nil {
- c.rawInput = c.in.newBlock()
- }
- b := c.rawInput
-
- // Read header, payload.
- if err := b.readFromUntil(c.conn, recordHeaderLen); err != nil {
- // RFC suggests that EOF without an alertCloseNotify is
- // an error, but popular web sites seem to do this,
- // so we can't make it an error.
- // if err == io.EOF {
- // err = io.ErrUnexpectedEOF
- // }
- if e, ok := err.(net.Error); !ok || !e.Temporary() {
- c.in.setErrorLocked(err)
- }
+ typ, b, err := c.doReadRecord(want)
+ if err != nil {
return err
}
- typ := recordType(b.data[0])
-
- // No valid TLS record has a type of 0x80, however SSLv2 handshakes
- // start with a uint16 length where the MSB is set and the first record
- // is always < 256 bytes long. Therefore typ == 0x80 strongly suggests
- // an SSLv2 client.
- if want == recordTypeHandshake && typ == 0x80 {
- c.sendAlert(alertProtocolVersion)
- return c.in.setErrorLocked(errors.New("tls: unsupported SSLv2 handshake received"))
- }
-
- vers := uint16(b.data[1])<<8 | uint16(b.data[2])
- n := int(b.data[3])<<8 | int(b.data[4])
- if c.haveVers && vers != c.vers {
- c.sendAlert(alertProtocolVersion)
- return c.in.setErrorLocked(fmt.Errorf("tls: received record with version %x when expecting version %x", vers, c.vers))
- }
- if n > maxCiphertext {
- c.sendAlert(alertRecordOverflow)
- return c.in.setErrorLocked(fmt.Errorf("tls: oversized record received with length %d", n))
- }
- if !c.haveVers {
- // First message, be extra suspicious:
- // this might not be a TLS client.
- // Bail out before reading a full 'body', if possible.
- // The current max version is 3.1.
- // If the version is >= 16.0, it's probably not real.
- // Similarly, a clientHello message encodes in
- // well under a kilobyte. If the length is >= 12 kB,
- // it's probably not real.
- if (typ != recordTypeAlert && typ != want) || vers >= 0x1000 || n >= 0x3000 {
- c.sendAlert(alertUnexpectedMessage)
- return c.in.setErrorLocked(fmt.Errorf("tls: first record does not look like a TLS handshake"))
- }
- }
- if err := b.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
- if err == io.EOF {
- err = io.ErrUnexpectedEOF
- }
- if e, ok := err.(net.Error); !ok || !e.Temporary() {
- c.in.setErrorLocked(err)
- }
- return err
- }
-
- // Process message.
- b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n)
- ok, off, err := c.in.decrypt(b)
- if !ok {
- c.in.setErrorLocked(c.sendAlert(err))
- }
- b.off = off
data := b.data[b.off:]
if len(data) > maxPlaintext {
err := c.sendAlert(alertRecordOverflow)
@@ -713,6 +770,11 @@
// to the connection and updates the record layer state.
// c.out.Mutex <= L.
func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err error) {
+ if c.isDTLS {
+ return c.dtlsWriteRecord(typ, data)
+ }
+
+ recordHeaderLen := tlsRecordHeaderLen
b := c.out.newBlock()
first := true
isClientHello := typ == recordTypeHandshake && len(data) > 0 && data[0] == typeClientHello
@@ -800,10 +862,11 @@
return
}
-// readHandshake reads the next handshake message from
-// the record layer.
-// c.in.Mutex < L; c.out.Mutex < L.
-func (c *Conn) readHandshake() (interface{}, error) {
+func (c *Conn) doReadHandshake() ([]byte, error) {
+ if c.isDTLS {
+ return c.dtlsDoReadHandshake()
+ }
+
for c.hand.Len() < 4 {
if err := c.in.err; err != nil {
return nil, err
@@ -826,13 +889,28 @@
return nil, err
}
}
- data = c.hand.Next(4 + n)
+ return c.hand.Next(4 + n), nil
+}
+
+// readHandshake reads the next handshake message from
+// the record layer.
+// c.in.Mutex < L; c.out.Mutex < L.
+func (c *Conn) readHandshake() (interface{}, error) {
+ data, err := c.doReadHandshake()
+ if err != nil {
+ return nil, err
+ }
+
var m handshakeMessage
switch data[0] {
case typeClientHello:
- m = new(clientHelloMsg)
+ m = &clientHelloMsg{
+ isDTLS: c.isDTLS,
+ }
case typeServerHello:
- m = new(serverHelloMsg)
+ m = &serverHelloMsg{
+ isDTLS: c.isDTLS,
+ }
case typeNewSessionTicket:
m = new(newSessionTicketMsg)
case typeCertificate:
@@ -857,6 +935,8 @@
m = new(nextProtoMsg)
case typeFinished:
m = new(finishedMsg)
+ case typeHelloVerifyRequest:
+ m = new(helloVerifyRequestMsg)
default:
return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
}
@@ -899,7 +979,7 @@
// http://www.imperialviolet.org/2012/01/15/beastfollowup.html
var m int
- if len(b) > 1 && c.vers <= VersionTLS10 {
+ if len(b) > 1 && c.vers <= VersionTLS10 && !c.isDTLS {
if _, ok := c.out.cipher.(cipher.BlockMode); ok {
n, err := c.writeRecord(recordTypeApplicationData, b[:1])
if err != nil {
@@ -938,7 +1018,7 @@
}
n, err = c.input.Read(b)
- if c.input.off >= len(c.input.data) {
+ if c.input.off >= len(c.input.data) || c.isDTLS {
c.in.freeBlock(c.input)
c.input = nil
}