]> Cypherpunks.ru repositories - gostls13.git/blob - src/crypto/tls/handshake_client_test.go
cf7c09b08faed9d63f2ed049ee78c787b2e5e6c8
[gostls13.git] / src / crypto / tls / handshake_client_test.go
1 // Copyright 2010 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package tls
6
7 import (
8         "bytes"
9         "context"
10         "crypto/rsa"
11         "crypto/x509"
12         "encoding/base64"
13         "encoding/binary"
14         "encoding/pem"
15         "errors"
16         "fmt"
17         "io"
18         "math/big"
19         "net"
20         "os"
21         "os/exec"
22         "path/filepath"
23         "reflect"
24         "runtime"
25         "strconv"
26         "strings"
27         "testing"
28         "time"
29 )
30
31 // Note: see comment in handshake_test.go for details of how the reference
32 // tests work.
33
// opensslInputEvent identifies one of the commands that can be fed to the
// stdin of the `openssl s_server` reference process.
type opensslInputEvent int

const (
	// opensslRenegotiate asks OpenSSL to request a renegotiation of the
	// connection.
	opensslRenegotiate opensslInputEvent = iota

	// opensslSendSentinel asks OpenSSL to send the contents of
	// opensslSentinel on the connection.
	opensslSendSentinel

	// opensslKeyUpdate asks OpenSSL to send a key update message to the
	// client and request one back.
	opensslKeyUpdate
)

// opensslSentinel is the application data line that OpenSSL is made to send
// in response to opensslSendSentinel.
const opensslSentinel = "SENTINEL\n"

// opensslInput is an io.Reader backed by a channel of events. Each event
// received from the channel is translated into the text that is written to
// the stdin of the OpenSSL process.
type opensslInput chan opensslInputEvent

// Read blocks until a single event is available on the channel and copies the
// corresponding stdin text into buf. It reports io.EOF once the channel has
// been closed and panics on an event it does not know.
func (i opensslInput) Read(buf []byte) (n int, err error) {
	event, open := <-i
	if !open {
		return 0, io.EOF
	}

	var text string
	switch event {
	case opensslRenegotiate:
		text = "R\n"
	case opensslKeyUpdate:
		text = "K\n"
	case opensslSendSentinel:
		text = opensslSentinel
	default:
		panic("unknown event")
	}
	return copy(buf, text), nil
}
72
// opensslOutputSink is an io.Writer that collects the stdout and stderr of an
// `openssl` process. Whenever a complete output line matches one of the
// well-known status messages, a value is sent on handshakeComplete or
// readKeyUpdate respectively.
type opensslOutputSink struct {
	// handshakeComplete receives a value for each opensslEndOfHandshake line.
	handshakeComplete chan struct{}
	// readKeyUpdate receives a value for each opensslReadKeyUpdate line.
	readKeyUpdate chan struct{}
	// all holds every byte written so far.
	all []byte
	// line holds the bytes of the current, not yet terminated, line.
	line []byte
}

// newOpensslOutputSink returns an empty sink with unbuffered signal channels.
func newOpensslOutputSink() *opensslOutputSink {
	return &opensslOutputSink{
		handshakeComplete: make(chan struct{}),
		readKeyUpdate:     make(chan struct{}),
	}
}

const (
	// opensslEndOfHandshake is a message that the `openssl s_server` tool
	// will print when a handshake completes if run with `-state`.
	opensslEndOfHandshake = "SSL_accept:SSLv3/TLS write finished"

	// opensslReadKeyUpdate is a message that the `openssl s_server` tool
	// will print when a KeyUpdate message is received if run with `-state`.
	opensslReadKeyUpdate = "SSL_accept:TLSv1.3 read client key update"
)

// Write records data and scans every newly completed line for the status
// messages above. Because the signal channels are unbuffered, Write blocks
// until the matching notification has been received. It always reports the
// whole of data as written.
func (o *opensslOutputSink) Write(data []byte) (n int, err error) {
	o.all = append(o.all, data...)
	o.line = append(o.line, data...)

	for {
		end := bytes.IndexByte(o.line, '\n')
		if end < 0 {
			// The remainder is an incomplete line; keep it for later.
			return len(data), nil
		}

		switch string(o.line[:end]) {
		case opensslEndOfHandshake:
			o.handshakeComplete <- struct{}{}
		case opensslReadKeyUpdate:
			o.readKeyUpdate <- struct{}{}
		}
		o.line = o.line[end+1:]
	}
}

