]> Cypherpunks.ru repositories - gostls13.git/blob - src/crypto/tls/handshake_client_test.go
crypto/tls: add WrapSession and UnwrapSession
[gostls13.git] / src / crypto / tls / handshake_client_test.go
1 // Copyright 2010 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package tls
6
7 import (
8         "bytes"
9         "context"
10         "crypto/rsa"
11         "crypto/x509"
12         "encoding/base64"
13         "encoding/binary"
14         "encoding/pem"
15         "errors"
16         "fmt"
17         "io"
18         "math/big"
19         "net"
20         "os"
21         "os/exec"
22         "path/filepath"
23         "reflect"
24         "runtime"
25         "strconv"
26         "strings"
27         "testing"
28         "time"
29 )
30
31 // Note: see comment in handshake_test.go for details of how the reference
32 // tests work.
33
// opensslInputEvent enumerates possible inputs that can be sent to the stdin
// of the `openssl` reference process (in this file that is `openssl s_server`,
// see serverCommand).
type opensslInputEvent int

const (
	// opensslRenegotiate causes OpenSSL to request a renegotiation of the
	// connection.
	opensslRenegotiate opensslInputEvent = iota

	// opensslSendSentinel causes OpenSSL to send the contents of
	// opensslSentinel on the connection.
	opensslSendSentinel

	// opensslKeyUpdate causes OpenSSL to send a key update message to the
	// client and request one back.
	opensslKeyUpdate
)

// opensslSentinel is the application data line that OpenSSL is asked to send
// in response to an opensslSendSentinel event; the client side of the tests
// compares what it reads against this value.
const opensslSentinel = "SENTINEL\n"

// opensslInput is a channel of events that doubles as the stdin stream of the
// OpenSSL child process through its Read method. Closing the channel makes
// Read report io.EOF.
type opensslInput chan opensslInputEvent
55
56 func (i opensslInput) Read(buf []byte) (n int, err error) {
57         for event := range i {
58                 switch event {
59                 case opensslRenegotiate:
60                         return copy(buf, []byte("R\n")), nil
61                 case opensslKeyUpdate:
62                         return copy(buf, []byte("K\n")), nil
63                 case opensslSendSentinel:
64                         return copy(buf, []byte(opensslSentinel)), nil
65                 default:
66                         panic("unknown event")
67                 }
68         }
69
70         return 0, io.EOF
71 }
72
// opensslOutputSink is an io.Writer that receives the stdout and stderr from an
// `openssl` process and sends a value to handshakeComplete or readKeyUpdate
// when certain messages are seen.
type opensslOutputSink struct {
	// handshakeComplete receives a value for every complete output line
	// equal to opensslEndOfHandshake.
	handshakeComplete chan struct{}
	// readKeyUpdate receives a value for every complete output line equal
	// to opensslReadKeyUpdate.
	readKeyUpdate chan struct{}
	// all accumulates every byte written to the sink.
	all []byte
	// line holds the output that has not yet been terminated by a newline.
	line []byte
}
82
83 func newOpensslOutputSink() *opensslOutputSink {
84         return &opensslOutputSink{make(chan struct{}), make(chan struct{}), nil, nil}
85 }
86
// opensslEndOfHandshake is a message that the “openssl s_server” tool will
// print when a handshake completes if run with “-state”. The output sink
// matches it against whole output lines.
const opensslEndOfHandshake = "SSL_accept:SSLv3/TLS write finished"

