]> Cypherpunks.ru repositories - gostls13.git/blob - src/crypto/tls/conn.go
crypto/tls: implement Extended Master Secret
[gostls13.git] / src / crypto / tls / conn.go
1 // Copyright 2010 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 // TLS low level connection and record layer
6
7 package tls
8
9 import (
10         "bytes"
11         "context"
12         "crypto/cipher"
13         "crypto/subtle"
14         "crypto/x509"
15         "errors"
16         "fmt"
17         "hash"
18         "io"
19         "net"
20         "sync"
21         "sync/atomic"
22         "time"
23 )
24
25 // A Conn represents a secured connection.
26 // It implements the net.Conn interface.
27 type Conn struct {
28         // constant
29         conn        net.Conn
30         isClient    bool
31         handshakeFn func(context.Context) error // (*Conn).clientHandshake or serverHandshake
32         quic        *quicState                  // nil for non-QUIC connections
33
34         // isHandshakeComplete is true if the connection is currently transferring
35         // application data (i.e. is not currently processing a handshake).
36         // isHandshakeComplete is true implies handshakeErr == nil.
37         isHandshakeComplete atomic.Bool
38         // constant after handshake; protected by handshakeMutex
39         handshakeMutex sync.Mutex
40         handshakeErr   error   // error resulting from handshake
41         vers           uint16  // TLS version
42         haveVers       bool    // version has been negotiated
43         config         *Config // configuration passed to constructor
44         // handshakes counts the number of handshakes performed on the
45         // connection so far. If renegotiation is disabled then this is either
46         // zero or one.
47         handshakes       int
48         extMasterSecret  bool
49         didResume        bool // whether this connection was a session resumption
50         cipherSuite      uint16
51         ocspResponse     []byte   // stapled OCSP response
52         scts             [][]byte // signed certificate timestamps from server
53         peerCertificates []*x509.Certificate
54         // activeCertHandles contains the cache handles to certificates in
55         // peerCertificates that are used to track active references.
56         activeCertHandles []*activeCert
57         // verifiedChains contains the certificate chains that we built, as
58         // opposed to the ones presented by the server.
59         verifiedChains [][]*x509.Certificate
60         // serverName contains the server name indicated by the client, if any.
61         serverName string
62         // secureRenegotiation is true if the server echoed the secure
63         // renegotiation extension. (This is meaningless as a server because
64         // renegotiation is not supported in that case.)
65         secureRenegotiation bool
66         // ekm is a closure for exporting keying material.
67         ekm func(label string, context []byte, length int) ([]byte, error)
68         // resumptionSecret is the resumption_master_secret for handling
69         // or sending NewSessionTicket messages.
70         resumptionSecret []byte
71
72         // ticketKeys is the set of active session ticket keys for this
73         // connection. The first one is used to encrypt new tickets and
74         // all are tried to decrypt tickets.
75         ticketKeys []ticketKey
76
77         // clientFinishedIsFirst is true if the client sent the first Finished
78         // message during the most recent handshake. This is recorded because
79         // the first transmitted Finished message is the tls-unique
80         // channel-binding value.
81         clientFinishedIsFirst bool
82
83         // closeNotifyErr is any error from sending the alertCloseNotify record.
84         closeNotifyErr error
85         // closeNotifySent is true if the Conn attempted to send an
86         // alertCloseNotify record.
87         closeNotifySent bool
88
89         // clientFinished and serverFinished contain the Finished message sent
90         // by the client or server in the most recent handshake. This is
91         // retained to support the renegotiation extension and tls-unique
92         // channel-binding.
93         clientFinished [12]byte
94         serverFinished [12]byte
95
96         // clientProtocol is the negotiated ALPN protocol.
97         clientProtocol string
98
99         // input/output
100         in, out   halfConn
101         rawInput  bytes.Buffer // raw input, starting with a record header
102         input     bytes.Reader // application data waiting to be read, from rawInput.Next
103         hand      bytes.Buffer // handshake data waiting to be read
104         buffering bool         // whether records are buffered in sendBuf
105         sendBuf   []byte       // a buffer of records waiting to be sent
106
107         // bytesSent counts the bytes of application data sent.
108         // packetsSent counts packets.
109         bytesSent   int64
110         packetsSent int64
111
112         // retryCount counts the number of consecutive non-advancing records
113         // received by Conn.readRecord. That is, records that neither advance the
114         // handshake, nor deliver application data. Protected by in.Mutex.
115         retryCount int
116
117         // activeCall indicates whether Close has been call in the low bit.
118         // the rest of the bits are the number of goroutines in Conn.Write.
119         activeCall atomic.Int32
120
121         tmp [16]byte
122 }
123
124 // Access to net.Conn methods.
125 // Cannot just embed net.Conn because that would
126 // export the struct field too.
127
128 // LocalAddr returns the local network address.
129 func (c *Conn) LocalAddr() net.Addr {
130         return c.conn.LocalAddr()
131 }
132
133 // RemoteAddr returns the remote network address.
134 func (c *Conn) RemoteAddr() net.Addr {
135         return c.conn.RemoteAddr()
136 }
137
138 // SetDeadline sets the read and write deadlines associated with the connection.
139 // A zero value for t means Read and Write will not time out.
140 // After a Write has timed out, the TLS state is corrupt and all future writes will return the same error.
141 func (c *Conn) SetDeadline(t time.Time) error {
142         return c.conn.SetDeadline(t)
143 }
144
145 // SetReadDeadline sets the read deadline on the underlying connection.
146 // A zero value for t means Read will not time out.
147 func (c *Conn) SetReadDeadline(t time.Time) error {
148         return c.conn.SetReadDeadline(t)
149 }
150
151 // SetWriteDeadline sets the write deadline on the underlying connection.
152 // A zero value for t means Write will not time out.
153 // After a Write has timed out, the TLS state is corrupt and all future writes will return the same error.
154 func (c *Conn) SetWriteDeadline(t time.Time) error {
155         return c.conn.SetWriteDeadline(t)
156 }
157
158 // NetConn returns the underlying connection that is wrapped by c.
159 // Note that writing to or reading from this connection directly will corrupt the
160 // TLS session.
161 func (c *Conn) NetConn() net.Conn {
162         return c.conn
163 }
164
165 // A halfConn represents one direction of the record layer
166 // connection, either sending or receiving.
167 type halfConn struct {
168         sync.Mutex
169
170         err     error  // first permanent error
171         version uint16 // protocol version
172         cipher  any    // cipher algorithm
173         mac     hash.Hash
174         seq     [8]byte // 64-bit sequence number
175
176         scratchBuf [13]byte // to avoid allocs; interface method args escape
177
178         nextCipher any       // next encryption state
179         nextMac    hash.Hash // next MAC algorithm
180
181         level         QUICEncryptionLevel // current QUIC encryption level
182         trafficSecret []byte              // current TLS 1.3 traffic secret
183 }
184
185 type permanentError struct {
186         err net.Error
187 }
188
189 func (e *permanentError) Error() string   { return e.err.Error() }
190 func (e *permanentError) Unwrap() error   { return e.err }
191 func (e *permanentError) Timeout() bool   { return e.err.Timeout() }
192 func (e *permanentError) Temporary() bool { return false }
193
194 func (hc *halfConn) setErrorLocked(err error) error {
195         if e, ok := err.(net.Error); ok {
196                 hc.err = &permanentError{err: e}
197         } else {
198                 hc.err = err
199         }
200         return hc.err
201 }
202
203 // prepareCipherSpec sets the encryption and MAC states
204 // that a subsequent changeCipherSpec will use.
205 func (hc *halfConn) prepareCipherSpec(version uint16, cipher any, mac hash.Hash) {
206         hc.version = version
207         hc.nextCipher = cipher
208         hc.nextMac = mac
209 }
210
211 // changeCipherSpec changes the encryption and MAC states
212 // to the ones previously passed to prepareCipherSpec.
213 func (hc *halfConn) changeCipherSpec() error {
214         if hc.nextCipher == nil || hc.version == VersionTLS13 {
215                 return alertInternalError
216         }
217         hc.cipher = hc.nextCipher
218         hc.mac = hc.nextMac
219         hc.nextCipher = nil
220         hc.nextMac = nil
221         for i := range hc.seq {
222                 hc.seq[i] = 0
223         }
224         return nil
225 }
226
227 func (hc *halfConn) setTrafficSecret(suite *cipherSuiteTLS13, level QUICEncryptionLevel, secret []byte) {
228         hc.trafficSecret = secret
229         hc.level = level
230         key, iv := suite.trafficKey(secret)
231         hc.cipher = suite.aead(key, iv)
232         for i := range hc.seq {
233                 hc.seq[i] = 0
234         }
235 }
236
237 // incSeq increments the sequence number.
238 func (hc *halfConn) incSeq() {
239         for i := 7; i >= 0; i-- {
240                 hc.seq[i]++
241                 if hc.seq[i] != 0 {
242                         return
243                 }
244         }
245
246         // Not allowed to let sequence number wrap.
247         // Instead, must renegotiate before it does.
248         // Not likely enough to bother.
249         panic("TLS: sequence number wraparound")
250 }
251
252 // explicitNonceLen returns the number of bytes of explicit nonce or IV included
253 // in each record. Explicit nonces are present only in CBC modes after TLS 1.0
254 // and in certain AEAD modes in TLS 1.2.
255 func (hc *halfConn) explicitNonceLen() int {
256         if hc.cipher == nil {
257                 return 0
258         }
259
260         switch c := hc.cipher.(type) {
261         case cipher.Stream:
262                 return 0
263         case aead:
264                 return c.explicitNonceLen()
265         case cbcMode:
266                 // TLS 1.1 introduced a per-record explicit IV to fix the BEAST attack.
267                 if hc.version >= VersionTLS11 {
268                         return c.BlockSize()
269                 }
270                 return 0
271         default:
272                 panic("unknown cipher type")
273         }
274 }
275
276 // extractPadding returns, in constant time, the length of the padding to remove
277 // from the end of payload. It also returns a byte which is equal to 255 if the
278 // padding was valid and 0 otherwise. See RFC 2246, Section 6.2.3.2.
279 func extractPadding(payload []byte) (toRemove int, good byte) {
280         if len(payload) < 1 {
281                 return 0, 0
282         }
283
284         paddingLen := payload[len(payload)-1]
285         t := uint(len(payload)-1) - uint(paddingLen)
286         // if len(payload) >= (paddingLen - 1) then the MSB of t is zero
287         good = byte(int32(^t) >> 31)
288
289         // The maximum possible padding length plus the actual length field
290         toCheck := 256
291         // The length of the padded data is public, so we can use an if here
292         if toCheck > len(payload) {
293                 toCheck = len(payload)
294         }
295
296         for i := 0; i < toCheck; i++ {
297                 t := uint(paddingLen) - uint(i)
298                 // if i <= paddingLen then the MSB of t is zero
299                 mask := byte(int32(^t) >> 31)
300                 b := payload[len(payload)-1-i]
301                 good &^= mask&paddingLen ^ mask&b
302         }
303
304         // We AND together the bits of good and replicate the result across
305         // all the bits.
306         good &= good << 4
307         good &= good << 2
308         good &= good << 1
309         good = uint8(int8(good) >> 7)
310
311         // Zero the padding length on error. This ensures any unchecked bytes
312         // are included in the MAC. Otherwise, an attacker that could
313         // distinguish MAC failures from padding failures could mount an attack
314         // similar to POODLE in SSL 3.0: given a good ciphertext that uses a
315         // full block's worth of padding, replace the final block with another
316         // block. If the MAC check passed but the padding check failed, the
317         // last byte of that block decrypted to the block size.
318         //
319         // See also macAndPaddingGood logic below.
320         paddingLen &= good
321
322         toRemove = int(paddingLen) + 1
323         return
324 }
325
326 func roundUp(a, b int) int {
327         return a + (b-a%b)%b
328 }
329
330 // cbcMode is an interface for block ciphers using cipher block chaining.
331 type cbcMode interface {
332         cipher.BlockMode
333         SetIV([]byte)
334 }
335
336 // decrypt authenticates and decrypts the record if protection is active at
337 // this stage. The returned plaintext might overlap with the input.
338 func (hc *halfConn) decrypt(record []byte) ([]byte, recordType, error) {
339         var plaintext []byte
340         typ := recordType(record[0])
341         payload := record[recordHeaderLen:]
342
343         // In TLS 1.3, change_cipher_spec messages are to be ignored without being
344         // decrypted. See RFC 8446, Appendix D.4.
345         if hc.version == VersionTLS13 && typ == recordTypeChangeCipherSpec {
346                 return payload, typ, nil
347         }
348
349         paddingGood := byte(255)
350         paddingLen := 0
351
352         explicitNonceLen := hc.explicitNonceLen()
353
354         if hc.cipher != nil {
355                 switch c := hc.cipher.(type) {
356                 case cipher.Stream:
357                         c.XORKeyStream(payload, payload)
358                 case aead:
359                         if len(payload) < explicitNonceLen {
360                                 return nil, 0, alertBadRecordMAC
361                         }
362                         nonce := payload[:explicitNonceLen]
363                         if len(nonce) == 0 {
364                                 nonce = hc.seq[:]
365                         }
366                         payload = payload[explicitNonceLen:]
367
368                         var additionalData []byte
369                         if hc.version == VersionTLS13 {
370                                 additionalData = record[:recordHeaderLen]
371                         } else {
372                                 additionalData = append(hc.scratchBuf[:0], hc.seq[:]...)
373                                 additionalData = append(additionalData, record[:3]...)
374                                 n := len(payload) - c.Overhead()
375                                 additionalData = append(additionalData, byte(n>>8), byte(n))
376                         }
377
378                         var err error
379                         plaintext, err = c.Open(payload[:0], nonce, payload, additionalData)
380                         if err != nil {
381                                 return nil, 0, alertBadRecordMAC
382                         }
383                 case cbcMode:
384                         blockSize := c.BlockSize()
385                         minPayload := explicitNonceLen + roundUp(hc.mac.Size()+1, blockSize)
386                         if len(payload)%blockSize != 0 || len(payload) < minPayload {
387                                 return nil, 0, alertBadRecordMAC
388                         }
389
390                         if explicitNonceLen > 0 {
391                                 c.SetIV(payload[:explicitNonceLen])
392                                 payload = payload[explicitNonceLen:]
393                         }
394                         c.CryptBlocks(payload, payload)
395
396                         // In a limited attempt to protect against CBC padding oracles like
397                         // Lucky13, the data past paddingLen (which is secret) is passed to
398                         // the MAC function as extra data, to be fed into the HMAC after
399                         // computing the digest. This makes the MAC roughly constant time as
400                         // long as the digest computation is constant time and does not
401                         // affect the subsequent write, modulo cache effects.
402                         paddingLen, paddingGood = extractPadding(payload)
403                 default:
404                         panic("unknown cipher type")
405                 }
406
407                 if hc.version == VersionTLS13 {
408                         if typ != recordTypeApplicationData {
409                                 return nil, 0, alertUnexpectedMessage
410                         }
411                         if len(plaintext) > maxPlaintext+1 {
412                                 return nil, 0, alertRecordOverflow
413                         }
414                         // Remove padding and find the ContentType scanning from the end.
415                         for i := len(plaintext) - 1; i >= 0; i-- {
416                                 if plaintext[i] != 0 {
417                                         typ = recordType(plaintext[i])
418                                         plaintext = plaintext[:i]
419                                         break
420                                 }
421                                 if i == 0 {
422                                         return nil, 0, alertUnexpectedMessage
423                                 }
424                         }
425                 }
426         } else {
427                 plaintext = payload
428         }
429
430         if hc.mac != nil {
431                 macSize := hc.mac.Size()
432                 if len(payload) < macSize {
433                         return nil, 0, alertBadRecordMAC
434                 }
435
436                 n := len(payload) - macSize - paddingLen
437                 n = subtle.ConstantTimeSelect(int(uint32(n)>>31), 0, n) // if n < 0 { n = 0 }
438                 record[3] = byte(n >> 8)
439                 record[4] = byte(n)
440                 remoteMAC := payload[n : n+macSize]
441                 localMAC := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload[:n], payload[n+macSize:])
442
443                 // This is equivalent to checking the MACs and paddingGood
444                 // separately, but in constant-time to prevent distinguishing
445                 // padding failures from MAC failures. Depending on what value
446                 // of paddingLen was returned on bad padding, distinguishing
447                 // bad MAC from bad padding can lead to an attack.
448                 //
449                 // See also the logic at the end of extractPadding.
450                 macAndPaddingGood := subtle.ConstantTimeCompare(localMAC, remoteMAC) & int(paddingGood)
451                 if macAndPaddingGood != 1 {
452                         return nil, 0, alertBadRecordMAC
453                 }
454
455                 plaintext = payload[:n]
456         }
457
458         hc.incSeq()
459         return plaintext, typ, nil
460 }
461
462 // sliceForAppend extends the input slice by n bytes. head is the full extended
463 // slice, while tail is the appended part. If the original slice has sufficient
464 // capacity no allocation is performed.
465 func sliceForAppend(in []byte, n int) (head, tail []byte) {
466         if total := len(in) + n; cap(in) >= total {
467                 head = in[:total]
468         } else {
469                 head = make([]byte, total)
470                 copy(head, in)
471         }
472         tail = head[len(in):]
473         return
474 }
475
476 // encrypt encrypts payload, adding the appropriate nonce and/or MAC, and
477 // appends it to record, which must already contain the record header.
478 func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, error) {
479         if hc.cipher == nil {
480                 return append(record, payload...), nil
481         }
482
483         var explicitNonce []byte
484         if explicitNonceLen := hc.explicitNonceLen(); explicitNonceLen > 0 {
485                 record, explicitNonce = sliceForAppend(record, explicitNonceLen)
486                 if _, isCBC := hc.cipher.(cbcMode); !isCBC && explicitNonceLen < 16 {
487                         // The AES-GCM construction in TLS has an explicit nonce so that the
488                         // nonce can be random. However, the nonce is only 8 bytes which is
489                         // too small for a secure, random nonce. Therefore we use the
490                         // sequence number as the nonce. The 3DES-CBC construction also has
491                         // an 8 bytes nonce but its nonces must be unpredictable (see RFC
492                         // 5246, Appendix F.3), forcing us to use randomness. That's not
493                         // 3DES' biggest problem anyway because the birthday bound on block
494                         // collision is reached first due to its similarly small block size
495                         // (see the Sweet32 attack).
496                         copy(explicitNonce, hc.seq[:])
497                 } else {
498                         if _, err := io.ReadFull(rand, explicitNonce); err != nil {
499                                 return nil, err
500                         }
501                 }
502         }
503
504         var dst []byte
505         switch c := hc.cipher.(type) {
506         case cipher.Stream:
507                 mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil)
508                 record, dst = sliceForAppend(record, len(payload)+len(mac))
509                 c.XORKeyStream(dst[:len(payload)], payload)
510                 c.XORKeyStream(dst[len(payload):], mac)
511         case aead:
512                 nonce := explicitNonce
513                 if len(nonce) == 0 {
514                         nonce = hc.seq[:]
515                 }
516
517                 if hc.version == VersionTLS13 {
518                         record = append(record, payload...)
519
520                         // Encrypt the actual ContentType and replace the plaintext one.
521                         record = append(record, record[0])
522                         record[0] = byte(recordTypeApplicationData)
523
524                         n := len(payload) + 1 + c.Overhead()
525                         record[3] = byte(n >> 8)
526                         record[4] = byte(n)
527
528                         record = c.Seal(record[:recordHeaderLen],
529                                 nonce, record[recordHeaderLen:], record[:recordHeaderLen])
530                 } else {
531                         additionalData := append(hc.scratchBuf[:0], hc.seq[:]...)
532                         additionalData = append(additionalData, record[:recordHeaderLen]...)
533                         record = c.Seal(record, nonce, payload, additionalData)
534                 }
535         case cbcMode:
536                 mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil)
537                 blockSize := c.BlockSize()
538                 plaintextLen := len(payload) + len(mac)
539                 paddingLen := blockSize - plaintextLen%blockSize
540                 record, dst = sliceForAppend(record, plaintextLen+paddingLen)
541                 copy(dst, payload)
542                 copy(dst[len(payload):], mac)
543                 for i := plaintextLen; i < len(dst); i++ {
544                         dst[i] = byte(paddingLen - 1)
545                 }
546                 if len(explicitNonce) > 0 {
547                         c.SetIV(explicitNonce)
548                 }
549                 c.CryptBlocks(dst, dst)
550         default:
551                 panic("unknown cipher type")
552         }
553
554         // Update length to include nonce, MAC and any block padding needed.
555         n := len(record) - recordHeaderLen
556         record[3] = byte(n >> 8)
557         record[4] = byte(n)
558         hc.incSeq()
559
560         return record, nil
561 }
562
563 // RecordHeaderError is returned when a TLS record header is invalid.
564 type RecordHeaderError struct {
565         // Msg contains a human readable string that describes the error.
566         Msg string
567         // RecordHeader contains the five bytes of TLS record header that
568         // triggered the error.
569         RecordHeader [5]byte
570         // Conn provides the underlying net.Conn in the case that a client
571         // sent an initial handshake that didn't look like TLS.
572         // It is nil if there's already been a handshake or a TLS alert has
573         // been written to the connection.
574         Conn net.Conn
575 }
576
577 func (e RecordHeaderError) Error() string { return "tls: " + e.Msg }
578
579 func (c *Conn) newRecordHeaderError(conn net.Conn, msg string) (err RecordHeaderError) {
580         err.Msg = msg
581         err.Conn = conn
582         copy(err.RecordHeader[:], c.rawInput.Bytes())
583         return err
584 }
585
586 func (c *Conn) readRecord() error {
587         return c.readRecordOrCCS(false)
588 }
589
590 func (c *Conn) readChangeCipherSpec() error {
591         return c.readRecordOrCCS(true)
592 }
593
594 // readRecordOrCCS reads one or more TLS records from the connection and
595 // updates the record layer state. Some invariants:
596 //   - c.in must be locked
597 //   - c.input must be empty
598 //
599 // During the handshake one and only one of the following will happen:
600 //   - c.hand grows
601 //   - c.in.changeCipherSpec is called
602 //   - an error is returned
603 //
604 // After the handshake one and only one of the following will happen:
605 //   - c.hand grows
606 //   - c.input is set
607 //   - an error is returned
608 func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error {
609         if c.in.err != nil {
610                 return c.in.err
611         }
612         handshakeComplete := c.isHandshakeComplete.Load()
613
614         // This function modifies c.rawInput, which owns the c.input memory.
615         if c.input.Len() != 0 {
616                 return c.in.setErrorLocked(errors.New("tls: internal error: attempted to read record with pending application data"))
617         }
618         c.input.Reset(nil)
619
620         if c.quic != nil {
621                 return c.in.setErrorLocked(errors.New("tls: internal error: attempted to read record with QUIC transport"))
622         }
623
624         // Read header, payload.
625         if err := c.readFromUntil(c.conn, recordHeaderLen); err != nil {
626                 // RFC 8446, Section 6.1 suggests that EOF without an alertCloseNotify
627                 // is an error, but popular web sites seem to do this, so we accept it
628                 // if and only if at the record boundary.
629                 if err == io.ErrUnexpectedEOF && c.rawInput.Len() == 0 {
630                         err = io.EOF
631                 }
632                 if e, ok := err.(net.Error); !ok || !e.Temporary() {
633                         c.in.setErrorLocked(err)
634                 }
635                 return err
636         }
637         hdr := c.rawInput.Bytes()[:recordHeaderLen]
638         typ := recordType(hdr[0])
639
640         // No valid TLS record has a type of 0x80, however SSLv2 handshakes
641         // start with a uint16 length where the MSB is set and the first record
642         // is always < 256 bytes long. Therefore typ == 0x80 strongly suggests
643         // an SSLv2 client.
644         if !handshakeComplete && typ == 0x80 {
645                 c.sendAlert(alertProtocolVersion)
646                 return c.in.setErrorLocked(c.newRecordHeaderError(nil, "unsupported SSLv2 handshake received"))
647         }
648
649         vers := uint16(hdr[1])<<8 | uint16(hdr[2])
650         expectedVers := c.vers
651         if expectedVers == VersionTLS13 {
652                 // All TLS 1.3 records are expected to have 0x0303 (1.2) after
653                 // the initial hello (RFC 8446 Section 5.1).
654                 expectedVers = VersionTLS12
655         }
656         n := int(hdr[3])<<8 | int(hdr[4])
657         if c.haveVers && vers != expectedVers {
658                 c.sendAlert(alertProtocolVersion)
659                 msg := fmt.Sprintf("received record with version %x when expecting version %x", vers, expectedVers)
660                 return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg))
661         }
662         if !c.haveVers {
663                 // First message, be extra suspicious: this might not be a TLS
664                 // client. Bail out before reading a full 'body', if possible.
665                 // The current max version is 3.3 so if the version is >= 16.0,
666                 // it's probably not real.
667                 if (typ != recordTypeAlert && typ != recordTypeHandshake) || vers >= 0x1000 {
668                         return c.in.setErrorLocked(c.newRecordHeaderError(c.conn, "first record does not look like a TLS handshake"))
669                 }
670         }
671         if c.vers == VersionTLS13 && n > maxCiphertextTLS13 || n > maxCiphertext {
672                 c.sendAlert(alertRecordOverflow)
673                 msg := fmt.Sprintf("oversized record received with length %d", n)
674                 return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg))
675         }
676         if err := c.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
677                 if e, ok := err.(net.Error); !ok || !e.Temporary() {
678                         c.in.setErrorLocked(err)
679                 }
680                 return err
681         }
682
683         // Process message.
684         record := c.rawInput.Next(recordHeaderLen + n)
685         data, typ, err := c.in.decrypt(record)
686         if err != nil {
687                 return c.in.setErrorLocked(c.sendAlert(err.(alert)))
688         }
689         if len(data) > maxPlaintext {
690                 return c.in.setErrorLocked(c.sendAlert(alertRecordOverflow))
691         }
692
693         // Application Data messages are always protected.
694         if c.in.cipher == nil && typ == recordTypeApplicationData {
695                 return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
696         }
697
698         if typ != recordTypeAlert && typ != recordTypeChangeCipherSpec && len(data) > 0 {
699                 // This is a state-advancing message: reset the retry count.
700                 c.retryCount = 0
701         }
702
703         // Handshake messages MUST NOT be interleaved with other record types in TLS 1.3.
704         if c.vers == VersionTLS13 && typ != recordTypeHandshake && c.hand.Len() > 0 {
705                 return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
706         }
707
708         switch typ {
709         default:
710                 return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
711
712         case recordTypeAlert:
713                 if c.quic != nil {
714                         return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
715                 }
716                 if len(data) != 2 {
717                         return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
718                 }
719                 if alert(data[1]) == alertCloseNotify {
720                         return c.in.setErrorLocked(io.EOF)
721                 }
722                 if c.vers == VersionTLS13 {
723                         return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
724                 }
725                 switch data[0] {
726                 case alertLevelWarning:
727                         // Drop the record on the floor and retry.
728                         return c.retryReadRecord(expectChangeCipherSpec)
729                 case alertLevelError:
730                         return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
731                 default:
732                         return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
733                 }
734
735         case recordTypeChangeCipherSpec:
736                 if len(data) != 1 || data[0] != 1 {
737                         return c.in.setErrorLocked(c.sendAlert(alertDecodeError))
738                 }
739                 // Handshake messages are not allowed to fragment across the CCS.
740                 if c.hand.Len() > 0 {
741                         return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
742                 }
743                 // In TLS 1.3, change_cipher_spec records are ignored until the
744                 // Finished. See RFC 8446, Appendix D.4. Note that according to Section
745                 // 5, a server can send a ChangeCipherSpec before its ServerHello, when
746                 // c.vers is still unset. That's not useful though and suspicious if the
747                 // server then selects a lower protocol version, so don't allow that.
748                 if c.vers == VersionTLS13 {
749                         return c.retryReadRecord(expectChangeCipherSpec)
750                 }
751                 if !expectChangeCipherSpec {
752                         return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
753                 }
754                 if err := c.in.changeCipherSpec(); err != nil {
755                         return c.in.setErrorLocked(c.sendAlert(err.(alert)))
756                 }
757
758         case recordTypeApplicationData:
759                 if !handshakeComplete || expectChangeCipherSpec {
760                         return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
761                 }
762                 // Some OpenSSL servers send empty records in order to randomize the
763                 // CBC IV. Ignore a limited number of empty records.
764                 if len(data) == 0 {
765                         return c.retryReadRecord(expectChangeCipherSpec)
766                 }
767                 // Note that data is owned by c.rawInput, following the Next call above,
768                 // to avoid copying the plaintext. This is safe because c.rawInput is
769                 // not read from or written to until c.input is drained.
770                 c.input.Reset(data)
771
772         case recordTypeHandshake:
773                 if len(data) == 0 || expectChangeCipherSpec {
774                         return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
775                 }
776                 c.hand.Write(data)
777         }
778
779         return nil
780 }
781
782 // retryReadRecord recurs into readRecordOrCCS to drop a non-advancing record, like
783 // a warning alert, empty application_data, or a change_cipher_spec in TLS 1.3.
784 func (c *Conn) retryReadRecord(expectChangeCipherSpec bool) error {
785         c.retryCount++
786         if c.retryCount > maxUselessRecords {
787                 c.sendAlert(alertUnexpectedMessage)
788                 return c.in.setErrorLocked(errors.New("tls: too many ignored records"))
789         }
790         return c.readRecordOrCCS(expectChangeCipherSpec)
791 }
792
793 // atLeastReader reads from R, stopping with EOF once at least N bytes have been
794 // read. It is different from an io.LimitedReader in that it doesn't cut short
795 // the last Read call, and in that it considers an early EOF an error.
796 type atLeastReader struct {
797         R io.Reader
798         N int64
799 }
800
801 func (r *atLeastReader) Read(p []byte) (int, error) {
802         if r.N <= 0 {
803                 return 0, io.EOF
804         }
805         n, err := r.R.Read(p)
806         r.N -= int64(n) // won't underflow unless len(p) >= n > 9223372036854775809
807         if r.N > 0 && err == io.EOF {
808                 return n, io.ErrUnexpectedEOF
809         }
810         if r.N <= 0 && err == nil {
811                 return n, io.EOF
812         }
813         return n, err
814 }
815
816 // readFromUntil reads from r into c.rawInput until c.rawInput contains
817 // at least n bytes or else returns an error.
818 func (c *Conn) readFromUntil(r io.Reader, n int) error {
819         if c.rawInput.Len() >= n {
820                 return nil
821         }
822         needs := n - c.rawInput.Len()
823         // There might be extra input waiting on the wire. Make a best effort
824         // attempt to fetch it so that it can be used in (*Conn).Read to
825         // "predict" closeNotify alerts.
826         c.rawInput.Grow(needs + bytes.MinRead)
827         _, err := c.rawInput.ReadFrom(&atLeastReader{r, int64(needs)})
828         return err
829 }
830
831 // sendAlertLocked sends a TLS alert message.
832 func (c *Conn) sendAlertLocked(err alert) error {
833         if c.quic != nil {
834                 return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
835         }
836
837         switch err {
838         case alertNoRenegotiation, alertCloseNotify:
839                 c.tmp[0] = alertLevelWarning
840         default:
841                 c.tmp[0] = alertLevelError
842         }
843         c.tmp[1] = byte(err)
844
845         _, writeErr := c.writeRecordLocked(recordTypeAlert, c.tmp[0:2])
846         if err == alertCloseNotify {
847                 // closeNotify is a special case in that it isn't an error.
848                 return writeErr
849         }
850
851         return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
852 }
853
854 // sendAlert sends a TLS alert message.
855 func (c *Conn) sendAlert(err alert) error {
856         c.out.Lock()
857         defer c.out.Unlock()
858         return c.sendAlertLocked(err)
859 }
860
861 const (
862         // tcpMSSEstimate is a conservative estimate of the TCP maximum segment
863         // size (MSS). A constant is used, rather than querying the kernel for
864         // the actual MSS, to avoid complexity. The value here is the IPv6
865         // minimum MTU (1280 bytes) minus the overhead of an IPv6 header (40
866         // bytes) and a TCP header with timestamps (32 bytes).
867         tcpMSSEstimate = 1208
868
869         // recordSizeBoostThreshold is the number of bytes of application data
870         // sent after which the TLS record size will be increased to the
871         // maximum.
872         recordSizeBoostThreshold = 128 * 1024
873 )
874
875 // maxPayloadSizeForWrite returns the maximum TLS payload size to use for the
876 // next application data record. There is the following trade-off:
877 //
878 //   - For latency-sensitive applications, such as web browsing, each TLS
879 //     record should fit in one TCP segment.
880 //   - For throughput-sensitive applications, such as large file transfers,
881 //     larger TLS records better amortize framing and encryption overheads.
882 //
883 // A simple heuristic that works well in practice is to use small records for
884 // the first 1MB of data, then use larger records for subsequent data, and
885 // reset back to smaller records after the connection becomes idle. See "High
886 // Performance Web Networking", Chapter 4, or:
887 // https://www.igvita.com/2013/10/24/optimizing-tls-record-size-and-buffering-latency/
888 //
889 // In the interests of simplicity and determinism, this code does not attempt
890 // to reset the record size once the connection is idle, however.
891 func (c *Conn) maxPayloadSizeForWrite(typ recordType) int {
892         if c.config.DynamicRecordSizingDisabled || typ != recordTypeApplicationData {
893                 return maxPlaintext
894         }
895
896         if c.bytesSent >= recordSizeBoostThreshold {
897                 return maxPlaintext
898         }
899
900         // Subtract TLS overheads to get the maximum payload size.
901         payloadBytes := tcpMSSEstimate - recordHeaderLen - c.out.explicitNonceLen()
902         if c.out.cipher != nil {
903                 switch ciph := c.out.cipher.(type) {
904                 case cipher.Stream:
905                         payloadBytes -= c.out.mac.Size()
906                 case cipher.AEAD:
907                         payloadBytes -= ciph.Overhead()
908                 case cbcMode:
909                         blockSize := ciph.BlockSize()
910                         // The payload must fit in a multiple of blockSize, with
911                         // room for at least one padding byte.
912                         payloadBytes = (payloadBytes & ^(blockSize - 1)) - 1
913                         // The MAC is appended before padding so affects the
914                         // payload size directly.
915                         payloadBytes -= c.out.mac.Size()
916                 default:
917                         panic("unknown cipher type")
918                 }
919         }
920         if c.vers == VersionTLS13 {
921                 payloadBytes-- // encrypted ContentType
922         }
923
924         // Allow packet growth in arithmetic progression up to max.
925         pkt := c.packetsSent
926         c.packetsSent++
927         if pkt > 1000 {
928                 return maxPlaintext // avoid overflow in multiply below
929         }
930
931         n := payloadBytes * int(pkt+1)
932         if n > maxPlaintext {
933                 n = maxPlaintext
934         }
935         return n
936 }
937
938 func (c *Conn) write(data []byte) (int, error) {
939         if c.buffering {
940                 c.sendBuf = append(c.sendBuf, data...)
941                 return len(data), nil
942         }
943
944         n, err := c.conn.Write(data)
945         c.bytesSent += int64(n)
946         return n, err
947 }
948
949 func (c *Conn) flush() (int, error) {
950         if len(c.sendBuf) == 0 {
951                 return 0, nil
952         }
953
954         n, err := c.conn.Write(c.sendBuf)
955         c.bytesSent += int64(n)
956         c.sendBuf = nil
957         c.buffering = false
958         return n, err
959 }
960
961 // outBufPool pools the record-sized scratch buffers used by writeRecordLocked.
962 var outBufPool = sync.Pool{
963         New: func() any {
964                 return new([]byte)
965         },
966 }
967
968 // writeRecordLocked writes a TLS record with the given type and payload to the
969 // connection and updates the record layer state.
970 func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) {
971         if c.quic != nil {
972                 if typ != recordTypeHandshake {
973                         return 0, errors.New("tls: internal error: sending non-handshake message to QUIC transport")
974                 }
975                 c.quicWriteCryptoData(c.out.level, data)
976                 if !c.buffering {
977                         if _, err := c.flush(); err != nil {
978                                 return 0, err
979                         }
980                 }
981                 return len(data), nil
982         }
983
984         outBufPtr := outBufPool.Get().(*[]byte)
985         outBuf := *outBufPtr
986         defer func() {
987                 // You might be tempted to simplify this by just passing &outBuf to Put,
988                 // but that would make the local copy of the outBuf slice header escape
989                 // to the heap, causing an allocation. Instead, we keep around the
990                 // pointer to the slice header returned by Get, which is already on the
991                 // heap, and overwrite and return that.
992                 *outBufPtr = outBuf
993                 outBufPool.Put(outBufPtr)
994         }()
995
996         var n int
997         for len(data) > 0 {
998                 m := len(data)
999                 if maxPayload := c.maxPayloadSizeForWrite(typ); m > maxPayload {
1000                         m = maxPayload
1001                 }
1002
1003                 _, outBuf = sliceForAppend(outBuf[:0], recordHeaderLen)
1004                 outBuf[0] = byte(typ)
1005                 vers := c.vers
1006                 if vers == 0 {
1007                         // Some TLS servers fail if the record version is
1008                         // greater than TLS 1.0 for the initial ClientHello.
1009                         vers = VersionTLS10
1010                 } else if vers == VersionTLS13 {
1011                         // TLS 1.3 froze the record layer version to 1.2.
1012                         // See RFC 8446, Section 5.1.
1013                         vers = VersionTLS12
1014                 }
1015                 outBuf[1] = byte(vers >> 8)
1016                 outBuf[2] = byte(vers)
1017                 outBuf[3] = byte(m >> 8)
1018                 outBuf[4] = byte(m)
1019
1020                 var err error
1021                 outBuf, err = c.out.encrypt(outBuf, data[:m], c.config.rand())
1022                 if err != nil {
1023                         return n, err
1024                 }
1025                 if _, err := c.write(outBuf); err != nil {
1026                         return n, err
1027                 }
1028                 n += m
1029                 data = data[m:]
1030         }
1031
1032         if typ == recordTypeChangeCipherSpec && c.vers != VersionTLS13 {
1033                 if err := c.out.changeCipherSpec(); err != nil {
1034                         return n, c.sendAlertLocked(err.(alert))
1035                 }
1036         }
1037
1038         return n, nil
1039 }
1040
1041 // writeHandshakeRecord writes a handshake message to the connection and updates
1042 // the record layer state. If transcript is non-nil the marshalled message is
1043 // written to it.
1044 func (c *Conn) writeHandshakeRecord(msg handshakeMessage, transcript transcriptHash) (int, error) {
1045         c.out.Lock()
1046         defer c.out.Unlock()
1047
1048         data, err := msg.marshal()
1049         if err != nil {
1050                 return 0, err
1051         }
1052         if transcript != nil {
1053                 transcript.Write(data)
1054         }
1055
1056         return c.writeRecordLocked(recordTypeHandshake, data)
1057 }
1058
1059 // writeChangeCipherRecord writes a ChangeCipherSpec message to the connection and
1060 // updates the record layer state.
1061 func (c *Conn) writeChangeCipherRecord() error {
1062         c.out.Lock()
1063         defer c.out.Unlock()
1064         _, err := c.writeRecordLocked(recordTypeChangeCipherSpec, []byte{1})
1065         return err
1066 }
1067
1068 // readHandshakeBytes reads handshake data until c.hand contains at least n bytes.
1069 func (c *Conn) readHandshakeBytes(n int) error {
1070         if c.quic != nil {
1071                 return c.quicReadHandshakeBytes(n)
1072         }
1073         for c.hand.Len() < n {
1074                 if err := c.readRecord(); err != nil {
1075                         return err
1076                 }
1077         }
1078         return nil
1079 }
1080
1081 // readHandshake reads the next handshake message from
1082 // the record layer. If transcript is non-nil, the message
1083 // is written to the passed transcriptHash.
1084 func (c *Conn) readHandshake(transcript transcriptHash) (any, error) {
1085         if err := c.readHandshakeBytes(4); err != nil {
1086                 return nil, err
1087         }
1088         data := c.hand.Bytes()
1089         n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
1090         if n > maxHandshake {
1091                 c.sendAlertLocked(alertInternalError)
1092                 return nil, c.in.setErrorLocked(fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake))
1093         }
1094         if err := c.readHandshakeBytes(4 + n); err != nil {
1095                 return nil, err
1096         }
1097         data = c.hand.Next(4 + n)
1098         return c.unmarshalHandshakeMessage(data, transcript)
1099 }
1100
1101 func (c *Conn) unmarshalHandshakeMessage(data []byte, transcript transcriptHash) (handshakeMessage, error) {
1102         var m handshakeMessage
1103         switch data[0] {
1104         case typeHelloRequest:
1105                 m = new(helloRequestMsg)
1106         case typeClientHello:
1107                 m = new(clientHelloMsg)
1108         case typeServerHello:
1109                 m = new(serverHelloMsg)
1110         case typeNewSessionTicket:
1111                 if c.vers == VersionTLS13 {
1112                         m = new(newSessionTicketMsgTLS13)
1113                 } else {
1114                         m = new(newSessionTicketMsg)
1115                 }
1116         case typeCertificate:
1117                 if c.vers == VersionTLS13 {
1118                         m = new(certificateMsgTLS13)
1119                 } else {
1120                         m = new(certificateMsg)
1121                 }
1122         case typeCertificateRequest:
1123                 if c.vers == VersionTLS13 {
1124                         m = new(certificateRequestMsgTLS13)
1125                 } else {
1126                         m = &certificateRequestMsg{
1127                                 hasSignatureAlgorithm: c.vers >= VersionTLS12,
1128                         }
1129                 }
1130         case typeCertificateStatus:
1131                 m = new(certificateStatusMsg)
1132         case typeServerKeyExchange:
1133                 m = new(serverKeyExchangeMsg)
1134         case typeServerHelloDone:
1135                 m = new(serverHelloDoneMsg)
1136         case typeClientKeyExchange:
1137                 m = new(clientKeyExchangeMsg)
1138         case typeCertificateVerify:
1139                 m = &certificateVerifyMsg{
1140                         hasSignatureAlgorithm: c.vers >= VersionTLS12,
1141                 }
1142         case typeFinished:
1143                 m = new(finishedMsg)
1144         case typeEncryptedExtensions:
1145                 m = new(encryptedExtensionsMsg)
1146         case typeEndOfEarlyData:
1147                 m = new(endOfEarlyDataMsg)
1148         case typeKeyUpdate:
1149                 m = new(keyUpdateMsg)
1150         default:
1151                 return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
1152         }
1153
1154         // The handshake message unmarshalers
1155         // expect to be able to keep references to data,
1156         // so pass in a fresh copy that won't be overwritten.
1157         data = append([]byte(nil), data...)
1158
1159         if !m.unmarshal(data) {
1160                 return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
1161         }
1162
1163         if transcript != nil {
1164                 transcript.Write(data)
1165         }
1166
1167         return m, nil
1168 }
1169
1170 var (
1171         errShutdown = errors.New("tls: protocol is shutdown")
1172 )
1173
1174 // Write writes data to the connection.
1175 //
1176 // As Write calls Handshake, in order to prevent indefinite blocking a deadline
1177 // must be set for both Read and Write before Write is called when the handshake
1178 // has not yet completed. See SetDeadline, SetReadDeadline, and
1179 // SetWriteDeadline.
1180 func (c *Conn) Write(b []byte) (int, error) {
1181         // interlock with Close below
1182         for {
1183                 x := c.activeCall.Load()
1184                 if x&1 != 0 {
1185                         return 0, net.ErrClosed
1186                 }
1187                 if c.activeCall.CompareAndSwap(x, x+2) {
1188                         break
1189                 }
1190         }
1191         defer c.activeCall.Add(-2)
1192
1193         if err := c.Handshake(); err != nil {
1194                 return 0, err
1195         }
1196
1197         c.out.Lock()
1198         defer c.out.Unlock()
1199
1200         if err := c.out.err; err != nil {
1201                 return 0, err
1202         }
1203
1204         if !c.isHandshakeComplete.Load() {
1205                 return 0, alertInternalError
1206         }
1207
1208         if c.closeNotifySent {
1209                 return 0, errShutdown
1210         }
1211
1212         // TLS 1.0 is susceptible to a chosen-plaintext
1213         // attack when using block mode ciphers due to predictable IVs.
1214         // This can be prevented by splitting each Application Data
1215         // record into two records, effectively randomizing the IV.
1216         //
1217         // https://www.openssl.org/~bodo/tls-cbc.txt
1218         // https://bugzilla.mozilla.org/show_bug.cgi?id=665814
1219         // https://www.imperialviolet.org/2012/01/15/beastfollowup.html
1220
1221         var m int
1222         if len(b) > 1 && c.vers == VersionTLS10 {
1223                 if _, ok := c.out.cipher.(cipher.BlockMode); ok {
1224                         n, err := c.writeRecordLocked(recordTypeApplicationData, b[:1])
1225                         if err != nil {
1226                                 return n, c.out.setErrorLocked(err)
1227                         }
1228                         m, b = 1, b[1:]
1229                 }
1230         }
1231
1232         n, err := c.writeRecordLocked(recordTypeApplicationData, b)
1233         return n + m, c.out.setErrorLocked(err)
1234 }
1235
1236 // handleRenegotiation processes a HelloRequest handshake message.
1237 func (c *Conn) handleRenegotiation() error {
1238         if c.vers == VersionTLS13 {
1239                 return errors.New("tls: internal error: unexpected renegotiation")
1240         }
1241
1242         msg, err := c.readHandshake(nil)
1243         if err != nil {
1244                 return err
1245         }
1246
1247         helloReq, ok := msg.(*helloRequestMsg)
1248         if !ok {
1249                 c.sendAlert(alertUnexpectedMessage)
1250                 return unexpectedMessageError(helloReq, msg)
1251         }
1252
1253         if !c.isClient {
1254                 return c.sendAlert(alertNoRenegotiation)
1255         }
1256
1257         switch c.config.Renegotiation {
1258         case RenegotiateNever:
1259                 return c.sendAlert(alertNoRenegotiation)
1260         case RenegotiateOnceAsClient:
1261                 if c.handshakes > 1 {
1262                         return c.sendAlert(alertNoRenegotiation)
1263                 }
1264         case RenegotiateFreelyAsClient:
1265                 // Ok.
1266         default:
1267                 c.sendAlert(alertInternalError)
1268                 return errors.New("tls: unknown Renegotiation value")
1269         }
1270
1271         c.handshakeMutex.Lock()
1272         defer c.handshakeMutex.Unlock()
1273
1274         c.isHandshakeComplete.Store(false)
1275         if c.handshakeErr = c.clientHandshake(context.Background()); c.handshakeErr == nil {
1276                 c.handshakes++
1277         }
1278         return c.handshakeErr
1279 }
1280
1281 // handlePostHandshakeMessage processes a handshake message arrived after the
1282 // handshake is complete. Up to TLS 1.2, it indicates the start of a renegotiation.
1283 func (c *Conn) handlePostHandshakeMessage() error {
1284         if c.vers != VersionTLS13 {
1285                 return c.handleRenegotiation()
1286         }
1287
1288         msg, err := c.readHandshake(nil)
1289         if err != nil {
1290                 return err
1291         }
1292         c.retryCount++
1293         if c.retryCount > maxUselessRecords {
1294                 c.sendAlert(alertUnexpectedMessage)
1295                 return c.in.setErrorLocked(errors.New("tls: too many non-advancing records"))
1296         }
1297
1298         switch msg := msg.(type) {
1299         case *newSessionTicketMsgTLS13:
1300                 return c.handleNewSessionTicket(msg)
1301         case *keyUpdateMsg:
1302                 return c.handleKeyUpdate(msg)
1303         }
1304         // The QUIC layer is supposed to treat an unexpected post-handshake CertificateRequest
1305         // as a QUIC-level PROTOCOL_VIOLATION error (RFC 9001, Section 4.4). Returning an
1306         // unexpected_message alert here doesn't provide it with enough information to distinguish
1307         // this condition from other unexpected messages. This is probably fine.
1308         c.sendAlert(alertUnexpectedMessage)
1309         return fmt.Errorf("tls: received unexpected handshake message of type %T", msg)
1310 }
1311
1312 func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error {
1313         if c.quic != nil {
1314                 c.sendAlert(alertUnexpectedMessage)
1315                 return c.in.setErrorLocked(errors.New("tls: received unexpected key update message"))
1316         }
1317
1318         cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite)
1319         if cipherSuite == nil {
1320                 return c.in.setErrorLocked(c.sendAlert(alertInternalError))
1321         }
1322
1323         newSecret := cipherSuite.nextTrafficSecret(c.in.trafficSecret)
1324         c.in.setTrafficSecret(cipherSuite, QUICEncryptionLevelInitial, newSecret)
1325
1326         if keyUpdate.updateRequested {
1327                 c.out.Lock()
1328                 defer c.out.Unlock()
1329
1330                 msg := &keyUpdateMsg{}
1331                 msgBytes, err := msg.marshal()
1332                 if err != nil {
1333                         return err
1334                 }
1335                 _, err = c.writeRecordLocked(recordTypeHandshake, msgBytes)
1336                 if err != nil {
1337                         // Surface the error at the next write.
1338                         c.out.setErrorLocked(err)
1339                         return nil
1340                 }
1341
1342                 newSecret := cipherSuite.nextTrafficSecret(c.out.trafficSecret)
1343                 c.out.setTrafficSecret(cipherSuite, QUICEncryptionLevelInitial, newSecret)
1344         }
1345
1346         return nil
1347 }
1348
1349 // Read reads data from the connection.
1350 //
1351 // As Read calls Handshake, in order to prevent indefinite blocking a deadline
1352 // must be set for both Read and Write before Read is called when the handshake
1353 // has not yet completed. See SetDeadline, SetReadDeadline, and
1354 // SetWriteDeadline.
1355 func (c *Conn) Read(b []byte) (int, error) {
1356         if err := c.Handshake(); err != nil {
1357                 return 0, err
1358         }
1359         if len(b) == 0 {
1360                 // Put this after Handshake, in case people were calling
1361                 // Read(nil) for the side effect of the Handshake.
1362                 return 0, nil
1363         }
1364
1365         c.in.Lock()
1366         defer c.in.Unlock()
1367
1368         for c.input.Len() == 0 {
1369                 if err := c.readRecord(); err != nil {
1370                         return 0, err
1371                 }
1372                 for c.hand.Len() > 0 {
1373                         if err := c.handlePostHandshakeMessage(); err != nil {
1374                                 return 0, err
1375                         }
1376                 }
1377         }
1378
1379         n, _ := c.input.Read(b)
1380
1381         // If a close-notify alert is waiting, read it so that we can return (n,
1382         // EOF) instead of (n, nil), to signal to the HTTP response reading
1383         // goroutine that the connection is now closed. This eliminates a race
1384         // where the HTTP response reading goroutine would otherwise not observe
1385         // the EOF until its next read, by which time a client goroutine might
1386         // have already tried to reuse the HTTP connection for a new request.
1387         // See https://golang.org/cl/76400046 and https://golang.org/issue/3514
1388         if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 &&
1389                 recordType(c.rawInput.Bytes()[0]) == recordTypeAlert {
1390                 if err := c.readRecord(); err != nil {
1391                         return n, err // will be io.EOF on closeNotify
1392                 }
1393         }
1394
1395         return n, nil
1396 }
1397
1398 // Close closes the connection.
1399 func (c *Conn) Close() error {
1400         // Interlock with Conn.Write above.
1401         var x int32
1402         for {
1403                 x = c.activeCall.Load()
1404                 if x&1 != 0 {
1405                         return net.ErrClosed
1406                 }
1407                 if c.activeCall.CompareAndSwap(x, x|1) {
1408                         break
1409                 }
1410         }
1411         if x != 0 {
1412                 // io.Writer and io.Closer should not be used concurrently.
1413                 // If Close is called while a Write is currently in-flight,
1414                 // interpret that as a sign that this Close is really just
1415                 // being used to break the Write and/or clean up resources and
1416                 // avoid sending the alertCloseNotify, which may block
1417                 // waiting on handshakeMutex or the c.out mutex.
1418                 return c.conn.Close()
1419         }
1420
1421         var alertErr error
1422         if c.isHandshakeComplete.Load() {
1423                 if err := c.closeNotify(); err != nil {
1424                         alertErr = fmt.Errorf("tls: failed to send closeNotify alert (but connection was closed anyway): %w", err)
1425                 }
1426         }
1427
1428         if err := c.conn.Close(); err != nil {
1429                 return err
1430         }
1431         return alertErr
1432 }
1433
1434 var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake complete")
1435
1436 // CloseWrite shuts down the writing side of the connection. It should only be
1437 // called once the handshake has completed and does not call CloseWrite on the
1438 // underlying connection. Most callers should just use Close.
1439 func (c *Conn) CloseWrite() error {
1440         if !c.isHandshakeComplete.Load() {
1441                 return errEarlyCloseWrite
1442         }
1443
1444         return c.closeNotify()
1445 }
1446
1447 func (c *Conn) closeNotify() error {
1448         c.out.Lock()
1449         defer c.out.Unlock()
1450
1451         if !c.closeNotifySent {
1452                 // Set a Write Deadline to prevent possibly blocking forever.
1453                 c.SetWriteDeadline(time.Now().Add(time.Second * 5))
1454                 c.closeNotifyErr = c.sendAlertLocked(alertCloseNotify)
1455                 c.closeNotifySent = true
1456                 // Any subsequent writes will fail.
1457                 c.SetWriteDeadline(time.Now())
1458         }
1459         return c.closeNotifyErr
1460 }
1461
1462 // Handshake runs the client or server handshake
1463 // protocol if it has not yet been run.
1464 //
1465 // Most uses of this package need not call Handshake explicitly: the
1466 // first Read or Write will call it automatically.
1467 //
1468 // For control over canceling or setting a timeout on a handshake, use
1469 // HandshakeContext or the Dialer's DialContext method instead.
1470 func (c *Conn) Handshake() error {
1471         return c.HandshakeContext(context.Background())
1472 }
1473
1474 // HandshakeContext runs the client or server handshake
1475 // protocol if it has not yet been run.
1476 //
1477 // The provided Context must be non-nil. If the context is canceled before
1478 // the handshake is complete, the handshake is interrupted and an error is returned.
1479 // Once the handshake has completed, cancellation of the context will not affect the
1480 // connection.
1481 //
1482 // Most uses of this package need not call HandshakeContext explicitly: the
1483 // first Read or Write will call it automatically.
1484 func (c *Conn) HandshakeContext(ctx context.Context) error {
1485         // Delegate to unexported method for named return
1486         // without confusing documented signature.
1487         return c.handshakeContext(ctx)
1488 }
1489
1490 func (c *Conn) handshakeContext(ctx context.Context) (ret error) {
1491         // Fast sync/atomic-based exit if there is no handshake in flight and the
1492         // last one succeeded without an error. Avoids the expensive context setup
1493         // and mutex for most Read and Write calls.
1494         if c.isHandshakeComplete.Load() {
1495                 return nil
1496         }
1497
1498         handshakeCtx, cancel := context.WithCancel(ctx)
1499         // Note: defer this before starting the "interrupter" goroutine
1500         // so that we can tell the difference between the input being canceled and
1501         // this cancellation. In the former case, we need to close the connection.
1502         defer cancel()
1503
1504         if c.quic != nil {
1505                 c.quic.cancelc = handshakeCtx.Done()
1506                 c.quic.cancel = cancel
1507         } else if ctx.Done() != nil {
1508                 // Start the "interrupter" goroutine, if this context might be canceled.
1509                 // (The background context cannot).
1510                 //
1511                 // The interrupter goroutine waits for the input context to be done and
1512                 // closes the connection if this happens before the function returns.
1513                 done := make(chan struct{})
1514                 interruptRes := make(chan error, 1)
1515                 defer func() {
1516                         close(done)
1517                         if ctxErr := <-interruptRes; ctxErr != nil {
1518                                 // Return context error to user.
1519                                 ret = ctxErr
1520                         }
1521                 }()
1522                 go func() {
1523                         select {
1524                         case <-handshakeCtx.Done():
1525                                 // Close the connection, discarding the error
1526                                 _ = c.conn.Close()
1527                                 interruptRes <- handshakeCtx.Err()
1528                         case <-done:
1529                                 interruptRes <- nil
1530                         }
1531                 }()
1532         }
1533
1534         c.handshakeMutex.Lock()
1535         defer c.handshakeMutex.Unlock()
1536
1537         if err := c.handshakeErr; err != nil {
1538                 return err
1539         }
1540         if c.isHandshakeComplete.Load() {
1541                 return nil
1542         }
1543
1544         c.in.Lock()
1545         defer c.in.Unlock()
1546
1547         c.handshakeErr = c.handshakeFn(handshakeCtx)
1548         if c.handshakeErr == nil {
1549                 c.handshakes++
1550         } else {
1551                 // If an error occurred during the handshake try to flush the
1552                 // alert that might be left in the buffer.
1553                 c.flush()
1554         }
1555
1556         if c.handshakeErr == nil && !c.isHandshakeComplete.Load() {
1557                 c.handshakeErr = errors.New("tls: internal error: handshake should have had a result")
1558         }
1559         if c.handshakeErr != nil && c.isHandshakeComplete.Load() {
1560                 panic("tls: internal error: handshake returned an error but is marked successful")
1561         }
1562
1563         if c.quic != nil {
1564                 if c.handshakeErr == nil {
1565                         c.quicHandshakeComplete()
1566                         // Provide the 1-RTT read secret now that the handshake is complete.
1567                         // The QUIC layer MUST NOT decrypt 1-RTT packets prior to completing
1568                         // the handshake (RFC 9001, Section 5.7).
1569                         c.quicSetReadSecret(QUICEncryptionLevelApplication, c.cipherSuite, c.in.trafficSecret)
1570                 } else {
1571                         var a alert
1572                         c.out.Lock()
1573                         if !errors.As(c.out.err, &a) {
1574                                 a = alertInternalError
1575                         }
1576                         c.out.Unlock()
1577                         // Return an error which wraps both the handshake error and
1578                         // any alert error we may have sent, or alertInternalError
1579                         // if we didn't send an alert.
1580                         // Truncate the text of the alert to 0 characters.
1581                         c.handshakeErr = fmt.Errorf("%w%.0w", c.handshakeErr, AlertError(a))
1582                 }
1583                 close(c.quic.blockedc)
1584                 close(c.quic.signalc)
1585         }
1586
1587         return c.handshakeErr
1588 }
1589
1590 // ConnectionState returns basic TLS details about the connection.
1591 func (c *Conn) ConnectionState() ConnectionState {
1592         c.handshakeMutex.Lock()
1593         defer c.handshakeMutex.Unlock()
1594         return c.connectionStateLocked()
1595 }
1596
1597 func (c *Conn) connectionStateLocked() ConnectionState {
1598         var state ConnectionState
1599         state.HandshakeComplete = c.isHandshakeComplete.Load()
1600         state.Version = c.vers
1601         state.NegotiatedProtocol = c.clientProtocol
1602         state.DidResume = c.didResume
1603         state.NegotiatedProtocolIsMutual = true
1604         state.ServerName = c.serverName
1605         state.CipherSuite = c.cipherSuite
1606         state.PeerCertificates = c.peerCertificates
1607         state.VerifiedChains = c.verifiedChains
1608         state.SignedCertificateTimestamps = c.scts
1609         state.OCSPResponse = c.ocspResponse
1610         if (!c.didResume || c.extMasterSecret) && c.vers != VersionTLS13 {
1611                 if c.clientFinishedIsFirst {
1612                         state.TLSUnique = c.clientFinished[:]
1613                 } else {
1614                         state.TLSUnique = c.serverFinished[:]
1615                 }
1616         }
1617         if c.config.Renegotiation != RenegotiateNever {
1618                 state.ekm = noExportedKeyingMaterial
1619         } else {
1620                 state.ekm = c.ekm
1621         }
1622         return state
1623 }
1624
1625 // OCSPResponse returns the stapled OCSP response from the TLS server, if
1626 // any. (Only valid for client connections.)
1627 func (c *Conn) OCSPResponse() []byte {
1628         c.handshakeMutex.Lock()
1629         defer c.handshakeMutex.Unlock()
1630
1631         return c.ocspResponse
1632 }
1633
1634 // VerifyHostname checks that the peer certificate chain is valid for
1635 // connecting to host. If so, it returns nil; if not, it returns an error
1636 // describing the problem.
1637 func (c *Conn) VerifyHostname(host string) error {
1638         c.handshakeMutex.Lock()
1639         defer c.handshakeMutex.Unlock()
1640         if !c.isClient {
1641                 return errors.New("tls: VerifyHostname called on TLS server connection")
1642         }
1643         if !c.isHandshakeComplete.Load() {
1644                 return errors.New("tls: handshake has not yet been performed")
1645         }
1646         if len(c.verifiedChains) == 0 {
1647                 return errors.New("tls: handshake did not verify certificate chain")
1648         }
1649         return c.peerCertificates[0].VerifyHostname(host)
1650 }