]> Cypherpunks.ru repositories - gostls13.git/commitdiff
crypto/tls: replace custom *block with standard buffers
authorFilippo Valsorda <filippo@golang.org>
Wed, 17 Oct 2018 00:06:08 +0000 (20:06 -0400)
committerFilippo Valsorda <filippo@golang.org>
Wed, 24 Oct 2018 10:03:23 +0000 (10:03 +0000)
The crypto/tls record layer used a custom buffer implementation with its
own semantics, freelist, and offset management. Replace it all with
per-task bytes.Buffer, bytes.Reader and byte slices, along with a
refactor of all the encrypt and decrypt code.

The main quirk of *block was to do a best-effort read past the record
boundary, so that if a closeNotify was waiting it would be peeked and
surfaced along with the last Read. Address that with atLeastReader and
ReadFrom to avoid a useless copy (instead of a LimitReader or CopyN).

There was also an optimization to split blocks along record boundary
lines without having to copy in and out the data. Replicate that by
aliasing c.input into consumed c.rawInput (after an in-place decrypt
operation). This is safe because c.rawInput is not used until c.input is
drained.

The benchmarks are noisy but look like an improvement across the board,
which is a nice side effect :)

name                                       old time/op   new time/op   delta
HandshakeServer/RSA-8                        817µs ± 2%    797µs ± 2%  -2.52%  (p=0.000 n=10+9)
HandshakeServer/ECDHE-P256-RSA-8             984µs ±11%    897µs ± 0%  -8.89%  (p=0.000 n=10+9)
HandshakeServer/ECDHE-P256-ECDSA-P256-8      206µs ±10%    199µs ± 3%    ~     (p=0.113 n=10+9)
HandshakeServer/ECDHE-X25519-ECDSA-P256-8    204µs ± 3%    202µs ± 1%  -1.06%  (p=0.013 n=10+9)
HandshakeServer/ECDHE-P521-ECDSA-P521-8     15.5ms ± 0%   15.6ms ± 1%    ~     (p=0.095 n=9+10)
Throughput/MaxPacket/1MB-8                  5.35ms ±19%   5.39ms ±36%    ~     (p=1.000 n=9+10)
Throughput/MaxPacket/2MB-8                  9.20ms ±15%   8.30ms ± 8%  -9.79%  (p=0.035 n=10+9)
Throughput/MaxPacket/4MB-8                  13.8ms ± 7%   13.6ms ± 8%    ~     (p=0.315 n=10+10)
Throughput/MaxPacket/8MB-8                  25.1ms ± 3%   23.2ms ± 2%  -7.66%  (p=0.000 n=10+9)
Throughput/MaxPacket/16MB-8                 46.9ms ± 1%   43.0ms ± 3%  -8.29%  (p=0.000 n=9+10)
Throughput/MaxPacket/32MB-8                 88.9ms ± 2%   82.3ms ± 2%  -7.40%  (p=0.000 n=9+9)
Throughput/MaxPacket/64MB-8                  175ms ± 2%    164ms ± 4%  -6.18%  (p=0.000 n=10+10)
Throughput/DynamicPacket/1MB-8              5.79ms ±26%   5.82ms ±22%    ~     (p=0.912 n=10+10)
Throughput/DynamicPacket/2MB-8              9.23ms ±14%   9.50ms ±23%    ~     (p=0.971 n=10+10)
Throughput/DynamicPacket/4MB-8              14.5ms ±11%   13.8ms ± 6%  -4.66%  (p=0.019 n=10+10)
Throughput/DynamicPacket/8MB-8              25.6ms ± 4%   23.5ms ± 3%  -8.33%  (p=0.000 n=10+10)
Throughput/DynamicPacket/16MB-8             47.3ms ± 3%   44.6ms ± 7%  -5.65%  (p=0.000 n=10+10)
Throughput/DynamicPacket/32MB-8             91.9ms ±14%   85.0ms ± 4%  -7.55%  (p=0.000 n=10+10)
Throughput/DynamicPacket/64MB-8              177ms ± 2%    168ms ± 4%  -4.97%  (p=0.000 n=8+10)
Latency/MaxPacket/200kbps-8                  694ms ± 0%    694ms ± 0%    ~     (p=0.315 n=10+9)
Latency/MaxPacket/500kbps-8                  279ms ± 0%    279ms ± 0%    ~     (p=0.447 n=9+10)
Latency/MaxPacket/1000kbps-8                 140ms ± 0%    140ms ± 0%    ~     (p=0.661 n=9+10)
Latency/MaxPacket/2000kbps-8                71.1ms ± 0%   71.1ms ± 0%  +0.05%  (p=0.019 n=9+9)
Latency/MaxPacket/5000kbps-8                30.4ms ± 7%   30.5ms ± 4%    ~     (p=0.720 n=9+10)
Latency/DynamicPacket/200kbps-8              134ms ± 0%    134ms ± 0%    ~     (p=0.075 n=10+10)
Latency/DynamicPacket/500kbps-8             54.8ms ± 0%   54.8ms ± 0%    ~     (p=0.631 n=10+10)
Latency/DynamicPacket/1000kbps-8            28.5ms ± 0%   28.5ms ± 0%    ~     (p=1.000 n=8+8)
Latency/DynamicPacket/2000kbps-8            15.7ms ±12%   16.1ms ± 0%    ~     (p=0.109 n=10+7)
Latency/DynamicPacket/5000kbps-8            8.20ms ±26%   8.17ms ±13%    ~     (p=1.000 n=9+9)