// opensslReadKeyUpdate is a message that the “openssl s_server” tool will
// print when a KeyUpdate message is received if run with “-state”. The output
// sink matches it against whole output lines.
const opensslReadKeyUpdate = "SSL_accept:TLSv1.3 read client key update"
94
95 func (o *opensslOutputSink) Write(data []byte) (n int, err error) {
96         o.line = append(o.line, data...)
97         o.all = append(o.all, data...)
98
99         for {
100                 line, next, ok := bytes.Cut(o.line, []byte("\n"))
101                 if !ok {
102                         break
103                 }
104
105                 if bytes.Equal([]byte(opensslEndOfHandshake), line) {
106                         o.handshakeComplete <- struct{}{}
107                 }
108                 if bytes.Equal([]byte(opensslReadKeyUpdate), line) {
109                         o.readKeyUpdate <- struct{}{}
110                 }
111                 o.line = next
112         }
113
114         return len(data), nil
115 }
116
// String returns everything that has been written to the sink so far.
func (o *opensslOutputSink) String() string {
	return string(o.all)
}
120
// clientTest represents a test of the TLS client handshake against a reference
// implementation.
type clientTest struct {
	// name is a freeform string identifying the test and the file in which
	// the expected results will be stored.
	name string
	// args, if not empty, contains a series of arguments for the
	// command to run for the reference server. They are appended after
	// serverCommand.
	args []string
	// config, if not nil, contains a custom Config to use for this test.
	// When nil, testConfig is used.
	config *Config
	// cert, if not empty, contains a DER-encoded certificate for the
	// reference server. When empty, testRSACertificate is used.
	cert []byte
	// key, if not nil, contains either a *rsa.PrivateKey, ed25519.PrivateKey or
	// *ecdsa.PrivateKey which is the private key for the reference server.
	// When nil, testRSAPrivateKey is used.
	key any
	// extensions, if not nil, contains a list of extension data to be returned
	// from the ServerHello. The data should be in standard TLS format with
	// a 2-byte uint16 type, 2-byte data length, followed by the extension data.
	extensions [][]byte
	// validate, if not nil, is a function that will be called with the
	// ConnectionState of the resulting connection. It returns a non-nil
	// error if the ConnectionState is unacceptable.
	validate func(ConnectionState) error
	// numRenegotiations is the number of times that the connection will be
	// renegotiated. When it is non-zero, args must include "-state".
	numRenegotiations int
	// renegotiationExpectedToFail, if not zero, is the number of the
	// renegotiation attempt that is expected to fail.
	renegotiationExpectedToFail int
	// checkRenegotiationError, if not nil, is called with any error
	// arising from renegotiation. It can map expected errors to nil to
	// ignore them.
	checkRenegotiationError func(renegotiationNum int, err error) error
	// sendKeyUpdate will cause the server to send a KeyUpdate message.
	// When it is set, args must include "-state".
	sendKeyUpdate bool
}
159
// serverCommand is the base command line used to start the reference server.
// connFromCommand appends the per-test arguments, certificate, key and port.
var serverCommand = []string{"openssl", "s_server", "-no_ticket", "-num_tickets", "0"}
161
// connFromCommand starts the reference server process, connects to it and
// returns a recordingConn for the connection. The stdin return value is an
// opensslInput for the stdin of the child process. It must be closed before
// Waiting for child. The stdout return value is the sink that collects the
// child's stdout and stderr.
//
// The certificate, key and (if any) serverinfo temporary files are removed
// when this function returns. It panics if the private key cannot be
// marshalled, or if renegotiation/KeyUpdate is requested without "-state" in
// the server arguments.
func (test *clientTest) connFromCommand() (conn *recordingConn, child *exec.Cmd, stdin opensslInput, stdout *opensslOutputSink, err error) {
	// Fall back to the default RSA certificate unless the test supplies one.
	cert := testRSACertificate
	if len(test.cert) > 0 {
		cert = test.cert
	}
	certPath := tempFile(string(cert))
	defer os.Remove(certPath)

	// Likewise for the private key, which is handed to OpenSSL as a
	// PEM-encoded PKCS #8 file.
	var key any = testRSAPrivateKey
	if test.key != nil {
		key = test.key
	}
	derBytes, err := x509.MarshalPKCS8PrivateKey(key)
	if err != nil {
		panic(err)
	}

	var pemOut bytes.Buffer
	pem.Encode(&pemOut, &pem.Block{Type: "PRIVATE KEY", Bytes: derBytes})

	keyPath := tempFile(pemOut.String())
	defer os.Remove(keyPath)

	// Assemble the command line: base command, per-test arguments, then the
	// certificate (DER) and key paths.
	var command []string
	command = append(command, serverCommand...)
	command = append(command, test.args...)
	command = append(command, "-cert", certPath, "-certform", "DER", "-key", keyPath)
	// serverPort contains the port that OpenSSL will listen on. OpenSSL
	// can't take "0" as an argument here so we have to pick a number and
	// hope that it's not in use on the machine. Since this only occurs
	// when -update is given and thus when there's a human watching the
	// test, this isn't too bad.
	const serverPort = 24323
	command = append(command, "-accept", strconv.Itoa(serverPort))

	if len(test.extensions) > 0 {
		// Each extension becomes one PEM block in a serverinfo file; the
		// block type carries the extension number taken from the first two
		// bytes of the extension data.
		var serverInfo bytes.Buffer
		for _, ext := range test.extensions {
			pem.Encode(&serverInfo, &pem.Block{
				Type:  fmt.Sprintf("SERVERINFO FOR EXTENSION %d", binary.BigEndian.Uint16(ext)),
				Bytes: ext,
			})
		}
		serverInfoPath := tempFile(serverInfo.String())
		defer os.Remove(serverInfoPath)
		command = append(command, "-serverinfo", serverInfoPath)
	}

	// The handshake and KeyUpdate notifications are parsed out of the
	// "-state" output, so the flag is mandatory for those tests.
	if test.numRenegotiations > 0 || test.sendKeyUpdate {
		found := false
		for _, flag := range command[1:] {
			if flag == "-state" {
				found = true
				break
			}
		}

		if !found {
			panic("-state flag missing to OpenSSL, you need this if testing renegotiation or KeyUpdate")
		}
	}

	cmd := exec.Command(command[0], command[1:]...)
	stdin = opensslInput(make(chan opensslInputEvent))
	cmd.Stdin = stdin
	// stdout and stderr are merged into a single sink.
	out := newOpensslOutputSink()
	cmd.Stdout = out
	cmd.Stderr = out
	if err := cmd.Start(); err != nil {
		return nil, nil, nil, nil, err
	}

	// OpenSSL does print an "ACCEPT" banner, but it does so *before*
	// opening the listening socket, so we can't use that to wait until it
	// has started listening. Thus we are forced to poll until we get a
	// connection.
	var tcpConn net.Conn
	// Up to five attempts, sleeping 5ms, 10ms, 20ms, ... after each failure.
	for i := uint(0); i < 5; i++ {
		tcpConn, err = net.DialTCP("tcp", nil, &net.TCPAddr{
			IP:   net.IPv4(127, 0, 0, 1),
			Port: serverPort,
		})
		if err == nil {
			break
		}
		time.Sleep((1 << i) * 5 * time.Millisecond)
	}
	if err != nil {
		// Tear the child down and include its exit status and output in
		// the returned error.
		close(stdin)
		cmd.Process.Kill()
		err = fmt.Errorf("error connecting to the OpenSSL server: %v (%v)\n\n%s", err, cmd.Wait(), out)
		return nil, nil, nil, nil, err
	}

	record := &recordingConn{
		Conn: tcpConn,
	}

	return record, cmd, stdin, out, nil
}
266
267 func (test *clientTest) dataPath() string {
268         return filepath.Join("testdata", "Client-"+test.name)
269 }
270
271 func (test *clientTest) loadData() (flows [][]byte, err error) {
272         in, err := os.Open(test.dataPath())
273         if err != nil {
274                 return nil, err
275         }
276         defer in.Close()
277         return parseTestData(in)
278 }
279
// run executes the test.
//
// If write is true, the reference server is started with connFromCommand, the
// client talks to it over a recording connection, and the recorded flows are
// written to test.dataPath() at the end. Otherwise the client talks to a local
// pipe and the flows previously stored in test.dataPath() are replayed: bytes
// expected from the client are compared against the recording and the
// server's bytes are written back verbatim.
func (test *clientTest) run(t *testing.T, write bool) {
	var clientConn, serverConn net.Conn
	var recordingConn *recordingConn
	var childProcess *exec.Cmd
	var stdin opensslInput
	var stdout *opensslOutputSink

	if write {
		var err error
		recordingConn, childProcess, stdin, stdout, err = test.connFromCommand()
		if err != nil {
			t.Fatalf("Failed to start subcommand: %s", err)
		}
		clientConn = recordingConn
		// On failure, dump everything OpenSSL printed to help debugging.
		defer func() {
			if t.Failed() {
				t.Logf("OpenSSL output:\n\n%s", stdout.all)
			}
		}()
	} else {
		clientConn, serverConn = localPipe(t)
	}

	// doneChan is closed when the client goroutine below has finished. The
	// deferred function closes the connection and waits for that goroutine
	// so it does not outlive the test, including on t.Fatalf paths.
	doneChan := make(chan bool)
	defer func() {
		clientConn.Close()
		<-doneChan
	}()
	go func() {
		defer close(doneChan)

		config := test.config
		if config == nil {
			config = testConfig
		}
		client := Client(clientConn, config)
		defer client.Close()

		// The first Write drives the initial handshake.
		if _, err := client.Write([]byte("hello\n")); err != nil {
			t.Errorf("Client.Write failed: %s", err)
			return
		}

		for i := 1; i <= test.numRenegotiations; i++ {
			// The initial handshake will generate a
			// handshakeComplete signal which needs to be quashed.
			if i == 1 && write {
				<-stdout.handshakeComplete
			}

			// OpenSSL will try to interleave application data and
			// a renegotiation if we send both concurrently.
			// Therefore: ask OpenSSL to start a renegotiation, run
			// a goroutine to call client.Read and thus process the
			// renegotiation request, watch for OpenSSL's stdout to
			// indicate that the handshake is complete and,
			// finally, have OpenSSL write something to cause
			// client.Read to complete.
			if write {
				stdin <- opensslRenegotiate
			}

			// signalChan is closed once the reader goroutine is done.
			signalChan := make(chan struct{})

			go func() {
				defer close(signalChan)

				buf := make([]byte, 256)
				n, err := client.Read(buf)

				// Give the test a chance to accept or rewrite the
				// error. An error that is mapped to nil ends this
				// renegotiation round without further checks.
				if test.checkRenegotiationError != nil {
					newErr := test.checkRenegotiationError(i, err)
					if err != nil && newErr == nil {
						return
					}
					err = newErr
				}

				if err != nil {
					t.Errorf("Client.Read failed after renegotiation #%d: %s", i, err)
					return
				}

				buf = buf[:n]
				if !bytes.Equal([]byte(opensslSentinel), buf) {
					t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel)
				}

				// One initial handshake plus i renegotiations.
				if expected := i + 1; client.handshakes != expected {
					t.Errorf("client should have recorded %d handshakes, but believes that %d have occurred", expected, client.handshakes)
				}
			}()

			// Only wait for completion and send the sentinel when this
			// renegotiation is expected to succeed.
			if write && test.renegotiationExpectedToFail != i {
				<-stdout.handshakeComplete
				stdin <- opensslSendSentinel
			}
			<-signalChan
		}

		if test.sendKeyUpdate {
			if write {
				<-stdout.handshakeComplete
				stdin <- opensslKeyUpdate
			}

			// doneRead is closed once the reader goroutine is done.
			doneRead := make(chan struct{})

			go func() {
				defer close(doneRead)

				buf := make([]byte, 256)
				n, err := client.Read(buf)

				if err != nil {
					t.Errorf("Client.Read failed after KeyUpdate: %s", err)
					return
				}

				buf = buf[:n]
				if !bytes.Equal([]byte(opensslSentinel), buf) {
					t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel)
				}
			}()

			if write {
				// There's no real reason to wait for the client KeyUpdate to
				// send data with the new server keys, except that s_server
				// drops writes if they are sent at the wrong time.
				<-stdout.readKeyUpdate
				stdin <- opensslSendSentinel
			}
			<-doneRead

			if _, err := client.Write([]byte("hello again\n")); err != nil {
				t.Errorf("Client.Write failed: %s", err)
				return
			}
		}

		if test.validate != nil {
			if err := test.validate(client.ConnectionState()); err != nil {
				t.Errorf("validate callback returned error: %s", err)
			}
		}

		// If the server sent us an alert after our last flight, give it a
		// chance to arrive.
		if write && test.renegotiationExpectedToFail == 0 {
			if err := peekError(client); err != nil {
				t.Errorf("final Read returned an error: %s", err)
			}
		}
	}()

	if !write {
		flows, err := test.loadData()
		if err != nil {
			t.Fatalf("%s: failed to load data from %s: %v", test.name, test.dataPath(), err)
		}
		// Flows alternate: even indexes are read from the client and
		// compared with the recording, odd indexes are written to the client
		// as the server's reply. Deadlines are one second with -fast and
		// one minute otherwise.
		for i, b := range flows {
			if i%2 == 1 {
				if *fast {
					serverConn.SetWriteDeadline(time.Now().Add(1 * time.Second))
				} else {
					serverConn.SetWriteDeadline(time.Now().Add(1 * time.Minute))
				}
				serverConn.Write(b)
				continue
			}
			bb := make([]byte, len(b))
			if *fast {
				serverConn.SetReadDeadline(time.Now().Add(1 * time.Second))
			} else {
				serverConn.SetReadDeadline(time.Now().Add(1 * time.Minute))
			}
			_, err := io.ReadFull(serverConn, bb)
			if err != nil {
				t.Fatalf("%s, flow %d: %s", test.name, i+1, err)
			}
			if !bytes.Equal(b, bb) {
				t.Fatalf("%s, flow %d: mismatch on read: got:%x want:%x", test.name, i+1, bb, b)
			}
		}
	}

	// Wait for the client goroutine before tearing anything down.
	<-doneChan
	if !write {
		serverConn.Close()
	}

	if write {
		path := test.dataPath()
		out, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
		if err != nil {
			t.Fatalf("Failed to create output file: %s", err)
		}
		defer out.Close()
		// Shut down the connection and the child process before saving.
		recordingConn.Close()
		close(stdin)
		childProcess.Process.Kill()
		childProcess.Wait()
		// Fewer than three recorded flows is treated as a failed
		// connection and nothing is written.
		if len(recordingConn.flows) < 3 {
			t.Fatalf("Client connection didn't work")
		}
		recordingConn.WriteTo(out)
		t.Logf("Wrote %s\n", path)
	}
}
489
// peekError does a read with a short timeout to check if the next read would
// cause an error, for example if there is an alert waiting on the wire.
//
// It returns nil when the read times out without delivering data, an error if
// any data was unexpectedly read, and otherwise whatever error the read
// produced. A timeout is recognised with errors.As so that a timeout error
// wrapped by an intermediate layer is still treated as "nothing pending"
// rather than as a failure.
func peekError(conn net.Conn) error {
	conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
	n, err := conn.Read(make([]byte, 1))
	if n != 0 {
		return errors.New("unexpectedly read data")
	}

	var netErr net.Error
	if errors.As(err, &netErr) && netErr.Timeout() {
		// Nothing arrived within the deadline: no pending error.
		return nil
	}
	return err
}
503
504 func runClientTestForVersion(t *testing.T, template *clientTest, version, option string) {
505         // Make a deep copy of the template before going parallel.
506         test := *template
507         if template.config != nil {
508                 test.config = template.config.Clone()
509         }
510         test.name = version + "-" + test.name
511         test.args = append([]string{option}, test.args...)
512
513         runTestAndUpdateIfNeeded(t, version, test.run, false)
514 }
515
// runClientTestTLS10 runs template as the "TLSv10" variant, passing "-tls1"
// as the first argument to the reference server.
func runClientTestTLS10(t *testing.T, template *clientTest) {
	runClientTestForVersion(t, template, "TLSv10", "-tls1")
}
519
// runClientTestTLS11 runs template as the "TLSv11" variant, passing "-tls1_1"
// as the first argument to the reference server.
func runClientTestTLS11(t *testing.T, template *clientTest) {
	runClientTestForVersion(t, template, "TLSv11", "-tls1_1")
}
523
// runClientTestTLS12 runs template as the "TLSv12" variant, passing "-tls1_2"
// as the first argument to the reference server.
func runClientTestTLS12(t *testing.T, template *clientTest) {
	runClientTestForVersion(t, template, "TLSv12", "-tls1_2")
}
527
// runClientTestTLS13 runs template as the "TLSv13" variant, passing "-tls1_3"
// as the first argument to the reference server.
func runClientTestTLS13(t *testing.T, template *clientTest) {
	runClientTestForVersion(t, template, "TLSv13", "-tls1_3")
}
531
532 func TestHandshakeClientRSARC4(t *testing.T) {
533         test := &clientTest{
534                 name: "RSA-RC4",
535                 args: []string{"-cipher", "RC4-SHA"},
536         }
537         runClientTestTLS10(t, test)
538         runClientTestTLS11(t, test)
539         runClientTestTLS12(t, test)
540 }
541
542 func TestHandshakeClientRSAAES128GCM(t *testing.T) {
543         test := &clientTest{
544                 name: "AES128-GCM-SHA256",
545                 args: []string{"-cipher", "AES128-GCM-SHA256"},
546         }
547         runClientTestTLS12(t, test)
548 }
549
550 func TestHandshakeClientRSAAES256GCM(t *testing.T) {
551         test := &clientTest{
552                 name: "AES256-GCM-SHA384",
553                 args: []string{"-cipher", "AES256-GCM-SHA384"},
554         }
555         runClientTestTLS12(t, test)
556 }
557
558 func TestHandshakeClientECDHERSAAES(t *testing.T) {
559         test := &clientTest{
560                 name: "ECDHE-RSA-AES",
561                 args: []string{"-cipher", "ECDHE-RSA-AES128-SHA"},
562         }
563         runClientTestTLS10(t, test)
564         runClientTestTLS11(t, test)
565         runClientTestTLS12(t, test)
566 }
567
568 func TestHandshakeClientECDHEECDSAAES(t *testing.T) {
569         test := &clientTest{
570                 name: "ECDHE-ECDSA-AES",
571                 args: []string{"-cipher", "ECDHE-ECDSA-AES128-SHA"},
572                 cert: testECDSACertificate,
573                 key:  testECDSAPrivateKey,
574         }
575         runClientTestTLS10(t, test)
576         runClientTestTLS11(t, test)
577         runClientTestTLS12(t, test)
578 }
579
580 func TestHandshakeClientECDHEECDSAAESGCM(t *testing.T) {
581         test := &clientTest{
582                 name: "ECDHE-ECDSA-AES-GCM",
583                 args: []string{"-cipher", "ECDHE-ECDSA-AES128-GCM-SHA256"},
584                 cert: testECDSACertificate,
585                 key:  testECDSAPrivateKey,
586         }
587         runClientTestTLS12(t, test)
588 }
589
590 func TestHandshakeClientAES256GCMSHA384(t *testing.T) {
591         test := &clientTest{
592                 name: "ECDHE-ECDSA-AES256-GCM-SHA384",
593                 args: []string{"-cipher", "ECDHE-ECDSA-AES256-GCM-SHA384"},
594                 cert: testECDSACertificate,
595                 key:  testECDSAPrivateKey,
596         }
597         runClientTestTLS12(t, test)
598 }
599
600 func TestHandshakeClientAES128CBCSHA256(t *testing.T) {
601         test := &clientTest{
602                 name: "AES128-SHA256",
603                 args: []string{"-cipher", "AES128-SHA256"},
604         }
605         runClientTestTLS12(t, test)
606 }
607
608 func TestHandshakeClientECDHERSAAES128CBCSHA256(t *testing.T) {
609         test := &clientTest{
610                 name: "ECDHE-RSA-AES128-SHA256",
611                 args: []string{"-cipher", "ECDHE-RSA-AES128-SHA256"},
612         }
613         runClientTestTLS12(t, test)
614 }
615
616 func TestHandshakeClientECDHEECDSAAES128CBCSHA256(t *testing.T) {
617         test := &clientTest{
618                 name: "ECDHE-ECDSA-AES128-SHA256",
619                 args: []string{"-cipher", "ECDHE-ECDSA-AES128-SHA256"},
620                 cert: testECDSACertificate,
621                 key:  testECDSAPrivateKey,
622         }
623         runClientTestTLS12(t, test)
624 }
625
626 func TestHandshakeClientX25519(t *testing.T) {
627         config := testConfig.Clone()
628         config.CurvePreferences = []CurveID{X25519}
629
630         test := &clientTest{
631                 name:   "X25519-ECDHE",
632                 args:   []string{"-cipher", "ECDHE-RSA-AES128-GCM-SHA256", "-curves", "X25519"},
633                 config: config,
634         }
635
636         runClientTestTLS12(t, test)
637         runClientTestTLS13(t, test)
638 }
639
640 func TestHandshakeClientP256(t *testing.T) {
641         config := testConfig.Clone()
642         config.CurvePreferences = []CurveID{CurveP256}
643
644         test := &clientTest{
645                 name:   "P256-ECDHE",
646                 args:   []string{"-cipher", "ECDHE-RSA-AES128-GCM-SHA256", "-curves", "P-256"},
647                 config: config,
648         }
649
650         runClientTestTLS12(t, test)
651         runClientTestTLS13(t, test)
652 }
653
654 func TestHandshakeClientHelloRetryRequest(t *testing.T) {
655         config := testConfig.Clone()
656         config.CurvePreferences = []CurveID{X25519, CurveP256}
657
658         test := &clientTest{
659                 name:   "HelloRetryRequest",
660                 args:   []string{"-cipher", "ECDHE-RSA-AES128-GCM-SHA256", "-curves", "P-256"},
661                 config: config,
662         }
663
664         runClientTestTLS13(t, test)
665 }
666
667 func TestHandshakeClientECDHERSAChaCha20(t *testing.T) {
668         config := testConfig.Clone()
669         config.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305}
670
671         test := &clientTest{
672                 name:   "ECDHE-RSA-CHACHA20-POLY1305",
673                 args:   []string{"-cipher", "ECDHE-RSA-CHACHA20-POLY1305"},
674                 config: config,
675         }
676
677         runClientTestTLS12(t, test)
678 }
679
680 func TestHandshakeClientECDHEECDSAChaCha20(t *testing.T) {
681         config := testConfig.Clone()
682         config.CipherSuites = []uint16{TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305}
683
684         test := &clientTest{
685                 name:   "ECDHE-ECDSA-CHACHA20-POLY1305",
686                 args:   []string{"-cipher", "ECDHE-ECDSA-CHACHA20-POLY1305"},
687                 config: config,
688                 cert:   testECDSACertificate,
689                 key:    testECDSAPrivateKey,
690         }
691
692         runClientTestTLS12(t, test)
693 }
694
695 func TestHandshakeClientAES128SHA256(t *testing.T) {
696         test := &clientTest{
697                 name: "AES128-SHA256",
698                 args: []string{"-ciphersuites", "TLS_AES_128_GCM_SHA256"},
699         }
700         runClientTestTLS13(t, test)
701 }
702 func TestHandshakeClientAES256SHA384(t *testing.T) {
703         test := &clientTest{
704                 name: "AES256-SHA384",
705                 args: []string{"-ciphersuites", "TLS_AES_256_GCM_SHA384"},
706         }
707         runClientTestTLS13(t, test)
708 }
709 func TestHandshakeClientCHACHA20SHA256(t *testing.T) {
710         test := &clientTest{
711                 name: "CHACHA20-SHA256",
712                 args: []string{"-ciphersuites", "TLS_CHACHA20_POLY1305_SHA256"},
713         }
714         runClientTestTLS13(t, test)
715 }
716
717 func TestHandshakeClientECDSATLS13(t *testing.T) {
718         test := &clientTest{
719                 name: "ECDSA",
720                 cert: testECDSACertificate,
721                 key:  testECDSAPrivateKey,
722         }
723         runClientTestTLS13(t, test)
724 }
725
726 func TestHandshakeClientEd25519(t *testing.T) {
727         test := &clientTest{
728                 name: "Ed25519",
729                 cert: testEd25519Certificate,
730                 key:  testEd25519PrivateKey,
731         }
732         runClientTestTLS12(t, test)
733         runClientTestTLS13(t, test)
734
735         config := testConfig.Clone()
736         cert, _ := X509KeyPair([]byte(clientEd25519CertificatePEM), []byte(clientEd25519KeyPEM))
737         config.Certificates = []Certificate{cert}
738
739         test = &clientTest{
740                 name:   "ClientCert-Ed25519",
741                 args:   []string{"-Verify", "1"},
742                 config: config,
743         }
744
745         runClientTestTLS12(t, test)
746         runClientTestTLS13(t, test)
747 }
748
749 func TestHandshakeClientCertRSA(t *testing.T) {
750         config := testConfig.Clone()
751         cert, _ := X509KeyPair([]byte(clientCertificatePEM), []byte(clientKeyPEM))
752         config.Certificates = []Certificate{cert}
753
754         test := &clientTest{
755                 name:   "ClientCert-RSA-RSA",
756                 args:   []string{"-cipher", "AES128", "-Verify", "1"},
757                 config: config,
758         }
759
760         runClientTestTLS10(t, test)
761         runClientTestTLS12(t, test)
762
763         test = &clientTest{
764                 name:   "ClientCert-RSA-ECDSA",
765                 args:   []string{"-cipher", "ECDHE-ECDSA-AES128-SHA", "-Verify", "1"},
766                 config: config,
767                 cert:   testECDSACertificate,
768                 key:    testECDSAPrivateKey,
769         }
770
771         runClientTestTLS10(t, test)
772         runClientTestTLS12(t, test)
773         runClientTestTLS13(t, test)
774
775         test = &clientTest{
776                 name:   "ClientCert-RSA-AES256-GCM-SHA384",
777                 args:   []string{"-cipher", "ECDHE-RSA-AES256-GCM-SHA384", "-Verify", "1"},
778                 config: config,
779                 cert:   testRSACertificate,
780                 key:    testRSAPrivateKey,
781         }
782
783         runClientTestTLS12(t, test)
784 }
785
786 func TestHandshakeClientCertECDSA(t *testing.T) {
787         config := testConfig.Clone()
788         cert, _ := X509KeyPair([]byte(clientECDSACertificatePEM), []byte(clientECDSAKeyPEM))
789         config.Certificates = []Certificate{cert}
790
791         test := &clientTest{
792                 name:   "ClientCert-ECDSA-RSA",
793                 args:   []string{"-cipher", "AES128", "-Verify", "1"},
794                 config: config,
795         }
796
797         runClientTestTLS10(t, test)
798         runClientTestTLS12(t, test)
799         runClientTestTLS13(t, test)
800
801         test = &clientTest{
802                 name:   "ClientCert-ECDSA-ECDSA",
803                 args:   []string{"-cipher", "ECDHE-ECDSA-AES128-SHA", "-Verify", "1"},
804                 config: config,
805                 cert:   testECDSACertificate,
806                 key:    testECDSAPrivateKey,
807         }
808
809         runClientTestTLS10(t, test)
810         runClientTestTLS12(t, test)
811 }
812
813 // TestHandshakeClientCertRSAPSS tests rsa_pss_rsae_sha256 signatures from both
814 // client and server certificates. It also serves from both sides a certificate
815 // signed itself with RSA-PSS, mostly to check that crypto/x509 chain validation
816 // works.
817 func TestHandshakeClientCertRSAPSS(t *testing.T) {
818         cert, err := x509.ParseCertificate(testRSAPSSCertificate)
819         if err != nil {
820                 panic(err)
821         }
822         rootCAs := x509.NewCertPool()
823         rootCAs.AddCert(cert)
824
825         config := testConfig.Clone()
826         // Use GetClientCertificate to bypass the client certificate selection logic.
827         config.GetClientCertificate = func(*CertificateRequestInfo) (*Certificate, error) {
828                 return &Certificate{
829                         Certificate: [][]byte{testRSAPSSCertificate},
830                         PrivateKey:  testRSAPrivateKey,
831                 }, nil
832         }
833         config.RootCAs = rootCAs
834
835         test := &clientTest{
836                 name: "ClientCert-RSA-RSAPSS",
837                 args: []string{"-cipher", "AES128", "-Verify", "1", "-client_sigalgs",
838                         "rsa_pss_rsae_sha256", "-sigalgs", "rsa_pss_rsae_sha256"},
839                 config: config,
840                 cert:   testRSAPSSCertificate,
841                 key:    testRSAPrivateKey,
842         }
843         runClientTestTLS12(t, test)
844         runClientTestTLS13(t, test)
845 }
846
847 func TestHandshakeClientCertRSAPKCS1v15(t *testing.T) {
848         config := testConfig.Clone()
849         cert, _ := X509KeyPair([]byte(clientCertificatePEM), []byte(clientKeyPEM))
850         config.Certificates = []Certificate{cert}
851
852         test := &clientTest{
853                 name: "ClientCert-RSA-RSAPKCS1v15",
854                 args: []string{"-cipher", "AES128", "-Verify", "1", "-client_sigalgs",
855                         "rsa_pkcs1_sha256", "-sigalgs", "rsa_pkcs1_sha256"},
856                 config: config,
857         }
858
859         runClientTestTLS12(t, test)
860 }
861
862 func TestClientKeyUpdate(t *testing.T) {
863         test := &clientTest{
864                 name:          "KeyUpdate",
865                 args:          []string{"-state"},
866                 sendKeyUpdate: true,
867         }
868         runClientTestTLS13(t, test)
869 }
870
871 func TestResumption(t *testing.T) {
872         t.Run("TLSv12", func(t *testing.T) { testResumption(t, VersionTLS12) })
873         t.Run("TLSv13", func(t *testing.T) { testResumption(t, VersionTLS13) })
874 }
875
876 func testResumption(t *testing.T, version uint16) {
877         if testing.Short() {
878                 t.Skip("skipping in -short mode")
879         }
880         serverConfig := &Config{
881                 MaxVersion:   version,
882                 CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA},
883                 Certificates: testConfig.Certificates,
884         }
885
886         issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
887         if err != nil {
888                 panic(err)
889         }
890
891         rootCAs := x509.NewCertPool()
892         rootCAs.AddCert(issuer)
893
894         clientConfig := &Config{
895                 MaxVersion:         version,
896                 CipherSuites:       []uint16{TLS_RSA_WITH_RC4_128_SHA},
897                 ClientSessionCache: NewLRUClientSessionCache(32),
898                 RootCAs:            rootCAs,
899                 ServerName:         "example.golang",
900         }
901
902         testResumeState := func(test string, didResume bool) {
903                 t.Helper()
904                 _, hs, err := testHandshake(t, clientConfig, serverConfig)
905                 if err != nil {
906                         t.Fatalf("%s: handshake failed: %s", test, err)
907                 }
908                 if hs.DidResume != didResume {
909                         t.Fatalf("%s resumed: %v, expected: %v", test, hs.DidResume, didResume)
910                 }
911                 if didResume && (hs.PeerCertificates == nil || hs.VerifiedChains == nil) {
912                         t.Fatalf("expected non-nil certificates after resumption. Got peerCertificates: %#v, verifiedCertificates: %#v", hs.PeerCertificates, hs.VerifiedChains)
913                 }
914                 if got, want := hs.ServerName, clientConfig.ServerName; got != want {
915                         t.Errorf("%s: server name %s, want %s", test, got, want)
916                 }
917         }
918
919         getTicket := func() []byte {
920                 return clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).state.ticket
921         }
922         deleteTicket := func() {
923                 ticketKey := clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).sessionKey
924                 clientConfig.ClientSessionCache.Put(ticketKey, nil)
925         }
926         corruptTicket := func() {
927                 clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).state.session.secret[0] ^= 0xff
928         }
929         randomKey := func() [32]byte {
930                 var k [32]byte
931                 if _, err := io.ReadFull(serverConfig.rand(), k[:]); err != nil {
932                         t.Fatalf("Failed to read new SessionTicketKey: %s", err)
933                 }
934                 return k
935         }
936
937         testResumeState("Handshake", false)
938         ticket := getTicket()
939         testResumeState("Resume", true)
940         if bytes.Equal(ticket, getTicket()) {
941                 t.Fatal("ticket didn't change after resumption")
942         }
943
944         // An old session ticket is replaced with a ticket encrypted with a fresh key.
945         ticket = getTicket()
946         serverConfig.Time = func() time.Time { return time.Now().Add(24*time.Hour + time.Minute) }
947         testResumeState("ResumeWithOldTicket", true)
948         if bytes.Equal(ticket, getTicket()) {
949                 t.Fatal("old first ticket matches the fresh one")
950         }
951
952         // Once the session master secret is expired, a full handshake should occur.
953         ticket = getTicket()
954         serverConfig.Time = func() time.Time { return time.Now().Add(24*8*time.Hour + time.Minute) }
955         testResumeState("ResumeWithExpiredTicket", false)
956         if bytes.Equal(ticket, getTicket()) {
957                 t.Fatal("expired first ticket matches the fresh one")
958         }
959
960         serverConfig.Time = func() time.Time { return time.Now() } // reset the time back
961         key1 := randomKey()
962         serverConfig.SetSessionTicketKeys([][32]byte{key1})
963
964         testResumeState("InvalidSessionTicketKey", false)
965         testResumeState("ResumeAfterInvalidSessionTicketKey", true)
966
967         key2 := randomKey()
968         serverConfig.SetSessionTicketKeys([][32]byte{key2, key1})
969         ticket = getTicket()
970         testResumeState("KeyChange", true)
971         if bytes.Equal(ticket, getTicket()) {
972                 t.Fatal("new ticket wasn't included while resuming")
973         }
974         testResumeState("KeyChangeFinish", true)
975
976         // Age the session ticket a bit, but not yet expired.
977         serverConfig.Time = func() time.Time { return time.Now().Add(24*time.Hour + time.Minute) }
978         testResumeState("OldSessionTicket", true)
979         ticket = getTicket()
980         // Expire the session ticket, which would force a full handshake.
981         serverConfig.Time = func() time.Time { return time.Now().Add(24*8*time.Hour + time.Minute) }
982         testResumeState("ExpiredSessionTicket", false)
983         if bytes.Equal(ticket, getTicket()) {
984                 t.Fatal("new ticket wasn't provided after old ticket expired")
985         }
986
987         // Age the session ticket a bit at a time, but don't expire it.
988         d := 0 * time.Hour
989         serverConfig.Time = func() time.Time { return time.Now().Add(d) }
990         deleteTicket()
991         testResumeState("GetFreshSessionTicket", false)
992         for i := 0; i < 13; i++ {
993                 d += 12 * time.Hour
994                 testResumeState("OldSessionTicket", true)
995         }
996         // Expire it (now a little more than 7 days) and make sure a full
997         // handshake occurs for TLS 1.2. Resumption should still occur for
998         // TLS 1.3 since the client should be using a fresh ticket sent over
999         // by the server.
1000         d += 12 * time.Hour
1001         if version == VersionTLS13 {
1002                 testResumeState("ExpiredSessionTicket", true)
1003         } else {
1004                 testResumeState("ExpiredSessionTicket", false)
1005         }
1006         if bytes.Equal(ticket, getTicket()) {
1007                 t.Fatal("new ticket wasn't provided after old ticket expired")
1008         }
1009
1010         // Reset serverConfig to ensure that calling SetSessionTicketKeys
1011         // before the serverConfig is used works.
1012         serverConfig = &Config{
1013                 MaxVersion:   version,
1014                 CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA},
1015                 Certificates: testConfig.Certificates,
1016         }
1017         serverConfig.SetSessionTicketKeys([][32]byte{key2})
1018
1019         testResumeState("FreshConfig", true)
1020
1021         // In TLS 1.3, cross-cipher suite resumption is allowed as long as the KDF
1022         // hash matches. Also, Config.CipherSuites does not apply to TLS 1.3.
1023         if version != VersionTLS13 {
1024                 clientConfig.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_RC4_128_SHA}
1025                 testResumeState("DifferentCipherSuite", false)
1026                 testResumeState("DifferentCipherSuiteRecovers", true)
1027         }
1028
1029         deleteTicket()
1030         testResumeState("WithoutSessionTicket", false)
1031
1032         // In TLS 1.3, HelloRetryRequest is sent after incorrect key share.
1033         // See https://www.rfc-editor.org/rfc/rfc8446#page-14.
1034         if version == VersionTLS13 {
1035                 deleteTicket()
1036                 serverConfig = &Config{
1037                         // Use a different curve than the client to force a HelloRetryRequest.
1038                         CurvePreferences: []CurveID{CurveP521, CurveP384, CurveP256},
1039                         MaxVersion:       version,
1040                         Certificates:     testConfig.Certificates,
1041                 }
1042                 testResumeState("InitialHandshake", false)
1043                 testResumeState("WithHelloRetryRequest", true)
1044
1045                 // Reset serverConfig back.
1046                 serverConfig = &Config{
1047                         MaxVersion:   version,
1048                         CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA},
1049                         Certificates: testConfig.Certificates,
1050                 }
1051         }
1052
1053         // Session resumption should work when using client certificates
1054         deleteTicket()
1055         serverConfig.ClientCAs = rootCAs
1056         serverConfig.ClientAuth = RequireAndVerifyClientCert
1057         clientConfig.Certificates = serverConfig.Certificates
1058         testResumeState("InitialHandshake", false)
1059         testResumeState("WithClientCertificates", true)
1060         serverConfig.ClientAuth = NoClientCert
1061
1062         // Tickets should be removed from the session cache on TLS handshake
1063         // failure, and the client should recover from a corrupted PSK
1064         testResumeState("FetchTicketToCorrupt", false)
1065         corruptTicket()
1066         _, _, err = testHandshake(t, clientConfig, serverConfig)
1067         if err == nil {
1068                 t.Fatalf("handshake did not fail with a corrupted client secret")
1069         }
1070         testResumeState("AfterHandshakeFailure", false)
1071
1072         clientConfig.ClientSessionCache = nil
1073         testResumeState("WithoutSessionCache", false)
1074
1075         clientConfig.ClientSessionCache = &serializingClientCache{t: t}
1076         testResumeState("BeforeSerializingCache", false)
1077         testResumeState("WithSerializingCache", true)
1078 }
1079
// serializingClientCache is a client session cache for tests that keeps a
// single session as raw bytes rather than as a live *ClientSessionState.
type serializingClientCache struct {
	// t receives any error hit while storing or loading the session.
	t *testing.T

	// ticket is the ticket of the stored session; nil means nothing is stored.
	ticket []byte
	// state is the serialized session state belonging to ticket.
	state []byte
}
1085
1086 func (c *serializingClientCache) Get(sessionKey string) (session *ClientSessionState, ok bool) {
1087         if c.ticket == nil {
1088                 return nil, false
1089         }
1090         state, err := ParseSessionState(c.state)
1091         if err != nil {
1092                 c.t.Error(err)
1093                 return nil, false
1094         }
1095         cs, err := NewResumptionState(c.ticket, state)
1096         if err != nil {
1097                 c.t.Error(err)
1098                 return nil, false
1099         }
1100         return cs, true
1101 }
1102
1103 func (c *serializingClientCache) Put(sessionKey string, cs *ClientSessionState) {
1104         ticket, state, err := cs.ResumptionState()
1105         if err != nil {
1106                 c.t.Error(err)
1107                 return
1108         }
1109         stateBytes, err := state.Bytes()
1110         if err != nil {
1111                 c.t.Error(err)
1112                 return
1113         }
1114         c.ticket, c.state = ticket, stateBytes
1115 }
1116
1117 func TestLRUClientSessionCache(t *testing.T) {
1118         // Initialize cache of capacity 4.
1119         cache := NewLRUClientSessionCache(4)
1120         cs := make([]ClientSessionState, 6)
1121         keys := []string{"0", "1", "2", "3", "4", "5", "6"}
1122
1123         // Add 4 entries to the cache and look them up.
1124         for i := 0; i < 4; i++ {
1125                 cache.Put(keys[i], &cs[i])
1126         }
1127         for i := 0; i < 4; i++ {
1128                 if s, ok := cache.Get(keys[i]); !ok || s != &cs[i] {
1129                         t.Fatalf("session cache failed lookup for added key: %s", keys[i])
1130                 }
1131         }
1132
1133         // Add 2 more entries to the cache. First 2 should be evicted.
1134         for i := 4; i < 6; i++ {
1135                 cache.Put(keys[i], &cs[i])
1136         }
1137         for i := 0; i < 2; i++ {
1138                 if s, ok := cache.Get(keys[i]); ok || s != nil {
1139                         t.Fatalf("session cache should have evicted key: %s", keys[i])
1140                 }
1141         }
1142
1143         // Touch entry 2. LRU should evict 3 next.
1144         cache.Get(keys[2])
1145         cache.Put(keys[0], &cs[0])
1146         if s, ok := cache.Get(keys[3]); ok || s != nil {
1147                 t.Fatalf("session cache should have evicted key 3")
1148         }
1149
1150         // Update entry 0 in place.
1151         cache.Put(keys[0], &cs[3])
1152         if s, ok := cache.Get(keys[0]); !ok || s != &cs[3] {
1153                 t.Fatalf("session cache failed update for key 0")
1154         }
1155
1156         // Calling Put with a nil entry deletes the key.
1157         cache.Put(keys[0], nil)
1158         if _, ok := cache.Get(keys[0]); ok {
1159                 t.Fatalf("session cache failed to delete key 0")
1160         }
1161
1162         // Delete entry 2. LRU should keep 4 and 5
1163         cache.Put(keys[2], nil)
1164         if _, ok := cache.Get(keys[2]); ok {
1165                 t.Fatalf("session cache failed to delete key 4")
1166         }
1167         for i := 4; i < 6; i++ {
1168                 if s, ok := cache.Get(keys[i]); !ok || s != &cs[i] {
1169                         t.Fatalf("session cache should not have deleted key: %s", keys[i])
1170                 }
1171         }
1172 }
1173
1174 func TestKeyLogTLS12(t *testing.T) {
1175         var serverBuf, clientBuf bytes.Buffer
1176
1177         clientConfig := testConfig.Clone()
1178         clientConfig.KeyLogWriter = &clientBuf
1179         clientConfig.MaxVersion = VersionTLS12
1180
1181         serverConfig := testConfig.Clone()
1182         serverConfig.KeyLogWriter = &serverBuf
1183         serverConfig.MaxVersion = VersionTLS12
1184
1185         c, s := localPipe(t)
1186         done := make(chan bool)
1187
1188         go func() {
1189                 defer close(done)
1190
1191                 if err := Server(s, serverConfig).Handshake(); err != nil {
1192                         t.Errorf("server: %s", err)
1193                         return
1194                 }
1195                 s.Close()
1196         }()
1197
1198         if err := Client(c, clientConfig).Handshake(); err != nil {
1199                 t.Fatalf("client: %s", err)
1200         }
1201
1202         c.Close()
1203         <-done
1204
1205         checkKeylogLine := func(side, loggedLine string) {
1206                 if len(loggedLine) == 0 {
1207                         t.Fatalf("%s: no keylog line was produced", side)
1208                 }
1209                 const expectedLen = 13 /* "CLIENT_RANDOM" */ +
1210                         1 /* space */ +
1211                         32*2 /* hex client nonce */ +
1212                         1 /* space */ +
1213                         48*2 /* hex master secret */ +
1214                         1 /* new line */
1215                 if len(loggedLine) != expectedLen {
1216                         t.Fatalf("%s: keylog line has incorrect length (want %d, got %d): %q", side, expectedLen, len(loggedLine), loggedLine)
1217                 }
1218                 if !strings.HasPrefix(loggedLine, "CLIENT_RANDOM "+strings.Repeat("0", 64)+" ") {
1219                         t.Fatalf("%s: keylog line has incorrect structure or nonce: %q", side, loggedLine)
1220                 }
1221         }
1222
1223         checkKeylogLine("client", clientBuf.String())
1224         checkKeylogLine("server", serverBuf.String())
1225 }
1226
1227 func TestKeyLogTLS13(t *testing.T) {
1228         var serverBuf, clientBuf bytes.Buffer
1229
1230         clientConfig := testConfig.Clone()
1231         clientConfig.KeyLogWriter = &clientBuf
1232
1233         serverConfig := testConfig.Clone()
1234         serverConfig.KeyLogWriter = &serverBuf
1235
1236         c, s := localPipe(t)
1237         done := make(chan bool)
1238
1239         go func() {
1240                 defer close(done)
1241
1242                 if err := Server(s, serverConfig).Handshake(); err != nil {
1243                         t.Errorf("server: %s", err)
1244                         return
1245                 }
1246                 s.Close()
1247         }()
1248
1249         if err := Client(c, clientConfig).Handshake(); err != nil {
1250                 t.Fatalf("client: %s", err)
1251         }
1252
1253         c.Close()
1254         <-done
1255
1256         checkKeylogLines := func(side, loggedLines string) {
1257                 loggedLines = strings.TrimSpace(loggedLines)
1258                 lines := strings.Split(loggedLines, "\n")
1259                 if len(lines) != 4 {
1260                         t.Errorf("Expected the %s to log 4 lines, got %d", side, len(lines))
1261                 }
1262         }
1263
1264         checkKeylogLines("client", clientBuf.String())
1265         checkKeylogLines("server", serverBuf.String())
1266 }
1267
1268 func TestHandshakeClientALPNMatch(t *testing.T) {
1269         config := testConfig.Clone()
1270         config.NextProtos = []string{"proto2", "proto1"}
1271
1272         test := &clientTest{
1273                 name: "ALPN",
1274                 // Note that this needs OpenSSL 1.0.2 because that is the first
1275                 // version that supports the -alpn flag.
1276                 args:   []string{"-alpn", "proto1,proto2"},
1277                 config: config,
1278                 validate: func(state ConnectionState) error {
1279                         // The server's preferences should override the client.
1280                         if state.NegotiatedProtocol != "proto1" {
1281                                 return fmt.Errorf("Got protocol %q, wanted proto1", state.NegotiatedProtocol)
1282                         }
1283                         return nil
1284                 },
1285         }
1286         runClientTestTLS12(t, test)
1287         runClientTestTLS13(t, test)
1288 }
1289
1290 func TestServerSelectingUnconfiguredApplicationProtocol(t *testing.T) {
1291         // This checks that the server can't select an application protocol that the
1292         // client didn't offer.
1293
1294         c, s := localPipe(t)
1295         errChan := make(chan error, 1)
1296
1297         go func() {
1298                 client := Client(c, &Config{
1299                         ServerName:   "foo",
1300                         CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
1301                         NextProtos:   []string{"http", "something-else"},
1302                 })
1303                 errChan <- client.Handshake()
1304         }()
1305
1306         var header [5]byte
1307         if _, err := io.ReadFull(s, header[:]); err != nil {
1308                 t.Fatal(err)
1309         }
1310         recordLen := int(header[3])<<8 | int(header[4])
1311
1312         record := make([]byte, recordLen)
1313         if _, err := io.ReadFull(s, record); err != nil {
1314                 t.Fatal(err)
1315         }
1316
1317         serverHello := &serverHelloMsg{
1318                 vers:         VersionTLS12,
1319                 random:       make([]byte, 32),
1320                 cipherSuite:  TLS_RSA_WITH_AES_128_GCM_SHA256,
1321                 alpnProtocol: "how-about-this",
1322         }
1323         serverHelloBytes := mustMarshal(t, serverHello)
1324
1325         s.Write([]byte{
1326                 byte(recordTypeHandshake),
1327                 byte(VersionTLS12 >> 8),
1328                 byte(VersionTLS12 & 0xff),
1329                 byte(len(serverHelloBytes) >> 8),
1330                 byte(len(serverHelloBytes)),
1331         })
1332         s.Write(serverHelloBytes)
1333         s.Close()
1334
1335         if err := <-errChan; !strings.Contains(err.Error(), "server selected unadvertised ALPN protocol") {
1336                 t.Fatalf("Expected error about unconfigured cipher suite but got %q", err)
1337         }
1338 }
1339
// sctsBase64 contains data from `openssl s_client -serverinfo 18 -connect ritter.vg:443`
// encoded as standard base64. TestHandshakClientSCTs decodes it, passes the
// decoded bytes as a server extension, and slices the three expected SCTs out
// of it at fixed byte offsets, so the value must not be altered.
const sctsBase64 = "ABIBaQFnAHUApLkJkLQYWBSHuxOizGdwCjw1mAT5G9+443fNDsgN3BAAAAFHl5nuFgAABAMARjBEAiAcS4JdlW5nW9sElUv2zvQyPoZ6ejKrGGB03gjaBZFMLwIgc1Qbbn+hsH0RvObzhS+XZhr3iuQQJY8S9G85D9KeGPAAdgBo9pj4H2SCvjqM7rkoHUz8cVFdZ5PURNEKZ6y7T0/7xAAAAUeX4bVwAAAEAwBHMEUCIDIhFDgG2HIuADBkGuLobU5a4dlCHoJLliWJ1SYT05z6AiEAjxIoZFFPRNWMGGIjskOTMwXzQ1Wh2e7NxXE1kd1J0QsAdgDuS723dc5guuFCaR+r4Z5mow9+X7By2IMAxHuJeqj9ywAAAUhcZIqHAAAEAwBHMEUCICmJ1rBT09LpkbzxtUC+Hi7nXLR0J+2PmwLp+sJMuqK+AiEAr0NkUnEVKVhAkccIFpYDqHOlZaBsuEhWWrYpg2RtKp0="
1342
1343 func TestHandshakClientSCTs(t *testing.T) {
1344         config := testConfig.Clone()
1345
1346         scts, err := base64.StdEncoding.DecodeString(sctsBase64)
1347         if err != nil {
1348                 t.Fatal(err)
1349         }
1350
1351         // Note that this needs OpenSSL 1.0.2 because that is the first
1352         // version that supports the -serverinfo flag.
1353         test := &clientTest{
1354                 name:       "SCT",
1355                 config:     config,
1356                 extensions: [][]byte{scts},
1357                 validate: func(state ConnectionState) error {
1358                         expectedSCTs := [][]byte{
1359                                 scts[8:125],
1360                                 scts[127:245],
1361                                 scts[247:],
1362                         }
1363                         if n := len(state.SignedCertificateTimestamps); n != len(expectedSCTs) {
1364                                 return fmt.Errorf("Got %d scts, wanted %d", n, len(expectedSCTs))
1365                         }
1366                         for i, expected := range expectedSCTs {
1367                                 if sct := state.SignedCertificateTimestamps[i]; !bytes.Equal(sct, expected) {
1368                                         return fmt.Errorf("SCT #%d contained %x, expected %x", i, sct, expected)
1369                                 }
1370                         }
1371                         return nil
1372                 },
1373         }
1374         runClientTestTLS12(t, test)
1375
1376         // TLS 1.3 moved SCTs to the Certificate extensions and -serverinfo only
1377         // supports ServerHello extensions.
1378 }
1379
1380 func TestRenegotiationRejected(t *testing.T) {
1381         config := testConfig.Clone()
1382         test := &clientTest{
1383                 name:                        "RenegotiationRejected",
1384                 args:                        []string{"-state"},
1385                 config:                      config,
1386                 numRenegotiations:           1,
1387                 renegotiationExpectedToFail: 1,
1388                 checkRenegotiationError: func(renegotiationNum int, err error) error {
1389                         if err == nil {
1390                                 return errors.New("expected error from renegotiation but got nil")
1391                         }
1392                         if !strings.Contains(err.Error(), "no renegotiation") {
1393                                 return fmt.Errorf("expected renegotiation to be rejected but got %q", err)
1394                         }
1395                         return nil
1396                 },
1397         }
1398         runClientTestTLS12(t, test)
1399 }
1400
1401 func TestRenegotiateOnce(t *testing.T) {
1402         config := testConfig.Clone()
1403         config.Renegotiation = RenegotiateOnceAsClient
1404
1405         test := &clientTest{
1406                 name:              "RenegotiateOnce",
1407                 args:              []string{"-state"},
1408                 config:            config,
1409                 numRenegotiations: 1,
1410         }
1411
1412         runClientTestTLS12(t, test)
1413 }
1414
1415 func TestRenegotiateTwice(t *testing.T) {
1416         config := testConfig.Clone()
1417         config.Renegotiation = RenegotiateFreelyAsClient
1418
1419         test := &clientTest{
1420                 name:              "RenegotiateTwice",
1421                 args:              []string{"-state"},
1422                 config:            config,
1423                 numRenegotiations: 2,
1424         }
1425
1426         runClientTestTLS12(t, test)
1427 }
1428
1429 func TestRenegotiateTwiceRejected(t *testing.T) {
1430         config := testConfig.Clone()
1431         config.Renegotiation = RenegotiateOnceAsClient
1432
1433         test := &clientTest{
1434                 name:                        "RenegotiateTwiceRejected",
1435                 args:                        []string{"-state"},
1436                 config:                      config,
1437                 numRenegotiations:           2,
1438                 renegotiationExpectedToFail: 2,
1439                 checkRenegotiationError: func(renegotiationNum int, err error) error {
1440                         if renegotiationNum == 1 {
1441                                 return err
1442                         }
1443
1444                         if err == nil {
1445                                 return errors.New("expected error from renegotiation but got nil")
1446                         }
1447                         if !strings.Contains(err.Error(), "no renegotiation") {
1448                                 return fmt.Errorf("expected renegotiation to be rejected but got %q", err)
1449                         }
1450                         return nil
1451                 },
1452         }
1453
1454         runClientTestTLS12(t, test)
1455 }
1456
1457 func TestHandshakeClientExportKeyingMaterial(t *testing.T) {
1458         test := &clientTest{
1459                 name:   "ExportKeyingMaterial",
1460                 config: testConfig.Clone(),
1461                 validate: func(state ConnectionState) error {
1462                         if km, err := state.ExportKeyingMaterial("test", nil, 42); err != nil {
1463                                 return fmt.Errorf("ExportKeyingMaterial failed: %v", err)
1464                         } else if len(km) != 42 {
1465                                 return fmt.Errorf("Got %d bytes from ExportKeyingMaterial, wanted %d", len(km), 42)
1466                         }
1467                         return nil
1468                 },
1469         }
1470         runClientTestTLS10(t, test)
1471         runClientTestTLS12(t, test)
1472         runClientTestTLS13(t, test)
1473 }
1474
// hostnameInSNITests maps a Config.ServerName value (in) to the server name
// expected in the resulting ClientHello (out). An empty out for a non-empty
// in means the name must not be sent.
var hostnameInSNITests = []struct {
	in, out string
}{
	// Opaque string
	{in: "", out: ""},
	{in: "localhost", out: "localhost"},
	{in: "foo, bar, baz and qux", out: "foo, bar, baz and qux"},

	// DNS hostname
	{in: "golang.org", out: "golang.org"},
	{in: "golang.org.", out: "golang.org"},

	// Literal IPv4 address
	{in: "1.2.3.4", out: ""},

	// Literal IPv6 address
	{in: "::1", out: ""},
	{in: "::1%lo0", out: ""}, // with zone identifier
	{in: "[::1]", out: ""},   // as per RFC 5952 we allow the [] style as IPv6 literal
	{in: "[::1%lo0]", out: ""},
}
1496
1497 func TestHostnameInSNI(t *testing.T) {
1498         for _, tt := range hostnameInSNITests {
1499                 c, s := localPipe(t)
1500
1501                 go func(host string) {
1502                         Client(c, &Config{ServerName: host, InsecureSkipVerify: true}).Handshake()
1503                 }(tt.in)
1504
1505                 var header [5]byte
1506                 if _, err := io.ReadFull(s, header[:]); err != nil {
1507                         t.Fatal(err)
1508                 }
1509                 recordLen := int(header[3])<<8 | int(header[4])
1510
1511                 record := make([]byte, recordLen)
1512                 if _, err := io.ReadFull(s, record[:]); err != nil {
1513                         t.Fatal(err)
1514                 }
1515
1516                 c.Close()
1517                 s.Close()
1518
1519                 var m clientHelloMsg
1520                 if !m.unmarshal(record) {
1521                         t.Errorf("unmarshaling ClientHello for %q failed", tt.in)
1522                         continue
1523                 }
1524                 if tt.in != tt.out && m.serverName == tt.in {
1525                         t.Errorf("prohibited %q found in ClientHello: %x", tt.in, record)
1526                 }
1527                 if m.serverName != tt.out {
1528                         t.Errorf("expected %q not found in ClientHello: %x", tt.out, record)
1529                 }
1530         }
1531 }
1532
1533 func TestServerSelectingUnconfiguredCipherSuite(t *testing.T) {
1534         // This checks that the server can't select a cipher suite that the
1535         // client didn't offer. See #13174.
1536
1537         c, s := localPipe(t)
1538         errChan := make(chan error, 1)
1539
1540         go func() {
1541                 client := Client(c, &Config{
1542                         ServerName:   "foo",
1543                         CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
1544                 })
1545                 errChan <- client.Handshake()
1546         }()
1547
1548         var header [5]byte
1549         if _, err := io.ReadFull(s, header[:]); err != nil {
1550                 t.Fatal(err)
1551         }
1552         recordLen := int(header[3])<<8 | int(header[4])
1553
1554         record := make([]byte, recordLen)
1555         if _, err := io.ReadFull(s, record); err != nil {
1556                 t.Fatal(err)
1557         }
1558
1559         // Create a ServerHello that selects a different cipher suite than the
1560         // sole one that the client offered.
1561         serverHello := &serverHelloMsg{
1562                 vers:        VersionTLS12,
1563                 random:      make([]byte, 32),
1564                 cipherSuite: TLS_RSA_WITH_AES_256_GCM_SHA384,
1565         }
1566         serverHelloBytes := mustMarshal(t, serverHello)
1567
1568         s.Write([]byte{
1569                 byte(recordTypeHandshake),
1570                 byte(VersionTLS12 >> 8),
1571                 byte(VersionTLS12 & 0xff),
1572                 byte(len(serverHelloBytes) >> 8),
1573                 byte(len(serverHelloBytes)),
1574         })
1575         s.Write(serverHelloBytes)
1576         s.Close()
1577
1578         if err := <-errChan; !strings.Contains(err.Error(), "unconfigured cipher") {
1579                 t.Fatalf("Expected error about unconfigured cipher suite but got %q", err)
1580         }
1581 }
1582
1583 func TestVerifyConnection(t *testing.T) {
1584         t.Run("TLSv12", func(t *testing.T) { testVerifyConnection(t, VersionTLS12) })
1585         t.Run("TLSv13", func(t *testing.T) { testVerifyConnection(t, VersionTLS13) })
1586 }
1587
1588 func testVerifyConnection(t *testing.T, version uint16) {
1589         checkFields := func(c ConnectionState, called *int, errorType string) error {
1590                 if c.Version != version {
1591                         return fmt.Errorf("%s: got Version %v, want %v", errorType, c.Version, version)
1592                 }
1593                 if c.HandshakeComplete {
1594                         return fmt.Errorf("%s: got HandshakeComplete, want false", errorType)
1595                 }
1596                 if c.ServerName != "example.golang" {
1597                         return fmt.Errorf("%s: got ServerName %s, want %s", errorType, c.ServerName, "example.golang")
1598                 }
1599                 if c.NegotiatedProtocol != "protocol1" {
1600                         return fmt.Errorf("%s: got NegotiatedProtocol %s, want %s", errorType, c.NegotiatedProtocol, "protocol1")
1601                 }
1602                 if c.CipherSuite == 0 {
1603                         return fmt.Errorf("%s: got CipherSuite 0, want non-zero", errorType)
1604                 }
1605                 wantDidResume := false
1606                 if *called == 2 { // if this is the second time, then it should be a resumption
1607                         wantDidResume = true
1608                 }
1609                 if c.DidResume != wantDidResume {
1610                         return fmt.Errorf("%s: got DidResume %t, want %t", errorType, c.DidResume, wantDidResume)
1611                 }
1612                 return nil
1613         }
1614
1615         tests := []struct {
1616                 name            string
1617                 configureServer func(*Config, *int)
1618                 configureClient func(*Config, *int)
1619         }{
1620                 {
1621                         name: "RequireAndVerifyClientCert",
1622                         configureServer: func(config *Config, called *int) {
1623                                 config.ClientAuth = RequireAndVerifyClientCert
1624                                 config.VerifyConnection = func(c ConnectionState) error {
1625                                         *called++
1626                                         if l := len(c.PeerCertificates); l != 1 {
1627                                                 return fmt.Errorf("server: got len(PeerCertificates) = %d, wanted 1", l)
1628                                         }
1629                                         if len(c.VerifiedChains) == 0 {
1630                                                 return fmt.Errorf("server: got len(VerifiedChains) = 0, wanted non-zero")
1631                                         }
1632                                         return checkFields(c, called, "server")
1633                                 }
1634                         },
1635                         configureClient: func(config *Config, called *int) {
1636                                 config.VerifyConnection = func(c ConnectionState) error {
1637                                         *called++
1638                                         if l := len(c.PeerCertificates); l != 1 {
1639                                                 return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l)
1640                                         }
1641                                         if len(c.VerifiedChains) == 0 {
1642                                                 return fmt.Errorf("client: got len(VerifiedChains) = 0, wanted non-zero")
1643                                         }
1644                                         if c.DidResume {
1645                                                 return nil
1646                                                 // The SCTs and OCSP Response are dropped on resumption.
1647                                                 // See http://golang.org/issue/39075.
1648                                         }
1649                                         if len(c.OCSPResponse) == 0 {
1650                                                 return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero")
1651                                         }
1652                                         if len(c.SignedCertificateTimestamps) == 0 {
1653                                                 return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
1654                                         }
1655                                         return checkFields(c, called, "client")
1656                                 }
1657                         },
1658                 },
1659                 {
1660                         name: "InsecureSkipVerify",
1661                         configureServer: func(config *Config, called *int) {
1662                                 config.ClientAuth = RequireAnyClientCert
1663                                 config.InsecureSkipVerify = true
1664                                 config.VerifyConnection = func(c ConnectionState) error {
1665                                         *called++
1666                                         if l := len(c.PeerCertificates); l != 1 {
1667                                                 return fmt.Errorf("server: got len(PeerCertificates) = %d, wanted 1", l)
1668                                         }
1669                                         if c.VerifiedChains != nil {
1670                                                 return fmt.Errorf("server: got Verified Chains %v, want nil", c.VerifiedChains)
1671                                         }
1672                                         return checkFields(c, called, "server")
1673                                 }
1674                         },
1675                         configureClient: func(config *Config, called *int) {
1676                                 config.InsecureSkipVerify = true
1677                                 config.VerifyConnection = func(c ConnectionState) error {
1678                                         *called++
1679                                         if l := len(c.PeerCertificates); l != 1 {
1680                                                 return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l)
1681                                         }
1682                                         if c.VerifiedChains != nil {
1683                                                 return fmt.Errorf("server: got Verified Chains %v, want nil", c.VerifiedChains)
1684                                         }
1685                                         if c.DidResume {
1686                                                 return nil
1687                                                 // The SCTs and OCSP Response are dropped on resumption.
1688                                                 // See http://golang.org/issue/39075.
1689                                         }
1690                                         if len(c.OCSPResponse) == 0 {
1691                                                 return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero")
1692                                         }
1693                                         if len(c.SignedCertificateTimestamps) == 0 {
1694                                                 return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
1695                                         }
1696                                         return checkFields(c, called, "client")
1697                                 }
1698                         },
1699                 },
1700                 {
1701                         name: "NoClientCert",
1702                         configureServer: func(config *Config, called *int) {
1703                                 config.ClientAuth = NoClientCert
1704                                 config.VerifyConnection = func(c ConnectionState) error {
1705                                         *called++
1706                                         return checkFields(c, called, "server")
1707                                 }
1708                         },
1709                         configureClient: func(config *Config, called *int) {
1710                                 config.VerifyConnection = func(c ConnectionState) error {
1711                                         *called++
1712                                         return checkFields(c, called, "client")
1713                                 }
1714                         },
1715                 },
1716                 {
1717                         name: "RequestClientCert",
1718                         configureServer: func(config *Config, called *int) {
1719                                 config.ClientAuth = RequestClientCert
1720                                 config.VerifyConnection = func(c ConnectionState) error {
1721                                         *called++
1722                                         return checkFields(c, called, "server")
1723                                 }
1724                         },
1725                         configureClient: func(config *Config, called *int) {
1726                                 config.Certificates = nil // clear the client cert
1727                                 config.VerifyConnection = func(c ConnectionState) error {
1728                                         *called++
1729                                         if l := len(c.PeerCertificates); l != 1 {
1730                                                 return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l)
1731                                         }
1732                                         if len(c.VerifiedChains) == 0 {
1733                                                 return fmt.Errorf("client: got len(VerifiedChains) = 0, wanted non-zero")
1734                                         }
1735                                         if c.DidResume {
1736                                                 return nil
1737                                                 // The SCTs and OCSP Response are dropped on resumption.
1738                                                 // See http://golang.org/issue/39075.
1739                                         }
1740                                         if len(c.OCSPResponse) == 0 {
1741                                                 return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero")
1742                                         }
1743                                         if len(c.SignedCertificateTimestamps) == 0 {
1744                                                 return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
1745                                         }
1746                                         return checkFields(c, called, "client")
1747                                 }
1748                         },
1749                 },
1750         }
1751         for _, test := range tests {
1752                 issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
1753                 if err != nil {
1754                         panic(err)
1755                 }
1756                 rootCAs := x509.NewCertPool()
1757                 rootCAs.AddCert(issuer)
1758
1759                 var serverCalled, clientCalled int
1760
1761                 serverConfig := &Config{
1762                         MaxVersion:   version,
1763                         Certificates: []Certificate{testConfig.Certificates[0]},
1764                         ClientCAs:    rootCAs,
1765                         NextProtos:   []string{"protocol1"},
1766                 }
1767                 serverConfig.Certificates[0].SignedCertificateTimestamps = [][]byte{[]byte("dummy sct 1"), []byte("dummy sct 2")}
1768                 serverConfig.Certificates[0].OCSPStaple = []byte("dummy ocsp")
1769                 test.configureServer(serverConfig, &serverCalled)
1770
1771                 clientConfig := &Config{
1772                         MaxVersion:         version,
1773                         ClientSessionCache: NewLRUClientSessionCache(32),
1774                         RootCAs:            rootCAs,
1775                         ServerName:         "example.golang",
1776                         Certificates:       []Certificate{testConfig.Certificates[0]},
1777                         NextProtos:         []string{"protocol1"},
1778                 }
1779                 test.configureClient(clientConfig, &clientCalled)
1780
1781                 testHandshakeState := func(name string, didResume bool) {
1782                         _, hs, err := testHandshake(t, clientConfig, serverConfig)
1783                         if err != nil {
1784                                 t.Fatalf("%s: handshake failed: %s", name, err)
1785                         }
1786                         if hs.DidResume != didResume {
1787                                 t.Errorf("%s: resumed: %v, expected: %v", name, hs.DidResume, didResume)
1788                         }
1789                         wantCalled := 1
1790                         if didResume {
1791                                 wantCalled = 2 // resumption would mean this is the second time it was called in this test
1792                         }
1793                         if clientCalled != wantCalled {
1794                                 t.Errorf("%s: expected client VerifyConnection called %d times, did %d times", name, wantCalled, clientCalled)
1795                         }
1796                         if serverCalled != wantCalled {
1797                                 t.Errorf("%s: expected server VerifyConnection called %d times, did %d times", name, wantCalled, serverCalled)
1798                         }
1799                 }
1800                 testHandshakeState(fmt.Sprintf("%s-FullHandshake", test.name), false)
1801                 testHandshakeState(fmt.Sprintf("%s-Resumption", test.name), true)
1802         }
1803 }
1804
1805 func TestVerifyPeerCertificate(t *testing.T) {
1806         t.Run("TLSv12", func(t *testing.T) { testVerifyPeerCertificate(t, VersionTLS12) })
1807         t.Run("TLSv13", func(t *testing.T) { testVerifyPeerCertificate(t, VersionTLS13) })
1808 }
1809
1810 func testVerifyPeerCertificate(t *testing.T, version uint16) {
1811         issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
1812         if err != nil {
1813                 panic(err)
1814         }
1815
1816         rootCAs := x509.NewCertPool()
1817         rootCAs.AddCert(issuer)
1818
1819         now := func() time.Time { return time.Unix(1476984729, 0) }
1820
1821         sentinelErr := errors.New("TestVerifyPeerCertificate")
1822
1823         verifyPeerCertificateCallback := func(called *bool, rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
1824                 if l := len(rawCerts); l != 1 {
1825                         return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l)
1826                 }
1827                 if len(validatedChains) == 0 {
1828                         return errors.New("got len(validatedChains) = 0, wanted non-zero")
1829                 }
1830                 *called = true
1831                 return nil
1832         }
1833         verifyConnectionCallback := func(called *bool, isClient bool, c ConnectionState) error {
1834                 if l := len(c.PeerCertificates); l != 1 {
1835                         return fmt.Errorf("got len(PeerCertificates) = %d, wanted 1", l)
1836                 }
1837                 if len(c.VerifiedChains) == 0 {
1838                         return fmt.Errorf("got len(VerifiedChains) = 0, wanted non-zero")
1839                 }
1840                 if isClient && len(c.OCSPResponse) == 0 {
1841                         return fmt.Errorf("got len(OCSPResponse) = 0, wanted non-zero")
1842                 }
1843                 *called = true
1844                 return nil
1845         }
1846
1847         tests := []struct {
1848                 configureServer func(*Config, *bool)
1849                 configureClient func(*Config, *bool)
1850                 validate        func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error)
1851         }{
1852                 {
1853                         configureServer: func(config *Config, called *bool) {
1854                                 config.InsecureSkipVerify = false
1855                                 config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
1856                                         return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
1857                                 }
1858                         },
1859                         configureClient: func(config *Config, called *bool) {
1860                                 config.InsecureSkipVerify = false
1861                                 config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
1862                                         return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
1863                                 }
1864                         },
1865                         validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
1866                                 if clientErr != nil {
1867                                         t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr)
1868                                 }
1869                                 if serverErr != nil {
1870                                         t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr)
1871                                 }
1872                                 if !clientCalled {
1873                                         t.Errorf("test[%d]: client did not call callback", testNo)
1874                                 }
1875                                 if !serverCalled {
1876                                         t.Errorf("test[%d]: server did not call callback", testNo)
1877                                 }
1878                         },
1879                 },
1880                 {
1881                         configureServer: func(config *Config, called *bool) {
1882                                 config.InsecureSkipVerify = false
1883                                 config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
1884                                         return sentinelErr
1885                                 }
1886                         },
1887                         configureClient: func(config *Config, called *bool) {
1888                                 config.VerifyPeerCertificate = nil
1889                         },
1890                         validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
1891                                 if serverErr != sentinelErr {
1892                                         t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr)
1893                                 }
1894                         },
1895                 },
1896                 {
1897                         configureServer: func(config *Config, called *bool) {
1898                                 config.InsecureSkipVerify = false
1899                         },
1900                         configureClient: func(config *Config, called *bool) {
1901                                 config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
1902                                         return sentinelErr
1903                                 }
1904                         },
1905                         validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
1906                                 if clientErr != sentinelErr {
1907                                         t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr)
1908                                 }
1909                         },
1910                 },
1911                 {
1912                         configureServer: func(config *Config, called *bool) {
1913                                 config.InsecureSkipVerify = false
1914                         },
1915                         configureClient: func(config *Config, called *bool) {
1916                                 config.InsecureSkipVerify = true
1917                                 config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
1918                                         if l := len(rawCerts); l != 1 {
1919                                                 return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l)
1920                                         }
1921                                         // With InsecureSkipVerify set, this
1922                                         // callback should still be called but
1923                                         // validatedChains must be empty.
1924                                         if l := len(validatedChains); l != 0 {
1925                                                 return fmt.Errorf("got len(validatedChains) = %d, wanted zero", l)
1926                                         }
1927                                         *called = true
1928                                         return nil
1929                                 }
1930                         },
1931                         validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
1932                                 if clientErr != nil {
1933                                         t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr)
1934                                 }
1935                                 if serverErr != nil {
1936                                         t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr)
1937                                 }
1938                                 if !clientCalled {
1939                                         t.Errorf("test[%d]: client did not call callback", testNo)
1940                                 }
1941                         },
1942                 },
1943                 {
1944                         configureServer: func(config *Config, called *bool) {
1945                                 config.InsecureSkipVerify = false
1946                                 config.VerifyConnection = func(c ConnectionState) error {
1947                                         return verifyConnectionCallback(called, false, c)
1948                                 }
1949                         },
1950                         configureClient: func(config *Config, called *bool) {
1951                                 config.InsecureSkipVerify = false
1952                                 config.VerifyConnection = func(c ConnectionState) error {
1953                                         return verifyConnectionCallback(called, true, c)
1954                                 }
1955                         },
1956                         validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
1957                                 if clientErr != nil {
1958                                         t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr)
1959                                 }
1960                                 if serverErr != nil {
1961                                         t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr)
1962                                 }
1963                                 if !clientCalled {
1964                                         t.Errorf("test[%d]: client did not call callback", testNo)
1965                                 }
1966                                 if !serverCalled {
1967                                         t.Errorf("test[%d]: server did not call callback", testNo)
1968                                 }
1969                         },
1970                 },
1971                 {
1972                         configureServer: func(config *Config, called *bool) {
1973                                 config.InsecureSkipVerify = false
1974                                 config.VerifyConnection = func(c ConnectionState) error {
1975                                         return sentinelErr
1976                                 }
1977                         },
1978                         configureClient: func(config *Config, called *bool) {
1979                                 config.InsecureSkipVerify = false
1980                                 config.VerifyConnection = nil
1981                         },
1982                         validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
1983                                 if serverErr != sentinelErr {
1984                                         t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr)
1985                                 }
1986                         },
1987                 },
1988                 {
1989                         configureServer: func(config *Config, called *bool) {
1990                                 config.InsecureSkipVerify = false
1991                                 config.VerifyConnection = nil
1992                         },
1993                         configureClient: func(config *Config, called *bool) {
1994                                 config.InsecureSkipVerify = false
1995                                 config.VerifyConnection = func(c ConnectionState) error {
1996                                         return sentinelErr
1997                                 }
1998                         },
1999                         validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
2000                                 if clientErr != sentinelErr {
2001                                         t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr)
2002                                 }
2003                         },
2004                 },
2005                 {
2006                         configureServer: func(config *Config, called *bool) {
2007                                 config.InsecureSkipVerify = false
2008                                 config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
2009                                         return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
2010                                 }
2011                                 config.VerifyConnection = func(c ConnectionState) error {
2012                                         return sentinelErr
2013                                 }
2014                         },
2015                         configureClient: func(config *Config, called *bool) {
2016                                 config.InsecureSkipVerify = false
2017                                 config.VerifyPeerCertificate = nil
2018                                 config.VerifyConnection = nil
2019                         },
2020                         validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
2021                                 if serverErr != sentinelErr {
2022                                         t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr)
2023                                 }
2024                                 if !serverCalled {
2025                                         t.Errorf("test[%d]: server did not call callback", testNo)
2026                                 }
2027                         },
2028                 },
2029                 {
2030                         configureServer: func(config *Config, called *bool) {
2031                                 config.InsecureSkipVerify = false
2032                                 config.VerifyPeerCertificate = nil
2033                                 config.VerifyConnection = nil
2034                         },
2035                         configureClient: func(config *Config, called *bool) {
2036                                 config.InsecureSkipVerify = false
2037                                 config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
2038                                         return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
2039                                 }
2040                                 config.VerifyConnection = func(c ConnectionState) error {
2041                                         return sentinelErr
2042                                 }
2043                         },
2044                         validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
2045                                 if clientErr != sentinelErr {
2046                                         t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr)
2047                                 }
2048                                 if !clientCalled {
2049                                         t.Errorf("test[%d]: client did not call callback", testNo)
2050                                 }
2051                         },
2052                 },
2053         }
2054
2055         for i, test := range tests {
2056                 c, s := localPipe(t)
2057                 done := make(chan error)
2058
2059                 var clientCalled, serverCalled bool
2060
2061                 go func() {
2062                         config := testConfig.Clone()
2063                         config.ServerName = "example.golang"
2064                         config.ClientAuth = RequireAndVerifyClientCert
2065                         config.ClientCAs = rootCAs
2066                         config.Time = now
2067                         config.MaxVersion = version
2068                         config.Certificates = make([]Certificate, 1)
2069                         config.Certificates[0].Certificate = [][]byte{testRSACertificate}
2070                         config.Certificates[0].PrivateKey = testRSAPrivateKey
2071                         config.Certificates[0].SignedCertificateTimestamps = [][]byte{[]byte("dummy sct 1"), []byte("dummy sct 2")}
2072                         config.Certificates[0].OCSPStaple = []byte("dummy ocsp")
2073                         test.configureServer(config, &serverCalled)
2074
2075                         err = Server(s, config).Handshake()
2076                         s.Close()
2077                         done <- err
2078                 }()
2079
2080                 config := testConfig.Clone()
2081                 config.ServerName = "example.golang"
2082                 config.RootCAs = rootCAs
2083                 config.Time = now
2084                 config.MaxVersion = version
2085                 test.configureClient(config, &clientCalled)
2086                 clientErr := Client(c, config).Handshake()
2087                 c.Close()
2088                 serverErr := <-done
2089
2090                 test.validate(t, i, clientCalled, serverCalled, clientErr, serverErr)
2091         }
2092 }
2093
// brokenConn is a net.Conn wrapper whose Write method starts failing with
// brokenConnErr once a fixed budget of writes has been used up. Every other
// net.Conn method is provided by the embedded Conn.
type brokenConn struct {
	net.Conn

	// breakAfter is how many writes are forwarded to the wrapped Conn
	// before Write begins returning brokenConnErr.
	breakAfter int

	// numWrites counts the writes forwarded to the wrapped Conn so far.
	numWrites int
}

