]> Cypherpunks.ru repositories - gostls13.git/blob - src/crypto/tls/conn.go
crypto/tls: support QUIC as a transport
[gostls13.git] / src / crypto / tls / conn.go
1 // Copyright 2010 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 // TLS low level connection and record layer
6
7 package tls
8
9 import (
10         "bytes"
11         "context"
12         "crypto/cipher"
13         "crypto/subtle"
14         "crypto/x509"
15         "errors"
16         "fmt"
17         "hash"
18         "io"
19         "net"
20         "sync"
21         "sync/atomic"
22         "time"
23 )
24
25 // A Conn represents a secured connection.
26 // It implements the net.Conn interface.
27 type Conn struct {
28         // constant
29         conn        net.Conn
30         isClient    bool
31         handshakeFn func(context.Context) error // (*Conn).clientHandshake or serverHandshake
32         quic        *quicState                  // nil for non-QUIC connections
33
34         // isHandshakeComplete is true if the connection is currently transferring
35         // application data (i.e. is not currently processing a handshake).
36         // isHandshakeComplete is true implies handshakeErr == nil.
37         isHandshakeComplete atomic.Bool
38         // constant after handshake; protected by handshakeMutex
39         handshakeMutex sync.Mutex
40         handshakeErr   error   // error resulting from handshake
41         vers           uint16  // TLS version
42         haveVers       bool    // version has been negotiated
43         config         *Config // configuration passed to constructor
44         // handshakes counts the number of handshakes performed on the
45         // connection so far. If renegotiation is disabled then this is either
46         // zero or one.
47         handshakes       int
48         didResume        bool // whether this connection was a session resumption
49         cipherSuite      uint16
50         ocspResponse     []byte   // stapled OCSP response
51         scts             [][]byte // signed certificate timestamps from server
52         peerCertificates []*x509.Certificate
53         // activeCertHandles contains the cache handles to certificates in
54         // peerCertificates that are used to track active references.
55         activeCertHandles []*activeCert
56         // verifiedChains contains the certificate chains that we built, as
57         // opposed to the ones presented by the server.
58         verifiedChains [][]*x509.Certificate
59         // serverName contains the server name indicated by the client, if any.
60         serverName string
61         // secureRenegotiation is true if the server echoed the secure
62         // renegotiation extension. (This is meaningless as a server because
63         // renegotiation is not supported in that case.)
64         secureRenegotiation bool
65         // ekm is a closure for exporting keying material.
66         ekm func(label string, context []byte, length int) ([]byte, error)
67         // resumptionSecret is the resumption_master_secret for handling
68         // NewSessionTicket messages. nil if config.SessionTicketsDisabled.
69         resumptionSecret []byte
70
71         // ticketKeys is the set of active session ticket keys for this
72         // connection. The first one is used to encrypt new tickets and
73         // all are tried to decrypt tickets.
74         ticketKeys []ticketKey
75
76         // clientFinishedIsFirst is true if the client sent the first Finished
77         // message during the most recent handshake. This is recorded because
78         // the first transmitted Finished message is the tls-unique
79         // channel-binding value.
80         clientFinishedIsFirst bool
81
82         // closeNotifyErr is any error from sending the alertCloseNotify record.
83         closeNotifyErr error
84         // closeNotifySent is true if the Conn attempted to send an
85         // alertCloseNotify record.
86         closeNotifySent bool
87
88         // clientFinished and serverFinished contain the Finished message sent
89         // by the client or server in the most recent handshake. This is
90         // retained to support the renegotiation extension and tls-unique
91         // channel-binding.
92         clientFinished [12]byte
93         serverFinished [12]byte
94
95         // clientProtocol is the negotiated ALPN protocol.
96         clientProtocol string
97
98         // input/output
99         in, out   halfConn
100         rawInput  bytes.Buffer // raw input, starting with a record header
101         input     bytes.Reader // application data waiting to be read, from rawInput.Next
102         hand      bytes.Buffer // handshake data waiting to be read
103         buffering bool         // whether records are buffered in sendBuf
104         sendBuf   []byte       // a buffer of records waiting to be sent
105
106         // bytesSent counts the bytes of application data sent.
107         // packetsSent counts packets.
108         bytesSent   int64
109         packetsSent int64
110
111         // retryCount counts the number of consecutive non-advancing records
112         // received by Conn.readRecord. That is, records that neither advance the
113         // handshake, nor deliver application data. Protected by in.Mutex.
114         retryCount int
115
116         // activeCall indicates whether Close has been call in the low bit.
117         // the rest of the bits are the number of goroutines in Conn.Write.
118         activeCall atomic.Int32
119
120         tmp [16]byte
121 }
122
123 // Access to net.Conn methods.
124 // Cannot just embed net.Conn because that would
125 // export the struct field too.
126
127 // LocalAddr returns the local network address.
128 func (c *Conn) LocalAddr() net.Addr {
129         return c.conn.LocalAddr()
130 }
131
132 // RemoteAddr returns the remote network address.
133 func (c *Conn) RemoteAddr() net.Addr {
134         return c.conn.RemoteAddr()
135 }
136
137 // SetDeadline sets the read and write deadlines associated with the connection.
138 // A zero value for t means Read and Write will not time out.
139 // After a Write has timed out, the TLS state is corrupt and all future writes will return the same error.
140 func (c *Conn) SetDeadline(t time.Time) error {
141         return c.conn.SetDeadline(t)
142 }
143
144 // SetReadDeadline sets the read deadline on the underlying connection.
145 // A zero value for t means Read will not time out.
146 func (c *Conn) SetReadDeadline(t time.Time) error {
147         return c.conn.SetReadDeadline(t)
148 }
149
150 // SetWriteDeadline sets the write deadline on the underlying connection.
151 // A zero value for t means Write will not time out.
152 // After a Write has timed out, the TLS state is corrupt and all future writes will return the same error.
153 func (c *Conn) SetWriteDeadline(t time.Time) error {
154         return c.conn.SetWriteDeadline(t)
155 }
156
157 // NetConn returns the underlying connection that is wrapped by c.
158 // Note that writing to or reading from this connection directly will corrupt the
159 // TLS session.
160 func (c *Conn) NetConn() net.Conn {
161         return c.conn
162 }
163
164 // A halfConn represents one direction of the record layer
165 // connection, either sending or receiving.
166 type halfConn struct {
167         sync.Mutex
168
169         err     error  // first permanent error
170         version uint16 // protocol version
171         cipher  any    // cipher algorithm
172         mac     hash.Hash
173         seq     [8]byte // 64-bit sequence number
174
175         scratchBuf [13]byte // to avoid allocs; interface method args escape
176
177         nextCipher any       // next encryption state
178         nextMac    hash.Hash // next MAC algorithm
179
180         level         QUICEncryptionLevel // current QUIC encryption level
181         trafficSecret []byte              // current TLS 1.3 traffic secret
182 }
183
184 type permanentError struct {
185         err net.Error
186 }
187
188 func (e *permanentError) Error() string   { return e.err.Error() }
189 func (e *permanentError) Unwrap() error   { return e.err }
190 func (e *permanentError) Timeout() bool   { return e.err.Timeout() }
191 func (e *permanentError) Temporary() bool { return false }
192
193 func (hc *halfConn) setErrorLocked(err error) error {
194         if e, ok := err.(net.Error); ok {
195                 hc.err = &permanentError{err: e}
196         } else {
197                 hc.err = err
198         }
199         return hc.err
200 }
201
202 // prepareCipherSpec sets the encryption and MAC states
203 // that a subsequent changeCipherSpec will use.
204 func (hc *halfConn) prepareCipherSpec(version uint16, cipher any, mac hash.Hash) {
205         hc.version = version
206         hc.nextCipher = cipher
207         hc.nextMac = mac
208 }
209
210 // changeCipherSpec changes the encryption and MAC states
211 // to the ones previously passed to prepareCipherSpec.
212 func (hc *halfConn) changeCipherSpec() error {
213         if hc.nextCipher == nil || hc.version == VersionTLS13 {
214                 return alertInternalError
215         }
216         hc.cipher = hc.nextCipher
217         hc.mac = hc.nextMac
218         hc.nextCipher = nil
219         hc.nextMac = nil
220         for i := range hc.seq {
221                 hc.seq[i] = 0
222         }
223         return nil
224 }
225
226 func (hc *halfConn) setTrafficSecret(suite *cipherSuiteTLS13, level QUICEncryptionLevel, secret []byte) {
227         hc.trafficSecret = secret
228         hc.level = level
229         key, iv := suite.trafficKey(secret)
230         hc.cipher = suite.aead(key, iv)
231         for i := range hc.seq {
232                 hc.seq[i] = 0
233         }
234 }
235
236 // incSeq increments the sequence number.
237 func (hc *halfConn) incSeq() {
238         for i := 7; i >= 0; i-- {
239                 hc.seq[i]++
240                 if hc.seq[i] != 0 {
241                         return
242                 }
243         }
244
245         // Not allowed to let sequence number wrap.
246         // Instead, must renegotiate before it does.
247         // Not likely enough to bother.
248         panic("TLS: sequence number wraparound")
249 }
250
251 // explicitNonceLen returns the number of bytes of explicit nonce or IV included
252 // in each record. Explicit nonces are present only in CBC modes after TLS 1.0
253 // and in certain AEAD modes in TLS 1.2.
254 func (hc *halfConn) explicitNonceLen() int {
255         if hc.cipher == nil {
256                 return 0
257         }
258
259         switch c := hc.cipher.(type) {
260         case cipher.Stream:
261                 return 0
262         case aead:
263                 return c.explicitNonceLen()
264         case cbcMode:
265                 // TLS 1.1 introduced a per-record explicit IV to fix the BEAST attack.
266                 if hc.version >= VersionTLS11 {
267                         return c.BlockSize()
268                 }
269                 return 0
270         default:
271                 panic("unknown cipher type")
272         }
273 }
274
275 // extractPadding returns, in constant time, the length of the padding to remove
276 // from the end of payload. It also returns a byte which is equal to 255 if the
277 // padding was valid and 0 otherwise. See RFC 2246, Section 6.2.3.2.
278 func extractPadding(payload []byte) (toRemove int, good byte) {
279         if len(payload) < 1 {
280                 return 0, 0
281         }
282
283         paddingLen := payload[len(payload)-1]
284         t := uint(len(payload)-1) - uint(paddingLen)
285         // if len(payload) >= (paddingLen - 1) then the MSB of t is zero
286         good = byte(int32(^t) >> 31)
287
288         // The maximum possible padding length plus the actual length field
289         toCheck := 256
290         // The length of the padded data is public, so we can use an if here
291         if toCheck > len(payload) {
292                 toCheck = len(payload)
293         }
294
295         for i := 0; i < toCheck; i++ {
296                 t := uint(paddingLen) - uint(i)
297                 // if i <= paddingLen then the MSB of t is zero
298                 mask := byte(int32(^t) >> 31)
299                 b := payload[len(payload)-1-i]
300                 good &^= mask&paddingLen ^ mask&b
301         }
302
303         // We AND together the bits of good and replicate the result across
304         // all the bits.
305         good &= good << 4
306         good &= good << 2
307         good &= good << 1
308         good = uint8(int8(good) >> 7)
309
310         // Zero the padding length on error. This ensures any unchecked bytes
311         // are included in the MAC. Otherwise, an attacker that could
312         // distinguish MAC failures from padding failures could mount an attack
313         // similar to POODLE in SSL 3.0: given a good ciphertext that uses a
314         // full block's worth of padding, replace the final block with another
315         // block. If the MAC check passed but the padding check failed, the
316         // last byte of that block decrypted to the block size.
317         //
318         // See also macAndPaddingGood logic below.
319         paddingLen &= good
320
321         toRemove = int(paddingLen) + 1
322         return
323 }
324
325 func roundUp(a, b int) int {
326         return a + (b-a%b)%b
327 }
328
329 // cbcMode is an interface for block ciphers using cipher block chaining.
330 type cbcMode interface {
331         cipher.BlockMode
332         SetIV([]byte)
333 }
334
335 // decrypt authenticates and decrypts the record if protection is active at
336 // this stage. The returned plaintext might overlap with the input.
337 func (hc *halfConn) decrypt(record []byte) ([]byte, recordType, error) {
338         var plaintext []byte
339         typ := recordType(record[0])
340         payload := record[recordHeaderLen:]
341
342         // In TLS 1.3, change_cipher_spec messages are to be ignored without being
343         // decrypted. See RFC 8446, Appendix D.4.
344         if hc.version == VersionTLS13 && typ == recordTypeChangeCipherSpec {
345                 return payload, typ, nil
346         }
347
348         paddingGood := byte(255)
349         paddingLen := 0
350
351         explicitNonceLen := hc.explicitNonceLen()
352
353         if hc.cipher != nil {
354                 switch c := hc.cipher.(type) {
355                 case cipher.Stream:
356                         c.XORKeyStream(payload, payload)
357                 case aead:
358                         if len(payload) < explicitNonceLen {
359                                 return nil, 0, alertBadRecordMAC
360                         }
361                         nonce := payload[:explicitNonceLen]
362                         if len(nonce) == 0 {
363                                 nonce = hc.seq[:]
364                         }
365                         payload = payload[explicitNonceLen:]
366
367                         var additionalData []byte
368                         if hc.version == VersionTLS13 {
369                                 additionalData = record[:recordHeaderLen]
370                         } else {
371                                 additionalData = append(hc.scratchBuf[:0], hc.seq[:]...)
372                                 additionalData = append(additionalData, record[:3]...)
373                                 n := len(payload) - c.Overhead()
374                                 additionalData = append(additionalData, byte(n>>8), byte(n))
375                         }
376
377                         var err error
378                         plaintext, err = c.Open(payload[:0], nonce, payload, additionalData)
379                         if err != nil {
380                                 return nil, 0, alertBadRecordMAC
381                         }
382                 case cbcMode:
383                         blockSize := c.BlockSize()
384                         minPayload := explicitNonceLen + roundUp(hc.mac.Size()+1, blockSize)
385                         if len(payload)%blockSize != 0 || len(payload) < minPayload {
386                                 return nil, 0, alertBadRecordMAC
387                         }
388
389                         if explicitNonceLen > 0 {
390                                 c.SetIV(payload[:explicitNonceLen])
391                                 payload = payload[explicitNonceLen:]
392                         }
393                         c.CryptBlocks(payload, payload)
394
395                         // In a limited attempt to protect against CBC padding oracles like
396                         // Lucky13, the data past paddingLen (which is secret) is passed to
397                         // the MAC function as extra data, to be fed into the HMAC after
398                         // computing the digest. This makes the MAC roughly constant time as
399                         // long as the digest computation is constant time and does not
400                         // affect the subsequent write, modulo cache effects.
401                         paddingLen, paddingGood = extractPadding(payload)
402                 default:
403                         panic("unknown cipher type")
404                 }
405
406                 if hc.version == VersionTLS13 {
407                         if typ != recordTypeApplicationData {
408                                 return nil, 0, alertUnexpectedMessage
409                         }
410                         if len(plaintext) > maxPlaintext+1 {
411                                 return nil, 0, alertRecordOverflow
412                         }
413                         // Remove padding and find the ContentType scanning from the end.
414                         for i := len(plaintext) - 1; i >= 0; i-- {
415                                 if plaintext[i] != 0 {
416                                         typ = recordType(plaintext[i])
417                                         plaintext = plaintext[:i]
418                                         break
419                                 }
420                                 if i == 0 {
421                                         return nil, 0, alertUnexpectedMessage
422                                 }
423                         }
424                 }
425         } else {
426                 plaintext = payload
427         }
428
429         if hc.mac != nil {
430                 macSize := hc.mac.Size()
431                 if len(payload) < macSize {
432                         return nil, 0, alertBadRecordMAC
433                 }
434
435                 n := len(payload) - macSize - paddingLen
436                 n = subtle.ConstantTimeSelect(int(uint32(n)>>31), 0, n) // if n < 0 { n = 0 }
437                 record[3] = byte(n >> 8)
438                 record[4] = byte(n)
439                 remoteMAC := payload[n : n+macSize]
440                 localMAC := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload[:n], payload[n+macSize:])
441
442                 // This is equivalent to checking the MACs and paddingGood
443                 // separately, but in constant-time to prevent distinguishing
444                 // padding failures from MAC failures. Depending on what value
445                 // of paddingLen was returned on bad padding, distinguishing
446                 // bad MAC from bad padding can lead to an attack.
447                 //
448                 // See also the logic at the end of extractPadding.
449                 macAndPaddingGood := subtle.ConstantTimeCompare(localMAC, remoteMAC) & int(paddingGood)
450                 if macAndPaddingGood != 1 {
451                         return nil, 0, alertBadRecordMAC
452                 }
453
454                 plaintext = payload[:n]
455         }
456
457         hc.incSeq()
458         return plaintext, typ, nil
459 }
460
461 // sliceForAppend extends the input slice by n bytes. head is the full extended
462 // slice, while tail is the appended part. If the original slice has sufficient
463 // capacity no allocation is performed.
464 func sliceForAppend(in []byte, n int) (head, tail []byte) {
465         if total := len(in) + n; cap(in) >= total {
466                 head = in[:total]
467         } else {
468                 head = make([]byte, total)
469                 copy(head, in)
470         }
471         tail = head[len(in):]
472         return
473 }
474
475 // encrypt encrypts payload, adding the appropriate nonce and/or MAC, and
476 // appends it to record, which must already contain the record header.
477 func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, error) {
478         if hc.cipher == nil {
479                 return append(record, payload...), nil
480         }
481
482         var explicitNonce []byte
483         if explicitNonceLen := hc.explicitNonceLen(); explicitNonceLen > 0 {
484                 record, explicitNonce = sliceForAppend(record, explicitNonceLen)
485                 if _, isCBC := hc.cipher.(cbcMode); !isCBC && explicitNonceLen < 16 {
486                         // The AES-GCM construction in TLS has an explicit nonce so that the
487                         // nonce can be random. However, the nonce is only 8 bytes which is
488                         // too small for a secure, random nonce. Therefore we use the
489                         // sequence number as the nonce. The 3DES-CBC construction also has
490                         // an 8 bytes nonce but its nonces must be unpredictable (see RFC
491                         // 5246, Appendix F.3), forcing us to use randomness. That's not
492                         // 3DES' biggest problem anyway because the birthday bound on block
493                         // collision is reached first due to its similarly small block size
494                         // (see the Sweet32 attack).
495                         copy(explicitNonce, hc.seq[:])
496                 } else {
497                         if _, err := io.ReadFull(rand, explicitNonce); err != nil {
498                                 return nil, err
499                         }
500                 }
501         }
502
503         var dst []byte
504         switch c := hc.cipher.(type) {
505         case cipher.Stream:
506                 mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil)
507                 record, dst = sliceForAppend(record, len(payload)+len(mac))
508                 c.XORKeyStream(dst[:len(payload)], payload)
509                 c.XORKeyStream(dst[len(payload):], mac)
510         case aead:
511                 nonce := explicitNonce
512                 if len(nonce) == 0 {
513                         nonce = hc.seq[:]
514                 }
515
516                 if hc.version == VersionTLS13 {
517                         record = append(record, payload...)
518
519                         // Encrypt the actual ContentType and replace the plaintext one.
520                         record = append(record, record[0])
521                         record[0] = byte(recordTypeApplicationData)
522
523                         n := len(payload) + 1 + c.Overhead()
524                         record[3] = byte(n >> 8)
525                         record[4] = byte(n)
526
527                         record = c.Seal(record[:recordHeaderLen],
528                                 nonce, record[recordHeaderLen:], record[:recordHeaderLen])
529                 } else {
530                         additionalData := append(hc.scratchBuf[:0], hc.seq[:]...)
531                         additionalData = append(additionalData, record[:recordHeaderLen]...)
532                         record = c.Seal(record, nonce, payload, additionalData)
533                 }
534         case cbcMode:
535                 mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil)
536                 blockSize := c.BlockSize()
537                 plaintextLen := len(payload) + len(mac)
538                 paddingLen := blockSize - plaintextLen%blockSize
539                 record, dst = sliceForAppend(record, plaintextLen+paddingLen)
540                 copy(dst, payload)
541                 copy(dst[len(payload):], mac)
542                 for i := plaintextLen; i < len(dst); i++ {
543                         dst[i] = byte(paddingLen - 1)
544                 }
545                 if len(explicitNonce) > 0 {
546                         c.SetIV(explicitNonce)
547                 }
548                 c.CryptBlocks(dst, dst)
549         default:
550                 panic("unknown cipher type")
551         }
552
553         // Update length to include nonce, MAC and any block padding needed.
554         n := len(record) - recordHeaderLen
555         record[3] = byte(n >> 8)
556         record[4] = byte(n)
557         hc.incSeq()
558
559         return record, nil
560 }
561
562 // RecordHeaderError is returned when a TLS record header is invalid.
563 type RecordHeaderError struct {
564         // Msg contains a human readable string that describes the error.
565         Msg string
566         // RecordHeader contains the five bytes of TLS record header that
567         // triggered the error.
568         RecordHeader [5]byte
569         // Conn provides the underlying net.Conn in the case that a client
570         // sent an initial handshake that didn't look like TLS.
571         // It is nil if there's already been a handshake or a TLS alert has
572         // been written to the connection.
573         Conn net.Conn
574 }
575
576 func (e RecordHeaderError) Error() string { return "tls: " + e.Msg }
577
578 func (c *Conn) newRecordHeaderError(conn net.Conn, msg string) (err RecordHeaderError) {
579         err.Msg = msg
580         err.Conn = conn
581         copy(err.RecordHeader[:], c.rawInput.Bytes())
582         return err
583 }
584
585 func (c *Conn) readRecord() error {
586         return c.readRecordOrCCS(false)
587 }
588
589 func (c *Conn) readChangeCipherSpec() error {
590         return c.readRecordOrCCS(true)
591 }
592
593 // readRecordOrCCS reads one or more TLS records from the connection and
594 // updates the record layer state. Some invariants:
595 //   - c.in must be locked
596 //   - c.input must be empty
597 //
598 // During the handshake one and only one of the following will happen:
599 //   - c.hand grows
600 //   - c.in.changeCipherSpec is called
601 //   - an error is returned
602 //
603 // After the handshake one and only one of the following will happen:
604 //   - c.hand grows
605 //   - c.input is set
606 //   - an error is returned
607 func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error {
608         if c.in.err != nil {
609                 return c.in.err
610         }
611         handshakeComplete := c.isHandshakeComplete.Load()
612
613         // This function modifies c.rawInput, which owns the c.input memory.
614         if c.input.Len() != 0 {
615                 return c.in.setErrorLocked(errors.New("tls: internal error: attempted to read record with pending application data"))
616         }
617         c.input.Reset(nil)
618
619         if c.quic != nil {
620                 return c.in.setErrorLocked(errors.New("tls: internal error: attempted to read record with QUIC transport"))
621         }
622
623         // Read header, payload.
624         if err := c.readFromUntil(c.conn, recordHeaderLen); err != nil {
625                 // RFC 8446, Section 6.1 suggests that EOF without an alertCloseNotify
626                 // is an error, but popular web sites seem to do this, so we accept it
627                 // if and only if at the record boundary.
628                 if err == io.ErrUnexpectedEOF && c.rawInput.Len() == 0 {
629                         err = io.EOF
630                 }
631                 if e, ok := err.(net.Error); !ok || !e.Temporary() {
632                         c.in.setErrorLocked(err)
633                 }
634                 return err
635         }
636         hdr := c.rawInput.Bytes()[:recordHeaderLen]
637         typ := recordType(hdr[0])
638
639         // No valid TLS record has a type of 0x80, however SSLv2 handshakes
640         // start with a uint16 length where the MSB is set and the first record
641         // is always < 256 bytes long. Therefore typ == 0x80 strongly suggests
642         // an SSLv2 client.
643         if !handshakeComplete && typ == 0x80 {
644                 c.sendAlert(alertProtocolVersion)
645                 return c.in.setErrorLocked(c.newRecordHeaderError(nil, "unsupported SSLv2 handshake received"))
646         }
647
648         vers := uint16(hdr[1])<<8 | uint16(hdr[2])
649         expectedVers := c.vers
650         if expectedVers == VersionTLS13 {
651                 // All TLS 1.3 records are expected to have 0x0303 (1.2) after
652                 // the initial hello (RFC 8446 Section 5.1).
653                 expectedVers = VersionTLS12
654         }
655         n := int(hdr[3])<<8 | int(hdr[4])
656         if c.haveVers && vers != expectedVers {
657                 c.sendAlert(alertProtocolVersion)
658                 msg := fmt.Sprintf("received record with version %x when expecting version %x", vers, expectedVers)
659                 return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg))
660         }
661         if !c.haveVers {
662                 // First message, be extra suspicious: this might not be a TLS
663                 // client. Bail out before reading a full 'body', if possible.
664                 // The current max version is 3.3 so if the version is >= 16.0,
665                 // it's probably not real.
666                 if (typ != recordTypeAlert && typ != recordTypeHandshake) || vers >= 0x1000 {
667                         return c.in.setErrorLocked(c.newRecordHeaderError(c.conn, "first record does not look like a TLS handshake"))
668                 }
669         }
670         if c.vers == VersionTLS13 && n > maxCiphertextTLS13 || n > maxCiphertext {
671                 c.sendAlert(alertRecordOverflow)
672                 msg := fmt.Sprintf("oversized record received with length %d", n)
673                 return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg))
674         }
675         if err := c.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
676                 if e, ok := err.(net.Error); !ok || !e.Temporary() {
677                         c.in.setErrorLocked(err)
678                 }
679                 return err
680         }
681
682         // Process message.
683         record := c.rawInput.Next(recordHeaderLen + n)
684         data, typ, err := c.in.decrypt(record)
685         if err != nil {
686                 return c.in.setErrorLocked(c.sendAlert(err.(alert)))
687         }
688         if len(data) > maxPlaintext {
689                 return c.in.setErrorLocked(c.sendAlert(alertRecordOverflow))
690         }
691
692         // Application Data messages are always protected.
693         if c.in.cipher == nil && typ == recordTypeApplicationData {
694                 return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
695         }
696
697         if typ != recordTypeAlert && typ != recordTypeChangeCipherSpec && len(data) > 0 {
698                 // This is a state-advancing message: reset the retry count.
699                 c.retryCount = 0
700         }
701
702         // Handshake messages MUST NOT be interleaved with other record types in TLS 1.3.
703         if c.vers == VersionTLS13 && typ != recordTypeHandshake && c.hand.Len() > 0 {
704                 return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
705         }
706
707         switch typ {
708         default:
709                 return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
710
711         case recordTypeAlert:
712                 if c.quic != nil {
713                         return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
714                 }
715                 if len(data) != 2 {
716                         return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
717                 }
718                 if alert(data[1]) == alertCloseNotify {
719                         return c.in.setErrorLocked(io.EOF)
720                 }
721                 if c.vers == VersionTLS13 {
722                         return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
723                 }
724                 switch data[0] {
725                 case alertLevelWarning:
726                         // Drop the record on the floor and retry.
727                         return c.retryReadRecord(expectChangeCipherSpec)
728                 case alertLevelError:
729                         return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
730                 default:
731                         return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
732                 }
733
734         case recordTypeChangeCipherSpec:
735                 if len(data) != 1 || data[0] != 1 {
736                         return c.in.setErrorLocked(c.sendAlert(alertDecodeError))
737                 }
738                 // Handshake messages are not allowed to fragment across the CCS.
739                 if c.hand.Len() > 0 {
740                         return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
741                 }
742                 // In TLS 1.3, change_cipher_spec records are ignored until the
743                 // Finished. See RFC 8446, Appendix D.4. Note that according to Section
744                 // 5, a server can send a ChangeCipherSpec before its ServerHello, when
745                 // c.vers is still unset. That's not useful though and suspicious if the
746                 // server then selects a lower protocol version, so don't allow that.
747                 if c.vers == VersionTLS13 {
748                         return c.retryReadRecord(expectChangeCipherSpec)
749                 }
750                 if !expectChangeCipherSpec {
751                         return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
752                 }
753                 if err := c.in.changeCipherSpec(); err != nil {
754                         return c.in.setErrorLocked(c.sendAlert(err.(alert)))
755                 }
756
757         case recordTypeApplicationData:
758                 if !handshakeComplete || expectChangeCipherSpec {
759                         return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
760                 }
761                 // Some OpenSSL servers send empty records in order to randomize the
762                 // CBC IV. Ignore a limited number of empty records.
763                 if len(data) == 0 {
764                         return c.retryReadRecord(expectChangeCipherSpec)
765                 }
766                 // Note that data is owned by c.rawInput, following the Next call above,
767                 // to avoid copying the plaintext. This is safe because c.rawInput is
768                 // not read from or written to until c.input is drained.
769                 c.input.Reset(data)
770
771         case recordTypeHandshake:
772                 if len(data) == 0 || expectChangeCipherSpec {
773                         return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
774                 }
775                 c.hand.Write(data)
776         }
777
778         return nil
779 }
780
781 // retryReadRecord recurs into readRecordOrCCS to drop a non-advancing record, like
782 // a warning alert, empty application_data, or a change_cipher_spec in TLS 1.3.
783 func (c *Conn) retryReadRecord(expectChangeCipherSpec bool) error {
784         c.retryCount++
785         if c.retryCount > maxUselessRecords {
786                 c.sendAlert(alertUnexpectedMessage)
787                 return c.in.setErrorLocked(errors.New("tls: too many ignored records"))
788         }
789         return c.readRecordOrCCS(expectChangeCipherSpec)
790 }
791
792 // atLeastReader reads from R, stopping with EOF once at least N bytes have been
793 // read. It is different from an io.LimitedReader in that it doesn't cut short
794 // the last Read call, and in that it considers an early EOF an error.
795 type atLeastReader struct {
796         R io.Reader
797         N int64
798 }
799
800 func (r *atLeastReader) Read(p []byte) (int, error) {
801         if r.N <= 0 {
802                 return 0, io.EOF
803         }
804         n, err := r.R.Read(p)
805         r.N -= int64(n) // won't underflow unless len(p) >= n > 9223372036854775809
806         if r.N > 0 && err == io.EOF {
807                 return n, io.ErrUnexpectedEOF
808         }
809         if r.N <= 0 && err == nil {
810                 return n, io.EOF
811         }
812         return n, err
813 }
814
815 // readFromUntil reads from r into c.rawInput until c.rawInput contains
816 // at least n bytes or else returns an error.
817 func (c *Conn) readFromUntil(r io.Reader, n int) error {
818         if c.rawInput.Len() >= n {
819                 return nil
820         }
821         needs := n - c.rawInput.Len()
822         // There might be extra input waiting on the wire. Make a best effort
823         // attempt to fetch it so that it can be used in (*Conn).Read to
824         // "predict" closeNotify alerts.
825         c.rawInput.Grow(needs + bytes.MinRead)
826         _, err := c.rawInput.ReadFrom(&atLeastReader{r, int64(needs)})
827         return err
828 }
829
830 // sendAlertLocked sends a TLS alert message.
831 func (c *Conn) sendAlertLocked(err alert) error {
832         if c.quic != nil {
833                 return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
834         }
835
836         switch err {
837         case alertNoRenegotiation, alertCloseNotify:
838                 c.tmp[0] = alertLevelWarning
839         default:
840                 c.tmp[0] = alertLevelError
841         }
842         c.tmp[1] = byte(err)
843
844         _, writeErr := c.writeRecordLocked(recordTypeAlert, c.tmp[0:2])
845         if err == alertCloseNotify {
846                 // closeNotify is a special case in that it isn't an error.
847                 return writeErr
848         }
849
850         return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
851 }
852
853 // sendAlert sends a TLS alert message.
854 func (c *Conn) sendAlert(err alert) error {
855         c.out.Lock()
856         defer c.out.Unlock()
857         return c.sendAlertLocked(err)
858 }
859
860 const (
861         // tcpMSSEstimate is a conservative estimate of the TCP maximum segment
862         // size (MSS). A constant is used, rather than querying the kernel for
863         // the actual MSS, to avoid complexity. The value here is the IPv6
864         // minimum MTU (1280 bytes) minus the overhead of an IPv6 header (40
865         // bytes) and a TCP header with timestamps (32 bytes).
866         tcpMSSEstimate = 1208
867
868         // recordSizeBoostThreshold is the number of bytes of application data
869         // sent after which the TLS record size will be increased to the
870         // maximum.
871         recordSizeBoostThreshold = 128 * 1024
872 )
873
874 // maxPayloadSizeForWrite returns the maximum TLS payload size to use for the
875 // next application data record. There is the following trade-off:
876 //
877 //   - For latency-sensitive applications, such as web browsing, each TLS
878 //     record should fit in one TCP segment.
879 //   - For throughput-sensitive applications, such as large file transfers,
880 //     larger TLS records better amortize framing and encryption overheads.
881 //
882 // A simple heuristic that works well in practice is to use small records for
883 // the first 1MB of data, then use larger records for subsequent data, and
884 // reset back to smaller records after the connection becomes idle. See "High
885 // Performance Web Networking", Chapter 4, or:
886 // https://www.igvita.com/2013/10/24/optimizing-tls-record-size-and-buffering-latency/
887 //
888 // In the interests of simplicity and determinism, this code does not attempt
889 // to reset the record size once the connection is idle, however.
890 func (c *Conn) maxPayloadSizeForWrite(typ recordType) int {
891         if c.config.DynamicRecordSizingDisabled || typ != recordTypeApplicationData {
892                 return maxPlaintext
893         }
894
895         if c.bytesSent >= recordSizeBoostThreshold {
896                 return maxPlaintext
897         }
898
899         // Subtract TLS overheads to get the maximum payload size.
900         payloadBytes := tcpMSSEstimate - recordHeaderLen - c.out.explicitNonceLen()
901         if c.out.cipher != nil {
902                 switch ciph := c.out.cipher.(type) {
903                 case cipher.Stream:
904                         payloadBytes -= c.out.mac.Size()
905                 case cipher.AEAD:
906                         payloadBytes -= ciph.Overhead()
907                 case cbcMode:
908                         blockSize := ciph.BlockSize()
909                         // The payload must fit in a multiple of blockSize, with
910                         // room for at least one padding byte.
911                         payloadBytes = (payloadBytes & ^(blockSize - 1)) - 1
912                         // The MAC is appended before padding so affects the
913                         // payload size directly.
914                         payloadBytes -= c.out.mac.Size()
915                 default:
916                         panic("unknown cipher type")
917                 }
918         }
919         if c.vers == VersionTLS13 {
920                 payloadBytes-- // encrypted ContentType
921         }
922
923         // Allow packet growth in arithmetic progression up to max.
924         pkt := c.packetsSent
925         c.packetsSent++
926         if pkt > 1000 {
927                 return maxPlaintext // avoid overflow in multiply below
928         }
929
930         n := payloadBytes * int(pkt+1)
931         if n > maxPlaintext {
932                 n = maxPlaintext
933         }
934         return n
935 }
936
937 func (c *Conn) write(data []byte) (int, error) {
938         if c.buffering {
939                 c.sendBuf = append(c.sendBuf, data...)
940                 return len(data), nil
941         }
942
943         n, err := c.conn.Write(data)
944         c.bytesSent += int64(n)
945         return n, err
946 }
947
948 func (c *Conn) flush() (int, error) {
949         if len(c.sendBuf) == 0 {
950                 return 0, nil
951         }
952
953         n, err := c.conn.Write(c.sendBuf)
954         c.bytesSent += int64(n)
955         c.sendBuf = nil
956         c.buffering = false
957         return n, err
958 }
959
960 // outBufPool pools the record-sized scratch buffers used by writeRecordLocked.
961 var outBufPool = sync.Pool{
962         New: func() any {
963                 return new([]byte)
964         },
965 }
966
967 // writeRecordLocked writes a TLS record with the given type and payload to the
968 // connection and updates the record layer state.
969 func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) {
970         if c.quic != nil {
971                 if typ != recordTypeHandshake {
972                         return 0, errors.New("tls: internal error: sending non-handshake message to QUIC transport")
973                 }
974                 c.quicWriteCryptoData(c.out.level, data)
975                 if !c.buffering {
976                         if _, err := c.flush(); err != nil {
977                                 return 0, err
978                         }
979                 }
980                 return len(data), nil
981         }
982
983         outBufPtr := outBufPool.Get().(*[]byte)
984         outBuf := *outBufPtr
985         defer func() {
986                 // You might be tempted to simplify this by just passing &outBuf to Put,
987                 // but that would make the local copy of the outBuf slice header escape
988                 // to the heap, causing an allocation. Instead, we keep around the
989                 // pointer to the slice header returned by Get, which is already on the
990                 // heap, and overwrite and return that.
991                 *outBufPtr = outBuf
992                 outBufPool.Put(outBufPtr)
993         }()
994
995         var n int
996         for len(data) > 0 {
997                 m := len(data)
998                 if maxPayload := c.maxPayloadSizeForWrite(typ); m > maxPayload {
999                         m = maxPayload
1000                 }
1001
1002                 _, outBuf = sliceForAppend(outBuf[:0], recordHeaderLen)
1003                 outBuf[0] = byte(typ)
1004                 vers := c.vers
1005                 if vers == 0 {
1006                         // Some TLS servers fail if the record version is
1007                         // greater than TLS 1.0 for the initial ClientHello.
1008                         vers = VersionTLS10
1009                 } else if vers == VersionTLS13 {
1010                         // TLS 1.3 froze the record layer version to 1.2.
1011                         // See RFC 8446, Section 5.1.
1012                         vers = VersionTLS12
1013                 }
1014                 outBuf[1] = byte(vers >> 8)
1015                 outBuf[2] = byte(vers)
1016                 outBuf[3] = byte(m >> 8)
1017                 outBuf[4] = byte(m)
1018
1019                 var err error
1020                 outBuf, err = c.out.encrypt(outBuf, data[:m], c.config.rand())
1021                 if err != nil {
1022                         return n, err
1023                 }
1024                 if _, err := c.write(outBuf); err != nil {
1025                         return n, err
1026                 }
1027                 n += m
1028                 data = data[m:]
1029         }
1030
1031         if typ == recordTypeChangeCipherSpec && c.vers != VersionTLS13 {
1032                 if err := c.out.changeCipherSpec(); err != nil {
1033                         return n, c.sendAlertLocked(err.(alert))
1034                 }
1035         }
1036
1037         return n, nil
1038 }
1039
1040 // writeHandshakeRecord writes a handshake message to the connection and updates
1041 // the record layer state. If transcript is non-nil the marshalled message is
1042 // written to it.
1043 func (c *Conn) writeHandshakeRecord(msg handshakeMessage, transcript transcriptHash) (int, error) {
1044         c.out.Lock()
1045         defer c.out.Unlock()
1046
1047         data, err := msg.marshal()
1048         if err != nil {
1049                 return 0, err
1050         }
1051         if transcript != nil {
1052                 transcript.Write(data)
1053         }
1054
1055         return c.writeRecordLocked(recordTypeHandshake, data)
1056 }
1057
1058 // writeChangeCipherRecord writes a ChangeCipherSpec message to the connection and
1059 // updates the record layer state.
1060 func (c *Conn) writeChangeCipherRecord() error {
1061         c.out.Lock()
1062         defer c.out.Unlock()
1063         _, err := c.writeRecordLocked(recordTypeChangeCipherSpec, []byte{1})
1064         return err
1065 }
1066
1067 // readHandshakeBytes reads handshake data until c.hand contains at least n bytes.
1068 func (c *Conn) readHandshakeBytes(n int) error {
1069         if c.quic != nil {
1070                 return c.quicReadHandshakeBytes(n)
1071         }
1072         for c.hand.Len() < n {
1073                 if err := c.readRecord(); err != nil {
1074                         return err
1075                 }
1076         }
1077         return nil
1078 }
1079
1080 // readHandshake reads the next handshake message from
1081 // the record layer. If transcript is non-nil, the message
1082 // is written to the passed transcriptHash.
1083 func (c *Conn) readHandshake(transcript transcriptHash) (any, error) {
1084         if err := c.readHandshakeBytes(4); err != nil {
1085                 return nil, err
1086         }
1087         data := c.hand.Bytes()
1088         n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
1089         if n > maxHandshake {
1090                 c.sendAlertLocked(alertInternalError)
1091                 return nil, c.in.setErrorLocked(fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake))
1092         }
1093         if err := c.readHandshakeBytes(4 + n); err != nil {
1094                 return nil, err
1095         }
1096         data = c.hand.Next(4 + n)
1097         return c.unmarshalHandshakeMessage(data, transcript)
1098 }
1099
1100 func (c *Conn) unmarshalHandshakeMessage(data []byte, transcript transcriptHash) (handshakeMessage, error) {
1101         var m handshakeMessage
1102         switch data[0] {
1103         case typeHelloRequest:
1104                 m = new(helloRequestMsg)
1105         case typeClientHello:
1106                 m = new(clientHelloMsg)
1107         case typeServerHello:
1108                 m = new(serverHelloMsg)
1109         case typeNewSessionTicket:
1110                 if c.vers == VersionTLS13 {
1111                         m = new(newSessionTicketMsgTLS13)
1112                 } else {
1113                         m = new(newSessionTicketMsg)
1114                 }
1115         case typeCertificate:
1116                 if c.vers == VersionTLS13 {
1117                         m = new(certificateMsgTLS13)
1118                 } else {
1119                         m = new(certificateMsg)
1120                 }
1121         case typeCertificateRequest:
1122                 if c.vers == VersionTLS13 {
1123                         m = new(certificateRequestMsgTLS13)
1124                 } else {
1125                         m = &certificateRequestMsg{
1126                                 hasSignatureAlgorithm: c.vers >= VersionTLS12,
1127                         }
1128                 }
1129         case typeCertificateStatus:
1130                 m = new(certificateStatusMsg)
1131         case typeServerKeyExchange:
1132                 m = new(serverKeyExchangeMsg)
1133         case typeServerHelloDone:
1134                 m = new(serverHelloDoneMsg)
1135         case typeClientKeyExchange:
1136                 m = new(clientKeyExchangeMsg)
1137         case typeCertificateVerify:
1138                 m = &certificateVerifyMsg{
1139                         hasSignatureAlgorithm: c.vers >= VersionTLS12,
1140                 }
1141         case typeFinished:
1142                 m = new(finishedMsg)
1143         case typeEncryptedExtensions:
1144                 m = new(encryptedExtensionsMsg)
1145         case typeEndOfEarlyData:
1146                 m = new(endOfEarlyDataMsg)
1147         case typeKeyUpdate:
1148                 m = new(keyUpdateMsg)
1149         default:
1150                 return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
1151         }
1152
1153         // The handshake message unmarshalers
1154         // expect to be able to keep references to data,
1155         // so pass in a fresh copy that won't be overwritten.
1156         data = append([]byte(nil), data...)
1157
1158         if !m.unmarshal(data) {
1159                 return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
1160         }
1161
1162         if transcript != nil {
1163                 transcript.Write(data)
1164         }
1165
1166         return m, nil
1167 }
1168
1169 var (
1170         errShutdown = errors.New("tls: protocol is shutdown")
1171 )
1172
1173 // Write writes data to the connection.
1174 //
1175 // As Write calls Handshake, in order to prevent indefinite blocking a deadline
1176 // must be set for both Read and Write before Write is called when the handshake
1177 // has not yet completed. See SetDeadline, SetReadDeadline, and
1178 // SetWriteDeadline.
1179 func (c *Conn) Write(b []byte) (int, error) {
1180         // interlock with Close below
1181         for {
1182                 x := c.activeCall.Load()
1183                 if x&1 != 0 {
1184                         return 0, net.ErrClosed
1185                 }
1186                 if c.activeCall.CompareAndSwap(x, x+2) {
1187                         break
1188                 }
1189         }
1190         defer c.activeCall.Add(-2)
1191
1192         if err := c.Handshake(); err != nil {
1193                 return 0, err
1194         }
1195
1196         c.out.Lock()
1197         defer c.out.Unlock()
1198
1199         if err := c.out.err; err != nil {
1200                 return 0, err
1201         }
1202
1203         if !c.isHandshakeComplete.Load() {
1204                 return 0, alertInternalError
1205         }
1206
1207         if c.closeNotifySent {
1208                 return 0, errShutdown
1209         }
1210
1211         // TLS 1.0 is susceptible to a chosen-plaintext
1212         // attack when using block mode ciphers due to predictable IVs.
1213         // This can be prevented by splitting each Application Data
1214         // record into two records, effectively randomizing the IV.
1215         //
1216         // https://www.openssl.org/~bodo/tls-cbc.txt
1217         // https://bugzilla.mozilla.org/show_bug.cgi?id=665814
1218         // https://www.imperialviolet.org/2012/01/15/beastfollowup.html
1219
1220         var m int
1221         if len(b) > 1 && c.vers == VersionTLS10 {
1222                 if _, ok := c.out.cipher.(cipher.BlockMode); ok {
1223                         n, err := c.writeRecordLocked(recordTypeApplicationData, b[:1])
1224                         if err != nil {
1225                                 return n, c.out.setErrorLocked(err)
1226                         }
1227                         m, b = 1, b[1:]
1228                 }
1229         }
1230
1231         n, err := c.writeRecordLocked(recordTypeApplicationData, b)
1232         return n + m, c.out.setErrorLocked(err)
1233 }
1234
1235 // handleRenegotiation processes a HelloRequest handshake message.
1236 func (c *Conn) handleRenegotiation() error {
1237         if c.vers == VersionTLS13 {
1238                 return errors.New("tls: internal error: unexpected renegotiation")
1239         }
1240
1241         msg, err := c.readHandshake(nil)
1242         if err != nil {
1243                 return err
1244         }
1245
1246         helloReq, ok := msg.(*helloRequestMsg)
1247         if !ok {
1248                 c.sendAlert(alertUnexpectedMessage)
1249                 return unexpectedMessageError(helloReq, msg)
1250         }
1251
1252         if !c.isClient {
1253                 return c.sendAlert(alertNoRenegotiation)
1254         }
1255
1256         switch c.config.Renegotiation {
1257         case RenegotiateNever:
1258                 return c.sendAlert(alertNoRenegotiation)
1259         case RenegotiateOnceAsClient:
1260                 if c.handshakes > 1 {
1261                         return c.sendAlert(alertNoRenegotiation)
1262                 }
1263         case RenegotiateFreelyAsClient:
1264                 // Ok.
1265         default:
1266                 c.sendAlert(alertInternalError)
1267                 return errors.New("tls: unknown Renegotiation value")
1268         }
1269
1270         c.handshakeMutex.Lock()
1271         defer c.handshakeMutex.Unlock()
1272
1273         c.isHandshakeComplete.Store(false)
1274         if c.handshakeErr = c.clientHandshake(context.Background()); c.handshakeErr == nil {
1275                 c.handshakes++
1276         }
1277         return c.handshakeErr
1278 }
1279
1280 // handlePostHandshakeMessage processes a handshake message arrived after the
1281 // handshake is complete. Up to TLS 1.2, it indicates the start of a renegotiation.
1282 func (c *Conn) handlePostHandshakeMessage() error {
1283         if c.vers != VersionTLS13 {
1284                 return c.handleRenegotiation()
1285         }
1286
1287         msg, err := c.readHandshake(nil)
1288         if err != nil {
1289                 return err
1290         }
1291         c.retryCount++
1292         if c.retryCount > maxUselessRecords {
1293                 c.sendAlert(alertUnexpectedMessage)
1294                 return c.in.setErrorLocked(errors.New("tls: too many non-advancing records"))
1295         }
1296
1297         switch msg := msg.(type) {
1298         case *newSessionTicketMsgTLS13:
1299                 return c.handleNewSessionTicket(msg)
1300         case *keyUpdateMsg:
1301                 return c.handleKeyUpdate(msg)
1302         }
1303         // The QUIC layer is supposed to treat an unexpected post-handshake CertificateRequest
1304         // as a QUIC-level PROTOCOL_VIOLATION error (RFC 9001, Section 4.4). Returning an
1305         // unexpected_message alert here doesn't provide it with enough information to distinguish
1306         // this condition from other unexpected messages. This is probably fine.
1307         c.sendAlert(alertUnexpectedMessage)
1308         return fmt.Errorf("tls: received unexpected handshake message of type %T", msg)
1309 }
1310
1311 func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error {
1312         if c.quic != nil {
1313                 c.sendAlert(alertUnexpectedMessage)
1314                 return c.in.setErrorLocked(errors.New("tls: received unexpected key update message"))
1315         }
1316
1317         cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite)
1318         if cipherSuite == nil {
1319                 return c.in.setErrorLocked(c.sendAlert(alertInternalError))
1320         }
1321
1322         newSecret := cipherSuite.nextTrafficSecret(c.in.trafficSecret)
1323         c.in.setTrafficSecret(cipherSuite, QUICEncryptionLevelInitial, newSecret)
1324
1325         if keyUpdate.updateRequested {
1326                 c.out.Lock()
1327                 defer c.out.Unlock()
1328
1329                 msg := &keyUpdateMsg{}
1330                 msgBytes, err := msg.marshal()
1331                 if err != nil {
1332                         return err
1333                 }
1334                 _, err = c.writeRecordLocked(recordTypeHandshake, msgBytes)
1335                 if err != nil {
1336                         // Surface the error at the next write.
1337                         c.out.setErrorLocked(err)
1338                         return nil
1339                 }
1340
1341                 newSecret := cipherSuite.nextTrafficSecret(c.out.trafficSecret)
1342                 c.out.setTrafficSecret(cipherSuite, QUICEncryptionLevelInitial, newSecret)
1343         }
1344
1345         return nil
1346 }
1347
1348 // Read reads data from the connection.
1349 //
1350 // As Read calls Handshake, in order to prevent indefinite blocking a deadline
1351 // must be set for both Read and Write before Read is called when the handshake
1352 // has not yet completed. See SetDeadline, SetReadDeadline, and
1353 // SetWriteDeadline.
1354 func (c *Conn) Read(b []byte) (int, error) {
1355         if err := c.Handshake(); err != nil {
1356                 return 0, err
1357         }
1358         if len(b) == 0 {
1359                 // Put this after Handshake, in case people were calling
1360                 // Read(nil) for the side effect of the Handshake.
1361                 return 0, nil
1362         }
1363
1364         c.in.Lock()
1365         defer c.in.Unlock()
1366
1367         for c.input.Len() == 0 {
1368                 if err := c.readRecord(); err != nil {
1369                         return 0, err
1370                 }
1371                 for c.hand.Len() > 0 {
1372                         if err := c.handlePostHandshakeMessage(); err != nil {
1373                                 return 0, err
1374                         }
1375                 }
1376         }
1377
1378         n, _ := c.input.Read(b)
1379
1380         // If a close-notify alert is waiting, read it so that we can return (n,
1381         // EOF) instead of (n, nil), to signal to the HTTP response reading
1382         // goroutine that the connection is now closed. This eliminates a race
1383         // where the HTTP response reading goroutine would otherwise not observe
1384         // the EOF until its next read, by which time a client goroutine might
1385         // have already tried to reuse the HTTP connection for a new request.
1386         // See https://golang.org/cl/76400046 and https://golang.org/issue/3514
1387         if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 &&
1388                 recordType(c.rawInput.Bytes()[0]) == recordTypeAlert {
1389                 if err := c.readRecord(); err != nil {
1390                         return n, err // will be io.EOF on closeNotify
1391                 }
1392         }
1393
1394         return n, nil
1395 }
1396
1397 // Close closes the connection.
1398 func (c *Conn) Close() error {
1399         // Interlock with Conn.Write above.
1400         var x int32
1401         for {
1402                 x = c.activeCall.Load()
1403                 if x&1 != 0 {
1404                         return net.ErrClosed
1405                 }
1406                 if c.activeCall.CompareAndSwap(x, x|1) {
1407                         break
1408                 }
1409         }
1410         if x != 0 {
1411                 // io.Writer and io.Closer should not be used concurrently.
1412                 // If Close is called while a Write is currently in-flight,
1413                 // interpret that as a sign that this Close is really just
1414                 // being used to break the Write and/or clean up resources and
1415                 // avoid sending the alertCloseNotify, which may block
1416                 // waiting on handshakeMutex or the c.out mutex.
1417                 return c.conn.Close()
1418         }
1419
1420         var alertErr error
1421         if c.isHandshakeComplete.Load() {
1422                 if err := c.closeNotify(); err != nil {
1423                         alertErr = fmt.Errorf("tls: failed to send closeNotify alert (but connection was closed anyway): %w", err)
1424                 }
1425         }
1426
1427         if err := c.conn.Close(); err != nil {
1428                 return err
1429         }
1430         return alertErr
1431 }
1432
1433 var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake complete")
1434
1435 // CloseWrite shuts down the writing side of the connection. It should only be
1436 // called once the handshake has completed and does not call CloseWrite on the
1437 // underlying connection. Most callers should just use Close.
1438 func (c *Conn) CloseWrite() error {
1439         if !c.isHandshakeComplete.Load() {
1440                 return errEarlyCloseWrite
1441         }
1442
1443         return c.closeNotify()
1444 }
1445
1446 func (c *Conn) closeNotify() error {
1447         c.out.Lock()
1448         defer c.out.Unlock()
1449
1450         if !c.closeNotifySent {
1451                 // Set a Write Deadline to prevent possibly blocking forever.
1452                 c.SetWriteDeadline(time.Now().Add(time.Second * 5))
1453                 c.closeNotifyErr = c.sendAlertLocked(alertCloseNotify)
1454                 c.closeNotifySent = true
1455                 // Any subsequent writes will fail.
1456                 c.SetWriteDeadline(time.Now())
1457         }
1458         return c.closeNotifyErr
1459 }
1460
1461 // Handshake runs the client or server handshake
1462 // protocol if it has not yet been run.
1463 //
1464 // Most uses of this package need not call Handshake explicitly: the
1465 // first Read or Write will call it automatically.
1466 //
1467 // For control over canceling or setting a timeout on a handshake, use
1468 // HandshakeContext or the Dialer's DialContext method instead.
1469 func (c *Conn) Handshake() error {
1470         return c.HandshakeContext(context.Background())
1471 }
1472
1473 // HandshakeContext runs the client or server handshake
1474 // protocol if it has not yet been run.
1475 //
1476 // The provided Context must be non-nil. If the context is canceled before
1477 // the handshake is complete, the handshake is interrupted and an error is returned.
1478 // Once the handshake has completed, cancellation of the context will not affect the
1479 // connection.
1480 //
1481 // Most uses of this package need not call HandshakeContext explicitly: the
1482 // first Read or Write will call it automatically.
1483 func (c *Conn) HandshakeContext(ctx context.Context) error {
1484         // Delegate to unexported method for named return
1485         // without confusing documented signature.
1486         return c.handshakeContext(ctx)
1487 }
1488
1489 func (c *Conn) handshakeContext(ctx context.Context) (ret error) {
1490         // Fast sync/atomic-based exit if there is no handshake in flight and the
1491         // last one succeeded without an error. Avoids the expensive context setup
1492         // and mutex for most Read and Write calls.
1493         if c.isHandshakeComplete.Load() {
1494                 return nil
1495         }
1496
1497         handshakeCtx, cancel := context.WithCancel(ctx)
1498         // Note: defer this before starting the "interrupter" goroutine
1499         // so that we can tell the difference between the input being canceled and
1500         // this cancellation. In the former case, we need to close the connection.
1501         defer cancel()
1502
1503         if c.quic != nil {
1504                 c.quic.cancelc = handshakeCtx.Done()
1505                 c.quic.cancel = cancel
1506         } else if ctx.Done() != nil {
1507                 // Start the "interrupter" goroutine, if this context might be canceled.
1508                 // (The background context cannot).
1509                 //
1510                 // The interrupter goroutine waits for the input context to be done and
1511                 // closes the connection if this happens before the function returns.
1512                 done := make(chan struct{})
1513                 interruptRes := make(chan error, 1)
1514                 defer func() {
1515                         close(done)
1516                         if ctxErr := <-interruptRes; ctxErr != nil {
1517                                 // Return context error to user.
1518                                 ret = ctxErr
1519                         }
1520                 }()
1521                 go func() {
1522                         select {
1523                         case <-handshakeCtx.Done():
1524                                 // Close the connection, discarding the error
1525                                 _ = c.conn.Close()
1526                                 interruptRes <- handshakeCtx.Err()
1527                         case <-done:
1528                                 interruptRes <- nil
1529                         }
1530                 }()
1531         }
1532
1533         c.handshakeMutex.Lock()
1534         defer c.handshakeMutex.Unlock()
1535
1536         if err := c.handshakeErr; err != nil {
1537                 return err
1538         }
1539         if c.isHandshakeComplete.Load() {
1540                 return nil
1541         }
1542
1543         c.in.Lock()
1544         defer c.in.Unlock()
1545
1546         c.handshakeErr = c.handshakeFn(handshakeCtx)
1547         if c.handshakeErr == nil {
1548                 c.handshakes++
1549         } else {
1550                 // If an error occurred during the handshake try to flush the
1551                 // alert that might be left in the buffer.
1552                 c.flush()
1553         }
1554
1555         if c.handshakeErr == nil && !c.isHandshakeComplete.Load() {
1556                 c.handshakeErr = errors.New("tls: internal error: handshake should have had a result")
1557         }
1558         if c.handshakeErr != nil && c.isHandshakeComplete.Load() {
1559                 panic("tls: internal error: handshake returned an error but is marked successful")
1560         }
1561
1562         if c.quic != nil {
1563                 if c.handshakeErr == nil {
1564                         c.quicHandshakeComplete()
1565                         // Provide the 1-RTT read secret now that the handshake is complete.
1566                         // The QUIC layer MUST NOT decrypt 1-RTT packets prior to completing
1567                         // the handshake (RFC 9001, Section 5.7).
1568                         c.quicSetReadSecret(QUICEncryptionLevelApplication, c.cipherSuite, c.in.trafficSecret)
1569                 } else {
1570                         var a alert
1571                         c.out.Lock()
1572                         if !errors.As(c.out.err, &a) {
1573                                 a = alertInternalError
1574                         }
1575                         c.out.Unlock()
1576                         // Return an error which wraps both the handshake error and
1577                         // any alert error we may have sent, or alertInternalError
1578                         // if we didn't send an alert.
1579                         // Truncate the text of the alert to 0 characters.
1580                         c.handshakeErr = fmt.Errorf("%w%.0w", c.handshakeErr, AlertError(a))
1581                 }
1582                 close(c.quic.blockedc)
1583                 close(c.quic.signalc)
1584         }
1585
1586         return c.handshakeErr
1587 }
1588
1589 // ConnectionState returns basic TLS details about the connection.
1590 func (c *Conn) ConnectionState() ConnectionState {
1591         c.handshakeMutex.Lock()
1592         defer c.handshakeMutex.Unlock()
1593         return c.connectionStateLocked()
1594 }
1595
1596 func (c *Conn) connectionStateLocked() ConnectionState {
1597         var state ConnectionState
1598         state.HandshakeComplete = c.isHandshakeComplete.Load()
1599         state.Version = c.vers
1600         state.NegotiatedProtocol = c.clientProtocol
1601         state.DidResume = c.didResume
1602         state.NegotiatedProtocolIsMutual = true
1603         state.ServerName = c.serverName
1604         state.CipherSuite = c.cipherSuite
1605         state.PeerCertificates = c.peerCertificates
1606         state.VerifiedChains = c.verifiedChains
1607         state.SignedCertificateTimestamps = c.scts
1608         state.OCSPResponse = c.ocspResponse
1609         if !c.didResume && c.vers != VersionTLS13 {
1610                 if c.clientFinishedIsFirst {
1611                         state.TLSUnique = c.clientFinished[:]
1612                 } else {
1613                         state.TLSUnique = c.serverFinished[:]
1614                 }
1615         }
1616         if c.config.Renegotiation != RenegotiateNever {
1617                 state.ekm = noExportedKeyingMaterial
1618         } else {
1619                 state.ekm = c.ekm
1620         }
1621         return state
1622 }
1623
1624 // OCSPResponse returns the stapled OCSP response from the TLS server, if
1625 // any. (Only valid for client connections.)
1626 func (c *Conn) OCSPResponse() []byte {
1627         c.handshakeMutex.Lock()
1628         defer c.handshakeMutex.Unlock()
1629
1630         return c.ocspResponse
1631 }
1632
1633 // VerifyHostname checks that the peer certificate chain is valid for
1634 // connecting to host. If so, it returns nil; if not, it returns an error
1635 // describing the problem.
1636 func (c *Conn) VerifyHostname(host string) error {
1637         c.handshakeMutex.Lock()
1638         defer c.handshakeMutex.Unlock()
1639         if !c.isClient {
1640                 return errors.New("tls: VerifyHostname called on TLS server connection")
1641         }
1642         if !c.isHandshakeComplete.Load() {
1643                 return errors.New("tls: handshake has not yet been performed")
1644         }
1645         if len(c.verifiedChains) == 0 {
1646                 return errors.New("tls: handshake did not verify certificate chain")
1647         }
1648         return c.peerCertificates[0].VerifyHostname(host)
1649 }