name                                       old speed     new speed     delta
Throughput/MaxPacket/1MB-8                 193MB/s ±14%  202MB/s ±30%    ~     (p=0.897 n=8+10)
Throughput/MaxPacket/2MB-8                 230MB/s ±14%  249MB/s ±17%    ~     (p=0.089 n=10+10)
Throughput/MaxPacket/4MB-8                 304MB/s ± 6%  309MB/s ± 7%    ~     (p=0.315 n=10+10)
Throughput/MaxPacket/8MB-8                 334MB/s ± 3%  362MB/s ± 2%  +8.29%  (p=0.000 n=10+9)
Throughput/MaxPacket/16MB-8                358MB/s ± 1%  390MB/s ± 3%  +9.08%  (p=0.000 n=9+10)
Throughput/MaxPacket/32MB-8                378MB/s ± 2%  408MB/s ± 2%  +8.00%  (p=0.000 n=9+9)
Throughput/MaxPacket/64MB-8                384MB/s ± 2%  410MB/s ± 4%  +6.61%  (p=0.000 n=10+10)
Throughput/DynamicPacket/1MB-8             178MB/s ±24%  182MB/s ±24%    ~     (p=0.604 n=9+10)
Throughput/DynamicPacket/2MB-8             228MB/s ±13%  225MB/s ±20%    ~     (p=0.971 n=10+10)
Throughput/DynamicPacket/4MB-8             291MB/s ±10%  305MB/s ± 6%  +4.83%  (p=0.019 n=10+10)
Throughput/DynamicPacket/8MB-8             327MB/s ± 4%  357MB/s ± 3%  +9.08%  (p=0.000 n=10+10)
Throughput/DynamicPacket/16MB-8            355MB/s ± 3%  376MB/s ± 6%  +6.07%  (p=0.000 n=10+10)
Throughput/DynamicPacket/32MB-8            366MB/s ±12%  395MB/s ± 4%  +7.91%  (p=0.000 n=10+10)
Throughput/DynamicPacket/64MB-8            380MB/s ± 2%  400MB/s ± 4%  +5.26%  (p=0.000 n=8+10)

Note that this reduced the buffer for the first read from 1024 to 5+512,
so it triggered the issue described at #24198 when using a synchronous
net.Pipe: the first server flight was not being consumed entirely by the
first read anymore, causing a deadlock as both the client and the server
were trying to send (the client a reply to the ServerHello, the server
the rest of the buffer). Fixed by rebasing on top of CL 142817.

Change-Id: Ie31b0a572b2ad37878469877798d5c6a5276f931
Reviewed-on: https://go-review.googlesource.com/c/142818
Run-TryBot: Filippo Valsorda <filippo@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Adam Langley <agl@golang.org>
src/crypto/tls/cipher_suites.go
src/crypto/tls/conn.go
src/crypto/tls/tls.go

index d232996629f25910da1167be5d3ecf308dc3b31b..e937235876e7ece3a6e9df6c89b0454793d59d83 100644 (file)
@@ -134,25 +134,29 @@ func macSHA1(version uint16, key []byte) macFunction {
                copy(mac.key, key)
                return mac
        }
-       return tls10MAC{hmac.New(newConstantTimeHash(sha1.New), key)}
+       return tls10MAC{h: hmac.New(newConstantTimeHash(sha1.New), key)}
 }
 
 // macSHA256 returns a SHA-256 based MAC. These are only supported in TLS 1.2
 // so the given version is ignored.
 func macSHA256(version uint16, key []byte) macFunction {
-       return tls10MAC{hmac.New(sha256.New, key)}
+       return tls10MAC{h: hmac.New(sha256.New, key)}
 }
 
 type macFunction interface {
+       // Size returns the length of the MAC.
        Size() int
-       MAC(digestBuf, seq, header, data, extra []byte) []byte
+       // MAC appends the MAC of (seq, header, data) to out. The extra data is fed
+       // into the MAC after obtaining the result to normalize timing. The result
+       // is only valid until the next invocation of MAC as the buffer is reused.
+       MAC(seq, header, data, extra []byte) []byte
 }
 
 type aead interface {
        cipher.AEAD
 
-       // explicitIVLen returns the number of bytes used by the explicit nonce
-       // that is included in the record. This is eight for older AEADs and
+       // explicitNonceLen returns the number of bytes of explicit nonce
+       // included in each record. This is eight for older AEADs and
        // zero for modern ones.
        explicitNonceLen() int
 }