// brokenConnErr is what brokenConn.Write returns after the write budget has
// been spent.
var brokenConnErr = errors.New("too many writes to brokenConn")

// Write forwards data to the wrapped Conn while fewer than breakAfter writes
// have been forwarded. After that it writes nothing and reports
// brokenConnErr.
func (b *brokenConn) Write(data []byte) (int, error) {
	if b.numWrites < b.breakAfter {
		b.numWrites++
		return b.Conn.Write(data)
	}
	return 0, brokenConnErr
}
2118
2119 func TestFailedWrite(t *testing.T) {
2120         // Test that a write error during the handshake is returned.
2121         for _, breakAfter := range []int{0, 1} {
2122                 c, s := localPipe(t)
2123                 done := make(chan bool)
2124
2125                 go func() {
2126                         Server(s, testConfig).Handshake()
2127                         s.Close()
2128                         done <- true
2129                 }()
2130
2131                 brokenC := &brokenConn{Conn: c, breakAfter: breakAfter}
2132                 err := Client(brokenC, testConfig).Handshake()
2133                 if err != brokenConnErr {
2134                         t.Errorf("#%d: expected error from brokenConn but got %q", breakAfter, err)
2135                 }
2136                 brokenC.Close()
2137
2138                 <-done
2139         }
2140 }
2141
// writeCountingConn is a net.Conn wrapper that records how many times Write
// has been called. Every other net.Conn method is provided by the embedded
// Conn.
type writeCountingConn struct {
	net.Conn

	// numWrites is the number of Write calls made so far, whether or not
	// the underlying write succeeded.
	numWrites int
}