// String returns everything that has been written to the sink.
func (o *opensslOutputSink) String() string {
	return string(o.all)
}
120
// clientTest represents a test of the TLS client handshake against a reference
// implementation. A test either talks to a live reference server and records
// the exchange, or replays a previously recorded exchange from the file named
// by dataPath.
type clientTest struct {
	// name is a freeform string identifying the test and the file in which
	// the expected results will be stored.
	name string
	// args, if not empty, contains a series of arguments for the
	// command to run for the reference server. They are appended after
	// serverCommand.
	args []string
	// config, if not nil, contains a custom Config to use for this test.
	// When nil, testConfig is used.
	config *Config
	// cert, if not empty, contains a DER-encoded certificate for the
	// reference server. When empty, testRSACertificate is used.
	cert []byte
	// key, if not nil, contains either a *rsa.PrivateKey, ed25519.PrivateKey or
	// *ecdsa.PrivateKey which is the private key for the reference server.
	// When nil, testRSAPrivateKey is used.
	key any
	// extensions, if not nil, contains a list of extension data to be returned
	// from the ServerHello. The data should be in standard TLS format with
	// a 2-byte uint16 type, 2-byte data length, followed by the extension data.
	extensions [][]byte
	// validate, if not nil, is a function that will be called with the
	// ConnectionState of the resulting connection. It returns a non-nil
	// error if the ConnectionState is unacceptable.
	validate func(ConnectionState) error
	// numRenegotiations is the number of times that the connection will be
	// renegotiated. A non-zero value requires "-state" to be present in args.
	numRenegotiations int
	// renegotiationExpectedToFail, if not zero, is the number of the
	// renegotiation attempt that is expected to fail.
	renegotiationExpectedToFail int
	// checkRenegotiationError, if not nil, is called with any error
	// arising from renegotiation. It can map expected errors to nil to
	// ignore them.
	checkRenegotiationError func(renegotiationNum int, err error) error
	// sendKeyUpdate will cause the server to send a KeyUpdate message.
	// Setting it requires "-state" to be present in args.
	sendKeyUpdate bool
}
159
// serverCommand is the base command line used to start the reference OpenSSL
// server. connFromCommand appends the per-test arguments, the certificate and
// key options and the listening port to it.
var serverCommand = []string{"openssl", "s_server", "-no_ticket", "-num_tickets", "0"}
161
// connFromCommand starts the reference server process, connects to it and
// returns a recordingConn for the connection. The stdin return value is an
// opensslInput for the stdin of the child process. It must be closed before
// Waiting for child.
//
// The server certificate and key default to testRSACertificate and
// testRSAPrivateKey unless test.cert or test.key are set. stdout collects
// both the stdout and the stderr of the child. On error all other results are
// nil. The function panics if the key cannot be marshalled, or if the test
// needs renegotiation or KeyUpdate but "-state" is not among the arguments.
func (test *clientTest) connFromCommand() (conn *recordingConn, child *exec.Cmd, stdin opensslInput, stdout *opensslOutputSink, err error) {
	cert := testRSACertificate
	if len(test.cert) > 0 {
		cert = test.cert
	}
	// The certificate is handed to OpenSSL via a temporary file, which is
	// removed again when this function returns.
	certPath := tempFile(string(cert))
	defer os.Remove(certPath)

	var key any = testRSAPrivateKey
	if test.key != nil {
		key = test.key
	}
	derBytes, err := x509.MarshalPKCS8PrivateKey(key)
	if err != nil {
		panic(err)
	}

	// The private key is PEM-encoded as PKCS #8 and also passed via a
	// temporary file.
	var pemOut bytes.Buffer
	pem.Encode(&pemOut, &pem.Block{Type: "PRIVATE KEY", Bytes: derBytes})

	keyPath := tempFile(pemOut.String())
	defer os.Remove(keyPath)

	// Assemble the command line: base command, per-test arguments, then the
	// certificate and key options.
	var command []string
	command = append(command, serverCommand...)
	command = append(command, test.args...)
	command = append(command, "-cert", certPath, "-certform", "DER", "-key", keyPath)
	// serverPort contains the port that OpenSSL will listen on. OpenSSL
	// can't take "0" as an argument here so we have to pick a number and
	// hope that it's not in use on the machine. Since this only occurs
	// when -update is given and thus when there's a human watching the
	// test, this isn't too bad.
	const serverPort = 24323
	command = append(command, "-accept", strconv.Itoa(serverPort))

	if len(test.extensions) > 0 {
		// Each extension becomes one PEM block whose type names the
		// extension number taken from the first two bytes of the data.
		var serverInfo bytes.Buffer
		for _, ext := range test.extensions {
			pem.Encode(&serverInfo, &pem.Block{
				Type:  fmt.Sprintf("SERVERINFO FOR EXTENSION %d", binary.BigEndian.Uint16(ext)),
				Bytes: ext,
			})
		}
		serverInfoPath := tempFile(serverInfo.String())
		defer os.Remove(serverInfoPath)
		command = append(command, "-serverinfo", serverInfoPath)
	}

	// Renegotiation and KeyUpdate tests synchronize on the status lines that
	// OpenSSL only prints with -state, so insist on the flag being present.
	if test.numRenegotiations > 0 || test.sendKeyUpdate {
		found := false
		for _, flag := range command[1:] {
			if flag == "-state" {
				found = true
				break
			}
		}

		if !found {
			panic("-state flag missing to OpenSSL, you need this if testing renegotiation or KeyUpdate")
		}
	}

	cmd := exec.Command(command[0], command[1:]...)
	stdin = opensslInput(make(chan opensslInputEvent))
	cmd.Stdin = stdin
	// Both output streams of the child go to the same sink.
	out := newOpensslOutputSink()
	cmd.Stdout = out
	cmd.Stderr = out
	if err := cmd.Start(); err != nil {
		return nil, nil, nil, nil, err
	}

	// OpenSSL does print an "ACCEPT" banner, but it does so *before*
	// opening the listening socket, so we can't use that to wait until it
	// has started listening. Thus we are forced to poll until we get a
	// connection.
	//
	// Up to five attempts are made, pausing 5ms, 10ms, 20ms, 40ms and 80ms
	// after each failed one.
	var tcpConn net.Conn
	for i := uint(0); i < 5; i++ {
		tcpConn, err = net.DialTCP("tcp", nil, &net.TCPAddr{
			IP:   net.IPv4(127, 0, 0, 1),
			Port: serverPort,
		})
		if err == nil {
			break
		}
		time.Sleep((1 << i) * 5 * time.Millisecond)
	}
	if err != nil {
		// Shut the child down and include its exit status and collected
		// output in the returned error.
		close(stdin)
		cmd.Process.Kill()
		err = fmt.Errorf("error connecting to the OpenSSL server: %v (%v)\n\n%s", err, cmd.Wait(), out)
		return nil, nil, nil, nil, err
	}

	record := &recordingConn{
		Conn: tcpConn,
	}

	return record, cmd, stdin, out, nil
}
266
267 func (test *clientTest) dataPath() string {
268         return filepath.Join("testdata", "Client-"+test.name)
269 }
270
271 func (test *clientTest) loadData() (flows [][]byte, err error) {
272         in, err := os.Open(test.dataPath())
273         if err != nil {
274                 return nil, err
275         }
276         defer in.Close()
277         return parseTestData(in)
278 }
279
// run executes the client handshake test.
//
// If write is true, the client talks to a live reference server started by
// connFromCommand, and the recorded flows are written to the file named by
// dataPath afterwards. Otherwise the flows are loaded from that file and
// replayed over a local pipe: the client's output is compared byte for byte
// with the recording and the recorded server flows are fed back to it.
//
// The client side always runs in its own goroutine and reports problems with
// t.Errorf; doneChan is closed when it has finished.
func (test *clientTest) run(t *testing.T, write bool) {
	var clientConn, serverConn net.Conn
	var recordingConn *recordingConn
	var childProcess *exec.Cmd
	var stdin opensslInput
	var stdout *opensslOutputSink

	if write {
		var err error
		recordingConn, childProcess, stdin, stdout, err = test.connFromCommand()
		if err != nil {
			t.Fatalf("Failed to start subcommand: %s", err)
		}
		clientConn = recordingConn
		// Dump whatever OpenSSL printed if the test ends up failing.
		defer func() {
			if t.Failed() {
				t.Logf("OpenSSL output:\n\n%s", stdout.all)
			}
		}()
	} else {
		clientConn, serverConn = localPipe(t)
	}

	doneChan := make(chan bool)
	// Whatever way this function exits, close the client connection and wait
	// for the client goroutine to finish.
	defer func() {
		clientConn.Close()
		<-doneChan
	}()
	go func() {
		defer close(doneChan)

		config := test.config
		if config == nil {
			config = testConfig
		}
		client := Client(clientConn, config)
		defer client.Close()

		if _, err := client.Write([]byte("hello\n")); err != nil {
			t.Errorf("Client.Write failed: %s", err)
			return
		}

		for i := 1; i <= test.numRenegotiations; i++ {
			// The initial handshake will generate a
			// handshakeComplete signal which needs to be quashed.
			if i == 1 && write {
				<-stdout.handshakeComplete
			}

			// OpenSSL will try to interleave application data and
			// a renegotiation if we send both concurrently.
			// Therefore: ask OpenSSL to start a renegotiation, run
			// a goroutine to call client.Read and thus process the
			// renegotiation request, watch for OpenSSL's stdout to
			// indicate that the handshake is complete and,
			// finally, have OpenSSL write something to cause
			// client.Read to complete.
			if write {
				stdin <- opensslRenegotiate
			}

			signalChan := make(chan struct{})

			go func() {
				defer close(signalChan)

				buf := make([]byte, 256)
				n, err := client.Read(buf)

				if test.checkRenegotiationError != nil {
					newErr := test.checkRenegotiationError(i, err)
					// An error that the callback maps to nil was
					// expected; nothing further to check.
					if err != nil && newErr == nil {
						return
					}
					err = newErr
				}

				if err != nil {
					t.Errorf("Client.Read failed after renegotiation #%d: %s", i, err)
					return
				}

				buf = buf[:n]
				if !bytes.Equal([]byte(opensslSentinel), buf) {
					t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel)
				}

				// The initial handshake counts as well, hence i + 1.
				if expected := i + 1; client.handshakes != expected {
					t.Errorf("client should have recorded %d handshakes, but believes that %d have occurred", expected, client.handshakes)
				}
			}()

			// Unless this renegotiation is the one expected to fail, wait
			// for the server to finish it and then have it send the
			// sentinel that unblocks the Read above.
			if write && test.renegotiationExpectedToFail != i {
				<-stdout.handshakeComplete
				stdin <- opensslSendSentinel
			}
			<-signalChan
		}

		if test.sendKeyUpdate {
			if write {
				<-stdout.handshakeComplete
				stdin <- opensslKeyUpdate
			}

			doneRead := make(chan struct{})

			go func() {
				defer close(doneRead)

				buf := make([]byte, 256)
				n, err := client.Read(buf)

				if err != nil {
					t.Errorf("Client.Read failed after KeyUpdate: %s", err)
					return
				}

				buf = buf[:n]
				if !bytes.Equal([]byte(opensslSentinel), buf) {
					t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel)
				}
			}()

			if write {
				// There's no real reason to wait for the client KeyUpdate to
				// send data with the new server keys, except that s_server
				// drops writes if they are sent at the wrong time.
				<-stdout.readKeyUpdate
				stdin <- opensslSendSentinel
			}
			<-doneRead

			if _, err := client.Write([]byte("hello again\n")); err != nil {
				t.Errorf("Client.Write failed: %s", err)
				return
			}
		}

		if test.validate != nil {
			if err := test.validate(client.ConnectionState()); err != nil {
				t.Errorf("validate callback returned error: %s", err)
			}
		}

		// If the server sent us an alert after our last flight, give it a
		// chance to arrive.
		if write && test.renegotiationExpectedToFail == 0 {
			if err := peekError(client); err != nil {
				t.Errorf("final Read returned an error: %s", err)
			}
		}
	}()

	if !write {
		// Replay mode: play the part of the server from the recording.
		flows, err := test.loadData()
		if err != nil {
			t.Fatalf("%s: failed to load data from %s: %v", test.name, test.dataPath(), err)
		}
		for i, b := range flows {
			// Flows at odd indexes are written to the client; flows at
			// even indexes are what the client is expected to send. The
			// deadline is one second when *fast is set, one minute
			// otherwise.
			if i%2 == 1 {
				if *fast {
					serverConn.SetWriteDeadline(time.Now().Add(1 * time.Second))
				} else {
					serverConn.SetWriteDeadline(time.Now().Add(1 * time.Minute))
				}
				serverConn.Write(b)
				continue
			}
			bb := make([]byte, len(b))
			if *fast {
				serverConn.SetReadDeadline(time.Now().Add(1 * time.Second))
			} else {
				serverConn.SetReadDeadline(time.Now().Add(1 * time.Minute))
			}
			_, err := io.ReadFull(serverConn, bb)
			if err != nil {
				t.Fatalf("%s, flow %d: %s", test.name, i+1, err)
			}
			if !bytes.Equal(b, bb) {
				t.Fatalf("%s, flow %d: mismatch on read: got:%x want:%x", test.name, i+1, bb, b)
			}
		}
	}

	<-doneChan
	if !write {
		serverConn.Close()
	}

	if write {
		// Record mode: shut the reference server down and store the
		// captured flows in the test data file.
		path := test.dataPath()
		out, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
		if err != nil {
			t.Fatalf("Failed to create output file: %s", err)
		}
		defer out.Close()
		recordingConn.Close()
		// stdin has to be closed before waiting for the child process.
		close(stdin)
		childProcess.Process.Kill()
		childProcess.Wait()
		// Fewer than three recorded flows is treated as a failed
		// connection.
		if len(recordingConn.flows) < 3 {
			t.Fatalf("Client connection didn't work")
		}
		recordingConn.WriteTo(out)
		t.Logf("Wrote %s\n", path)
	}
}
489
// peekError does a read with a short timeout to check if the next read would
// cause an error, for example if there is an alert waiting on the wire.
//
// It returns nil if the read timed out or returned neither data nor an error.
// It returns an error if any data was read, if the read deadline could not be
// set, or if the read failed with anything other than a timeout. A timeout is
// recognized even when the net.Error is wrapped inside another error.
func peekError(conn net.Conn) error {
	// Without a deadline the Read below could block indefinitely, so a
	// failure to set it is reported rather than ignored.
	if err := conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond)); err != nil {
		return fmt.Errorf("setting read deadline: %w", err)
	}

	n, err := conn.Read(make([]byte, 1))
	if n != 0 {
		return errors.New("unexpectedly read data")
	}

	var netErr net.Error
	if errors.As(err, &netErr) && netErr.Timeout() {
		// Nothing arrived before the deadline: no pending error.
		return nil
	}
	return err
}
503
504 func runClientTestForVersion(t *testing.T, template *clientTest, version, option string) {
505         // Make a deep copy of the template before going parallel.
506         test := *template
507         if template.config != nil {
508                 test.config = template.config.Clone()
509         }
510         test.name = version + "-" + test.name
511         test.args = append([]string{option}, test.args...)
512
513         runTestAndUpdateIfNeeded(t, version, test.run, false)
514 }
515
516 func runClientTestTLS10(t *testing.T, template *clientTest) {
517         runClientTestForVersion(t, template, "TLSv10", "-tls1")
518 }
519
520 func runClientTestTLS11(t *testing.T, template *clientTest) {
521         runClientTestForVersion(t, template, "TLSv11", "-tls1_1")
522 }
523
524 func runClientTestTLS12(t *testing.T, template *clientTest) {
525         runClientTestForVersion(t, template, "TLSv12", "-tls1_2")
526 }
527
528 func runClientTestTLS13(t *testing.T, template *clientTest) {
529         runClientTestForVersion(t, template, "TLSv13", "-tls1_3")
530 }
531
532 func TestHandshakeClientRSARC4(t *testing.T) {
533         test := &clientTest{
534                 name: "RSA-RC4",
535                 args: []string{"-cipher", "RC4-SHA"},
536         }
537         runClientTestTLS10(t, test)
538         runClientTestTLS11(t, test)
539         runClientTestTLS12(t, test)
540 }
541
542 func TestHandshakeClientRSAAES128GCM(t *testing.T) {
543         test := &clientTest{
544                 name: "AES128-GCM-SHA256",
545                 args: []string{"-cipher", "AES128-GCM-SHA256"},
546         }
547         runClientTestTLS12(t, test)
548 }
549
550 func TestHandshakeClientRSAAES256GCM(t *testing.T) {
551         test := &clientTest{
552                 name: "AES256-GCM-SHA384",
553                 args: []string{"-cipher", "AES256-GCM-SHA384"},
554         }
555         runClientTestTLS12(t, test)
556 }
557
558 func TestHandshakeClientECDHERSAAES(t *testing.T) {
559         test := &clientTest{
560                 name: "ECDHE-RSA-AES",
561                 args: []string{"-cipher", "ECDHE-RSA-AES128-SHA"},
562         }
563         runClientTestTLS10(t, test)
564         runClientTestTLS11(t, test)
565         runClientTestTLS12(t, test)
566 }
567
568 func TestHandshakeClientECDHEECDSAAES(t *testing.T) {
569         test := &clientTest{
570                 name: "ECDHE-ECDSA-AES",
571                 args: []string{"-cipher", "ECDHE-ECDSA-AES128-SHA"},
572                 cert: testECDSACertificate,
573                 key:  testECDSAPrivateKey,
574         }
575         runClientTestTLS10(t, test)
576         runClientTestTLS11(t, test)
577         runClientTestTLS12(t, test)
578 }
579
580 func TestHandshakeClientECDHEECDSAAESGCM(t *testing.T) {
581         test := &clientTest{
582                 name: "ECDHE-ECDSA-AES-GCM",
583                 args: []string{"-cipher", "ECDHE-ECDSA-AES128-GCM-SHA256"},
584                 cert: testECDSACertificate,
585                 key:  testECDSAPrivateKey,
586         }
587         runClientTestTLS12(t, test)
588 }
589
590 func TestHandshakeClientAES256GCMSHA384(t *testing.T) {
591         test := &clientTest{
592                 name: "ECDHE-ECDSA-AES256-GCM-SHA384",
593                 args: []string{"-cipher", "ECDHE-ECDSA-AES256-GCM-SHA384"},
594                 cert: testECDSACertificate,
595                 key:  testECDSAPrivateKey,
596         }
597         runClientTestTLS12(t, test)
598 }
599
600 func TestHandshakeClientAES128CBCSHA256(t *testing.T) {
601         test := &clientTest{
602                 name: "AES128-SHA256",
603                 args: []string{"-cipher", "AES128-SHA256"},
604         }
605         runClientTestTLS12(t, test)
606 }
607
608 func TestHandshakeClientECDHERSAAES128CBCSHA256(t *testing.T) {
609         test := &clientTest{
610                 name: "ECDHE-RSA-AES128-SHA256",
611                 args: []string{"-cipher", "ECDHE-RSA-AES128-SHA256"},
612         }
613         runClientTestTLS12(t, test)
614 }
615
616 func TestHandshakeClientECDHEECDSAAES128CBCSHA256(t *testing.T) {
617         test := &clientTest{
618                 name: "ECDHE-ECDSA-AES128-SHA256",
619                 args: []string{"-cipher", "ECDHE-ECDSA-AES128-SHA256"},
620                 cert: testECDSACertificate,
621                 key:  testECDSAPrivateKey,
622         }
623         runClientTestTLS12(t, test)
624 }
625
626 func TestHandshakeClientX25519(t *testing.T) {
627         config := testConfig.Clone()
628         config.CurvePreferences = []CurveID{X25519}
629
630         test := &clientTest{
631                 name:   "X25519-ECDHE",
632                 args:   []string{"-cipher", "ECDHE-RSA-AES128-GCM-SHA256", "-curves", "X25519"},
633                 config: config,
634         }
635
636         runClientTestTLS12(t, test)
637         runClientTestTLS13(t, test)
638 }
639
640 func TestHandshakeClientP256(t *testing.T) {
641         config := testConfig.Clone()
642         config.CurvePreferences = []CurveID{CurveP256}
643
644         test := &clientTest{
645                 name:   "P256-ECDHE",
646                 args:   []string{"-cipher", "ECDHE-RSA-AES128-GCM-SHA256", "-curves", "P-256"},
647                 config: config,
648         }
649
650         runClientTestTLS12(t, test)
651         runClientTestTLS13(t, test)
652 }
653
654 func TestHandshakeClientHelloRetryRequest(t *testing.T) {
655         config := testConfig.Clone()
656         config.CurvePreferences = []CurveID{X25519, CurveP256}
657
658         test := &clientTest{
659                 name:   "HelloRetryRequest",
660                 args:   []string{"-cipher", "ECDHE-RSA-AES128-GCM-SHA256", "-curves", "P-256"},
661                 config: config,
662         }
663
664         runClientTestTLS13(t, test)
665 }
666
667 func TestHandshakeClientECDHERSAChaCha20(t *testing.T) {
668         config := testConfig.Clone()
669         config.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305}
670
671         test := &clientTest{
672                 name:   "ECDHE-RSA-CHACHA20-POLY1305",
673                 args:   []string{"-cipher", "ECDHE-RSA-CHACHA20-POLY1305"},
674                 config: config,
675         }
676
677         runClientTestTLS12(t, test)
678 }
679
680 func TestHandshakeClientECDHEECDSAChaCha20(t *testing.T) {
681         config := testConfig.Clone()
682         config.CipherSuites = []uint16{TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305}
683
684         test := &clientTest{
685                 name:   "ECDHE-ECDSA-CHACHA20-POLY1305",
686                 args:   []string{"-cipher", "ECDHE-ECDSA-CHACHA20-POLY1305"},
687                 config: config,
688                 cert:   testECDSACertificate,
689                 key:    testECDSAPrivateKey,
690         }
691
692         runClientTestTLS12(t, test)
693 }
694
695 func TestHandshakeClientAES128SHA256(t *testing.T) {
696         test := &clientTest{
697                 name: "AES128-SHA256",
698                 args: []string{"-ciphersuites", "TLS_AES_128_GCM_SHA256"},
699         }
700         runClientTestTLS13(t, test)
701 }
702 func TestHandshakeClientAES256SHA384(t *testing.T) {
703         test := &clientTest{
704                 name: "AES256-SHA384",
705                 args: []string{"-ciphersuites", "TLS_AES_256_GCM_SHA384"},
706         }
707         runClientTestTLS13(t, test)
708 }
709 func TestHandshakeClientCHACHA20SHA256(t *testing.T) {
710         test := &clientTest{
711                 name: "CHACHA20-SHA256",
712                 args: []string{"-ciphersuites", "TLS_CHACHA20_POLY1305_SHA256"},
713         }
714         runClientTestTLS13(t, test)
715 }
716
717 func TestHandshakeClientECDSATLS13(t *testing.T) {
718         test := &clientTest{
719                 name: "ECDSA",
720                 cert: testECDSACertificate,
721                 key:  testECDSAPrivateKey,
722         }
723         runClientTestTLS13(t, test)
724 }
725
726 func TestHandshakeClientEd25519(t *testing.T) {
727         test := &clientTest{
728                 name: "Ed25519",
729                 cert: testEd25519Certificate,
730                 key:  testEd25519PrivateKey,
731         }
732         runClientTestTLS12(t, test)
733         runClientTestTLS13(t, test)
734
735         config := testConfig.Clone()
736         cert, _ := X509KeyPair([]byte(clientEd25519CertificatePEM), []byte(clientEd25519KeyPEM))
737         config.Certificates = []Certificate{cert}
738
739         test = &clientTest{
740                 name:   "ClientCert-Ed25519",
741                 args:   []string{"-Verify", "1"},
742                 config: config,
743         }
744
745         runClientTestTLS12(t, test)
746         runClientTestTLS13(t, test)
747 }
748
749 func TestHandshakeClientCertRSA(t *testing.T) {
750         config := testConfig.Clone()
751         cert, _ := X509KeyPair([]byte(clientCertificatePEM), []byte(clientKeyPEM))
752         config.Certificates = []Certificate{cert}
753
754         test := &clientTest{
755                 name:   "ClientCert-RSA-RSA",
756                 args:   []string{"-cipher", "AES128", "-Verify", "1"},
757                 config: config,
758         }
759
760         runClientTestTLS10(t, test)
761         runClientTestTLS12(t, test)
762
763         test = &clientTest{
764                 name:   "ClientCert-RSA-ECDSA",
765                 args:   []string{"-cipher", "ECDHE-ECDSA-AES128-SHA", "-Verify", "1"},
766                 config: config,
767                 cert:   testECDSACertificate,
768                 key:    testECDSAPrivateKey,
769         }
770
771         runClientTestTLS10(t, test)
772         runClientTestTLS12(t, test)
773         runClientTestTLS13(t, test)
774
775         test = &clientTest{
776                 name:   "ClientCert-RSA-AES256-GCM-SHA384",
777                 args:   []string{"-cipher", "ECDHE-RSA-AES256-GCM-SHA384", "-Verify", "1"},
778                 config: config,
779                 cert:   testRSACertificate,
780                 key:    testRSAPrivateKey,
781         }
782
783         runClientTestTLS12(t, test)
784 }
785
786 func TestHandshakeClientCertECDSA(t *testing.T) {
787         config := testConfig.Clone()
788         cert, _ := X509KeyPair([]byte(clientECDSACertificatePEM), []byte(clientECDSAKeyPEM))
789         config.Certificates = []Certificate{cert}
790
791         test := &clientTest{
792                 name:   "ClientCert-ECDSA-RSA",
793                 args:   []string{"-cipher", "AES128", "-Verify", "1"},
794                 config: config,
795         }
796
797         runClientTestTLS10(t, test)
798         runClientTestTLS12(t, test)
799         runClientTestTLS13(t, test)
800
801         test = &clientTest{
802                 name:   "ClientCert-ECDSA-ECDSA",
803                 args:   []string{"-cipher", "ECDHE-ECDSA-AES128-SHA", "-Verify", "1"},
804                 config: config,
805                 cert:   testECDSACertificate,
806                 key:    testECDSAPrivateKey,
807         }
808
809         runClientTestTLS10(t, test)
810         runClientTestTLS12(t, test)
811 }
812
813 // TestHandshakeClientCertRSAPSS tests rsa_pss_rsae_sha256 signatures from both
814 // client and server certificates. It also serves from both sides a certificate
815 // signed itself with RSA-PSS, mostly to check that crypto/x509 chain validation
816 // works.
817 func TestHandshakeClientCertRSAPSS(t *testing.T) {
818         cert, err := x509.ParseCertificate(testRSAPSSCertificate)
819         if err != nil {
820                 panic(err)
821         }
822         rootCAs := x509.NewCertPool()
823         rootCAs.AddCert(cert)
824
825         config := testConfig.Clone()
826         // Use GetClientCertificate to bypass the client certificate selection logic.
827         config.GetClientCertificate = func(*CertificateRequestInfo) (*Certificate, error) {
828                 return &Certificate{
829                         Certificate: [][]byte{testRSAPSSCertificate},
830                         PrivateKey:  testRSAPrivateKey,
831                 }, nil
832         }
833         config.RootCAs = rootCAs
834
835         test := &clientTest{
836                 name: "ClientCert-RSA-RSAPSS",
837                 args: []string{"-cipher", "AES128", "-Verify", "1", "-client_sigalgs",
838                         "rsa_pss_rsae_sha256", "-sigalgs", "rsa_pss_rsae_sha256"},
839                 config: config,
840                 cert:   testRSAPSSCertificate,
841                 key:    testRSAPrivateKey,
842         }
843         runClientTestTLS12(t, test)
844         runClientTestTLS13(t, test)
845 }
846
847 func TestHandshakeClientCertRSAPKCS1v15(t *testing.T) {
848         config := testConfig.Clone()
849         cert, _ := X509KeyPair([]byte(clientCertificatePEM), []byte(clientKeyPEM))
850         config.Certificates = []Certificate{cert}
851
852         test := &clientTest{
853                 name: "ClientCert-RSA-RSAPKCS1v15",
854                 args: []string{"-cipher", "AES128", "-Verify", "1", "-client_sigalgs",
855                         "rsa_pkcs1_sha256", "-sigalgs", "rsa_pkcs1_sha256"},
856                 config: config,
857         }
858
859         runClientTestTLS12(t, test)
860 }
861
862 func TestClientKeyUpdate(t *testing.T) {
863         test := &clientTest{
864                 name:          "KeyUpdate",
865                 args:          []string{"-state"},
866                 sendKeyUpdate: true,
867         }
868         runClientTestTLS13(t, test)
869 }
870
871 func TestResumption(t *testing.T) {
872         t.Run("TLSv12", func(t *testing.T) { testResumption(t, VersionTLS12) })
873         t.Run("TLSv13", func(t *testing.T) { testResumption(t, VersionTLS13) })
874 }
875
876 func testResumption(t *testing.T, version uint16) {
877         if testing.Short() {
878                 t.Skip("skipping in -short mode")
879         }
880         serverConfig := &Config{
881                 MaxVersion:   version,
882                 CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA},
883                 Certificates: testConfig.Certificates,
884         }
885
886         issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
887         if err != nil {
888                 panic(err)
889         }
890
891         rootCAs := x509.NewCertPool()
892         rootCAs.AddCert(issuer)
893
894         clientConfig := &Config{
895                 MaxVersion:         version,
896                 CipherSuites:       []uint16{TLS_RSA_WITH_RC4_128_SHA},
897                 ClientSessionCache: NewLRUClientSessionCache(32),
898                 RootCAs:            rootCAs,
899                 ServerName:         "example.golang",
900         }
901
902         testResumeState := func(test string, didResume bool) {
903                 _, hs, err := testHandshake(t, clientConfig, serverConfig)
904                 if err != nil {
905                         t.Fatalf("%s: handshake failed: %s", test, err)
906                 }
907                 if hs.DidResume != didResume {
908                         t.Fatalf("%s resumed: %v, expected: %v", test, hs.DidResume, didResume)
909                 }
910                 if didResume && (hs.PeerCertificates == nil || hs.VerifiedChains == nil) {
911                         t.Fatalf("expected non-nil certificates after resumption. Got peerCertificates: %#v, verifiedCertificates: %#v", hs.PeerCertificates, hs.VerifiedChains)
912                 }
913                 if got, want := hs.ServerName, clientConfig.ServerName; got != want {
914                         t.Errorf("%s: server name %s, want %s", test, got, want)
915                 }
916         }
917
918         getTicket := func() []byte {
919                 return clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).state.ticket
920         }
921         deleteTicket := func() {
922                 ticketKey := clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).sessionKey
923                 clientConfig.ClientSessionCache.Put(ticketKey, nil)
924         }
925         corruptTicket := func() {
926                 clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).state.session.secret[0] ^= 0xff
927         }
928         randomKey := func() [32]byte {
929                 var k [32]byte
930                 if _, err := io.ReadFull(serverConfig.rand(), k[:]); err != nil {
931                         t.Fatalf("Failed to read new SessionTicketKey: %s", err)
932                 }
933                 return k
934         }
935
936         testResumeState("Handshake", false)
937         ticket := getTicket()
938         testResumeState("Resume", true)
939         if bytes.Equal(ticket, getTicket()) {
940                 t.Fatal("ticket didn't change after resumption")
941         }
942
943         // An old session ticket is replaced with a ticket encrypted with a fresh key.
944         ticket = getTicket()
945         serverConfig.Time = func() time.Time { return time.Now().Add(24*time.Hour + time.Minute) }
946         testResumeState("ResumeWithOldTicket", true)
947         if bytes.Equal(ticket, getTicket()) {
948                 t.Fatal("old first ticket matches the fresh one")
949         }
950
951         // Once the session master secret is expired, a full handshake should occur.
952         ticket = getTicket()
953         serverConfig.Time = func() time.Time { return time.Now().Add(24*8*time.Hour + time.Minute) }
954         testResumeState("ResumeWithExpiredTicket", false)
955         if bytes.Equal(ticket, getTicket()) {
956                 t.Fatal("expired first ticket matches the fresh one")
957         }
958
959         serverConfig.Time = func() time.Time { return time.Now() } // reset the time back
960         key1 := randomKey()
961         serverConfig.SetSessionTicketKeys([][32]byte{key1})
962
963         testResumeState("InvalidSessionTicketKey", false)
964         testResumeState("ResumeAfterInvalidSessionTicketKey", true)
965
966         key2 := randomKey()
967         serverConfig.SetSessionTicketKeys([][32]byte{key2, key1})
968         ticket = getTicket()
969         testResumeState("KeyChange", true)
970         if bytes.Equal(ticket, getTicket()) {
971                 t.Fatal("new ticket wasn't included while resuming")
972         }
973         testResumeState("KeyChangeFinish", true)
974
975         // Age the session ticket a bit, but not yet expired.
976         serverConfig.Time = func() time.Time { return time.Now().Add(24*time.Hour + time.Minute) }
977         testResumeState("OldSessionTicket", true)
978         ticket = getTicket()
979         // Expire the session ticket, which would force a full handshake.
980         serverConfig.Time = func() time.Time { return time.Now().Add(24*8*time.Hour + time.Minute) }
981         testResumeState("ExpiredSessionTicket", false)
982         if bytes.Equal(ticket, getTicket()) {
983                 t.Fatal("new ticket wasn't provided after old ticket expired")
984         }
985
986         // Age the session ticket a bit at a time, but don't expire it.
987         d := 0 * time.Hour
988         for i := 0; i < 13; i++ {
989                 d += 12 * time.Hour
990                 serverConfig.Time = func() time.Time { return time.Now().Add(d) }
991                 testResumeState("OldSessionTicket", true)
992         }
993         // Expire it (now a little more than 7 days) and make sure a full
994         // handshake occurs for TLS 1.2. Resumption should still occur for
995         // TLS 1.3 since the client should be using a fresh ticket sent over
996         // by the server.
997         d += 12 * time.Hour
998         serverConfig.Time = func() time.Time { return time.Now().Add(d) }
999         if version == VersionTLS13 {
1000                 testResumeState("ExpiredSessionTicket", true)
1001         } else {
1002                 testResumeState("ExpiredSessionTicket", false)
1003         }
1004         if bytes.Equal(ticket, getTicket()) {
1005                 t.Fatal("new ticket wasn't provided after old ticket expired")
1006         }
1007
1008         // Reset serverConfig to ensure that calling SetSessionTicketKeys
1009         // before the serverConfig is used works.
1010         serverConfig = &Config{
1011                 MaxVersion:   version,
1012                 CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA},
1013                 Certificates: testConfig.Certificates,
1014         }
1015         serverConfig.SetSessionTicketKeys([][32]byte{key2})
1016
1017         testResumeState("FreshConfig", true)
1018
1019         // In TLS 1.3, cross-cipher suite resumption is allowed as long as the KDF
1020         // hash matches. Also, Config.CipherSuites does not apply to TLS 1.3.
1021         if version != VersionTLS13 {
1022                 clientConfig.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_RC4_128_SHA}
1023                 testResumeState("DifferentCipherSuite", false)
1024                 testResumeState("DifferentCipherSuiteRecovers", true)
1025         }
1026
1027         deleteTicket()
1028         testResumeState("WithoutSessionTicket", false)
1029
1030         // In TLS 1.3, HelloRetryRequest is sent after incorrect key share.
1031         // See https://www.rfc-editor.org/rfc/rfc8446#page-14.
1032         if version == VersionTLS13 {
1033                 deleteTicket()
1034                 serverConfig = &Config{
1035                         // Use a different curve than the client to force a HelloRetryRequest.
1036                         CurvePreferences: []CurveID{CurveP521, CurveP384, CurveP256},
1037                         MaxVersion:       version,
1038                         Certificates:     testConfig.Certificates,
1039                 }
1040                 testResumeState("InitialHandshake", false)
1041                 testResumeState("WithHelloRetryRequest", true)
1042
1043                 // Reset serverConfig back.
1044                 serverConfig = &Config{
1045                         MaxVersion:   version,
1046                         CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA},
1047                         Certificates: testConfig.Certificates,
1048                 }
1049         }
1050
1051         // Session resumption should work when using client certificates
1052         deleteTicket()
1053         serverConfig.ClientCAs = rootCAs
1054         serverConfig.ClientAuth = RequireAndVerifyClientCert
1055         clientConfig.Certificates = serverConfig.Certificates
1056         testResumeState("InitialHandshake", false)
1057         testResumeState("WithClientCertificates", true)
1058         serverConfig.ClientAuth = NoClientCert
1059
1060         // Tickets should be removed from the session cache on TLS handshake
1061         // failure, and the client should recover from a corrupted PSK
1062         testResumeState("FetchTicketToCorrupt", false)
1063         corruptTicket()
1064         _, _, err = testHandshake(t, clientConfig, serverConfig)
1065         if err == nil {
1066                 t.Fatalf("handshake did not fail with a corrupted client secret")
1067         }
1068         testResumeState("AfterHandshakeFailure", false)
1069
1070         clientConfig.ClientSessionCache = nil
1071         testResumeState("WithoutSessionCache", false)
1072 }
1073
1074 func TestLRUClientSessionCache(t *testing.T) {
1075         // Initialize cache of capacity 4.
1076         cache := NewLRUClientSessionCache(4)
1077         cs := make([]ClientSessionState, 6)
1078         keys := []string{"0", "1", "2", "3", "4", "5", "6"}
1079
1080         // Add 4 entries to the cache and look them up.
1081         for i := 0; i < 4; i++ {
1082                 cache.Put(keys[i], &cs[i])
1083         }
1084         for i := 0; i < 4; i++ {
1085                 if s, ok := cache.Get(keys[i]); !ok || s != &cs[i] {
1086                         t.Fatalf("session cache failed lookup for added key: %s", keys[i])
1087                 }
1088         }
1089
1090         // Add 2 more entries to the cache. First 2 should be evicted.
1091         for i := 4; i < 6; i++ {
1092                 cache.Put(keys[i], &cs[i])
1093         }
1094         for i := 0; i < 2; i++ {
1095                 if s, ok := cache.Get(keys[i]); ok || s != nil {
1096                         t.Fatalf("session cache should have evicted key: %s", keys[i])
1097                 }
1098         }
1099
1100         // Touch entry 2. LRU should evict 3 next.
1101         cache.Get(keys[2])
1102         cache.Put(keys[0], &cs[0])
1103         if s, ok := cache.Get(keys[3]); ok || s != nil {
1104                 t.Fatalf("session cache should have evicted key 3")
1105         }
1106
1107         // Update entry 0 in place.
1108         cache.Put(keys[0], &cs[3])
1109         if s, ok := cache.Get(keys[0]); !ok || s != &cs[3] {
1110                 t.Fatalf("session cache failed update for key 0")
1111         }
1112
1113         // Calling Put with a nil entry deletes the key.
1114         cache.Put(keys[0], nil)
1115         if _, ok := cache.Get(keys[0]); ok {
1116                 t.Fatalf("session cache failed to delete key 0")
1117         }
1118
1119         // Delete entry 2. LRU should keep 4 and 5
1120         cache.Put(keys[2], nil)
1121         if _, ok := cache.Get(keys[2]); ok {
1122                 t.Fatalf("session cache failed to delete key 4")
1123         }
1124         for i := 4; i < 6; i++ {
1125                 if s, ok := cache.Get(keys[i]); !ok || s != &cs[i] {
1126                         t.Fatalf("session cache should not have deleted key: %s", keys[i])
1127                 }
1128         }
1129 }
1130
1131 func TestKeyLogTLS12(t *testing.T) {
1132         var serverBuf, clientBuf bytes.Buffer
1133
1134         clientConfig := testConfig.Clone()
1135         clientConfig.KeyLogWriter = &clientBuf
1136         clientConfig.MaxVersion = VersionTLS12
1137
1138         serverConfig := testConfig.Clone()
1139         serverConfig.KeyLogWriter = &serverBuf
1140         serverConfig.MaxVersion = VersionTLS12
1141
1142         c, s := localPipe(t)
1143         done := make(chan bool)
1144
1145         go func() {
1146                 defer close(done)
1147
1148                 if err := Server(s, serverConfig).Handshake(); err != nil {
1149                         t.Errorf("server: %s", err)
1150                         return
1151                 }
1152                 s.Close()
1153         }()
1154
1155         if err := Client(c, clientConfig).Handshake(); err != nil {
1156                 t.Fatalf("client: %s", err)
1157         }
1158
1159         c.Close()
1160         <-done
1161
1162         checkKeylogLine := func(side, loggedLine string) {
1163                 if len(loggedLine) == 0 {
1164                         t.Fatalf("%s: no keylog line was produced", side)
1165                 }
1166                 const expectedLen = 13 /* "CLIENT_RANDOM" */ +
1167                         1 /* space */ +
1168                         32*2 /* hex client nonce */ +
1169                         1 /* space */ +
1170                         48*2 /* hex master secret */ +
1171                         1 /* new line */
1172                 if len(loggedLine) != expectedLen {
1173                         t.Fatalf("%s: keylog line has incorrect length (want %d, got %d): %q", side, expectedLen, len(loggedLine), loggedLine)
1174                 }
1175                 if !strings.HasPrefix(loggedLine, "CLIENT_RANDOM "+strings.Repeat("0", 64)+" ") {
1176                         t.Fatalf("%s: keylog line has incorrect structure or nonce: %q", side, loggedLine)
1177                 }
1178         }
1179
1180         checkKeylogLine("client", clientBuf.String())
1181         checkKeylogLine("server", serverBuf.String())
1182 }
1183
1184 func TestKeyLogTLS13(t *testing.T) {
1185         var serverBuf, clientBuf bytes.Buffer
1186
1187         clientConfig := testConfig.Clone()
1188         clientConfig.KeyLogWriter = &clientBuf
1189
1190         serverConfig := testConfig.Clone()
1191         serverConfig.KeyLogWriter = &serverBuf
1192
1193         c, s := localPipe(t)
1194         done := make(chan bool)
1195
1196         go func() {
1197                 defer close(done)
1198
1199                 if err := Server(s, serverConfig).Handshake(); err != nil {
1200                         t.Errorf("server: %s", err)
1201                         return
1202                 }
1203                 s.Close()
1204         }()
1205
1206         if err := Client(c, clientConfig).Handshake(); err != nil {
1207                 t.Fatalf("client: %s", err)
1208         }
1209
1210         c.Close()
1211         <-done
1212
1213         checkKeylogLines := func(side, loggedLines string) {
1214                 loggedLines = strings.TrimSpace(loggedLines)
1215                 lines := strings.Split(loggedLines, "\n")
1216                 if len(lines) != 4 {
1217                         t.Errorf("Expected the %s to log 4 lines, got %d", side, len(lines))
1218                 }
1219         }
1220
1221         checkKeylogLines("client", clientBuf.String())
1222         checkKeylogLines("server", serverBuf.String())
1223 }
1224
1225 func TestHandshakeClientALPNMatch(t *testing.T) {
1226         config := testConfig.Clone()
1227         config.NextProtos = []string{"proto2", "proto1"}
1228
1229         test := &clientTest{
1230                 name: "ALPN",
1231                 // Note that this needs OpenSSL 1.0.2 because that is the first
1232                 // version that supports the -alpn flag.
1233                 args:   []string{"-alpn", "proto1,proto2"},
1234                 config: config,
1235                 validate: func(state ConnectionState) error {
1236                         // The server's preferences should override the client.
1237                         if state.NegotiatedProtocol != "proto1" {
1238                                 return fmt.Errorf("Got protocol %q, wanted proto1", state.NegotiatedProtocol)
1239                         }
1240                         return nil
1241                 },
1242         }
1243         runClientTestTLS12(t, test)
1244         runClientTestTLS13(t, test)
1245 }
1246
1247 func TestServerSelectingUnconfiguredApplicationProtocol(t *testing.T) {
1248         // This checks that the server can't select an application protocol that the
1249         // client didn't offer.
1250
1251         c, s := localPipe(t)
1252         errChan := make(chan error, 1)
1253
1254         go func() {
1255                 client := Client(c, &Config{
1256                         ServerName:   "foo",
1257                         CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
1258                         NextProtos:   []string{"http", "something-else"},
1259                 })
1260                 errChan <- client.Handshake()
1261         }()
1262
1263         var header [5]byte
1264         if _, err := io.ReadFull(s, header[:]); err != nil {
1265                 t.Fatal(err)
1266         }
1267         recordLen := int(header[3])<<8 | int(header[4])
1268
1269         record := make([]byte, recordLen)
1270         if _, err := io.ReadFull(s, record); err != nil {
1271                 t.Fatal(err)
1272         }
1273
1274         serverHello := &serverHelloMsg{
1275                 vers:         VersionTLS12,
1276                 random:       make([]byte, 32),
1277                 cipherSuite:  TLS_RSA_WITH_AES_128_GCM_SHA256,
1278                 alpnProtocol: "how-about-this",
1279         }
1280         serverHelloBytes := mustMarshal(t, serverHello)
1281
1282         s.Write([]byte{
1283                 byte(recordTypeHandshake),
1284                 byte(VersionTLS12 >> 8),
1285                 byte(VersionTLS12 & 0xff),
1286                 byte(len(serverHelloBytes) >> 8),
1287                 byte(len(serverHelloBytes)),
1288         })
1289         s.Write(serverHelloBytes)
1290         s.Close()
1291
1292         if err := <-errChan; !strings.Contains(err.Error(), "server selected unadvertised ALPN protocol") {
1293                 t.Fatalf("Expected error about unconfigured cipher suite but got %q", err)
1294         }
1295 }
1296
// sctsBase64 is the base64 encoding of the data captured by running
// `openssl s_client -serverinfo 18 -connect ritter.vg:443`.
const sctsBase64 = "ABIBaQFnAHUApLkJkLQYWBSHuxOizGdwCjw1mAT5G9+443fNDsgN3BAAAAFHl5nuFgAABAMARjBEAiAcS4JdlW5nW9sElUv2zvQyPoZ6ejKrGGB03gjaBZFMLwIgc1Qbbn+hsH0RvObzhS+XZhr3iuQQJY8S9G85D9KeGPAAdgBo9pj4H2SCvjqM7rkoHUz8cVFdZ5PURNEKZ6y7T0/7xAAAAUeX4bVwAAAEAwBHMEUCIDIhFDgG2HIuADBkGuLobU5a4dlCHoJLliWJ1SYT05z6AiEAjxIoZFFPRNWMGGIjskOTMwXzQ1Wh2e7NxXE1kd1J0QsAdgDuS723dc5guuFCaR+r4Z5mow9+X7By2IMAxHuJeqj9ywAAAUhcZIqHAAAEAwBHMEUCICmJ1rBT09LpkbzxtUC+Hi7nXLR0J+2PmwLp+sJMuqK+AiEAr0NkUnEVKVhAkccIFpYDqHOlZaBsuEhWWrYpg2RtKp0="
1299
1300 func TestHandshakClientSCTs(t *testing.T) {
1301         config := testConfig.Clone()
1302
1303         scts, err := base64.StdEncoding.DecodeString(sctsBase64)
1304         if err != nil {
1305                 t.Fatal(err)
1306         }
1307
1308         // Note that this needs OpenSSL 1.0.2 because that is the first
1309         // version that supports the -serverinfo flag.
1310         test := &clientTest{
1311                 name:       "SCT",
1312                 config:     config,
1313                 extensions: [][]byte{scts},
1314                 validate: func(state ConnectionState) error {
1315                         expectedSCTs := [][]byte{
1316                                 scts[8:125],
1317                                 scts[127:245],
1318                                 scts[247:],
1319                         }
1320                         if n := len(state.SignedCertificateTimestamps); n != len(expectedSCTs) {
1321                                 return fmt.Errorf("Got %d scts, wanted %d", n, len(expectedSCTs))
1322                         }
1323                         for i, expected := range expectedSCTs {
1324                                 if sct := state.SignedCertificateTimestamps[i]; !bytes.Equal(sct, expected) {
1325                                         return fmt.Errorf("SCT #%d contained %x, expected %x", i, sct, expected)
1326                                 }
1327                         }
1328                         return nil
1329                 },
1330         }
1331         runClientTestTLS12(t, test)
1332
1333         // TLS 1.3 moved SCTs to the Certificate extensions and -serverinfo only
1334         // supports ServerHello extensions.
1335 }
1336
1337 func TestRenegotiationRejected(t *testing.T) {
1338         config := testConfig.Clone()
1339         test := &clientTest{
1340                 name:                        "RenegotiationRejected",
1341                 args:                        []string{"-state"},
1342                 config:                      config,
1343                 numRenegotiations:           1,
1344                 renegotiationExpectedToFail: 1,
1345                 checkRenegotiationError: func(renegotiationNum int, err error) error {
1346                         if err == nil {
1347                                 return errors.New("expected error from renegotiation but got nil")
1348                         }
1349                         if !strings.Contains(err.Error(), "no renegotiation") {
1350                                 return fmt.Errorf("expected renegotiation to be rejected but got %q", err)
1351                         }
1352                         return nil
1353                 },
1354         }
1355         runClientTestTLS12(t, test)
1356 }
1357
1358 func TestRenegotiateOnce(t *testing.T) {
1359         config := testConfig.Clone()
1360         config.Renegotiation = RenegotiateOnceAsClient
1361
1362         test := &clientTest{
1363                 name:              "RenegotiateOnce",
1364                 args:              []string{"-state"},
1365                 config:            config,
1366                 numRenegotiations: 1,
1367         }
1368
1369         runClientTestTLS12(t, test)
1370 }
1371
1372 func TestRenegotiateTwice(t *testing.T) {
1373         config := testConfig.Clone()
1374         config.Renegotiation = RenegotiateFreelyAsClient
1375
1376         test := &clientTest{
1377                 name:              "RenegotiateTwice",
1378                 args:              []string{"-state"},
1379                 config:            config,
1380                 numRenegotiations: 2,
1381         }
1382
1383         runClientTestTLS12(t, test)
1384 }
1385
1386 func TestRenegotiateTwiceRejected(t *testing.T) {
1387         config := testConfig.Clone()
1388         config.Renegotiation = RenegotiateOnceAsClient
1389
1390         test := &clientTest{
1391                 name:                        "RenegotiateTwiceRejected",
1392                 args:                        []string{"-state"},
1393                 config:                      config,
1394                 numRenegotiations:           2,
1395                 renegotiationExpectedToFail: 2,
1396                 checkRenegotiationError: func(renegotiationNum int, err error) error {
1397                         if renegotiationNum == 1 {
1398                                 return err
1399                         }
1400
1401                         if err == nil {
1402                                 return errors.New("expected error from renegotiation but got nil")
1403                         }
1404                         if !strings.Contains(err.Error(), "no renegotiation") {
1405                                 return fmt.Errorf("expected renegotiation to be rejected but got %q", err)
1406                         }
1407                         return nil
1408                 },
1409         }
1410
1411         runClientTestTLS12(t, test)
1412 }
1413
1414 func TestHandshakeClientExportKeyingMaterial(t *testing.T) {
1415         test := &clientTest{
1416                 name:   "ExportKeyingMaterial",
1417                 config: testConfig.Clone(),
1418                 validate: func(state ConnectionState) error {
1419                         if km, err := state.ExportKeyingMaterial("test", nil, 42); err != nil {
1420                                 return fmt.Errorf("ExportKeyingMaterial failed: %v", err)
1421                         } else if len(km) != 42 {
1422                                 return fmt.Errorf("Got %d bytes from ExportKeyingMaterial, wanted %d", len(km), 42)
1423                         }
1424                         return nil
1425                 },
1426         }
1427         runClientTestTLS10(t, test)
1428         runClientTestTLS12(t, test)
1429         runClientTestTLS13(t, test)
1430 }
1431
// hostnameInSNITests pairs a configured ServerName (in) with the server name
// expected in the resulting ClientHello (out). An empty out means no name is
// expected to be sent.
var hostnameInSNITests = []struct {
	in, out string
}{
	// Opaque string
	{in: "", out: ""},
	{in: "localhost", out: "localhost"},
	{in: "foo, bar, baz and qux", out: "foo, bar, baz and qux"},

	// DNS hostname
	{in: "golang.org", out: "golang.org"},
	{in: "golang.org.", out: "golang.org"},

	// Literal IPv4 address
	{in: "1.2.3.4", out: ""},

	// Literal IPv6 address
	{in: "::1", out: ""},
	{in: "::1%lo0", out: ""}, // with zone identifier
	{in: "[::1]", out: ""},   // as per RFC 5952 we allow the [] style as IPv6 literal
	{in: "[::1%lo0]", out: ""},
}
1453
1454 func TestHostnameInSNI(t *testing.T) {
1455         for _, tt := range hostnameInSNITests {
1456                 c, s := localPipe(t)
1457
1458                 go func(host string) {
1459                         Client(c, &Config{ServerName: host, InsecureSkipVerify: true}).Handshake()
1460                 }(tt.in)
1461
1462                 var header [5]byte
1463                 if _, err := io.ReadFull(s, header[:]); err != nil {
1464                         t.Fatal(err)
1465                 }
1466                 recordLen := int(header[3])<<8 | int(header[4])
1467
1468                 record := make([]byte, recordLen)
1469                 if _, err := io.ReadFull(s, record[:]); err != nil {
1470                         t.Fatal(err)
1471                 }
1472
1473                 c.Close()
1474                 s.Close()
1475
1476                 var m clientHelloMsg
1477                 if !m.unmarshal(record) {
1478                         t.Errorf("unmarshaling ClientHello for %q failed", tt.in)
1479                         continue
1480                 }
1481                 if tt.in != tt.out && m.serverName == tt.in {
1482                         t.Errorf("prohibited %q found in ClientHello: %x", tt.in, record)
1483                 }
1484                 if m.serverName != tt.out {
1485                         t.Errorf("expected %q not found in ClientHello: %x", tt.out, record)
1486                 }
1487         }
1488 }
1489
1490 func TestServerSelectingUnconfiguredCipherSuite(t *testing.T) {
1491         // This checks that the server can't select a cipher suite that the
1492         // client didn't offer. See #13174.
1493
1494         c, s := localPipe(t)
1495         errChan := make(chan error, 1)
1496
1497         go func() {
1498                 client := Client(c, &Config{
1499                         ServerName:   "foo",
1500                         CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
1501                 })
1502                 errChan <- client.Handshake()
1503         }()
1504
1505         var header [5]byte
1506         if _, err := io.ReadFull(s, header[:]); err != nil {
1507                 t.Fatal(err)
1508         }
1509         recordLen := int(header[3])<<8 | int(header[4])
1510
1511         record := make([]byte, recordLen)
1512         if _, err := io.ReadFull(s, record); err != nil {
1513                 t.Fatal(err)
1514         }
1515
1516         // Create a ServerHello that selects a different cipher suite than the
1517         // sole one that the client offered.
1518         serverHello := &serverHelloMsg{
1519                 vers:        VersionTLS12,
1520                 random:      make([]byte, 32),
1521                 cipherSuite: TLS_RSA_WITH_AES_256_GCM_SHA384,
1522         }
1523         serverHelloBytes := mustMarshal(t, serverHello)
1524
1525         s.Write([]byte{
1526                 byte(recordTypeHandshake),
1527                 byte(VersionTLS12 >> 8),
1528                 byte(VersionTLS12 & 0xff),
1529                 byte(len(serverHelloBytes) >> 8),
1530                 byte(len(serverHelloBytes)),
1531         })
1532         s.Write(serverHelloBytes)
1533         s.Close()
1534
1535         if err := <-errChan; !strings.Contains(err.Error(), "unconfigured cipher") {
1536                 t.Fatalf("Expected error about unconfigured cipher suite but got %q", err)
1537         }
1538 }
1539
1540 func TestVerifyConnection(t *testing.T) {
1541         t.Run("TLSv12", func(t *testing.T) { testVerifyConnection(t, VersionTLS12) })
1542         t.Run("TLSv13", func(t *testing.T) { testVerifyConnection(t, VersionTLS13) })
1543 }
1544
1545 func testVerifyConnection(t *testing.T, version uint16) {
1546         checkFields := func(c ConnectionState, called *int, errorType string) error {
1547                 if c.Version != version {
1548                         return fmt.Errorf("%s: got Version %v, want %v", errorType, c.Version, version)
1549                 }
1550                 if c.HandshakeComplete {
1551                         return fmt.Errorf("%s: got HandshakeComplete, want false", errorType)
1552                 }
1553                 if c.ServerName != "example.golang" {
1554                         return fmt.Errorf("%s: got ServerName %s, want %s", errorType, c.ServerName, "example.golang")
1555                 }
1556                 if c.NegotiatedProtocol != "protocol1" {
1557                         return fmt.Errorf("%s: got NegotiatedProtocol %s, want %s", errorType, c.NegotiatedProtocol, "protocol1")
1558                 }
1559                 if c.CipherSuite == 0 {
1560                         return fmt.Errorf("%s: got CipherSuite 0, want non-zero", errorType)
1561                 }
1562                 wantDidResume := false
1563                 if *called == 2 { // if this is the second time, then it should be a resumption
1564                         wantDidResume = true
1565                 }
1566                 if c.DidResume != wantDidResume {
1567                         return fmt.Errorf("%s: got DidResume %t, want %t", errorType, c.DidResume, wantDidResume)
1568                 }
1569                 return nil
1570         }
1571
1572         tests := []struct {
1573                 name            string
1574                 configureServer func(*Config, *int)
1575                 configureClient func(*Config, *int)
1576         }{
1577                 {
1578                         name: "RequireAndVerifyClientCert",
1579                         configureServer: func(config *Config, called *int) {
1580                                 config.ClientAuth = RequireAndVerifyClientCert
1581                                 config.VerifyConnection = func(c ConnectionState) error {
1582                                         *called++
1583                                         if l := len(c.PeerCertificates); l != 1 {
1584                                                 return fmt.Errorf("server: got len(PeerCertificates) = %d, wanted 1", l)
1585                                         }
1586                                         if len(c.VerifiedChains) == 0 {
1587                                                 return fmt.Errorf("server: got len(VerifiedChains) = 0, wanted non-zero")
1588                                         }
1589                                         return checkFields(c, called, "server")
1590                                 }
1591                         },
1592                         configureClient: func(config *Config, called *int) {
1593                                 config.VerifyConnection = func(c ConnectionState) error {
1594                                         *called++
1595                                         if l := len(c.PeerCertificates); l != 1 {
1596                                                 return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l)
1597                                         }
1598                                         if len(c.VerifiedChains) == 0 {
1599                                                 return fmt.Errorf("client: got len(VerifiedChains) = 0, wanted non-zero")
1600                                         }
1601                                         if c.DidResume {
1602                                                 return nil
1603                                                 // The SCTs and OCSP Response are dropped on resumption.
1604                                                 // See http://golang.org/issue/39075.
1605                                         }
1606                                         if len(c.OCSPResponse) == 0 {
1607                                                 return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero")
1608                                         }
1609                                         if len(c.SignedCertificateTimestamps) == 0 {
1610                                                 return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
1611                                         }
1612                                         return checkFields(c, called, "client")
1613                                 }
1614                         },
1615                 },
1616                 {
1617                         name: "InsecureSkipVerify",
1618                         configureServer: func(config *Config, called *int) {
1619                                 config.ClientAuth = RequireAnyClientCert
1620                                 config.InsecureSkipVerify = true
1621                                 config.VerifyConnection = func(c ConnectionState) error {
1622                                         *called++
1623                                         if l := len(c.PeerCertificates); l != 1 {
1624                                                 return fmt.Errorf("server: got len(PeerCertificates) = %d, wanted 1", l)
1625                                         }
1626                                         if c.VerifiedChains != nil {
1627                                                 return fmt.Errorf("server: got Verified Chains %v, want nil", c.VerifiedChains)
1628                                         }
1629                                         return checkFields(c, called, "server")
1630                                 }
1631                         },
1632                         configureClient: func(config *Config, called *int) {
1633                                 config.InsecureSkipVerify = true
1634                                 config.VerifyConnection = func(c ConnectionState) error {
1635                                         *called++
1636                                         if l := len(c.PeerCertificates); l != 1 {
1637                                                 return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l)
1638                                         }
1639                                         if c.VerifiedChains != nil {
1640                                                 return fmt.Errorf("server: got Verified Chains %v, want nil", c.VerifiedChains)
1641                                         }
1642                                         if c.DidResume {
1643                                                 return nil
1644                                                 // The SCTs and OCSP Response are dropped on resumption.
1645                                                 // See http://golang.org/issue/39075.
1646                                         }
1647                                         if len(c.OCSPResponse) == 0 {
1648                                                 return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero")
1649                                         }
1650                                         if len(c.SignedCertificateTimestamps) == 0 {
1651                                                 return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
1652                                         }
1653                                         return checkFields(c, called, "client")
1654                                 }
1655                         },
1656                 },
1657                 {
1658                         name: "NoClientCert",
1659                         configureServer: func(config *Config, called *int) {
1660                                 config.ClientAuth = NoClientCert
1661                                 config.VerifyConnection = func(c ConnectionState) error {
1662                                         *called++
1663                                         return checkFields(c, called, "server")
1664                                 }
1665                         },
1666                         configureClient: func(config *Config, called *int) {
1667                                 config.VerifyConnection = func(c ConnectionState) error {
1668                                         *called++
1669                                         return checkFields(c, called, "client")
1670                                 }
1671                         },
1672                 },
1673                 {
1674                         name: "RequestClientCert",
1675                         configureServer: func(config *Config, called *int) {
1676                                 config.ClientAuth = RequestClientCert
1677                                 config.VerifyConnection = func(c ConnectionState) error {
1678                                         *called++
1679                                         return checkFields(c, called, "server")
1680                                 }
1681                         },
1682                         configureClient: func(config *Config, called *int) {
1683                                 config.Certificates = nil // clear the client cert
1684                                 config.VerifyConnection = func(c ConnectionState) error {
1685                                         *called++
1686                                         if l := len(c.PeerCertificates); l != 1 {
1687                                                 return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l)
1688                                         }
1689                                         if len(c.VerifiedChains) == 0 {
1690                                                 return fmt.Errorf("client: got len(VerifiedChains) = 0, wanted non-zero")
1691                                         }
1692                                         if c.DidResume {
1693                                                 return nil
1694                                                 // The SCTs and OCSP Response are dropped on resumption.
1695                                                 // See http://golang.org/issue/39075.
1696                                         }
1697                                         if len(c.OCSPResponse) == 0 {
1698                                                 return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero")
1699                                         }
1700                                         if len(c.SignedCertificateTimestamps) == 0 {
1701                                                 return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
1702                                         }
1703                                         return checkFields(c, called, "client")
1704                                 }
1705                         },
1706                 },
1707         }
1708         for _, test := range tests {
1709                 issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
1710                 if err != nil {
1711                         panic(err)
1712                 }
1713                 rootCAs := x509.NewCertPool()
1714                 rootCAs.AddCert(issuer)
1715
1716                 var serverCalled, clientCalled int
1717
1718                 serverConfig := &Config{
1719                         MaxVersion:   version,
1720                         Certificates: []Certificate{testConfig.Certificates[0]},
1721                         ClientCAs:    rootCAs,
1722                         NextProtos:   []string{"protocol1"},
1723                 }
1724                 serverConfig.Certificates[0].SignedCertificateTimestamps = [][]byte{[]byte("dummy sct 1"), []byte("dummy sct 2")}
1725                 serverConfig.Certificates[0].OCSPStaple = []byte("dummy ocsp")
1726                 test.configureServer(serverConfig, &serverCalled)
1727
1728                 clientConfig := &Config{
1729                         MaxVersion:         version,
1730                         ClientSessionCache: NewLRUClientSessionCache(32),
1731                         RootCAs:            rootCAs,
1732                         ServerName:         "example.golang",
1733                         Certificates:       []Certificate{testConfig.Certificates[0]},
1734                         NextProtos:         []string{"protocol1"},
1735                 }
1736                 test.configureClient(clientConfig, &clientCalled)
1737
1738                 testHandshakeState := func(name string, didResume bool) {
1739                         _, hs, err := testHandshake(t, clientConfig, serverConfig)
1740                         if err != nil {
1741                                 t.Fatalf("%s: handshake failed: %s", name, err)
1742                         }
1743                         if hs.DidResume != didResume {
1744                                 t.Errorf("%s: resumed: %v, expected: %v", name, hs.DidResume, didResume)
1745                         }
1746                         wantCalled := 1
1747                         if didResume {
1748                                 wantCalled = 2 // resumption would mean this is the second time it was called in this test
1749                         }
1750                         if clientCalled != wantCalled {
1751                                 t.Errorf("%s: expected client VerifyConnection called %d times, did %d times", name, wantCalled, clientCalled)
1752                         }
1753                         if serverCalled != wantCalled {
1754                                 t.Errorf("%s: expected server VerifyConnection called %d times, did %d times", name, wantCalled, serverCalled)
1755                         }
1756                 }
1757                 testHandshakeState(fmt.Sprintf("%s-FullHandshake", test.name), false)
1758                 testHandshakeState(fmt.Sprintf("%s-Resumption", test.name), true)
1759         }
1760 }
1761
1762 func TestVerifyPeerCertificate(t *testing.T) {
1763         t.Run("TLSv12", func(t *testing.T) { testVerifyPeerCertificate(t, VersionTLS12) })
1764         t.Run("TLSv13", func(t *testing.T) { testVerifyPeerCertificate(t, VersionTLS13) })
1765 }
1766
1767 func testVerifyPeerCertificate(t *testing.T, version uint16) {
1768         issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
1769         if err != nil {
1770                 panic(err)
1771         }
1772
1773         rootCAs := x509.NewCertPool()
1774         rootCAs.AddCert(issuer)
1775
1776         now := func() time.Time { return time.Unix(1476984729, 0) }
1777
1778         sentinelErr := errors.New("TestVerifyPeerCertificate")
1779
1780         verifyPeerCertificateCallback := func(called *bool, rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
1781                 if l := len(rawCerts); l != 1 {
1782                         return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l)
1783                 }
1784                 if len(validatedChains) == 0 {
1785                         return errors.New("got len(validatedChains) = 0, wanted non-zero")
1786                 }
1787                 *called = true
1788                 return nil
1789         }
1790         verifyConnectionCallback := func(called *bool, isClient bool, c ConnectionState) error {
1791                 if l := len(c.PeerCertificates); l != 1 {
1792                         return fmt.Errorf("got len(PeerCertificates) = %d, wanted 1", l)
1793                 }
1794                 if len(c.VerifiedChains) == 0 {
1795                         return fmt.Errorf("got len(VerifiedChains) = 0, wanted non-zero")
1796                 }
1797                 if isClient && len(c.OCSPResponse) == 0 {
1798                         return fmt.Errorf("got len(OCSPResponse) = 0, wanted non-zero")
1799                 }
1800                 *called = true
1801                 return nil
1802         }
1803
1804         tests := []struct {
1805                 configureServer func(*Config, *bool)
1806                 configureClient func(*Config, *bool)
1807                 validate        func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error)
1808         }{
1809                 {
1810                         configureServer: func(config *Config, called *bool) {
1811                                 config.InsecureSkipVerify = false
1812                                 config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
1813                                         return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
1814                                 }
1815                         },
1816                         configureClient: func(config *Config, called *bool) {
1817                                 config.InsecureSkipVerify = false
1818                                 config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
1819                                         return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
1820                                 }
1821                         },
1822                         validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
1823                                 if clientErr != nil {
1824                                         t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr)
1825                                 }
1826                                 if serverErr != nil {
1827                                         t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr)
1828                                 }
1829                                 if !clientCalled {
1830                                         t.Errorf("test[%d]: client did not call callback", testNo)
1831                                 }
1832                                 if !serverCalled {
1833                                         t.Errorf("test[%d]: server did not call callback", testNo)
1834                                 }
1835                         },
1836                 },
1837                 {
1838                         configureServer: func(config *Config, called *bool) {
1839                                 config.InsecureSkipVerify = false
1840                                 config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
1841                                         return sentinelErr
1842                                 }
1843                         },
1844                         configureClient: func(config *Config, called *bool) {
1845                                 config.VerifyPeerCertificate = nil
1846                         },
1847                         validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
1848                                 if serverErr != sentinelErr {
1849                                         t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr)
1850                                 }
1851                         },
1852                 },
1853                 {
1854                         configureServer: func(config *Config, called *bool) {
1855                                 config.InsecureSkipVerify = false
1856                         },
1857                         configureClient: func(config *Config, called *bool) {
1858                                 config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
1859                                         return sentinelErr
1860                                 }
1861                         },
1862                         validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
1863                                 if clientErr != sentinelErr {
1864                                         t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr)
1865                                 }
1866                         },
1867                 },
1868                 {
1869                         configureServer: func(config *Config, called *bool) {
1870                                 config.InsecureSkipVerify = false
1871                         },
1872                         configureClient: func(config *Config, called *bool) {
1873                                 config.InsecureSkipVerify = true
1874                                 config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
1875                                         if l := len(rawCerts); l != 1 {
1876                                                 return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l)
1877                                         }
1878                                         // With InsecureSkipVerify set, this
1879                                         // callback should still be called but
1880                                         // validatedChains must be empty.
1881                                         if l := len(validatedChains); l != 0 {
1882                                                 return fmt.Errorf("got len(validatedChains) = %d, wanted zero", l)
1883                                         }
1884                                         *called = true
1885                                         return nil
1886                                 }
1887                         },
1888                         validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
1889                                 if clientErr != nil {
1890                                         t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr)
1891                                 }
1892                                 if serverErr != nil {
1893                                         t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr)
1894                                 }
1895                                 if !clientCalled {
1896                                         t.Errorf("test[%d]: client did not call callback", testNo)
1897                                 }
1898                         },
1899                 },
1900                 {
1901                         configureServer: func(config *Config, called *bool) {
1902                                 config.InsecureSkipVerify = false
1903                                 config.VerifyConnection = func(c ConnectionState) error {
1904                                         return verifyConnectionCallback(called, false, c)
1905                                 }
1906                         },
1907                         configureClient: func(config *Config, called *bool) {
1908                                 config.InsecureSkipVerify = false
1909                                 config.VerifyConnection = func(c ConnectionState) error {
1910                                         return verifyConnectionCallback(called, true, c)
1911                                 }
1912                         },
1913                         validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
1914                                 if clientErr != nil {
1915                                         t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr)
1916                                 }
1917                                 if serverErr != nil {
1918                                         t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr)
1919                                 }
1920                                 if !clientCalled {
1921                                         t.Errorf("test[%d]: client did not call callback", testNo)
1922                                 }
1923                                 if !serverCalled {
1924                                         t.Errorf("test[%d]: server did not call callback", testNo)
1925                                 }
1926                         },
1927                 },
1928                 {
1929                         configureServer: func(config *Config, called *bool) {
1930                                 config.InsecureSkipVerify = false
1931                                 config.VerifyConnection = func(c ConnectionState) error {
1932                                         return sentinelErr
1933                                 }
1934                         },
1935                         configureClient: func(config *Config, called *bool) {
1936                                 config.InsecureSkipVerify = false
1937                                 config.VerifyConnection = nil
1938                         },
1939                         validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
1940                                 if serverErr != sentinelErr {
1941                                         t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr)
1942                                 }
1943                         },
1944                 },
1945                 {
1946                         configureServer: func(config *Config, called *bool) {
1947                                 config.InsecureSkipVerify = false
1948                                 config.VerifyConnection = nil
1949                         },
1950                         configureClient: func(config *Config, called *bool) {
1951                                 config.InsecureSkipVerify = false
1952                                 config.VerifyConnection = func(c ConnectionState) error {
1953                                         return sentinelErr
1954                                 }
1955                         },
1956                         validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
1957                                 if clientErr != sentinelErr {
1958                                         t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr)
1959                                 }
1960                         },
1961                 },
1962                 {
1963                         configureServer: func(config *Config, called *bool) {
1964                                 config.InsecureSkipVerify = false
1965                                 config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
1966                                         return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
1967                                 }
1968                                 config.VerifyConnection = func(c ConnectionState) error {
1969                                         return sentinelErr
1970                                 }
1971                         },
1972                         configureClient: func(config *Config, called *bool) {
1973                                 config.InsecureSkipVerify = false
1974                                 config.VerifyPeerCertificate = nil
1975                                 config.VerifyConnection = nil
1976                         },
1977                         validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
1978                                 if serverErr != sentinelErr {
1979                                         t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr)
1980                                 }
1981                                 if !serverCalled {
1982                                         t.Errorf("test[%d]: server did not call callback", testNo)
1983                                 }
1984                         },
1985                 },
1986                 {
1987                         configureServer: func(config *Config, called *bool) {
1988                                 config.InsecureSkipVerify = false
1989                                 config.VerifyPeerCertificate = nil
1990                                 config.VerifyConnection = nil
1991                         },
1992                         configureClient: func(config *Config, called *bool) {
1993                                 config.InsecureSkipVerify = false
1994                                 config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
1995                                         return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
1996                                 }
1997                                 config.VerifyConnection = func(c ConnectionState) error {
1998                                         return sentinelErr
1999                                 }
2000                         },
2001                         validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
2002                                 if clientErr != sentinelErr {
2003                                         t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr)
2004                                 }
2005                                 if !clientCalled {
2006                                         t.Errorf("test[%d]: client did not call callback", testNo)
2007                                 }
2008                         },
2009                 },
2010         }
2011
2012         for i, test := range tests {
2013                 c, s := localPipe(t)
2014                 done := make(chan error)
2015
2016                 var clientCalled, serverCalled bool
2017
2018                 go func() {
2019                         config := testConfig.Clone()
2020                         config.ServerName = "example.golang"
2021                         config.ClientAuth = RequireAndVerifyClientCert
2022                         config.ClientCAs = rootCAs
2023                         config.Time = now
2024                         config.MaxVersion = version
2025                         config.Certificates = make([]Certificate, 1)
2026                         config.Certificates[0].Certificate = [][]byte{testRSACertificate}
2027                         config.Certificates[0].PrivateKey = testRSAPrivateKey
2028                         config.Certificates[0].SignedCertificateTimestamps = [][]byte{[]byte("dummy sct 1"), []byte("dummy sct 2")}
2029                         config.Certificates[0].OCSPStaple = []byte("dummy ocsp")
2030                         test.configureServer(config, &serverCalled)
2031
2032                         err = Server(s, config).Handshake()
2033                         s.Close()
2034                         done <- err
2035                 }()
2036
2037                 config := testConfig.Clone()
2038                 config.ServerName = "example.golang"
2039                 config.RootCAs = rootCAs
2040                 config.Time = now
2041                 config.MaxVersion = version
2042                 test.configureClient(config, &clientCalled)
2043                 clientErr := Client(c, config).Handshake()
2044                 c.Close()
2045                 serverErr := <-done
2046
2047                 test.validate(t, i, clientCalled, serverCalled, clientErr, serverErr)
2048         }
2049 }
2050
// brokenConn is a net.Conn wrapper whose Write method starts failing with
// brokenConnErr once a fixed budget of writes has been used up. All other
// net.Conn methods are provided by the embedded connection.
type brokenConn struct {
	net.Conn

	// breakAfter is how many writes are passed through to the wrapped
	// connection before every later write is rejected.
	breakAfter int

	// numWrites counts the writes that have been passed through so far.
	numWrites int
}