@@ -245,6 +249,7 @@ func aeadChaCha20Poly1305(key, fixedNonce []byte) cipher.AEAD {
 type ssl30MAC struct {
        h   hash.Hash
        key []byte
+       buf []byte
 }
 
 func (s ssl30MAC) Size() int {
@@ -257,7 +262,7 @@ var ssl30Pad2 = [48]byte{0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0
 
 // MAC does not offer constant timing guarantees for SSL v3.0, since it's deemed
 // useless considering the similar, protocol-level POODLE vulnerability.
-func (s ssl30MAC) MAC(digestBuf, seq, header, data, extra []byte) []byte {
+func (s ssl30MAC) MAC(seq, header, data, extra []byte) []byte {
        padLength := 48
        if s.h.Size() == 20 {
                padLength = 40
@@ -270,13 +275,13 @@ func (s ssl30MAC) MAC(digestBuf, seq, header, data, extra []byte) []byte {
        s.h.Write(header[:1])
        s.h.Write(header[3:5])
        s.h.Write(data)
-       digestBuf = s.h.Sum(digestBuf[:0])
+       s.buf = s.h.Sum(s.buf[:0])
 
        s.h.Reset()
        s.h.Write(s.key)
        s.h.Write(ssl30Pad2[:padLength])
-       s.h.Write(digestBuf)
-       return s.h.Sum(digestBuf[:0])
+       s.h.Write(s.buf)
+       return s.h.Sum(s.buf[:0])
 }
 
 type constantTimeHash interface {
@@ -304,7 +309,8 @@ func newConstantTimeHash(h func() hash.Hash) func() hash.Hash {
 
 // tls10MAC implements the TLS 1.0 MAC function. RFC 2246, Section 6.2.3.
 type tls10MAC struct {
-       h hash.Hash
+       h   hash.Hash
+       buf []byte
 }
 
 func (s tls10MAC) Size() int {
@@ -314,12 +320,12 @@ func (s tls10MAC) Size() int {
 // MAC is guaranteed to take constant time, as long as
 // len(seq)+len(header)+len(data)+len(extra) is constant. extra is not fed into
 // the MAC, but is only provided to make the timing profile constant.
-func (s tls10MAC) MAC(digestBuf, seq, header, data, extra []byte) []byte {
+func (s tls10MAC) MAC(seq, header, data, extra []byte) []byte {
        s.h.Reset()
        s.h.Write(seq)
        s.h.Write(header)
        s.h.Write(data)
-       res := s.h.Sum(digestBuf[:0])
+       res := s.h.Sum(s.buf[:0])
        if extra != nil {
                s.h.Write(extra)
        }
index f05135bc33230edd8a9828035aa2dcbf8af719ed..13cebc9042dd3daa72e6dd92117a36c2d4587c46 100644 (file)
@@ -82,9 +82,10 @@ type Conn struct {
 
        // input/output
        in, out   halfConn
-       rawInput  *block       // raw input, right off the wire
-       input     *block       // application data waiting to be read
+       rawInput  bytes.Buffer // raw input, starting with a record header
+       input     bytes.Reader // application data waiting to be read, from rawInput.Next
        hand      bytes.Buffer // handshake data waiting to be read
+       outBuf    []byte       // scratch buffer used by out.encrypt
        buffering bool         // whether records are buffered in sendBuf
        sendBuf   []byte       // a buffer of records waiting to be sent
 
@@ -149,14 +150,10 @@ type halfConn struct {
        cipher         interface{} // cipher algorithm
        mac            macFunction
        seq            [8]byte  // 64-bit sequence number
-       bfree          *block   // list of free blocks
        additionalData [13]byte // to avoid allocs; interface method args escape
 
        nextCipher interface{} // next encryption state
        nextMac    macFunction // next MAC algorithm
-
-       // used to save allocating a new buffer for each MAC.
-       inDigestBuf, outDigestBuf []byte
 }
 
 func (hc *halfConn) setErrorLocked(err error) error {
@@ -203,6 +200,30 @@ func (hc *halfConn) incSeq() {
        panic("TLS: sequence number wraparound")
 }
 
+// explicitNonceLen returns the number of bytes of explicit nonce or IV included
+// in each record. Explicit nonces are present only in CBC modes after TLS 1.0
+// and in certain AEAD modes in TLS 1.2.
+func (hc *halfConn) explicitNonceLen() int {
+       if hc.cipher == nil {
+               return 0
+       }
+
+       switch c := hc.cipher.(type) {
+       case cipher.Stream:
+               return 0
+       case aead:
+               return c.explicitNonceLen()
+       case cbcMode:
+               // TLS 1.1 introduced a per-record explicit IV to fix the BEAST attack.
+               if hc.version >= VersionTLS11 {
+                       return c.BlockSize()
+               }
+               return 0
+       default:
+               panic("unknown cipher type")
+       }
+}
+
 // extractPadding returns, in constant time, the length of the padding to remove
 // from the end of payload. It also returns a byte which is equal to 255 if the
 // padding was valid and 0 otherwise. See RFC 2246, Section 6.2.3.2.
@@ -268,283 +289,189 @@ type cbcMode interface {
        SetIV([]byte)
 }
 
-// decrypt checks and strips the mac and decrypts the data in b. Returns a
-// success boolean, the number of bytes to skip from the start of the record in
-// order to get the application payload, and an optional alert value.
-func (hc *halfConn) decrypt(b *block) (ok bool, prefixLen int, alertValue alert) {
-       // pull out payload
-       payload := b.data[recordHeaderLen:]
-
-       macSize := 0
-       if hc.mac != nil {
-               macSize = hc.mac.Size()
-       }
+// decrypt authenticates and decrypts the record if protection is active at
+// this stage. The returned plaintext might overlap with the input.
+func (hc *halfConn) decrypt(record []byte) (plaintext []byte, err error) {
+       payload := record[recordHeaderLen:]
 
        paddingGood := byte(255)
        paddingLen := 0
-       explicitIVLen := 0
 
-       // decrypt
+       explicitNonceLen := hc.explicitNonceLen()
+
        if hc.cipher != nil {
                switch c := hc.cipher.(type) {
                case cipher.Stream:
                        c.XORKeyStream(payload, payload)
                case aead:
-                       explicitIVLen = c.explicitNonceLen()
-                       if len(payload) < explicitIVLen {
-                               return false, 0, alertBadRecordMAC
+                       if len(payload) < explicitNonceLen {
+                               return nil, alertBadRecordMAC
                        }
-                       nonce := payload[:explicitIVLen]
-                       payload = payload[explicitIVLen:]
-
+                       nonce := payload[:explicitNonceLen]
                        if len(nonce) == 0 {
                                nonce = hc.seq[:]
                        }
+                       payload = payload[explicitNonceLen:]
 
                        copy(hc.additionalData[:], hc.seq[:])
-                       copy(hc.additionalData[8:], b.data[:3])
+                       copy(hc.additionalData[8:], record[:3])
                        n := len(payload) - c.Overhead()
                        hc.additionalData[11] = byte(n >> 8)
                        hc.additionalData[12] = byte(n)
+
                        var err error
-                       payload, err = c.Open(payload[:0], nonce, payload, hc.additionalData[:])
+                       plaintext, err = c.Open(payload[:0], nonce, payload, hc.additionalData[:])
                        if err != nil {
-                               return false, 0, alertBadRecordMAC
+                               return nil, alertBadRecordMAC
                        }
-                       b.resize(recordHeaderLen + explicitIVLen + len(payload))
                case cbcMode:
                        blockSize := c.BlockSize()
-                       if hc.version >= VersionTLS11 {
-                               explicitIVLen = blockSize
+                       minPayload := explicitNonceLen + roundUp(hc.mac.Size()+1, blockSize) // TODO: vuln?
+                       if len(payload)%blockSize != 0 || len(payload) < minPayload {
+                               return nil, alertBadRecordMAC
                        }
 
-                       if len(payload)%blockSize != 0 || len(payload) < roundUp(explicitIVLen+macSize+1, blockSize) {
-                               return false, 0, alertBadRecordMAC
-                       }
-
-                       if explicitIVLen > 0 {
-                               c.SetIV(payload[:explicitIVLen])
-                               payload = payload[explicitIVLen:]
+                       if explicitNonceLen > 0 {
+                               c.SetIV(payload[:explicitNonceLen])
+                               payload = payload[explicitNonceLen:]
                        }
                        c.CryptBlocks(payload, payload)
+
+                       // In a limited attempt to protect against CBC padding oracles like
+                       // Lucky13, the data past paddingLen (which is secret) is passed to
+                       // the MAC function as extra data, to be fed into the HMAC after
+                       // computing the digest. This makes the MAC roughly constant time as
+                       // long as the digest computation is constant time and does not
+                       // affect the subsequent write, modulo cache effects.
                        if hc.version == VersionSSL30 {
                                paddingLen, paddingGood = extractPaddingSSL30(payload)
                        } else {
                                paddingLen, paddingGood = extractPadding(payload)
-
-                               // To protect against CBC padding oracles like Lucky13, the data
-                               // past paddingLen (which is secret) is passed to the MAC
-                               // function as extra data, to be fed into the HMAC after
-                               // computing the digest. This makes the MAC constant time as
-                               // long as the digest computation is constant time and does not
-                               // affect the subsequent write.
                        }
                default:
                        panic("unknown cipher type")
                }
+       } else {
+               plaintext = payload
        }
 
-       // check, strip mac
        if hc.mac != nil {
+               macSize := hc.mac.Size()
                if len(payload) < macSize {
-                       return false, 0, alertBadRecordMAC
+                       return nil, alertBadRecordMAC
                }
 
-               // strip mac off payload, b.data
                n := len(payload) - macSize - paddingLen
                n = subtle.ConstantTimeSelect(int(uint32(n)>>31), 0, n) // if n < 0 { n = 0 }
-               b.data[3] = byte(n >> 8)
-               b.data[4] = byte(n)
+               record[3] = byte(n >> 8)
+               record[4] = byte(n)
                remoteMAC := payload[n : n+macSize]
-               localMAC := hc.mac.MAC(hc.inDigestBuf, hc.seq[0:], b.data[:recordHeaderLen], payload[:n], payload[n+macSize:])
+               localMAC := hc.mac.MAC(hc.seq[0:], record[:recordHeaderLen], payload[:n], payload[n+macSize:])
 
                if subtle.ConstantTimeCompare(localMAC, remoteMAC) != 1 || paddingGood != 255 {
-                       return false, 0, alertBadRecordMAC
+                       return nil, alertBadRecordMAC
                }
-               hc.inDigestBuf = localMAC
 
-               b.resize(recordHeaderLen + explicitIVLen + n)
+               plaintext = payload[:n]
        }
-       hc.incSeq()
 
-       return true, recordHeaderLen + explicitIVLen, 0
+       hc.incSeq()
+       return plaintext, nil
 }
 
-// padToBlockSize calculates the needed padding block, if any, for a payload.
-// On exit, prefix aliases payload and extends to the end of the last full
-// block of payload. finalBlock is a fresh slice which contains the contents of
-// any suffix of payload as well as the needed padding to make finalBlock a
-// full block.
-func padToBlockSize(payload []byte, blockSize int) (prefix, finalBlock []byte) {
-       overrun := len(payload) % blockSize
-       paddingLen := blockSize - overrun
-       prefix = payload[:len(payload)-overrun]
-       finalBlock = make([]byte, blockSize)
-       copy(finalBlock, payload[len(payload)-overrun:])
-       for i := overrun; i < blockSize; i++ {
-               finalBlock[i] = byte(paddingLen - 1)
+// sliceForAppend extends the input slice by n bytes. head is the full extended
+// slice, while tail is the appended part. If the original slice has sufficient
+// capacity no allocation is performed.
+func sliceForAppend(in []byte, n int) (head, tail []byte) {
+       if total := len(in) + n; cap(in) >= total {
+               head = in[:total]
+       } else {
+               head = make([]byte, total)
+               copy(head, in)
        }
+       tail = head[len(in):]
        return
 }
 
-// encrypt encrypts and macs the data in b.
-func (hc *halfConn) encrypt(b *block, explicitIVLen int) (bool, alert) {
-       // mac
-       if hc.mac != nil {
-               mac := hc.mac.MAC(hc.outDigestBuf, hc.seq[0:], b.data[:recordHeaderLen], b.data[recordHeaderLen+explicitIVLen:], nil)
-
-               n := len(b.data)
-               b.resize(n + len(mac))
-               copy(b.data[n:], mac)
-               hc.outDigestBuf = mac
-       }
-
-       payload := b.data[recordHeaderLen:]
-
-       // encrypt
-       if hc.cipher != nil {
-               switch c := hc.cipher.(type) {
-               case cipher.Stream:
-                       c.XORKeyStream(payload, payload)
-               case aead:
-                       payloadLen := len(b.data) - recordHeaderLen - explicitIVLen
-                       b.resize(len(b.data) + c.Overhead())
-                       nonce := b.data[recordHeaderLen : recordHeaderLen+explicitIVLen]
-                       if len(nonce) == 0 {
-                               nonce = hc.seq[:]
-                       }
-                       payload := b.data[recordHeaderLen+explicitIVLen:]
-                       payload = payload[:payloadLen]
-
-                       copy(hc.additionalData[:], hc.seq[:])
-                       copy(hc.additionalData[8:], b.data[:3])
-                       hc.additionalData[11] = byte(payloadLen >> 8)
-                       hc.additionalData[12] = byte(payloadLen)
-
-                       c.Seal(payload[:0], nonce, payload, hc.additionalData[:])
-               case cbcMode:
-                       blockSize := c.BlockSize()
-                       if explicitIVLen > 0 {
-                               c.SetIV(payload[:explicitIVLen])
-                               payload = payload[explicitIVLen:]
+// encrypt encrypts payload, adding the appropriate nonce and/or MAC, and
+// appends it to record, which contains the record header.
+func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, error) {
+       if hc.cipher == nil {
+               return append(record, payload...), nil
+       }
+
+       var explicitNonce []byte
+       if explicitNonceLen := hc.explicitNonceLen(); explicitNonceLen > 0 {
+               record, explicitNonce = sliceForAppend(record, explicitNonceLen)
+               if _, isCBC := hc.cipher.(cbcMode); !isCBC && explicitNonceLen < 16 {
+                       // The AES-GCM construction in TLS has an explicit nonce so that the
+                       // nonce can be random. However, the nonce is only 8 bytes which is
+                       // too small for a secure, random nonce. Therefore we use the
+                       // sequence number as the nonce. The 3DES-CBC construction also has
+                       // an 8 bytes nonce but its nonces must be unpredictable (see RFC
+                       // 5246, Appendix F.3), forcing us to use randomness. That's not
+                       // 3DES' biggest problem anyway because the birthday bound on block
+                       // collision is reached first due to its simlarly small block size
+                       // (see the Sweet32 attack).
+                       copy(explicitNonce, hc.seq[:])
+               } else {
+                       if _, err := io.ReadFull(rand, explicitNonce); err != nil {
+                               return nil, err
                        }
-                       prefix, finalBlock := padToBlockSize(payload, blockSize)
-                       b.resize(recordHeaderLen + explicitIVLen + len(prefix) + len(finalBlock))
-                       c.CryptBlocks(b.data[recordHeaderLen+explicitIVLen:], prefix)
-                       c.CryptBlocks(b.data[recordHeaderLen+explicitIVLen+len(prefix):], finalBlock)
-               default:
-                       panic("unknown cipher type")
                }
        }
 
-       // update length to include MAC and any block padding needed.
-       n := len(b.data) - recordHeaderLen
-       b.data[3] = byte(n >> 8)
-       b.data[4] = byte(n)
-       hc.incSeq()
-
-       return true, 0
-}
-
-// A block is a simple data buffer.
-type block struct {
-       data []byte
-       off  int // index for Read
-       link *block
-}
-
-// resize resizes block to be n bytes, growing if necessary.
-func (b *block) resize(n int) {
-       if n > cap(b.data) {
-               b.reserve(n)
-       }
-       b.data = b.data[0:n]
-}
-
-// reserve makes sure that block contains a capacity of at least n bytes.
-func (b *block) reserve(n int) {
-       if cap(b.data) >= n {
-               return
-       }
-       m := cap(b.data)
-       if m == 0 {
-               m = 1024
-       }
-       for m < n {
-               m *= 2
-       }
-       data := make([]byte, len(b.data), m)
-       copy(data, b.data)
-       b.data = data
-}
-
-// readFromUntil reads from r into b until b contains at least n bytes
-// or else returns an error.
-func (b *block) readFromUntil(r io.Reader, n int) error {
-       // quick case
-       if len(b.data) >= n {
-               return nil
-       }
+       var mac []byte
+       if hc.mac != nil {
+               mac = hc.mac.MAC(hc.seq[:], record[:recordHeaderLen], payload, nil)
+       }
+
+       var dst []byte
+       switch c := hc.cipher.(type) {
+       case cipher.Stream:
+               record, dst = sliceForAppend(record, len(payload)+len(mac))
+               c.XORKeyStream(dst[:len(payload)], payload)
+               c.XORKeyStream(dst[len(payload):], mac)
+       case aead:
+               nonce := explicitNonce
+               if len(nonce) == 0 {
+                       nonce = hc.seq[:]
+               }
 
-       // read until have enough.
-       b.reserve(n)
-       for {
-               m, err := r.Read(b.data[len(b.data):cap(b.data)])
-               b.data = b.data[0 : len(b.data)+m]
-               if len(b.data) >= n {
-                       // TODO(bradfitz,agl): slightly suspicious
-                       // that we're throwing away r.Read's err here.
-                       break
+               copy(hc.additionalData[:], hc.seq[:])
+               copy(hc.additionalData[8:], record[:3])
+               hc.additionalData[11] = byte(len(payload) >> 8)
+               hc.additionalData[12] = byte(len(payload))
+
+               record = c.Seal(record, nonce, payload, hc.additionalData[:])
+       case cbcMode:
+               blockSize := c.BlockSize()
+               plaintextLen := len(payload) + len(mac)
+               paddingLen := blockSize - plaintextLen%blockSize
+               record, dst = sliceForAppend(record, plaintextLen+paddingLen)
+               copy(dst, payload)
+               copy(dst[len(payload):], mac)
+               for i := plaintextLen; i < len(dst); i++ {
+                       dst[i] = byte(paddingLen - 1)
                }
-               if err != nil {
-                       return err
+               if len(explicitNonce) > 0 {
+                       c.SetIV(explicitNonce)
                }
+               c.CryptBlocks(dst, dst)
+       default:
+               panic("unknown cipher type")
        }
-       return nil
-}
-
-func (b *block) Read(p []byte) (n int, err error) {
-       n = copy(p, b.data[b.off:])
-       b.off += n
-       return
-}
-
-// newBlock allocates a new block, from hc's free list if possible.
-func (hc *halfConn) newBlock() *block {
-       b := hc.bfree
-       if b == nil {
-               return new(block)
-       }
-       hc.bfree = b.link
-       b.link = nil
-       b.resize(0)
-       return b
-}
 
-// freeBlock returns a block to hc's free list.
-// The protocol is such that each side only has a block or two on
-// its free list at a time, so there's no need to worry about
-// trimming the list, etc.
-func (hc *halfConn) freeBlock(b *block) {
-       b.link = hc.bfree
-       hc.bfree = b
-}
+       // Update length to include nonce, MAC and any block padding needed.
+       n := len(record) - recordHeaderLen
+       record[3] = byte(n >> 8)
+       record[4] = byte(n)
+       hc.incSeq()
 
-// splitBlock splits a block after the first n bytes,
-// returning a block with those n bytes and a
-// block with the remainder.  the latter may be nil.
-func (hc *halfConn) splitBlock(b *block, n int) (*block, *block) {
-       if len(b.data) <= n {
-               return b, nil
-       }
-       bb := hc.newBlock()
-       bb.resize(len(b.data) - n)
-       copy(bb.data, b.data[n:])
-       b.data = b.data[0:n]
-       return b, bb
+       return record, nil
 }
 
-// RecordHeaderError results when a TLS record header is invalid.
+// RecordHeaderError is returned when a TLS record header is invalid.
 type RecordHeaderError struct {
        // Msg contains a human readable string that describes the error.
        Msg string
@@ -557,7 +484,7 @@ func (e RecordHeaderError) Error() string { return "tls: " + e.Msg }
 
 func (c *Conn) newRecordHeaderError(msg string) (err RecordHeaderError) {
        err.Msg = msg
-       copy(err.RecordHeader[:], c.rawInput.data)
+       copy(err.RecordHeader[:], c.rawInput.Bytes())
        return err
 }
 
@@ -569,40 +496,38 @@ func (c *Conn) readRecord(want recordType) error {
        // else application data.
        switch want {
        default:
-               c.sendAlert(alertInternalError)
-               return c.in.setErrorLocked(errors.New("tls: unknown record type requested"))
+               panic("tls: unknown record type requested")
        case recordTypeHandshake, recordTypeChangeCipherSpec:
                if c.handshakeComplete() {
-                       c.sendAlert(alertInternalError)
-                       return c.in.setErrorLocked(errors.New("tls: handshake or ChangeCipherSpec requested while not in handshake"))
+                       panic("tls: handshake or ChangeCipherSpec requested while not in handshake")
                }
        case recordTypeApplicationData:
                if !c.handshakeComplete() {
-                       c.sendAlert(alertInternalError)
-                       return c.in.setErrorLocked(errors.New("tls: application data record requested while in handshake"))
+                       panic("tls: application data record requested while in handshake")
                }
        }
 
-Again:
-       if c.rawInput == nil {
-               c.rawInput = c.in.newBlock()
+       // This function modifies c.rawInput, which owns the c.input memory.
+       if c.input.Len() != 0 {
+               panic("tls: attempted to read record with pending application data")
        }
-       b := c.rawInput
+       c.input.Reset(nil)
 
        // Read header, payload.
-       if err := b.readFromUntil(c.conn, recordHeaderLen); err != nil {
-               // RFC suggests that EOF without an alertCloseNotify is
-               // an error, but popular web sites seem to do this,
-               // so we can't make it an error.
-               // if err == io.EOF {
-               //      err = io.ErrUnexpectedEOF
-               // }
+       if err := c.readFromUntil(c.conn, recordHeaderLen); err != nil {
+               // RFC 8446, Section 6.1 suggests that EOF without an alertCloseNotify
+               // is an error, but popular web sites seem to do this, so we accept it
+               // if and only if at the record boundary.
+               if err == io.ErrUnexpectedEOF && c.rawInput.Len() == 0 {
+                       err = io.EOF
+               }
                if e, ok := err.(net.Error); !ok || !e.Temporary() {
                        c.in.setErrorLocked(err)
                }
                return err
        }
-       typ := recordType(b.data[0])
+       hdr := c.rawInput.Bytes()[:recordHeaderLen]
+       typ := recordType(hdr[0])
 
        // No valid TLS record has a type of 0x80, however SSLv2 handshakes
        // start with a uint16 length where the MSB is set and the first record
@@ -613,8 +538,8 @@ Again:
                return c.in.setErrorLocked(c.newRecordHeaderError("unsupported SSLv2 handshake received"))
        }
 
-       vers := uint16(b.data[1])<<8 | uint16(b.data[2])
-       n := int(b.data[3])<<8 | int(b.data[4])
+       vers := uint16(hdr[1])<<8 | uint16(hdr[2])
+       n := int(hdr[3])<<8 | int(hdr[4])
        if c.haveVers && vers != c.vers {
                c.sendAlert(alertProtocolVersion)
                msg := fmt.Sprintf("received record with version %x when expecting version %x", vers, c.vers)
@@ -635,10 +560,7 @@ Again:
                        return c.in.setErrorLocked(c.newRecordHeaderError("first record does not look like a TLS handshake"))
                }
        }
-       if err := b.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
-               if err == io.EOF {
-                       err = io.ErrUnexpectedEOF
-               }
+       if err := c.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
                if e, ok := err.(net.Error); !ok || !e.Temporary() {
                        c.in.setErrorLocked(err)
                }
@@ -646,18 +568,13 @@ Again:
        }
 
        // Process message.
-       b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n)
-       ok, off, alertValue := c.in.decrypt(b)
-       if !ok {
-               c.in.freeBlock(b)
-               return c.in.setErrorLocked(c.sendAlert(alertValue))
+       record := c.rawInput.Next(recordHeaderLen + n)
+       data, err := c.in.decrypt(record)
+       if err != nil {
+               return c.in.setErrorLocked(c.sendAlert(err.(alert)))
        }
-       b.off = off
-       data := b.data[b.off:]
        if len(data) > maxPlaintext {
-               err := c.sendAlert(alertRecordOverflow)
-               c.in.freeBlock(b)
-               return c.in.setErrorLocked(err)
+               return c.in.setErrorLocked(c.sendAlert(alertRecordOverflow))
        }
 
        if typ != recordTypeAlert && len(data) > 0 {
@@ -667,70 +584,97 @@ Again:
 
        switch typ {
        default:
-               c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
+               return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
 
        case recordTypeAlert:
                if len(data) != 2 {
-                       c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
-                       break
+                       return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
                }
                if alert(data[1]) == alertCloseNotify {
-                       c.in.setErrorLocked(io.EOF)
-                       break
+                       return c.in.setErrorLocked(io.EOF)
                }
                switch data[0] {
                case alertLevelWarning:
-                       // drop on the floor
-                       c.in.freeBlock(b)
-
                        c.warnCount++
                        if c.warnCount > maxWarnAlertCount {
                                c.sendAlert(alertUnexpectedMessage)
                                return c.in.setErrorLocked(errors.New("tls: too many warn alerts"))
                        }
-
-                       goto Again
+                       return c.readRecord(want) // Drop the record on the floor and retry.
                case alertLevelError:
-                       c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
+                       return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
                default:
-                       c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
+                       return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
                }
 
        case recordTypeChangeCipherSpec:
                if typ != want || len(data) != 1 || data[0] != 1 {
-                       c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
-                       break
+                       return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
                }
-               // Handshake messages are not allowed to fragment across the CCS
+               // Handshake messages are not allowed to fragment across the CCS.
                if c.hand.Len() > 0 {
-                       c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
-                       break
+                       return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
                }
-               err := c.in.changeCipherSpec()
-               if err != nil {
-                       c.in.setErrorLocked(c.sendAlert(err.(alert)))
+               if err := c.in.changeCipherSpec(); err != nil {
+                       return c.in.setErrorLocked(c.sendAlert(err.(alert)))
                }
+               return nil
 
        case recordTypeApplicationData:
                if typ != want {
-                       c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
-                       break
+                       return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
                }
-               c.input = b
-               b = nil
+               // Note that data is owned by c.rawInput, following the Next call above,
+               // to avoid copying the plaintext. This is safe because c.rawInput is
+               // not read from or written to until c.input is drained.
+               c.input.Reset(data)
+               return nil
 
        case recordTypeHandshake:
-               // TODO(rsc): Should at least pick off connection close.
                if typ != want && !(c.isClient && c.config.Renegotiation != RenegotiateNever) {
                        return c.in.setErrorLocked(c.sendAlert(alertNoRenegotiation))
                }
                c.hand.Write(data)
+               return nil
+       }
+}
+
+// atLeastReader reads from R, stopping with EOF once at least N bytes have been
+// read. It is different from an io.LimitedReader in that it doesn't cut short
+// the last Read call, and in that it considers an early EOF an error.
+type atLeastReader struct {
+       R io.Reader
+       N int64
+}
+
+func (r *atLeastReader) Read(p []byte) (int, error) {
+       if r.N <= 0 {
+               return 0, io.EOF
+       }
+       n, err := r.R.Read(p)
+       r.N -= int64(n) // won't underflow unless len(p) >= n > 9223372036854775809
+       if r.N > 0 && err == io.EOF {
+               return n, io.ErrUnexpectedEOF
        }
+       if r.N <= 0 && err == nil {
+               return n, io.EOF
+       }
+       return n, err
+}
 
-       if b != nil {
-               c.in.freeBlock(b)
+// readFromUntil reads from r into c.rawInput until c.rawInput contains
+// at least n bytes or else returns an error.
+func (c *Conn) readFromUntil(r io.Reader, n int) error {
+       if c.rawInput.Len() >= n {
+               return nil
        }
-       return c.in.err
+       needs := n - c.rawInput.Len()
+       // There might be extra input waiting on the wire. Make a best effort
+       // attempt to fetch it so that it can be used in (*Conn).Read to
+       // "predict" closeNotify alerts.
+       c.rawInput.Grow(needs + bytes.MinRead)
+       _, err := c.rawInput.ReadFrom(&atLeastReader{r, int64(needs)})
+       return err
 }
 
 // sendAlert sends a TLS alert message.
@@ -789,7 +733,7 @@ const (
 //
 // In the interests of simplicity and determinism, this code does not attempt
 // to reset the record size once the connection is idle, however.
-func (c *Conn) maxPayloadSizeForWrite(typ recordType, explicitIVLen int) int {
+func (c *Conn) maxPayloadSizeForWrite(typ recordType) int {
        if c.config.DynamicRecordSizingDisabled || typ != recordTypeApplicationData {
                return maxPlaintext
        }
@@ -799,16 +743,11 @@ func (c *Conn) maxPayloadSizeForWrite(typ recordType, explicitIVLen int) int {
        }
 
        // Subtract TLS overheads to get the maximum payload size.
-       macSize := 0
-       if c.out.mac != nil {
-               macSize = c.out.mac.Size()
-       }
-
-       payloadBytes := tcpMSSEstimate - recordHeaderLen - explicitIVLen
+       payloadBytes := tcpMSSEstimate - recordHeaderLen - c.out.explicitNonceLen()
        if c.out.cipher != nil {
                switch ciph := c.out.cipher.(type) {
                case cipher.Stream:
-                       payloadBytes -= macSize
+                       payloadBytes -= c.out.mac.Size()
                case cipher.AEAD:
                        payloadBytes -= ciph.Overhead()
                case cbcMode:
@@ -818,7 +757,7 @@ func (c *Conn) maxPayloadSizeForWrite(typ recordType, explicitIVLen int) int {
                        payloadBytes = (payloadBytes & ^(blockSize - 1)) - 1
                        // The MAC is appended before padding so affects the
                        // payload size directly.
-                       payloadBytes -= macSize
+                       payloadBytes -= c.out.mac.Size()
                default:
                        panic("unknown cipher type")
                }
@@ -864,63 +803,32 @@ func (c *Conn) flush() (int, error) {
 // writeRecordLocked writes a TLS record with the given type and payload to the
 // connection and updates the record layer state.
 func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) {
-       b := c.out.newBlock()
-       defer c.out.freeBlock(b)
-
        var n int
        for len(data) > 0 {
-               explicitIVLen := 0
-               explicitIVIsSeq := false
-
-               var cbc cbcMode
-               if c.out.version >= VersionTLS11 {
-                       var ok bool
-                       if cbc, ok = c.out.cipher.(cbcMode); ok {
-                               explicitIVLen = cbc.BlockSize()
-                       }
-               }
-               if explicitIVLen == 0 {
-                       if c, ok := c.out.cipher.(aead); ok {
-                               explicitIVLen = c.explicitNonceLen()
-
-                               // The AES-GCM construction in TLS has an
-                               // explicit nonce so that the nonce can be
-                               // random. However, the nonce is only 8 bytes
-                               // which is too small for a secure, random
-                               // nonce. Therefore we use the sequence number
-                               // as the nonce.
-                               explicitIVIsSeq = explicitIVLen > 0
-                       }
-               }
                m := len(data)
-               if maxPayload := c.maxPayloadSizeForWrite(typ, explicitIVLen); m > maxPayload {
+               if maxPayload := c.maxPayloadSizeForWrite(typ); m > maxPayload {
                        m = maxPayload
                }
-               b.resize(recordHeaderLen + explicitIVLen + m)
-               b.data[0] = byte(typ)
+
+               _, c.outBuf = sliceForAppend(c.outBuf[:0], recordHeaderLen)
+               c.outBuf[0] = byte(typ)
                vers := c.vers
                if vers == 0 {
                        // Some TLS servers fail if the record version is
                        // greater than TLS 1.0 for the initial ClientHello.
                        vers = VersionTLS10
                }
-               b.data[1] = byte(vers >> 8)
-               b.data[2] = byte(vers)
-               b.data[3] = byte(m >> 8)
-               b.data[4] = byte(m)
-               if explicitIVLen > 0 {
-                       explicitIV := b.data[recordHeaderLen : recordHeaderLen+explicitIVLen]
-                       if explicitIVIsSeq {
-                               copy(explicitIV, c.out.seq[:])
-                       } else {
-                               if _, err := io.ReadFull(c.config.rand(), explicitIV); err != nil {
-                                       return n, err
-                               }
-                       }
+               c.outBuf[1] = byte(vers >> 8)
+               c.outBuf[2] = byte(vers)
+               c.outBuf[3] = byte(m >> 8)
+               c.outBuf[4] = byte(m)
+
+               var err error
+               c.outBuf, err = c.out.encrypt(c.outBuf, data[:m], c.config.rand())
+               if err != nil {
+                       return n, err
                }
-               copy(b.data[recordHeaderLen+explicitIVLen:], data)
-               c.out.encrypt(b, explicitIVLen)
-               if _, err := c.write(b.data); err != nil {
+               if _, err := c.write(c.outBuf); err != nil {
                        return n, err
                }
                n += m
@@ -1124,14 +1032,14 @@ func (c *Conn) handleRenegotiation() error {
 
 // Read can be made to time out and return a net.Error with Timeout() == true
 // after a fixed time limit; see SetDeadline and SetReadDeadline.
-func (c *Conn) Read(b []byte) (n int, err error) {
-       if err = c.Handshake(); err != nil {
-               return
+func (c *Conn) Read(b []byte) (int, error) {
+       if err := c.Handshake(); err != nil {
+               return 0, err
        }
        if len(b) == 0 {
                // Put this after Handshake, in case people were calling
                // Read(nil) for the side effect of the Handshake.
-               return
+               return 0, nil
        }
 
        c.in.Lock()
@@ -1141,9 +1049,8 @@ func (c *Conn) Read(b []byte) (n int, err error) {
        // CBC IV. So this loop ignores a limited number of empty records.
        const maxConsecutiveEmptyRecords = 100
        for emptyRecordCount := 0; emptyRecordCount <= maxConsecutiveEmptyRecords; emptyRecordCount++ {
-               for c.input == nil && c.in.err == nil {
+               for c.input.Len() == 0 && c.in.err == nil {
                        if err := c.readRecord(recordTypeApplicationData); err != nil {
-                               // Soft error, like EAGAIN
                                return 0, err
                        }
                        if c.hand.Len() > 0 {
@@ -1158,33 +1065,24 @@ func (c *Conn) Read(b []byte) (n int, err error) {
                        return 0, err
                }
 
-               n, err = c.input.Read(b)
-               if c.input.off >= len(c.input.data) {
-                       c.in.freeBlock(c.input)
-                       c.input = nil
-               }
-
-               // If a close-notify alert is waiting, read it so that
-               // we can return (n, EOF) instead of (n, nil), to signal
-               // to the HTTP response reading goroutine that the
-               // connection is now closed. This eliminates a race
-               // where the HTTP response reading goroutine would
-               // otherwise not observe the EOF until its next read,
-               // by which time a client goroutine might have already
-               // tried to reuse the HTTP connection for a new
-               // request.
-               // See https://codereview.appspot.com/76400046
-               // and https://golang.org/issue/3514
-               if ri := c.rawInput; ri != nil &&
-                       n != 0 && err == nil &&
-                       c.input == nil && len(ri.data) > 0 && recordType(ri.data[0]) == recordTypeAlert {
-                       if recErr := c.readRecord(recordTypeApplicationData); recErr != nil {
-                               err = recErr // will be io.EOF on closeNotify
+               n, _ := c.input.Read(b)
+
+               // If a close-notify alert is waiting, read it so that we can return (n,
+               // EOF) instead of (n, nil), to signal to the HTTP response reading
+               // goroutine that the connection is now closed. This eliminates a race
+               // where the HTTP response reading goroutine would otherwise not observe
+               // the EOF until its next read, by which time a client goroutine might
+               // have already tried to reuse the HTTP connection for a new request.
+               // See https://golang.org/cl/76400046 and https://golang.org/issue/3514
+               if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 &&
+                       recordType(c.rawInput.Bytes()[0]) == recordTypeAlert {
+                       if err := c.readRecord(recordTypeApplicationData); err != nil {
+                               return n, err // will be io.EOF on closeNotify
                        }
                }
 
-               if n != 0 || err != nil {
-                       return n, err
+               if n != 0 {
+                       return n, nil
                }
        }
 
index 8fd429431595ead9b87834a400d0bfb71ff2340b..51932882c013853c6ee62852bb30cfdab0fcfa3c 100644 (file)
@@ -11,6 +11,7 @@ package tls
 // https://www.imperialviolet.org/2013/02/04/luckythirteen.html.
 
 import (
+       "bytes"
        "crypto"
        "crypto/ecdsa"
        "crypto/rsa"
@@ -29,7 +30,10 @@ import (
 // The configuration config must be non-nil and must include
 // at least one certificate or else set GetCertificate.
 func Server(conn net.Conn, config *Config) *Conn {
-       return &Conn{conn: conn, config: config}
+       return &Conn{
+               conn: conn, config: config,
+               input: *bytes.NewReader(nil), // Issue 28269
+       }
 }
 
 // Client returns a new TLS client side connection
@@ -37,7 +41,10 @@ func Server(conn net.Conn, config *Config) *Conn {
 // The config cannot be nil: users must set either ServerName or
 // InsecureSkipVerify in the config.
 func Client(conn net.Conn, config *Config) *Conn {
-       return &Conn{conn: conn, config: config, isClient: true}
+       return &Conn{
+               conn: conn, config: config, isClient: true,
+               input: *bytes.NewReader(nil), // Issue 28269
+       }
 }
 
 // A listener implements a network listener (net.Listener) for TLS connections.