// Write bumps the call counter and then hands data to the wrapped Conn,
// returning whatever the wrapped Write returns.
func (w *writeCountingConn) Write(data []byte) (int, error) {
	w.numWrites += 1
	n, err := w.Conn.Write(data)
	return n, err
}
2154
2155 func TestBuffering(t *testing.T) {
2156         t.Run("TLSv12", func(t *testing.T) { testBuffering(t, VersionTLS12) })
2157         t.Run("TLSv13", func(t *testing.T) { testBuffering(t, VersionTLS13) })
2158 }
2159
2160 func testBuffering(t *testing.T, version uint16) {
2161         c, s := localPipe(t)
2162         done := make(chan bool)
2163
2164         clientWCC := &writeCountingConn{Conn: c}
2165         serverWCC := &writeCountingConn{Conn: s}
2166
2167         go func() {
2168                 config := testConfig.Clone()
2169                 config.MaxVersion = version
2170                 Server(serverWCC, config).Handshake()
2171                 serverWCC.Close()
2172                 done <- true
2173         }()
2174
2175         err := Client(clientWCC, testConfig).Handshake()
2176         if err != nil {
2177                 t.Fatal(err)
2178         }
2179         clientWCC.Close()
2180         <-done
2181
2182         var expectedClient, expectedServer int
2183         if version == VersionTLS13 {
2184                 expectedClient = 2
2185                 expectedServer = 1
2186         } else {
2187                 expectedClient = 2
2188                 expectedServer = 2
2189         }
2190
2191         if n := clientWCC.numWrites; n != expectedClient {
2192                 t.Errorf("expected client handshake to complete with %d writes, but saw %d", expectedClient, n)
2193         }
2194
2195         if n := serverWCC.numWrites; n != expectedServer {
2196                 t.Errorf("expected server handshake to complete with %d writes, but saw %d", expectedServer, n)
2197         }
2198 }
2199
2200 func TestAlertFlushing(t *testing.T) {
2201         c, s := localPipe(t)
2202         done := make(chan bool)
2203
2204         clientWCC := &writeCountingConn{Conn: c}
2205         serverWCC := &writeCountingConn{Conn: s}
2206
2207         serverConfig := testConfig.Clone()
2208
2209         // Cause a signature-time error
2210         brokenKey := rsa.PrivateKey{PublicKey: testRSAPrivateKey.PublicKey}
2211         brokenKey.D = big.NewInt(42)
2212         serverConfig.Certificates = []Certificate{{
2213                 Certificate: [][]byte{testRSACertificate},
2214                 PrivateKey:  &brokenKey,
2215         }}
2216
2217         go func() {
2218                 Server(serverWCC, serverConfig).Handshake()
2219                 serverWCC.Close()
2220                 done <- true
2221         }()
2222
2223         err := Client(clientWCC, testConfig).Handshake()
2224         if err == nil {
2225                 t.Fatal("client unexpectedly returned no error")
2226         }
2227
2228         const expectedError = "remote error: tls: internal error"
2229         if e := err.Error(); !strings.Contains(e, expectedError) {
2230                 t.Fatalf("expected to find %q in error but error was %q", expectedError, e)
2231         }
2232         clientWCC.Close()
2233         <-done
2234
2235         if n := serverWCC.numWrites; n != 1 {
2236                 t.Errorf("expected server handshake to complete with one write, but saw %d", n)
2237         }
2238 }
2239
2240 func TestHandshakeRace(t *testing.T) {
2241         if testing.Short() {
2242                 t.Skip("skipping in -short mode")
2243         }
2244         t.Parallel()
2245         // This test races a Read and Write to try and complete a handshake in
2246         // order to provide some evidence that there are no races or deadlocks
2247         // in the handshake locking.
2248         for i := 0; i < 32; i++ {
2249                 c, s := localPipe(t)
2250
2251                 go func() {
2252                         server := Server(s, testConfig)
2253                         if err := server.Handshake(); err != nil {
2254                                 panic(err)
2255                         }
2256
2257                         var request [1]byte
2258                         if n, err := server.Read(request[:]); err != nil || n != 1 {
2259                                 panic(err)
2260                         }
2261
2262                         server.Write(request[:])
2263                         server.Close()
2264                 }()
2265
2266                 startWrite := make(chan struct{})
2267                 startRead := make(chan struct{})
2268                 readDone := make(chan struct{}, 1)
2269
2270                 client := Client(c, testConfig)
2271                 go func() {
2272                         <-startWrite
2273                         var request [1]byte
2274                         client.Write(request[:])
2275                 }()
2276
2277                 go func() {
2278                         <-startRead
2279                         var reply [1]byte
2280                         if _, err := io.ReadFull(client, reply[:]); err != nil {
2281                                 panic(err)
2282                         }
2283                         c.Close()
2284                         readDone <- struct{}{}
2285                 }()
2286
2287                 if i&1 == 1 {
2288                         startWrite <- struct{}{}
2289                         startRead <- struct{}{}
2290                 } else {
2291                         startRead <- struct{}{}
2292                         startWrite <- struct{}{}
2293                 }
2294                 <-readDone
2295         }
2296 }
2297
2298 var getClientCertificateTests = []struct {
2299         setup               func(*Config, *Config)
2300         expectedClientError string
2301         verify              func(*testing.T, int, *ConnectionState)
2302 }{
2303         {
2304                 func(clientConfig, serverConfig *Config) {
2305                         // Returning a Certificate with no certificate data
2306                         // should result in an empty message being sent to the
2307                         // server.
2308                         serverConfig.ClientCAs = nil
2309                         clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
2310                                 if len(cri.SignatureSchemes) == 0 {
2311                                         panic("empty SignatureSchemes")
2312                                 }
2313                                 if len(cri.AcceptableCAs) != 0 {
2314                                         panic("AcceptableCAs should have been empty")
2315                                 }
2316                                 return new(Certificate), nil
2317                         }
2318                 },
2319                 "",
2320                 func(t *testing.T, testNum int, cs *ConnectionState) {
2321                         if l := len(cs.PeerCertificates); l != 0 {
2322                                 t.Errorf("#%d: expected no certificates but got %d", testNum, l)
2323                         }
2324                 },
2325         },
2326         {
2327                 func(clientConfig, serverConfig *Config) {
2328                         // With TLS 1.1, the SignatureSchemes should be
2329                         // synthesised from the supported certificate types.
2330                         clientConfig.MaxVersion = VersionTLS11
2331                         clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
2332                                 if len(cri.SignatureSchemes) == 0 {
2333                                         panic("empty SignatureSchemes")
2334                                 }
2335                                 return new(Certificate), nil
2336                         }
2337                 },
2338                 "",
2339                 func(t *testing.T, testNum int, cs *ConnectionState) {
2340                         if l := len(cs.PeerCertificates); l != 0 {
2341                                 t.Errorf("#%d: expected no certificates but got %d", testNum, l)
2342                         }
2343                 },
2344         },
2345         {
2346                 func(clientConfig, serverConfig *Config) {
2347                         // Returning an error should abort the handshake with
2348                         // that error.
2349                         clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
2350                                 return nil, errors.New("GetClientCertificate")
2351                         }
2352                 },
2353                 "GetClientCertificate",
2354                 func(t *testing.T, testNum int, cs *ConnectionState) {
2355                 },
2356         },
2357         {
2358                 func(clientConfig, serverConfig *Config) {
2359                         clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
2360                                 if len(cri.AcceptableCAs) == 0 {
2361                                         panic("empty AcceptableCAs")
2362                                 }
2363                                 cert := &Certificate{
2364                                         Certificate: [][]byte{testRSACertificate},
2365                                         PrivateKey:  testRSAPrivateKey,
2366                                 }
2367                                 return cert, nil
2368                         }
2369                 },
2370                 "",
2371                 func(t *testing.T, testNum int, cs *ConnectionState) {
2372                         if len(cs.VerifiedChains) == 0 {
2373                                 t.Errorf("#%d: expected some verified chains, but found none", testNum)
2374                         }
2375                 },
2376         },
2377 }
2378
2379 func TestGetClientCertificate(t *testing.T) {
2380         t.Run("TLSv12", func(t *testing.T) { testGetClientCertificate(t, VersionTLS12) })
2381         t.Run("TLSv13", func(t *testing.T) { testGetClientCertificate(t, VersionTLS13) })
2382 }
2383
2384 func testGetClientCertificate(t *testing.T, version uint16) {
2385         issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
2386         if err != nil {
2387                 panic(err)
2388         }
2389
2390         for i, test := range getClientCertificateTests {
2391                 serverConfig := testConfig.Clone()
2392                 serverConfig.ClientAuth = VerifyClientCertIfGiven
2393                 serverConfig.RootCAs = x509.NewCertPool()
2394                 serverConfig.RootCAs.AddCert(issuer)
2395                 serverConfig.ClientCAs = serverConfig.RootCAs
2396                 serverConfig.Time = func() time.Time { return time.Unix(1476984729, 0) }
2397                 serverConfig.MaxVersion = version
2398
2399                 clientConfig := testConfig.Clone()
2400                 clientConfig.MaxVersion = version
2401
2402                 test.setup(clientConfig, serverConfig)
2403
2404                 type serverResult struct {
2405                         cs  ConnectionState
2406                         err error
2407                 }
2408
2409                 c, s := localPipe(t)
2410                 done := make(chan serverResult)
2411
2412                 go func() {
2413                         defer s.Close()
2414                         server := Server(s, serverConfig)
2415                         err := server.Handshake()
2416
2417                         var cs ConnectionState
2418                         if err == nil {
2419                                 cs = server.ConnectionState()
2420                         }
2421                         done <- serverResult{cs, err}
2422                 }()
2423
2424                 clientErr := Client(c, clientConfig).Handshake()
2425                 c.Close()
2426
2427                 result := <-done
2428
2429                 if clientErr != nil {
2430                         if len(test.expectedClientError) == 0 {
2431                                 t.Errorf("#%d: client error: %v", i, clientErr)
2432                         } else if got := clientErr.Error(); got != test.expectedClientError {
2433                                 t.Errorf("#%d: expected client error %q, but got %q", i, test.expectedClientError, got)
2434                         } else {
2435                                 test.verify(t, i, &result.cs)
2436                         }
2437                 } else if len(test.expectedClientError) > 0 {
2438                         t.Errorf("#%d: expected client error %q, but got no error", i, test.expectedClientError)
2439                 } else if err := result.err; err != nil {
2440                         t.Errorf("#%d: server error: %v", i, err)
2441                 } else {
2442                         test.verify(t, i, &result.cs)
2443                 }
2444         }
2445 }
2446
// TestRSAPSSKeyError guards against RSASSA-PSS certificates being handed to
// crypto/tls as ordinary RSA keys. The test passes if the embedded
// certificate fails to parse, or if it parses with a public key that is not
// an *rsa.PublicKey.
func TestRSAPSSKeyError(t *testing.T) {
	// crypto/tls does not support the rsa_pss_pss_* SignatureSchemes. If support for
	// public keys with OID RSASSA-PSS is added to crypto/x509, they will be misused with
	// the rsa_pss_rsae_* SignatureSchemes. Assert that RSASSA-PSS certificates don't
	// parse, or that they don't carry *rsa.PublicKey keys.
	const rsaPSSCertPEM = `
-----BEGIN CERTIFICATE-----
MIIDZTCCAhygAwIBAgIUCF2x0FyTgZG0CC9QTDjGWkB5vgEwPgYJKoZIhvcNAQEK
MDGgDTALBglghkgBZQMEAgGhGjAYBgkqhkiG9w0BAQgwCwYJYIZIAWUDBAIBogQC
AgDeMBIxEDAOBgNVBAMMB1JTQS1QU1MwHhcNMTgwNjI3MjI0NDM2WhcNMTgwNzI3
MjI0NDM2WjASMRAwDgYDVQQDDAdSU0EtUFNTMIIBIDALBgkqhkiG9w0BAQoDggEP
ADCCAQoCggEBANxDm0f76JdI06YzsjB3AmmjIYkwUEGxePlafmIASFjDZl/elD0Z
/a7xLX468b0qGxLS5al7XCcEprSdsDR6DF5L520+pCbpfLyPOjuOvGmk9KzVX4x5
b05YXYuXdsQ0Kjxcx2i3jjCday6scIhMJVgBZxTEyMj1thPQM14SHzKCd/m6HmCL
QmswpH2yMAAcBRWzRpp/vdH5DeOJEB3aelq7094no731mrLUCHRiZ1htq8BDB3ou
czwqgwspbqZ4dnMXl2MvfySQ5wJUxQwILbiuAKO2lVVPUbFXHE9pgtznNoPvKwQT
JNcX8ee8WIZc2SEGzofjk3NpjR+2ADB2u3sCAwEAAaNTMFEwHQYDVR0OBBYEFNEz
AdyJ2f+fU+vSCS6QzohnOnprMB8GA1UdIwQYMBaAFNEzAdyJ2f+fU+vSCS6Qzohn
OnprMA8GA1UdEwEB/wQFMAMBAf8wPgYJKoZIhvcNAQEKMDGgDTALBglghkgBZQME
AgGhGjAYBgkqhkiG9w0BAQgwCwYJYIZIAWUDBAIBogQCAgDeA4IBAQCjEdrR5aab
sZmCwrMeKidXgfkmWvfuLDE+TCbaqDZp7BMWcMQXT9O0UoUT5kqgKj2ARm2pEW0Z
H3Z1vj3bbds72qcDIJXp+l0fekyLGeCrX/CbgnMZXEP7+/+P416p34ChR1Wz4dU1
KD3gdsUuTKKeMUog3plxlxQDhRQmiL25ygH1LmjLd6dtIt0GVRGr8lj3euVeprqZ
bZ3Uq5eLfsn8oPgfC57gpO6yiN+UURRTlK3bgYvLh4VWB3XXk9UaQZ7Mq1tpXjoD
HYFybkWzibkZp4WRo+Fa28rirH+/wHt0vfeN7UCceURZEx4JaxIIfe4ku7uDRhJi
RwBA9Xk1KBNF
-----END CERTIFICATE-----`

	block, _ := pem.Decode([]byte(rsaPSSCertPEM))
	if block == nil {
		t.Fatal("Failed to decode certificate")
	}

	cert, err := x509.ParseCertificate(block.Bytes)
	if err != nil {
		// Refusing to parse the certificate is an acceptable outcome.
		return
	}

	switch cert.PublicKey.(type) {
	case *rsa.PublicKey:
		t.Error("A RSASSA-PSS certificate was parsed like a PKCS#1 v1.5 one, and it will be mistakenly used with rsa_pss_rsae_* signature algorithms")
	}
}
2485
2486 func TestCloseClientConnectionOnIdleServer(t *testing.T) {
2487         clientConn, serverConn := localPipe(t)
2488         client := Client(clientConn, testConfig.Clone())
2489         go func() {
2490                 var b [1]byte
2491                 serverConn.Read(b[:])
2492                 client.Close()
2493         }()
2494         client.SetWriteDeadline(time.Now().Add(time.Minute))
2495         err := client.Handshake()
2496         if err != nil {
2497                 if err, ok := err.(net.Error); ok && err.Timeout() {
2498                         t.Errorf("Expected a closed network connection error but got '%s'", err.Error())
2499                 }
2500         } else {
2501                 t.Errorf("Error expected, but no error returned")
2502         }
2503 }
2504
2505 func testDowngradeCanary(t *testing.T, clientVersion, serverVersion uint16) error {
2506         defer func() { testingOnlyForceDowngradeCanary = false }()
2507         testingOnlyForceDowngradeCanary = true
2508
2509         clientConfig := testConfig.Clone()
2510         clientConfig.MaxVersion = clientVersion
2511         serverConfig := testConfig.Clone()
2512         serverConfig.MaxVersion = serverVersion
2513         _, _, err := testHandshake(t, clientConfig, serverConfig)
2514         return err
2515 }
2516
2517 func TestDowngradeCanary(t *testing.T) {
2518         if err := testDowngradeCanary(t, VersionTLS13, VersionTLS12); err == nil {
2519                 t.Errorf("downgrade from TLS 1.3 to TLS 1.2 was not detected")
2520         }
2521         if testing.Short() {
2522                 t.Skip("skipping the rest of the checks in short mode")
2523         }
2524         if err := testDowngradeCanary(t, VersionTLS13, VersionTLS11); err == nil {
2525                 t.Errorf("downgrade from TLS 1.3 to TLS 1.1 was not detected")
2526         }
2527         if err := testDowngradeCanary(t, VersionTLS13, VersionTLS10); err == nil {
2528                 t.Errorf("downgrade from TLS 1.3 to TLS 1.0 was not detected")
2529         }
2530         if err := testDowngradeCanary(t, VersionTLS12, VersionTLS11); err == nil {
2531                 t.Errorf("downgrade from TLS 1.2 to TLS 1.1 was not detected")
2532         }
2533         if err := testDowngradeCanary(t, VersionTLS12, VersionTLS10); err == nil {
2534                 t.Errorf("downgrade from TLS 1.2 to TLS 1.0 was not detected")
2535         }
2536         if err := testDowngradeCanary(t, VersionTLS13, VersionTLS13); err != nil {
2537                 t.Errorf("server unexpectedly sent downgrade canary for TLS 1.3")
2538         }
2539         if err := testDowngradeCanary(t, VersionTLS12, VersionTLS12); err != nil {
2540                 t.Errorf("client didn't ignore expected TLS 1.2 canary")
2541         }
2542         if err := testDowngradeCanary(t, VersionTLS11, VersionTLS11); err != nil {
2543                 t.Errorf("client unexpectedly reacted to a canary in TLS 1.1")
2544         }
2545         if err := testDowngradeCanary(t, VersionTLS10, VersionTLS10); err != nil {
2546                 t.Errorf("client unexpectedly reacted to a canary in TLS 1.0")
2547         }
2548 }
2549
2550 func TestResumptionKeepsOCSPAndSCT(t *testing.T) {
2551         t.Run("TLSv12", func(t *testing.T) { testResumptionKeepsOCSPAndSCT(t, VersionTLS12) })
2552         t.Run("TLSv13", func(t *testing.T) { testResumptionKeepsOCSPAndSCT(t, VersionTLS13) })
2553 }
2554
2555 func testResumptionKeepsOCSPAndSCT(t *testing.T, ver uint16) {
2556         issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
2557         if err != nil {
2558                 t.Fatalf("failed to parse test issuer")
2559         }
2560         roots := x509.NewCertPool()
2561         roots.AddCert(issuer)
2562         clientConfig := &Config{
2563                 MaxVersion:         ver,
2564                 ClientSessionCache: NewLRUClientSessionCache(32),
2565                 ServerName:         "example.golang",
2566                 RootCAs:            roots,
2567         }
2568         serverConfig := testConfig.Clone()
2569         serverConfig.MaxVersion = ver
2570         serverConfig.Certificates[0].OCSPStaple = []byte{1, 2, 3}
2571         serverConfig.Certificates[0].SignedCertificateTimestamps = [][]byte{{4, 5, 6}}
2572
2573         _, ccs, err := testHandshake(t, clientConfig, serverConfig)
2574         if err != nil {
2575                 t.Fatalf("handshake failed: %s", err)
2576         }
2577         // after a new session we expect to see OCSPResponse and
2578         // SignedCertificateTimestamps populated as usual
2579         if !bytes.Equal(ccs.OCSPResponse, serverConfig.Certificates[0].OCSPStaple) {
2580                 t.Errorf("client ConnectionState contained unexpected OCSPResponse: wanted %v, got %v",
2581                         serverConfig.Certificates[0].OCSPStaple, ccs.OCSPResponse)
2582         }
2583         if !reflect.DeepEqual(ccs.SignedCertificateTimestamps, serverConfig.Certificates[0].SignedCertificateTimestamps) {
2584                 t.Errorf("client ConnectionState contained unexpected SignedCertificateTimestamps: wanted %v, got %v",
2585                         serverConfig.Certificates[0].SignedCertificateTimestamps, ccs.SignedCertificateTimestamps)
2586         }
2587
2588         // if the server doesn't send any SCTs, repopulate the old SCTs
2589         oldSCTs := serverConfig.Certificates[0].SignedCertificateTimestamps
2590         serverConfig.Certificates[0].SignedCertificateTimestamps = nil
2591         _, ccs, err = testHandshake(t, clientConfig, serverConfig)
2592         if err != nil {
2593                 t.Fatalf("handshake failed: %s", err)
2594         }
2595         if !ccs.DidResume {
2596                 t.Fatalf("expected session to be resumed")
2597         }
2598         // after a resumed session we also expect to see OCSPResponse
2599         // and SignedCertificateTimestamps populated
2600         if !bytes.Equal(ccs.OCSPResponse, serverConfig.Certificates[0].OCSPStaple) {
2601                 t.Errorf("client ConnectionState contained unexpected OCSPResponse after resumption: wanted %v, got %v",
2602                         serverConfig.Certificates[0].OCSPStaple, ccs.OCSPResponse)
2603         }
2604         if !reflect.DeepEqual(ccs.SignedCertificateTimestamps, oldSCTs) {
2605                 t.Errorf("client ConnectionState contained unexpected SignedCertificateTimestamps after resumption: wanted %v, got %v",
2606                         oldSCTs, ccs.SignedCertificateTimestamps)
2607         }
2608
2609         //  Only test overriding the SCTs for TLS 1.2, since in 1.3
2610         // the server won't send the message containing them
2611         if ver == VersionTLS13 {
2612                 return
2613         }
2614
2615         // if the server changes the SCTs it sends, they should override the saved SCTs
2616         serverConfig.Certificates[0].SignedCertificateTimestamps = [][]byte{{7, 8, 9}}
2617         _, ccs, err = testHandshake(t, clientConfig, serverConfig)
2618         if err != nil {
2619                 t.Fatalf("handshake failed: %s", err)
2620         }
2621         if !ccs.DidResume {
2622                 t.Fatalf("expected session to be resumed")
2623         }
2624         if !reflect.DeepEqual(ccs.SignedCertificateTimestamps, serverConfig.Certificates[0].SignedCertificateTimestamps) {
2625                 t.Errorf("client ConnectionState contained unexpected SignedCertificateTimestamps after resumption: wanted %v, got %v",
2626                         serverConfig.Certificates[0].SignedCertificateTimestamps, ccs.SignedCertificateTimestamps)
2627         }
2628 }
2629
2630 // TestClientHandshakeContextCancellation tests that canceling
2631 // the context given to the client side conn.HandshakeContext
2632 // interrupts the in-progress handshake.
2633 func TestClientHandshakeContextCancellation(t *testing.T) {
2634         c, s := localPipe(t)
2635         ctx, cancel := context.WithCancel(context.Background())
2636         unblockServer := make(chan struct{})
2637         defer close(unblockServer)
2638         go func() {
2639                 cancel()
2640                 <-unblockServer
2641                 _ = s.Close()
2642         }()
2643         cli := Client(c, testConfig)
2644         // Initiates client side handshake, which will block until the client hello is read
2645         // by the server, unless the cancellation works.
2646         err := cli.HandshakeContext(ctx)
2647         if err == nil {
2648                 t.Fatal("Client handshake did not error when the context was canceled")
2649         }
2650         if err != context.Canceled {
2651                 t.Errorf("Unexpected client handshake error: %v", err)
2652         }
2653         if runtime.GOARCH == "wasm" {
2654                 t.Skip("conn.Close does not error as expected when called multiple times on WASM")
2655         }
2656         err = cli.Close()
2657         if err == nil {
2658                 t.Error("Client connection was not closed when the context was canceled")
2659         }
2660 }