// brokenConnErr is what brokenConn.Write returns once its budget is spent.
var brokenConnErr = errors.New("too many writes to brokenConn")

// Write forwards data to the wrapped connection while fewer than breakAfter
// writes have been forwarded; after that it returns (0, brokenConnErr)
// without touching the wrapped connection. A forwarded write is counted
// whether or not the wrapped connection reports an error for it.
func (b *brokenConn) Write(data []byte) (int, error) {
	if b.numWrites < b.breakAfter {
		b.numWrites++
		return b.Conn.Write(data)
	}
	return 0, brokenConnErr
}
2075
2076 func TestFailedWrite(t *testing.T) {
2077         // Test that a write error during the handshake is returned.
2078         for _, breakAfter := range []int{0, 1} {
2079                 c, s := localPipe(t)
2080                 done := make(chan bool)
2081
2082                 go func() {
2083                         Server(s, testConfig).Handshake()
2084                         s.Close()
2085                         done <- true
2086                 }()
2087
2088                 brokenC := &brokenConn{Conn: c, breakAfter: breakAfter}
2089                 err := Client(brokenC, testConfig).Handshake()
2090                 if err != brokenConnErr {
2091                         t.Errorf("#%d: expected error from brokenConn but got %q", breakAfter, err)
2092                 }
2093                 brokenC.Close()
2094
2095                 <-done
2096         }
2097 }
2098
// writeCountingConn is a net.Conn wrapper that records how many times Write
// has been called on it.
type writeCountingConn struct {
	net.Conn

	// numWrites is the number of Write calls made so far.
	numWrites int
}

// Write bumps the call counter and then hands data to the wrapped
// connection, returning whatever the wrapped connection returns.
func (wcc *writeCountingConn) Write(data []byte) (int, error) {
	wcc.numWrites += 1
	written, err := wcc.Conn.Write(data)
	return written, err
}
2111
2112 func TestBuffering(t *testing.T) {
2113         t.Run("TLSv12", func(t *testing.T) { testBuffering(t, VersionTLS12) })
2114         t.Run("TLSv13", func(t *testing.T) { testBuffering(t, VersionTLS13) })
2115 }
2116
2117 func testBuffering(t *testing.T, version uint16) {
2118         c, s := localPipe(t)
2119         done := make(chan bool)
2120
2121         clientWCC := &writeCountingConn{Conn: c}
2122         serverWCC := &writeCountingConn{Conn: s}
2123
2124         go func() {
2125                 config := testConfig.Clone()
2126                 config.MaxVersion = version
2127                 Server(serverWCC, config).Handshake()
2128                 serverWCC.Close()
2129                 done <- true
2130         }()
2131
2132         err := Client(clientWCC, testConfig).Handshake()
2133         if err != nil {
2134                 t.Fatal(err)
2135         }
2136         clientWCC.Close()
2137         <-done
2138
2139         var expectedClient, expectedServer int
2140         if version == VersionTLS13 {
2141                 expectedClient = 2
2142                 expectedServer = 1
2143         } else {
2144                 expectedClient = 2
2145                 expectedServer = 2
2146         }
2147
2148         if n := clientWCC.numWrites; n != expectedClient {
2149                 t.Errorf("expected client handshake to complete with %d writes, but saw %d", expectedClient, n)
2150         }
2151
2152         if n := serverWCC.numWrites; n != expectedServer {
2153                 t.Errorf("expected server handshake to complete with %d writes, but saw %d", expectedServer, n)
2154         }
2155 }
2156
2157 func TestAlertFlushing(t *testing.T) {
2158         c, s := localPipe(t)
2159         done := make(chan bool)
2160
2161         clientWCC := &writeCountingConn{Conn: c}
2162         serverWCC := &writeCountingConn{Conn: s}
2163
2164         serverConfig := testConfig.Clone()
2165
2166         // Cause a signature-time error
2167         brokenKey := rsa.PrivateKey{PublicKey: testRSAPrivateKey.PublicKey}
2168         brokenKey.D = big.NewInt(42)
2169         serverConfig.Certificates = []Certificate{{
2170                 Certificate: [][]byte{testRSACertificate},
2171                 PrivateKey:  &brokenKey,
2172         }}
2173
2174         go func() {
2175                 Server(serverWCC, serverConfig).Handshake()
2176                 serverWCC.Close()
2177                 done <- true
2178         }()
2179
2180         err := Client(clientWCC, testConfig).Handshake()
2181         if err == nil {
2182                 t.Fatal("client unexpectedly returned no error")
2183         }
2184
2185         const expectedError = "remote error: tls: internal error"
2186         if e := err.Error(); !strings.Contains(e, expectedError) {
2187                 t.Fatalf("expected to find %q in error but error was %q", expectedError, e)
2188         }
2189         clientWCC.Close()
2190         <-done
2191
2192         if n := serverWCC.numWrites; n != 1 {
2193                 t.Errorf("expected server handshake to complete with one write, but saw %d", n)
2194         }
2195 }
2196
2197 func TestHandshakeRace(t *testing.T) {
2198         if testing.Short() {
2199                 t.Skip("skipping in -short mode")
2200         }
2201         t.Parallel()
2202         // This test races a Read and Write to try and complete a handshake in
2203         // order to provide some evidence that there are no races or deadlocks
2204         // in the handshake locking.
2205         for i := 0; i < 32; i++ {
2206                 c, s := localPipe(t)
2207
2208                 go func() {
2209                         server := Server(s, testConfig)
2210                         if err := server.Handshake(); err != nil {
2211                                 panic(err)
2212                         }
2213
2214                         var request [1]byte
2215                         if n, err := server.Read(request[:]); err != nil || n != 1 {
2216                                 panic(err)
2217                         }
2218
2219                         server.Write(request[:])
2220                         server.Close()
2221                 }()
2222
2223                 startWrite := make(chan struct{})
2224                 startRead := make(chan struct{})
2225                 readDone := make(chan struct{}, 1)
2226
2227                 client := Client(c, testConfig)
2228                 go func() {
2229                         <-startWrite
2230                         var request [1]byte
2231                         client.Write(request[:])
2232                 }()
2233
2234                 go func() {
2235                         <-startRead
2236                         var reply [1]byte
2237                         if _, err := io.ReadFull(client, reply[:]); err != nil {
2238                                 panic(err)
2239                         }
2240                         c.Close()
2241                         readDone <- struct{}{}
2242                 }()
2243
2244                 if i&1 == 1 {
2245                         startWrite <- struct{}{}
2246                         startRead <- struct{}{}
2247                 } else {
2248                         startRead <- struct{}{}
2249                         startWrite <- struct{}{}
2250                 }
2251                 <-readDone
2252         }
2253 }
2254
2255 var getClientCertificateTests = []struct {
2256         setup               func(*Config, *Config)
2257         expectedClientError string
2258         verify              func(*testing.T, int, *ConnectionState)
2259 }{
2260         {
2261                 func(clientConfig, serverConfig *Config) {
2262                         // Returning a Certificate with no certificate data
2263                         // should result in an empty message being sent to the
2264                         // server.
2265                         serverConfig.ClientCAs = nil
2266                         clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
2267                                 if len(cri.SignatureSchemes) == 0 {
2268                                         panic("empty SignatureSchemes")
2269                                 }
2270                                 if len(cri.AcceptableCAs) != 0 {
2271                                         panic("AcceptableCAs should have been empty")
2272                                 }
2273                                 return new(Certificate), nil
2274                         }
2275                 },
2276                 "",
2277                 func(t *testing.T, testNum int, cs *ConnectionState) {
2278                         if l := len(cs.PeerCertificates); l != 0 {
2279                                 t.Errorf("#%d: expected no certificates but got %d", testNum, l)
2280                         }
2281                 },
2282         },
2283         {
2284                 func(clientConfig, serverConfig *Config) {
2285                         // With TLS 1.1, the SignatureSchemes should be
2286                         // synthesised from the supported certificate types.
2287                         clientConfig.MaxVersion = VersionTLS11
2288                         clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
2289                                 if len(cri.SignatureSchemes) == 0 {
2290                                         panic("empty SignatureSchemes")
2291                                 }
2292                                 return new(Certificate), nil
2293                         }
2294                 },
2295                 "",
2296                 func(t *testing.T, testNum int, cs *ConnectionState) {
2297                         if l := len(cs.PeerCertificates); l != 0 {
2298                                 t.Errorf("#%d: expected no certificates but got %d", testNum, l)
2299                         }
2300                 },
2301         },
2302         {
2303                 func(clientConfig, serverConfig *Config) {
2304                         // Returning an error should abort the handshake with
2305                         // that error.
2306                         clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
2307                                 return nil, errors.New("GetClientCertificate")
2308                         }
2309                 },
2310                 "GetClientCertificate",
2311                 func(t *testing.T, testNum int, cs *ConnectionState) {
2312                 },
2313         },
2314         {
2315                 func(clientConfig, serverConfig *Config) {
2316                         clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
2317                                 if len(cri.AcceptableCAs) == 0 {
2318                                         panic("empty AcceptableCAs")
2319                                 }
2320                                 cert := &Certificate{
2321                                         Certificate: [][]byte{testRSACertificate},
2322                                         PrivateKey:  testRSAPrivateKey,
2323                                 }
2324                                 return cert, nil
2325                         }
2326                 },
2327                 "",
2328                 func(t *testing.T, testNum int, cs *ConnectionState) {
2329                         if len(cs.VerifiedChains) == 0 {
2330                                 t.Errorf("#%d: expected some verified chains, but found none", testNum)
2331                         }
2332                 },
2333         },
2334 }
2335
2336 func TestGetClientCertificate(t *testing.T) {
2337         t.Run("TLSv12", func(t *testing.T) { testGetClientCertificate(t, VersionTLS12) })
2338         t.Run("TLSv13", func(t *testing.T) { testGetClientCertificate(t, VersionTLS13) })
2339 }
2340
2341 func testGetClientCertificate(t *testing.T, version uint16) {
2342         issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
2343         if err != nil {
2344                 panic(err)
2345         }
2346
2347         for i, test := range getClientCertificateTests {
2348                 serverConfig := testConfig.Clone()
2349                 serverConfig.ClientAuth = VerifyClientCertIfGiven
2350                 serverConfig.RootCAs = x509.NewCertPool()
2351                 serverConfig.RootCAs.AddCert(issuer)
2352                 serverConfig.ClientCAs = serverConfig.RootCAs
2353                 serverConfig.Time = func() time.Time { return time.Unix(1476984729, 0) }
2354                 serverConfig.MaxVersion = version
2355
2356                 clientConfig := testConfig.Clone()
2357                 clientConfig.MaxVersion = version
2358
2359                 test.setup(clientConfig, serverConfig)
2360
2361                 type serverResult struct {
2362                         cs  ConnectionState
2363                         err error
2364                 }
2365
2366                 c, s := localPipe(t)
2367                 done := make(chan serverResult)
2368
2369                 go func() {
2370                         defer s.Close()
2371                         server := Server(s, serverConfig)
2372                         err := server.Handshake()
2373
2374                         var cs ConnectionState
2375                         if err == nil {
2376                                 cs = server.ConnectionState()
2377                         }
2378                         done <- serverResult{cs, err}
2379                 }()
2380
2381                 clientErr := Client(c, clientConfig).Handshake()
2382                 c.Close()
2383
2384                 result := <-done
2385
2386                 if clientErr != nil {
2387                         if len(test.expectedClientError) == 0 {
2388                                 t.Errorf("#%d: client error: %v", i, clientErr)
2389                         } else if got := clientErr.Error(); got != test.expectedClientError {
2390                                 t.Errorf("#%d: expected client error %q, but got %q", i, test.expectedClientError, got)
2391                         } else {
2392                                 test.verify(t, i, &result.cs)
2393                         }
2394                 } else if len(test.expectedClientError) > 0 {
2395                         t.Errorf("#%d: expected client error %q, but got no error", i, test.expectedClientError)
2396                 } else if err := result.err; err != nil {
2397                         t.Errorf("#%d: server error: %v", i, err)
2398                 } else {
2399                         test.verify(t, i, &result.cs)
2400                 }
2401         }
2402 }
2403
// TestRSAPSSKeyError guards against RSASSA-PSS certificates being treated as
// ordinary RSA certificates.
//
// crypto/tls does not support the rsa_pss_pss_* SignatureSchemes. If support
// for public keys with OID RSASSA-PSS is added to crypto/x509, they will be
// misused with the rsa_pss_rsae_* SignatureSchemes. This test asserts that an
// RSASSA-PSS certificate either fails to parse or does not carry an
// *rsa.PublicKey.
func TestRSAPSSKeyError(t *testing.T) {
	const pssCertPEM = `
-----BEGIN CERTIFICATE-----
MIIDZTCCAhygAwIBAgIUCF2x0FyTgZG0CC9QTDjGWkB5vgEwPgYJKoZIhvcNAQEK
MDGgDTALBglghkgBZQMEAgGhGjAYBgkqhkiG9w0BAQgwCwYJYIZIAWUDBAIBogQC
AgDeMBIxEDAOBgNVBAMMB1JTQS1QU1MwHhcNMTgwNjI3MjI0NDM2WhcNMTgwNzI3
MjI0NDM2WjASMRAwDgYDVQQDDAdSU0EtUFNTMIIBIDALBgkqhkiG9w0BAQoDggEP
ADCCAQoCggEBANxDm0f76JdI06YzsjB3AmmjIYkwUEGxePlafmIASFjDZl/elD0Z
/a7xLX468b0qGxLS5al7XCcEprSdsDR6DF5L520+pCbpfLyPOjuOvGmk9KzVX4x5
b05YXYuXdsQ0Kjxcx2i3jjCday6scIhMJVgBZxTEyMj1thPQM14SHzKCd/m6HmCL
QmswpH2yMAAcBRWzRpp/vdH5DeOJEB3aelq7094no731mrLUCHRiZ1htq8BDB3ou
czwqgwspbqZ4dnMXl2MvfySQ5wJUxQwILbiuAKO2lVVPUbFXHE9pgtznNoPvKwQT
JNcX8ee8WIZc2SEGzofjk3NpjR+2ADB2u3sCAwEAAaNTMFEwHQYDVR0OBBYEFNEz
AdyJ2f+fU+vSCS6QzohnOnprMB8GA1UdIwQYMBaAFNEzAdyJ2f+fU+vSCS6Qzohn
OnprMA8GA1UdEwEB/wQFMAMBAf8wPgYJKoZIhvcNAQEKMDGgDTALBglghkgBZQME
AgGhGjAYBgkqhkiG9w0BAQgwCwYJYIZIAWUDBAIBogQCAgDeA4IBAQCjEdrR5aab
sZmCwrMeKidXgfkmWvfuLDE+TCbaqDZp7BMWcMQXT9O0UoUT5kqgKj2ARm2pEW0Z
H3Z1vj3bbds72qcDIJXp+l0fekyLGeCrX/CbgnMZXEP7+/+P416p34ChR1Wz4dU1
KD3gdsUuTKKeMUog3plxlxQDhRQmiL25ygH1LmjLd6dtIt0GVRGr8lj3euVeprqZ
bZ3Uq5eLfsn8oPgfC57gpO6yiN+UURRTlK3bgYvLh4VWB3XXk9UaQZ7Mq1tpXjoD
HYFybkWzibkZp4WRo+Fa28rirH+/wHt0vfeN7UCceURZEx4JaxIIfe4ku7uDRhJi
RwBA9Xk1KBNF
-----END CERTIFICATE-----`

	block, _ := pem.Decode([]byte(pssCertPEM))
	if block == nil {
		t.Fatal("Failed to decode certificate")
	}

	parsed, parseErr := x509.ParseCertificate(block.Bytes)
	if parseErr != nil {
		// Refusing to parse the certificate is an acceptable outcome.
		return
	}

	_, isRSA := parsed.PublicKey.(*rsa.PublicKey)
	if isRSA {
		t.Error("A RSASSA-PSS certificate was parsed like a PKCS#1 v1.5 one, and it will be mistakenly used with rsa_pss_rsae_* signature algorithms")
	}
}
2442
2443 func TestCloseClientConnectionOnIdleServer(t *testing.T) {
2444         clientConn, serverConn := localPipe(t)
2445         client := Client(clientConn, testConfig.Clone())
2446         go func() {
2447                 var b [1]byte
2448                 serverConn.Read(b[:])
2449                 client.Close()
2450         }()
2451         client.SetWriteDeadline(time.Now().Add(time.Minute))
2452         err := client.Handshake()
2453         if err != nil {
2454                 if err, ok := err.(net.Error); ok && err.Timeout() {
2455                         t.Errorf("Expected a closed network connection error but got '%s'", err.Error())
2456                 }
2457         } else {
2458                 t.Errorf("Error expected, but no error returned")
2459         }
2460 }
2461
2462 func testDowngradeCanary(t *testing.T, clientVersion, serverVersion uint16) error {
2463         defer func() { testingOnlyForceDowngradeCanary = false }()
2464         testingOnlyForceDowngradeCanary = true
2465
2466         clientConfig := testConfig.Clone()
2467         clientConfig.MaxVersion = clientVersion
2468         serverConfig := testConfig.Clone()
2469         serverConfig.MaxVersion = serverVersion
2470         _, _, err := testHandshake(t, clientConfig, serverConfig)
2471         return err
2472 }
2473
2474 func TestDowngradeCanary(t *testing.T) {
2475         if err := testDowngradeCanary(t, VersionTLS13, VersionTLS12); err == nil {
2476                 t.Errorf("downgrade from TLS 1.3 to TLS 1.2 was not detected")
2477         }
2478         if testing.Short() {
2479                 t.Skip("skipping the rest of the checks in short mode")
2480         }
2481         if err := testDowngradeCanary(t, VersionTLS13, VersionTLS11); err == nil {
2482                 t.Errorf("downgrade from TLS 1.3 to TLS 1.1 was not detected")
2483         }
2484         if err := testDowngradeCanary(t, VersionTLS13, VersionTLS10); err == nil {
2485                 t.Errorf("downgrade from TLS 1.3 to TLS 1.0 was not detected")
2486         }
2487         if err := testDowngradeCanary(t, VersionTLS12, VersionTLS11); err == nil {
2488                 t.Errorf("downgrade from TLS 1.2 to TLS 1.1 was not detected")
2489         }
2490         if err := testDowngradeCanary(t, VersionTLS12, VersionTLS10); err == nil {
2491                 t.Errorf("downgrade from TLS 1.2 to TLS 1.0 was not detected")
2492         }
2493         if err := testDowngradeCanary(t, VersionTLS13, VersionTLS13); err != nil {
2494                 t.Errorf("server unexpectedly sent downgrade canary for TLS 1.3")
2495         }
2496         if err := testDowngradeCanary(t, VersionTLS12, VersionTLS12); err != nil {
2497                 t.Errorf("client didn't ignore expected TLS 1.2 canary")
2498         }
2499         if err := testDowngradeCanary(t, VersionTLS11, VersionTLS11); err != nil {
2500                 t.Errorf("client unexpectedly reacted to a canary in TLS 1.1")
2501         }
2502         if err := testDowngradeCanary(t, VersionTLS10, VersionTLS10); err != nil {
2503                 t.Errorf("client unexpectedly reacted to a canary in TLS 1.0")
2504         }
2505 }
2506
2507 func TestResumptionKeepsOCSPAndSCT(t *testing.T) {
2508         t.Run("TLSv12", func(t *testing.T) { testResumptionKeepsOCSPAndSCT(t, VersionTLS12) })
2509         t.Run("TLSv13", func(t *testing.T) { testResumptionKeepsOCSPAndSCT(t, VersionTLS13) })
2510 }
2511
2512 func testResumptionKeepsOCSPAndSCT(t *testing.T, ver uint16) {
2513         issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
2514         if err != nil {
2515                 t.Fatalf("failed to parse test issuer")
2516         }
2517         roots := x509.NewCertPool()
2518         roots.AddCert(issuer)
2519         clientConfig := &Config{
2520                 MaxVersion:         ver,
2521                 ClientSessionCache: NewLRUClientSessionCache(32),
2522                 ServerName:         "example.golang",
2523                 RootCAs:            roots,
2524         }
2525         serverConfig := testConfig.Clone()
2526         serverConfig.MaxVersion = ver
2527         serverConfig.Certificates[0].OCSPStaple = []byte{1, 2, 3}
2528         serverConfig.Certificates[0].SignedCertificateTimestamps = [][]byte{{4, 5, 6}}
2529
2530         _, ccs, err := testHandshake(t, clientConfig, serverConfig)
2531         if err != nil {
2532                 t.Fatalf("handshake failed: %s", err)
2533         }
2534         // after a new session we expect to see OCSPResponse and
2535         // SignedCertificateTimestamps populated as usual
2536         if !bytes.Equal(ccs.OCSPResponse, serverConfig.Certificates[0].OCSPStaple) {
2537                 t.Errorf("client ConnectionState contained unexpected OCSPResponse: wanted %v, got %v",
2538                         serverConfig.Certificates[0].OCSPStaple, ccs.OCSPResponse)
2539         }
2540         if !reflect.DeepEqual(ccs.SignedCertificateTimestamps, serverConfig.Certificates[0].SignedCertificateTimestamps) {
2541                 t.Errorf("client ConnectionState contained unexpected SignedCertificateTimestamps: wanted %v, got %v",
2542                         serverConfig.Certificates[0].SignedCertificateTimestamps, ccs.SignedCertificateTimestamps)
2543         }
2544
2545         // if the server doesn't send any SCTs, repopulate the old SCTs
2546         oldSCTs := serverConfig.Certificates[0].SignedCertificateTimestamps
2547         serverConfig.Certificates[0].SignedCertificateTimestamps = nil
2548         _, ccs, err = testHandshake(t, clientConfig, serverConfig)
2549         if err != nil {
2550                 t.Fatalf("handshake failed: %s", err)
2551         }
2552         if !ccs.DidResume {
2553                 t.Fatalf("expected session to be resumed")
2554         }
2555         // after a resumed session we also expect to see OCSPResponse
2556         // and SignedCertificateTimestamps populated
2557         if !bytes.Equal(ccs.OCSPResponse, serverConfig.Certificates[0].OCSPStaple) {
2558                 t.Errorf("client ConnectionState contained unexpected OCSPResponse after resumption: wanted %v, got %v",
2559                         serverConfig.Certificates[0].OCSPStaple, ccs.OCSPResponse)
2560         }
2561         if !reflect.DeepEqual(ccs.SignedCertificateTimestamps, oldSCTs) {
2562                 t.Errorf("client ConnectionState contained unexpected SignedCertificateTimestamps after resumption: wanted %v, got %v",
2563                         oldSCTs, ccs.SignedCertificateTimestamps)
2564         }
2565
2566         //  Only test overriding the SCTs for TLS 1.2, since in 1.3
2567         // the server won't send the message containing them
2568         if ver == VersionTLS13 {
2569                 return
2570         }
2571
2572         // if the server changes the SCTs it sends, they should override the saved SCTs
2573         serverConfig.Certificates[0].SignedCertificateTimestamps = [][]byte{{7, 8, 9}}
2574         _, ccs, err = testHandshake(t, clientConfig, serverConfig)
2575         if err != nil {
2576                 t.Fatalf("handshake failed: %s", err)
2577         }
2578         if !ccs.DidResume {
2579                 t.Fatalf("expected session to be resumed")
2580         }
2581         if !reflect.DeepEqual(ccs.SignedCertificateTimestamps, serverConfig.Certificates[0].SignedCertificateTimestamps) {
2582                 t.Errorf("client ConnectionState contained unexpected SignedCertificateTimestamps after resumption: wanted %v, got %v",
2583                         serverConfig.Certificates[0].SignedCertificateTimestamps, ccs.SignedCertificateTimestamps)
2584         }
2585 }
2586
2587 // TestClientHandshakeContextCancellation tests that canceling
2588 // the context given to the client side conn.HandshakeContext
2589 // interrupts the in-progress handshake.
2590 func TestClientHandshakeContextCancellation(t *testing.T) {
2591         c, s := localPipe(t)
2592         ctx, cancel := context.WithCancel(context.Background())
2593         unblockServer := make(chan struct{})
2594         defer close(unblockServer)
2595         go func() {
2596                 cancel()
2597                 <-unblockServer
2598                 _ = s.Close()
2599         }()
2600         cli := Client(c, testConfig)
2601         // Initiates client side handshake, which will block until the client hello is read
2602         // by the server, unless the cancellation works.
2603         err := cli.HandshakeContext(ctx)
2604         if err == nil {
2605                 t.Fatal("Client handshake did not error when the context was canceled")
2606         }
2607         if err != context.Canceled {
2608                 t.Errorf("Unexpected client handshake error: %v", err)
2609         }
2610         if runtime.GOARCH == "wasm" {
2611                 t.Skip("conn.Close does not error as expected when called multiple times on WASM")
2612         }
2613         err = cli.Close()
2614         if err == nil {
2615                 t.Error("Client connection was not closed when the context was canceled")
2616         }
2617 }