]> Cypherpunks.ru repositories - gostls13.git/blob - src/crypto/tls/tls_test.go
all: fix incorrect channel and API usage in some unit tests
[gostls13.git] / src / crypto / tls / tls_test.go
1 // Copyright 2012 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package tls
6
7 import (
8         "bytes"
9         "crypto"
10         "crypto/x509"
11         "encoding/json"
12         "errors"
13         "fmt"
14         "internal/testenv"
15         "io"
16         "io/ioutil"
17         "math"
18         "net"
19         "os"
20         "reflect"
21         "strings"
22         "testing"
23         "time"
24 )
25
// rsaCertPEM is a PEM-encoded X.509 certificate for the RSA key held in
// rsaKeyPEM and keyPEM; keyPairTests pairs it with both, and
// TestX509KeyPair expects each pairing to load.
var rsaCertPEM = `-----BEGIN CERTIFICATE-----
MIIB0zCCAX2gAwIBAgIJAI/M7BYjwB+uMA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV
BAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX
aWRnaXRzIFB0eSBMdGQwHhcNMTIwOTEyMjE1MjAyWhcNMTUwOTEyMjE1MjAyWjBF
MQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50
ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBANLJ
hPHhITqQbPklG3ibCVxwGMRfp/v4XqhfdQHdcVfHap6NQ5Wok/4xIA+ui35/MmNa
rtNuC+BdZ1tMuVCPFZcCAwEAAaNQME4wHQYDVR0OBBYEFJvKs8RfJaXTH08W+SGv
zQyKn0H8MB8GA1UdIwQYMBaAFJvKs8RfJaXTH08W+SGvzQyKn0H8MAwGA1UdEwQF
MAMBAf8wDQYJKoZIhvcNAQEFBQADQQBJlffJHybjDGxRMqaRmDhX0+6v02TUKZsW
r5QuVbpQhH6u+0UgcW0jp9QwpxoPTLTWGXEWBBBurxFwiCBhkQ+V
-----END CERTIFICATE-----
`
39
// rsaKeyPEM is the RSA private key matching rsaCertPEM. The PEM type is
// spelled "RSA TESTING KEY" and passed through testingKey (defined
// elsewhere), which presumably rewrites it to "RSA PRIVATE KEY" — TODO
// confirm against testingKey.
var rsaKeyPEM = testingKey(`-----BEGIN RSA TESTING KEY-----
MIIBOwIBAAJBANLJhPHhITqQbPklG3ibCVxwGMRfp/v4XqhfdQHdcVfHap6NQ5Wo
k/4xIA+ui35/MmNartNuC+BdZ1tMuVCPFZcCAwEAAQJAEJ2N+zsR0Xn8/Q6twa4G
6OB1M1WO+k+ztnX/1SvNeWu8D6GImtupLTYgjZcHufykj09jiHmjHx8u8ZZB/o1N
MQIhAPW+eyZo7ay3lMz1V01WVjNKK9QSn1MJlb06h/LuYv9FAiEA25WPedKgVyCW
SmUwbPw8fnTcpqDWE3yTO3vKcebqMSsCIBF3UmVue8YU3jybC3NxuXq3wNm34R8T
xVLHwDXh/6NJAiEAl2oHGGLz64BuAfjKrqwz7qMYr9HCLIe/YsoWq/olzScCIQDi
D2lWusoe2/nEqfDVVWGWlyJ7yOmqaVm/iNUN9B2N2g==
-----END RSA TESTING KEY-----
`)
50
// keyPEM carries the same key bytes as rsaKeyPEM, but its PEM block is typed
// as just "PRIVATE KEY" rather than "RSA PRIVATE KEY" (after testingKey's
// rewrite of "TESTING KEY", presumably). keyPairTests uses it to check that
// X509KeyPair still loads a key whose header does not name the algorithm.
// https://golang.org/issue/4477
var keyPEM = testingKey(`-----BEGIN TESTING KEY-----
MIIBOwIBAAJBANLJhPHhITqQbPklG3ibCVxwGMRfp/v4XqhfdQHdcVfHap6NQ5Wo
k/4xIA+ui35/MmNartNuC+BdZ1tMuVCPFZcCAwEAAQJAEJ2N+zsR0Xn8/Q6twa4G
6OB1M1WO+k+ztnX/1SvNeWu8D6GImtupLTYgjZcHufykj09jiHmjHx8u8ZZB/o1N
MQIhAPW+eyZo7ay3lMz1V01WVjNKK9QSn1MJlb06h/LuYv9FAiEA25WPedKgVyCW
SmUwbPw8fnTcpqDWE3yTO3vKcebqMSsCIBF3UmVue8YU3jybC3NxuXq3wNm34R8T
xVLHwDXh/6NJAiEAl2oHGGLz64BuAfjKrqwz7qMYr9HCLIe/YsoWq/olzScCIQDi
D2lWusoe2/nEqfDVVWGWlyJ7yOmqaVm/iNUN9B2N2g==
-----END TESTING KEY-----
`)
63
// ecdsaCertPEM is a PEM-encoded X.509 certificate for the ECDSA key in
// ecdsaKeyPEM; keyPairTests pairs the two, and TestX509MixedKeyPair checks
// that it is rejected alongside an RSA key.
var ecdsaCertPEM = `-----BEGIN CERTIFICATE-----
MIIB/jCCAWICCQDscdUxw16XFDAJBgcqhkjOPQQBMEUxCzAJBgNVBAYTAkFVMRMw
EQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBXaWRnaXRzIFB0
eSBMdGQwHhcNMTIxMTE0MTI0MDQ4WhcNMTUxMTE0MTI0MDQ4WjBFMQswCQYDVQQG
EwJBVTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lk
Z2l0cyBQdHkgTHRkMIGbMBAGByqGSM49AgEGBSuBBAAjA4GGAAQBY9+my9OoeSUR
lDQdV/x8LsOuLilthhiS1Tz4aGDHIPwC1mlvnf7fg5lecYpMCrLLhauAc1UJXcgl
01xoLuzgtAEAgv2P/jgytzRSpUYvgLBt1UA0leLYBy6mQQbrNEuqT3INapKIcUv8
XxYP0xMEUksLPq6Ca+CRSqTtrd/23uTnapkwCQYHKoZIzj0EAQOBigAwgYYCQXJo
A7Sl2nLVf+4Iu/tAX/IF4MavARKC4PPHK3zfuGfPR3oCCcsAoz3kAzOeijvd0iXb
H5jBImIxPL4WxQNiBTexAkF8D1EtpYuWdlVQ80/h/f4pBcGiXPqX5h2PQSQY7hP1
+jwM1FGS4fREIOvlBYr/SzzQRtwrvrzGYxDEDbsC0ZGRnA==
-----END CERTIFICATE-----
`
78
// ecdsaKeyPEM is the ECDSA private key matching ecdsaCertPEM. It is preceded
// by an "EC PARAMETERS" block (its payload encodes the curve OID
// 1.3.132.0.35, i.e. secp521r1), so loading it also exercises skipping a
// non-key PEM block. The "EC TESTING KEY" type is passed through testingKey
// (defined elsewhere), presumably becoming "EC PRIVATE KEY" — TODO confirm.
var ecdsaKeyPEM = testingKey(`-----BEGIN EC PARAMETERS-----
BgUrgQQAIw==
-----END EC PARAMETERS-----
-----BEGIN EC TESTING KEY-----
MIHcAgEBBEIBrsoKp0oqcv6/JovJJDoDVSGWdirrkgCWxrprGlzB9o0X8fV675X0
NwuBenXFfeZvVcwluO7/Q9wkYoPd/t3jGImgBwYFK4EEACOhgYkDgYYABAFj36bL
06h5JRGUNB1X/Hwuw64uKW2GGJLVPPhoYMcg/ALWaW+d/t+DmV5xikwKssuFq4Bz
VQldyCXTXGgu7OC0AQCC/Y/+ODK3NFKlRi+AsG3VQDSV4tgHLqZBBus0S6pPcg1q
kohxS/xfFg/TEwRSSws+roJr4JFKpO2t3/be5OdqmQ==
-----END EC TESTING KEY-----
`)
90
// keyPairTests lists matching certificate/key PEM pairs. TestX509KeyPair
// expects X509KeyPair to load each pair whether the certificate or the key
// comes first in the concatenated input.
var keyPairTests = []struct {
	algo string // label used in failure messages
	cert string // PEM certificate
	key  string // PEM private key matching cert
}{
	{"ECDSA", ecdsaCertPEM, ecdsaKeyPEM},
	{"RSA", rsaCertPEM, rsaKeyPEM},
	{"RSA-untyped", rsaCertPEM, keyPEM}, // golang.org/issue/4477
}
100
101 func TestX509KeyPair(t *testing.T) {
102         t.Parallel()
103         var pem []byte
104         for _, test := range keyPairTests {
105                 pem = []byte(test.cert + test.key)
106                 if _, err := X509KeyPair(pem, pem); err != nil {
107                         t.Errorf("Failed to load %s cert followed by %s key: %s", test.algo, test.algo, err)
108                 }
109                 pem = []byte(test.key + test.cert)
110                 if _, err := X509KeyPair(pem, pem); err != nil {
111                         t.Errorf("Failed to load %s key followed by %s cert: %s", test.algo, test.algo, err)
112                 }
113         }
114 }
115
116 func TestX509KeyPairErrors(t *testing.T) {
117         _, err := X509KeyPair([]byte(rsaKeyPEM), []byte(rsaCertPEM))
118         if err == nil {
119                 t.Fatalf("X509KeyPair didn't return an error when arguments were switched")
120         }
121         if subStr := "been switched"; !strings.Contains(err.Error(), subStr) {
122                 t.Fatalf("Expected %q in the error when switching arguments to X509KeyPair, but the error was %q", subStr, err)
123         }
124
125         _, err = X509KeyPair([]byte(rsaCertPEM), []byte(rsaCertPEM))
126         if err == nil {
127                 t.Fatalf("X509KeyPair didn't return an error when both arguments were certificates")
128         }
129         if subStr := "certificate"; !strings.Contains(err.Error(), subStr) {
130                 t.Fatalf("Expected %q in the error when both arguments to X509KeyPair were certificates, but the error was %q", subStr, err)
131         }
132
133         const nonsensePEM = `
134 -----BEGIN NONSENSE-----
135 Zm9vZm9vZm9v
136 -----END NONSENSE-----
137 `
138
139         _, err = X509KeyPair([]byte(nonsensePEM), []byte(nonsensePEM))
140         if err == nil {
141                 t.Fatalf("X509KeyPair didn't return an error when both arguments were nonsense")
142         }
143         if subStr := "NONSENSE"; !strings.Contains(err.Error(), subStr) {
144                 t.Fatalf("Expected %q in the error when both arguments to X509KeyPair were nonsense, but the error was %q", subStr, err)
145         }
146 }
147
148 func TestX509MixedKeyPair(t *testing.T) {
149         if _, err := X509KeyPair([]byte(rsaCertPEM), []byte(ecdsaKeyPEM)); err == nil {
150                 t.Error("Load of RSA certificate succeeded with ECDSA private key")
151         }
152         if _, err := X509KeyPair([]byte(ecdsaCertPEM), []byte(rsaKeyPEM)); err == nil {
153                 t.Error("Load of ECDSA certificate succeeded with RSA private key")
154         }
155 }
156
// newLocalListener returns a TCP listener on an ephemeral loopback port. It
// tries IPv4 (127.0.0.1) first and falls back to IPv6 ([::1]) if that fails;
// if neither works it fails the calling test or benchmark via t.Fatal.
// The caller is responsible for closing the returned listener.
func newLocalListener(t testing.TB) net.Listener {
	// Attribute a t.Fatal here to the caller's line, not to this helper.
	t.Helper()
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		ln, err = net.Listen("tcp6", "[::1]:0")
	}
	if err != nil {
		t.Fatal(err)
	}
	return ln
}
167
168 func TestDialTimeout(t *testing.T) {
169         if testing.Short() {
170                 t.Skip("skipping in short mode")
171         }
172         listener := newLocalListener(t)
173
174         addr := listener.Addr().String()
175         defer listener.Close()
176
177         complete := make(chan bool)
178         defer close(complete)
179
180         go func() {
181                 conn, err := listener.Accept()
182                 if err != nil {
183                         t.Error(err)
184                         return
185                 }
186                 <-complete
187                 conn.Close()
188         }()
189
190         dialer := &net.Dialer{
191                 Timeout: 10 * time.Millisecond,
192         }
193
194         var err error
195         if _, err = DialWithDialer(dialer, "tcp", addr, nil); err == nil {
196                 t.Fatal("DialWithTimeout completed successfully")
197         }
198
199         if !isTimeoutError(err) {
200                 t.Errorf("resulting error not a timeout: %v\nType %T: %#v", err, err, err)
201         }
202 }
203
// isTimeoutError reports whether err is, or wraps, a net.Error whose Timeout
// method returns true. A nil err is not a timeout.
func isTimeoutError(err error) bool {
	// errors.As also finds a net.Error inside a wrapped chain, which a plain
	// type assertion would miss.
	var ne net.Error
	if errors.As(err, &ne) {
		return ne.Timeout()
	}
	return false
}
210
211 // tests that Conn.Read returns (non-zero, io.EOF) instead of
212 // (non-zero, nil) when a Close (alertCloseNotify) is sitting right
213 // behind the application data in the buffer.
214 func TestConnReadNonzeroAndEOF(t *testing.T) {
215         // This test is racy: it assumes that after a write to a
216         // localhost TCP connection, the peer TCP connection can
217         // immediately read it. Because it's racy, we skip this test
218         // in short mode, and then retry it several times with an
219         // increasing sleep in between our final write (via srv.Close
220         // below) and the following read.
221         if testing.Short() {
222                 t.Skip("skipping in short mode")
223         }
224         var err error
225         for delay := time.Millisecond; delay <= 64*time.Millisecond; delay *= 2 {
226                 if err = testConnReadNonzeroAndEOF(t, delay); err == nil {
227                         return
228                 }
229         }
230         t.Error(err)
231 }
232
233 func testConnReadNonzeroAndEOF(t *testing.T, delay time.Duration) error {
234         ln := newLocalListener(t)
235         defer ln.Close()
236
237         srvCh := make(chan *Conn, 1)
238         var serr error
239         go func() {
240                 sconn, err := ln.Accept()
241                 if err != nil {
242                         serr = err
243                         srvCh <- nil
244                         return
245                 }
246                 serverConfig := testConfig.Clone()
247                 srv := Server(sconn, serverConfig)
248                 if err := srv.Handshake(); err != nil {
249                         serr = fmt.Errorf("handshake: %v", err)
250                         srvCh <- nil
251                         return
252                 }
253                 srvCh <- srv
254         }()
255
256         clientConfig := testConfig.Clone()
257         // In TLS 1.3, alerts are encrypted and disguised as application data, so
258         // the opportunistic peek won't work.
259         clientConfig.MaxVersion = VersionTLS12
260         conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
261         if err != nil {
262                 t.Fatal(err)
263         }
264         defer conn.Close()
265
266         srv := <-srvCh
267         if srv == nil {
268                 return serr
269         }
270
271         buf := make([]byte, 6)
272
273         srv.Write([]byte("foobar"))
274         n, err := conn.Read(buf)
275         if n != 6 || err != nil || string(buf) != "foobar" {
276                 return fmt.Errorf("Read = %d, %v, data %q; want 6, nil, foobar", n, err, buf)
277         }
278
279         srv.Write([]byte("abcdef"))
280         srv.Close()
281         time.Sleep(delay)
282         n, err = conn.Read(buf)
283         if n != 6 || string(buf) != "abcdef" {
284                 return fmt.Errorf("Read = %d, buf= %q; want 6, abcdef", n, buf)
285         }
286         if err != io.EOF {
287                 return fmt.Errorf("Second Read error = %v; want io.EOF", err)
288         }
289         return nil
290 }
291
292 func TestTLSUniqueMatches(t *testing.T) {
293         ln := newLocalListener(t)
294         defer ln.Close()
295
296         serverTLSUniques := make(chan []byte)
297         parentDone := make(chan struct{})
298         childDone := make(chan struct{})
299         defer close(parentDone)
300         go func() {
301                 defer close(childDone)
302                 for i := 0; i < 2; i++ {
303                         sconn, err := ln.Accept()
304                         if err != nil {
305                                 t.Error(err)
306                                 return
307                         }
308                         serverConfig := testConfig.Clone()
309                         serverConfig.MaxVersion = VersionTLS12 // TLSUnique is not defined in TLS 1.3
310                         srv := Server(sconn, serverConfig)
311                         if err := srv.Handshake(); err != nil {
312                                 t.Error(err)
313                                 return
314                         }
315                         select {
316                         case <-parentDone:
317                                 return
318                         case serverTLSUniques <- srv.ConnectionState().TLSUnique:
319                         }
320                 }
321         }()
322
323         clientConfig := testConfig.Clone()
324         clientConfig.ClientSessionCache = NewLRUClientSessionCache(1)
325         conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
326         if err != nil {
327                 t.Fatal(err)
328         }
329
330         var serverTLSUniquesValue []byte
331         select {
332         case <-childDone:
333                 return
334         case serverTLSUniquesValue = <-serverTLSUniques:
335         }
336
337         if !bytes.Equal(conn.ConnectionState().TLSUnique, serverTLSUniquesValue) {
338                 t.Error("client and server channel bindings differ")
339         }
340         conn.Close()
341
342         conn, err = Dial("tcp", ln.Addr().String(), clientConfig)
343         if err != nil {
344                 t.Fatal(err)
345         }
346         defer conn.Close()
347         if !conn.ConnectionState().DidResume {
348                 t.Error("second session did not use resumption")
349         }
350
351         select {
352         case <-childDone:
353                 return
354         case serverTLSUniquesValue = <-serverTLSUniques:
355         }
356
357         if !bytes.Equal(conn.ConnectionState().TLSUnique, serverTLSUniquesValue) {
358                 t.Error("client and server channel bindings differ when session resumption is used")
359         }
360 }
361
362 func TestVerifyHostname(t *testing.T) {
363         testenv.MustHaveExternalNetwork(t)
364
365         c, err := Dial("tcp", "www.google.com:https", nil)
366         if err != nil {
367                 t.Fatal(err)
368         }
369         if err := c.VerifyHostname("www.google.com"); err != nil {
370                 t.Fatalf("verify www.google.com: %v", err)
371         }
372         if err := c.VerifyHostname("www.yahoo.com"); err == nil {
373                 t.Fatalf("verify www.yahoo.com succeeded")
374         }
375
376         c, err = Dial("tcp", "www.google.com:https", &Config{InsecureSkipVerify: true})
377         if err != nil {
378                 t.Fatal(err)
379         }
380         if err := c.VerifyHostname("www.google.com"); err == nil {
381                 t.Fatalf("verify www.google.com succeeded with InsecureSkipVerify=true")
382         }
383 }
384
// TestConnCloseBreakingWrite checks that Conn.Close, called while another
// goroutine's Write is stuck inside the underlying connection's Write,
// returns instead of blocking forever. The stuck Write must fail with the
// underlying error, and a later Close must report errClosed.
func TestConnCloseBreakingWrite(t *testing.T) {
	ln := newLocalListener(t)
	defer ln.Close()

	// The server accepts and handshakes in the background, sending the
	// resulting *Conn (or nil, with serr set) on srvCh.
	srvCh := make(chan *Conn, 1)
	var serr error
	var sconn net.Conn
	go func() {
		var err error
		sconn, err = ln.Accept()
		if err != nil {
			serr = err
			srvCh <- nil
			return
		}
		serverConfig := testConfig.Clone()
		srv := Server(sconn, serverConfig)
		if err := srv.Handshake(); err != nil {
			serr = fmt.Errorf("handshake: %v", err)
			srvCh <- nil
			return
		}
		srvCh <- srv
	}()

	cconn, err := net.Dial("tcp", ln.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer cconn.Close()

	// Wrap the raw client connection so its Write and Close can be replaced
	// once the handshake has completed normally.
	conn := &changeImplConn{
		Conn: cconn,
	}

	clientConfig := testConfig.Clone()
	tconn := Client(conn, clientConfig)
	if err := tconn.Handshake(); err != nil {
		t.Fatal(err)
	}

	srv := <-srvCh
	if srv == nil {
		t.Fatal(serr)
	}
	defer sconn.Close()

	// From here on, closing the wrapper only signals connClosed; the real
	// cconn is left to the deferred cconn.Close above.
	connClosed := make(chan struct{})
	conn.closeFunc = func() error {
		close(connClosed)
		return nil
	}

	// Every write now announces itself on inWrite, then blocks until the
	// wrapper is closed and fails with errConnClosed.
	inWrite := make(chan bool, 1)
	var errConnClosed = errors.New("conn closed for test")
	conn.writeFunc = func(p []byte) (n int, err error) {
		inWrite <- true
		<-connClosed
		return 0, errConnClosed
	}

	// Once the Write below is blocked inside writeFunc, call Close from
	// another goroutine. The Write can only return after something closes
	// the wrapper (closeFunc), presumably this tconn.Close.
	closeReturned := make(chan bool, 1)
	go func() {
		<-inWrite
		tconn.Close() // test that this doesn't block forever.
		closeReturned <- true
	}()

	_, err = tconn.Write([]byte("foo"))
	if err != errConnClosed {
		t.Errorf("Write error = %v; want errConnClosed", err)
	}

	<-closeReturned
	if err := tconn.Close(); err != errClosed {
		t.Errorf("Close error = %v; want errClosed", err)
	}
}
463
464 func TestConnCloseWrite(t *testing.T) {
465         ln := newLocalListener(t)
466         defer ln.Close()
467
468         clientDoneChan := make(chan struct{})
469
470         serverCloseWrite := func() error {
471                 sconn, err := ln.Accept()
472                 if err != nil {
473                         return fmt.Errorf("accept: %v", err)
474                 }
475                 defer sconn.Close()
476
477                 serverConfig := testConfig.Clone()
478                 srv := Server(sconn, serverConfig)
479                 if err := srv.Handshake(); err != nil {
480                         return fmt.Errorf("handshake: %v", err)
481                 }
482                 defer srv.Close()
483
484                 data, err := ioutil.ReadAll(srv)
485                 if err != nil {
486                         return err
487                 }
488                 if len(data) > 0 {
489                         return fmt.Errorf("Read data = %q; want nothing", data)
490                 }
491
492                 if err := srv.CloseWrite(); err != nil {
493                         return fmt.Errorf("server CloseWrite: %v", err)
494                 }
495
496                 // Wait for clientCloseWrite to finish, so we know we
497                 // tested the CloseWrite before we defer the
498                 // sconn.Close above, which would also cause the
499                 // client to unblock like CloseWrite.
500                 <-clientDoneChan
501                 return nil
502         }
503
504         clientCloseWrite := func() error {
505                 defer close(clientDoneChan)
506
507                 clientConfig := testConfig.Clone()
508                 conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
509                 if err != nil {
510                         return err
511                 }
512                 if err := conn.Handshake(); err != nil {
513                         return err
514                 }
515                 defer conn.Close()
516
517                 if err := conn.CloseWrite(); err != nil {
518                         return fmt.Errorf("client CloseWrite: %v", err)
519                 }
520
521                 if _, err := conn.Write([]byte{0}); err != errShutdown {
522                         return fmt.Errorf("CloseWrite error = %v; want errShutdown", err)
523                 }
524
525                 data, err := ioutil.ReadAll(conn)
526                 if err != nil {
527                         return err
528                 }
529                 if len(data) > 0 {
530                         return fmt.Errorf("Read data = %q; want nothing", data)
531                 }
532                 return nil
533         }
534
535         errChan := make(chan error, 2)
536
537         go func() { errChan <- serverCloseWrite() }()
538         go func() { errChan <- clientCloseWrite() }()
539
540         for i := 0; i < 2; i++ {
541                 select {
542                 case err := <-errChan:
543                         if err != nil {
544                                 t.Fatal(err)
545                         }
546                 case <-time.After(10 * time.Second):
547                         t.Fatal("deadlock")
548                 }
549         }
550
551         // Also test CloseWrite being called before the handshake is
552         // finished:
553         {
554                 ln2 := newLocalListener(t)
555                 defer ln2.Close()
556
557                 netConn, err := net.Dial("tcp", ln2.Addr().String())
558                 if err != nil {
559                         t.Fatal(err)
560                 }
561                 defer netConn.Close()
562                 conn := Client(netConn, testConfig.Clone())
563
564                 if err := conn.CloseWrite(); err != errEarlyCloseWrite {
565                         t.Errorf("CloseWrite error = %v; want errEarlyCloseWrite", err)
566                 }
567         }
568 }
569
570 func TestWarningAlertFlood(t *testing.T) {
571         ln := newLocalListener(t)
572         defer ln.Close()
573
574         server := func() error {
575                 sconn, err := ln.Accept()
576                 if err != nil {
577                         return fmt.Errorf("accept: %v", err)
578                 }
579                 defer sconn.Close()
580
581                 serverConfig := testConfig.Clone()
582                 srv := Server(sconn, serverConfig)
583                 if err := srv.Handshake(); err != nil {
584                         return fmt.Errorf("handshake: %v", err)
585                 }
586                 defer srv.Close()
587
588                 _, err = ioutil.ReadAll(srv)
589                 if err == nil {
590                         return errors.New("unexpected lack of error from server")
591                 }
592                 const expected = "too many ignored"
593                 if str := err.Error(); !strings.Contains(str, expected) {
594                         return fmt.Errorf("expected error containing %q, but saw: %s", expected, str)
595                 }
596
597                 return nil
598         }
599
600         errChan := make(chan error, 1)
601         go func() { errChan <- server() }()
602
603         clientConfig := testConfig.Clone()
604         clientConfig.MaxVersion = VersionTLS12 // there are no warning alerts in TLS 1.3
605         conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
606         if err != nil {
607                 t.Fatal(err)
608         }
609         defer conn.Close()
610         if err := conn.Handshake(); err != nil {
611                 t.Fatal(err)
612         }
613
614         for i := 0; i < maxUselessRecords+1; i++ {
615                 conn.sendAlert(alertNoRenegotiation)
616         }
617
618         if err := <-errChan; err != nil {
619                 t.Fatal(err)
620         }
621 }
622
623 func TestCloneFuncFields(t *testing.T) {
624         const expectedCount = 5
625         called := 0
626
627         c1 := Config{
628                 Time: func() time.Time {
629                         called |= 1 << 0
630                         return time.Time{}
631                 },
632                 GetCertificate: func(*ClientHelloInfo) (*Certificate, error) {
633                         called |= 1 << 1
634                         return nil, nil
635                 },
636                 GetClientCertificate: func(*CertificateRequestInfo) (*Certificate, error) {
637                         called |= 1 << 2
638                         return nil, nil
639                 },
640                 GetConfigForClient: func(*ClientHelloInfo) (*Config, error) {
641                         called |= 1 << 3
642                         return nil, nil
643                 },
644                 VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
645                         called |= 1 << 4
646                         return nil
647                 },
648         }
649
650         c2 := c1.Clone()
651
652         c2.Time()
653         c2.GetCertificate(nil)
654         c2.GetClientCertificate(nil)
655         c2.GetConfigForClient(nil)
656         c2.VerifyPeerCertificate(nil, nil)
657
658         if called != (1<<expectedCount)-1 {
659                 t.Fatalf("expected %d calls but saw calls %b", expectedCount, called)
660         }
661 }
662
663 func TestCloneNonFuncFields(t *testing.T) {
664         var c1 Config
665         v := reflect.ValueOf(&c1).Elem()
666
667         typ := v.Type()
668         for i := 0; i < typ.NumField(); i++ {
669                 f := v.Field(i)
670                 if !f.CanSet() {
671                         // unexported field; not cloned.
672                         continue
673                 }
674
675                 // testing/quick can't handle functions or interfaces and so
676                 // isn't used here.
677                 switch fn := typ.Field(i).Name; fn {
678                 case "Rand":
679                         f.Set(reflect.ValueOf(io.Reader(os.Stdin)))
680                 case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate", "GetClientCertificate":
681                         // DeepEqual can't compare functions. If you add a
682                         // function field to this list, you must also change
683                         // TestCloneFuncFields to ensure that the func field is
684                         // cloned.
685                 case "Certificates":
686                         f.Set(reflect.ValueOf([]Certificate{
687                                 {Certificate: [][]byte{{'b'}}},
688                         }))
689                 case "NameToCertificate":
690                         f.Set(reflect.ValueOf(map[string]*Certificate{"a": nil}))
691                 case "RootCAs", "ClientCAs":
692                         f.Set(reflect.ValueOf(x509.NewCertPool()))
693                 case "ClientSessionCache":
694                         f.Set(reflect.ValueOf(NewLRUClientSessionCache(10)))
695                 case "KeyLogWriter":
696                         f.Set(reflect.ValueOf(io.Writer(os.Stdout)))
697                 case "NextProtos":
698                         f.Set(reflect.ValueOf([]string{"a", "b"}))
699                 case "ServerName":
700                         f.Set(reflect.ValueOf("b"))
701                 case "ClientAuth":
702                         f.Set(reflect.ValueOf(VerifyClientCertIfGiven))
703                 case "InsecureSkipVerify", "SessionTicketsDisabled", "DynamicRecordSizingDisabled", "PreferServerCipherSuites":
704                         f.Set(reflect.ValueOf(true))
705                 case "MinVersion", "MaxVersion":
706                         f.Set(reflect.ValueOf(uint16(VersionTLS12)))
707                 case "SessionTicketKey":
708                         f.Set(reflect.ValueOf([32]byte{}))
709                 case "CipherSuites":
710                         f.Set(reflect.ValueOf([]uint16{1, 2}))
711                 case "CurvePreferences":
712                         f.Set(reflect.ValueOf([]CurveID{CurveP256}))
713                 case "Renegotiation":
714                         f.Set(reflect.ValueOf(RenegotiateOnceAsClient))
715                 default:
716                         t.Errorf("all fields must be accounted for, but saw unknown field %q", fn)
717                 }
718         }
719
720         c2 := c1.Clone()
721         // DeepEqual also compares unexported fields, thus c2 needs to have run
722         // serverInit in order to be DeepEqual to c1. Cloning it and discarding
723         // the result is sufficient.
724         c2.Clone()
725
726         if !reflect.DeepEqual(&c1, c2) {
727                 t.Errorf("clone failed to copy a field")
728         }
729 }
730
// changeImplConn wraps a net.Conn so a test can swap out its Write and Close
// behavior at runtime. While a hook is nil, the call falls through to the
// embedded Conn.
type changeImplConn struct {
	net.Conn
	writeFunc func([]byte) (int, error)
	closeFunc func() error
}

// Write calls writeFunc when it is set, otherwise the embedded Conn's Write.
func (c *changeImplConn) Write(p []byte) (n int, err error) {
	if c.writeFunc == nil {
		return c.Conn.Write(p)
	}
	return c.writeFunc(p)
}

// Close calls closeFunc when it is set, otherwise the embedded Conn's Close.
func (c *changeImplConn) Close() error {
	if c.closeFunc == nil {
		return c.Conn.Close()
	}
	return c.closeFunc()
}
752
753 func throughput(b *testing.B, version uint16, totalBytes int64, dynamicRecordSizingDisabled bool) {
754         ln := newLocalListener(b)
755         defer ln.Close()
756
757         N := b.N
758
759         // Less than 64KB because Windows appears to use a TCP rwin < 64KB.
760         // See Issue #15899.
761         const bufsize = 32 << 10
762
763         go func() {
764                 buf := make([]byte, bufsize)
765                 for i := 0; i < N; i++ {
766                         sconn, err := ln.Accept()
767                         if err != nil {
768                                 // panic rather than synchronize to avoid benchmark overhead
769                                 // (cannot call b.Fatal in goroutine)
770                                 panic(fmt.Errorf("accept: %v", err))
771                         }
772                         serverConfig := testConfig.Clone()
773                         serverConfig.CipherSuites = nil // the defaults may prefer faster ciphers
774                         serverConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled
775                         srv := Server(sconn, serverConfig)
776                         if err := srv.Handshake(); err != nil {
777                                 panic(fmt.Errorf("handshake: %v", err))
778                         }
779                         if _, err := io.CopyBuffer(srv, srv, buf); err != nil {
780                                 panic(fmt.Errorf("copy buffer: %v", err))
781                         }
782                 }
783         }()
784
785         b.SetBytes(totalBytes)
786         clientConfig := testConfig.Clone()
787         clientConfig.CipherSuites = nil // the defaults may prefer faster ciphers
788         clientConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled
789         clientConfig.MaxVersion = version
790
791         buf := make([]byte, bufsize)
792         chunks := int(math.Ceil(float64(totalBytes) / float64(len(buf))))
793         for i := 0; i < N; i++ {
794                 conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
795                 if err != nil {
796                         b.Fatal(err)
797                 }
798                 for j := 0; j < chunks; j++ {
799                         _, err := conn.Write(buf)
800                         if err != nil {
801                                 b.Fatal(err)
802                         }
803                         _, err = io.ReadFull(conn, buf)
804                         if err != nil {
805                                 b.Fatal(err)
806                         }
807                 }
808                 conn.Close()
809         }
810 }
811
812 func BenchmarkThroughput(b *testing.B) {
813         for _, mode := range []string{"Max", "Dynamic"} {
814                 for size := 1; size <= 64; size <<= 1 {
815                         name := fmt.Sprintf("%sPacket/%dMB", mode, size)
816                         b.Run(name, func(b *testing.B) {
817                                 b.Run("TLSv12", func(b *testing.B) {
818                                         throughput(b, VersionTLS12, int64(size<<20), mode == "Max")
819                                 })
820                                 b.Run("TLSv13", func(b *testing.B) {
821                                         throughput(b, VersionTLS13, int64(size<<20), mode == "Max")
822                                 })
823                         })
824                 }
825         }
826 }
827
// slowConn wraps a net.Conn and throttles Write so that data is released at no
// more than bps bits per second. Only Write is overridden; every other method,
// including Read, is the embedded connection's.
type slowConn struct {
	net.Conn
	bps int // write rate limit in bits per second; must be positive
}

// Write sends p to the underlying connection in bursts, polling every 100µs
// and never letting the cumulative byte count exceed what bps allows for the
// time elapsed since the call began. It returns len(p) on success, or the
// number of bytes written so far together with the first error from the
// underlying connection.
//
// A non-positive bps panics: with such a rate no bytes would ever be allowed
// and the loop below would sleep forever.
func (c *slowConn) Write(p []byte) (int, error) {
	if c.bps <= 0 {
		panic("too slow")
	}
	t0 := time.Now()
	wrote := 0
	for wrote < len(p) {
		time.Sleep(100 * time.Microsecond)
		// Budget so far in bytes: elapsed seconds × bits per second ÷ 8.
		allowed := int(time.Since(t0).Seconds()*float64(c.bps)) / 8
		if allowed > len(p) {
			allowed = len(p)
		}
		if wrote < allowed {
			n, err := c.Conn.Write(p[wrote:allowed])
			wrote += n
			if err != nil {
				return wrote, err
			}
		}
	}
	return len(p), nil
}
855
856 func latency(b *testing.B, version uint16, bps int, dynamicRecordSizingDisabled bool) {
857         ln := newLocalListener(b)
858         defer ln.Close()
859
860         N := b.N
861
862         go func() {
863                 for i := 0; i < N; i++ {
864                         sconn, err := ln.Accept()
865                         if err != nil {
866                                 // panic rather than synchronize to avoid benchmark overhead
867                                 // (cannot call b.Fatal in goroutine)
868                                 panic(fmt.Errorf("accept: %v", err))
869                         }
870                         serverConfig := testConfig.Clone()
871                         serverConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled
872                         srv := Server(&slowConn{sconn, bps}, serverConfig)
873                         if err := srv.Handshake(); err != nil {
874                                 panic(fmt.Errorf("handshake: %v", err))
875                         }
876                         io.Copy(srv, srv)
877                 }
878         }()
879
880         clientConfig := testConfig.Clone()
881         clientConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled
882         clientConfig.MaxVersion = version
883
884         buf := make([]byte, 16384)
885         peek := make([]byte, 1)
886
887         for i := 0; i < N; i++ {
888                 conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
889                 if err != nil {
890                         b.Fatal(err)
891                 }
892                 // make sure we're connected and previous connection has stopped
893                 if _, err := conn.Write(buf[:1]); err != nil {
894                         b.Fatal(err)
895                 }
896                 if _, err := io.ReadFull(conn, peek); err != nil {
897                         b.Fatal(err)
898                 }
899                 if _, err := conn.Write(buf); err != nil {
900                         b.Fatal(err)
901                 }
902                 if _, err = io.ReadFull(conn, peek); err != nil {
903                         b.Fatal(err)
904                 }
905                 conn.Close()
906         }
907 }
908
909 func BenchmarkLatency(b *testing.B) {
910         for _, mode := range []string{"Max", "Dynamic"} {
911                 for _, kbps := range []int{200, 500, 1000, 2000, 5000} {
912                         name := fmt.Sprintf("%sPacket/%dkbps", mode, kbps)
913                         b.Run(name, func(b *testing.B) {
914                                 b.Run("TLSv12", func(b *testing.B) {
915                                         latency(b, VersionTLS12, kbps*1000, mode == "Max")
916                                 })
917                                 b.Run("TLSv13", func(b *testing.B) {
918                                         latency(b, VersionTLS13, kbps*1000, mode == "Max")
919                                 })
920                         })
921                 }
922         }
923 }
924
925 func TestConnectionStateMarshal(t *testing.T) {
926         cs := &ConnectionState{}
927         _, err := json.Marshal(cs)
928         if err != nil {
929                 t.Errorf("json.Marshal failed on ConnectionState: %v", err)
930         }
931 }
932
933 func TestConnectionState(t *testing.T) {
934         issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
935         if err != nil {
936                 panic(err)
937         }
938         rootCAs := x509.NewCertPool()
939         rootCAs.AddCert(issuer)
940
941         now := func() time.Time { return time.Unix(1476984729, 0) }
942
943         const alpnProtocol = "golang"
944         const serverName = "example.golang"
945         var scts = [][]byte{[]byte("dummy sct 1"), []byte("dummy sct 2")}
946         var ocsp = []byte("dummy ocsp")
947
948         for _, v := range []uint16{VersionTLS12, VersionTLS13} {
949                 var name string
950                 switch v {
951                 case VersionTLS12:
952                         name = "TLSv12"
953                 case VersionTLS13:
954                         name = "TLSv13"
955                 }
956                 t.Run(name, func(t *testing.T) {
957                         config := &Config{
958                                 Time:         now,
959                                 Rand:         zeroSource{},
960                                 Certificates: make([]Certificate, 1),
961                                 MaxVersion:   v,
962                                 RootCAs:      rootCAs,
963                                 ClientCAs:    rootCAs,
964                                 ClientAuth:   RequireAndVerifyClientCert,
965                                 NextProtos:   []string{alpnProtocol},
966                                 ServerName:   serverName,
967                         }
968                         config.Certificates[0].Certificate = [][]byte{testRSACertificate}
969                         config.Certificates[0].PrivateKey = testRSAPrivateKey
970                         config.Certificates[0].SignedCertificateTimestamps = scts
971                         config.Certificates[0].OCSPStaple = ocsp
972
973                         ss, cs, err := testHandshake(t, config, config)
974                         if err != nil {
975                                 t.Fatalf("Handshake failed: %v", err)
976                         }
977
978                         if ss.Version != v || cs.Version != v {
979                                 t.Errorf("Got versions %x (server) and %x (client), expected %x", ss.Version, cs.Version, v)
980                         }
981
982                         if !ss.HandshakeComplete || !cs.HandshakeComplete {
983                                 t.Errorf("Got HandshakeComplete %v (server) and %v (client), expected true", ss.HandshakeComplete, cs.HandshakeComplete)
984                         }
985
986                         if ss.DidResume || cs.DidResume {
987                                 t.Errorf("Got DidResume %v (server) and %v (client), expected false", ss.DidResume, cs.DidResume)
988                         }
989
990                         if ss.CipherSuite == 0 || cs.CipherSuite == 0 {
991                                 t.Errorf("Got invalid cipher suite: %v (server) and %v (client)", ss.CipherSuite, cs.CipherSuite)
992                         }
993
994                         if ss.NegotiatedProtocol != alpnProtocol || cs.NegotiatedProtocol != alpnProtocol {
995                                 t.Errorf("Got negotiated protocol %q (server) and %q (client), expected %q", ss.NegotiatedProtocol, cs.NegotiatedProtocol, alpnProtocol)
996                         }
997
998                         if !cs.NegotiatedProtocolIsMutual {
999                                 t.Errorf("Got false NegotiatedProtocolIsMutual on the client side")
1000                         }
1001                         // NegotiatedProtocolIsMutual on the server side is unspecified.
1002
1003                         if ss.ServerName != serverName {
1004                                 t.Errorf("Got server name %q, expected %q", ss.ServerName, serverName)
1005                         }
1006                         if cs.ServerName != "" {
1007                                 t.Errorf("Got unexpected server name on the client side")
1008                         }
1009
1010                         if len(ss.PeerCertificates) != 1 || len(cs.PeerCertificates) != 1 {
1011                                 t.Errorf("Got %d (server) and %d (client) peer certificates, expected %d", len(ss.PeerCertificates), len(cs.PeerCertificates), 1)
1012                         }
1013
1014                         if len(ss.VerifiedChains) != 1 || len(cs.VerifiedChains) != 1 {
1015                                 t.Errorf("Got %d (server) and %d (client) verified chains, expected %d", len(ss.VerifiedChains), len(cs.VerifiedChains), 1)
1016                         } else if len(ss.VerifiedChains[0]) != 2 || len(cs.VerifiedChains[0]) != 2 {
1017                                 t.Errorf("Got %d (server) and %d (client) long verified chain, expected %d", len(ss.VerifiedChains[0]), len(cs.VerifiedChains[0]), 2)
1018                         }
1019
1020                         if len(cs.SignedCertificateTimestamps) != 2 {
1021                                 t.Errorf("Got %d SCTs, expected %d", len(cs.SignedCertificateTimestamps), 2)
1022                         }
1023                         if !bytes.Equal(cs.OCSPResponse, ocsp) {
1024                                 t.Errorf("Got OCSPs %x, expected %x", cs.OCSPResponse, ocsp)
1025                         }
1026                         // Only TLS 1.3 supports OCSP and SCTs on client certs.
1027                         if v == VersionTLS13 {
1028                                 if len(ss.SignedCertificateTimestamps) != 2 {
1029                                         t.Errorf("Got %d client SCTs, expected %d", len(ss.SignedCertificateTimestamps), 2)
1030                                 }
1031                                 if !bytes.Equal(ss.OCSPResponse, ocsp) {
1032                                         t.Errorf("Got client OCSPs %x, expected %x", ss.OCSPResponse, ocsp)
1033                                 }
1034                         }
1035
1036                         if v == VersionTLS13 {
1037                                 if ss.TLSUnique != nil || cs.TLSUnique != nil {
1038                                         t.Errorf("Got TLSUnique %x (server) and %x (client), expected nil in TLS 1.3", ss.TLSUnique, cs.TLSUnique)
1039                                 }
1040                         } else {
1041                                 if ss.TLSUnique == nil || cs.TLSUnique == nil {
1042                                         t.Errorf("Got TLSUnique %x (server) and %x (client), expected non-nil", ss.TLSUnique, cs.TLSUnique)
1043                                 }
1044                         }
1045                 })
1046         }
1047 }
1048
1049 // Issue 28744: Ensure that we don't modify memory
1050 // that Config doesn't own such as Certificates.
1051 func TestBuildNameToCertificate_doesntModifyCertificates(t *testing.T) {
1052         c0 := Certificate{
1053                 Certificate: [][]byte{testRSACertificate},
1054                 PrivateKey:  testRSAPrivateKey,
1055         }
1056         c1 := Certificate{
1057                 Certificate: [][]byte{testSNICertificate},
1058                 PrivateKey:  testRSAPrivateKey,
1059         }
1060         config := testConfig.Clone()
1061         config.Certificates = []Certificate{c0, c1}
1062
1063         config.BuildNameToCertificate()
1064         got := config.Certificates
1065         want := []Certificate{c0, c1}
1066         if !reflect.DeepEqual(got, want) {
1067                 t.Fatalf("Certificates were mutated by BuildNameToCertificate\nGot: %#v\nWant: %#v\n", got, want)
1068         }
1069 }
1070
// testingKey rewrites every "TESTING KEY" label in s to "PRIVATE KEY", turning
// the disguised PEM blocks stored in this file back into parseable private
// keys. Text without the label is returned unchanged.
func testingKey(s string) string {
	const disguised, real = "TESTING KEY", "PRIVATE KEY"
	return strings.Replace(s, disguised, real, -1)
}
1072
1073 func TestClientHelloInfo_SupportsCertificate(t *testing.T) {
1074         rsaCert := &Certificate{
1075                 Certificate: [][]byte{testRSACertificate},
1076                 PrivateKey:  testRSAPrivateKey,
1077         }
1078         pkcs1Cert := &Certificate{
1079                 Certificate:                  [][]byte{testRSACertificate},
1080                 PrivateKey:                   testRSAPrivateKey,
1081                 SupportedSignatureAlgorithms: []SignatureScheme{PKCS1WithSHA1, PKCS1WithSHA256},
1082         }
1083         ecdsaCert := &Certificate{
1084                 // ECDSA P-256 certificate
1085                 Certificate: [][]byte{testP256Certificate},
1086                 PrivateKey:  testP256PrivateKey,
1087         }
1088         ed25519Cert := &Certificate{
1089                 Certificate: [][]byte{testEd25519Certificate},
1090                 PrivateKey:  testEd25519PrivateKey,
1091         }
1092
1093         tests := []struct {
1094                 c       *Certificate
1095                 chi     *ClientHelloInfo
1096                 wantErr string
1097         }{
1098                 {rsaCert, &ClientHelloInfo{
1099                         ServerName:        "example.golang",
1100                         SignatureSchemes:  []SignatureScheme{PSSWithSHA256},
1101                         SupportedVersions: []uint16{VersionTLS13},
1102                 }, ""},
1103                 {ecdsaCert, &ClientHelloInfo{
1104                         SignatureSchemes:  []SignatureScheme{PSSWithSHA256, ECDSAWithP256AndSHA256},
1105                         SupportedVersions: []uint16{VersionTLS13, VersionTLS12},
1106                 }, ""},
1107                 {rsaCert, &ClientHelloInfo{
1108                         ServerName:        "example.com",
1109                         SignatureSchemes:  []SignatureScheme{PSSWithSHA256},
1110                         SupportedVersions: []uint16{VersionTLS13},
1111                 }, "not valid for requested server name"},
1112                 {ecdsaCert, &ClientHelloInfo{
1113                         SignatureSchemes:  []SignatureScheme{ECDSAWithP384AndSHA384},
1114                         SupportedVersions: []uint16{VersionTLS13},
1115                 }, "signature algorithms"},
1116                 {pkcs1Cert, &ClientHelloInfo{
1117                         SignatureSchemes:  []SignatureScheme{PSSWithSHA256, ECDSAWithP256AndSHA256},
1118                         SupportedVersions: []uint16{VersionTLS13},
1119                 }, "signature algorithms"},
1120
1121                 {rsaCert, &ClientHelloInfo{
1122                         CipherSuites:      []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
1123                         SignatureSchemes:  []SignatureScheme{PKCS1WithSHA1},
1124                         SupportedVersions: []uint16{VersionTLS13, VersionTLS12},
1125                 }, "signature algorithms"},
1126                 {rsaCert, &ClientHelloInfo{
1127                         CipherSuites:      []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
1128                         SignatureSchemes:  []SignatureScheme{PKCS1WithSHA1},
1129                         SupportedVersions: []uint16{VersionTLS13, VersionTLS12},
1130                         config: &Config{
1131                                 MaxVersion: VersionTLS12,
1132                         },
1133                 }, ""}, // Check that mutual version selection works.
1134
1135                 {ecdsaCert, &ClientHelloInfo{
1136                         CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1137                         SupportedCurves:   []CurveID{CurveP256},
1138                         SupportedPoints:   []uint8{pointFormatUncompressed},
1139                         SignatureSchemes:  []SignatureScheme{ECDSAWithP256AndSHA256},
1140                         SupportedVersions: []uint16{VersionTLS12},
1141                 }, ""},
1142                 {ecdsaCert, &ClientHelloInfo{
1143                         CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1144                         SupportedCurves:   []CurveID{CurveP256},
1145                         SupportedPoints:   []uint8{pointFormatUncompressed},
1146                         SignatureSchemes:  []SignatureScheme{ECDSAWithP384AndSHA384},
1147                         SupportedVersions: []uint16{VersionTLS12},
1148                 }, ""}, // TLS 1.2 does not restrict curves based on the SignatureScheme.
1149                 {ecdsaCert, &ClientHelloInfo{
1150                         CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1151                         SupportedCurves:   []CurveID{CurveP256},
1152                         SupportedPoints:   []uint8{pointFormatUncompressed},
1153                         SignatureSchemes:  nil,
1154                         SupportedVersions: []uint16{VersionTLS12},
1155                 }, ""}, // TLS 1.2 comes with default signature schemes.
1156                 {ecdsaCert, &ClientHelloInfo{
1157                         CipherSuites:      []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
1158                         SupportedCurves:   []CurveID{CurveP256},
1159                         SupportedPoints:   []uint8{pointFormatUncompressed},
1160                         SignatureSchemes:  []SignatureScheme{ECDSAWithP256AndSHA256},
1161                         SupportedVersions: []uint16{VersionTLS12},
1162                 }, "cipher suite"},
1163                 {ecdsaCert, &ClientHelloInfo{
1164                         CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1165                         SupportedCurves:   []CurveID{CurveP256},
1166                         SupportedPoints:   []uint8{pointFormatUncompressed},
1167                         SignatureSchemes:  []SignatureScheme{ECDSAWithP256AndSHA256},
1168                         SupportedVersions: []uint16{VersionTLS12},
1169                         config: &Config{
1170                                 CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
1171                         },
1172                 }, "cipher suite"},
1173                 {ecdsaCert, &ClientHelloInfo{
1174                         CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1175                         SupportedCurves:   []CurveID{CurveP384},
1176                         SupportedPoints:   []uint8{pointFormatUncompressed},
1177                         SignatureSchemes:  []SignatureScheme{ECDSAWithP256AndSHA256},
1178                         SupportedVersions: []uint16{VersionTLS12},
1179                 }, "certificate curve"},
1180                 {ecdsaCert, &ClientHelloInfo{
1181                         CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1182                         SupportedCurves:   []CurveID{CurveP256},
1183                         SupportedPoints:   []uint8{1},
1184                         SignatureSchemes:  []SignatureScheme{ECDSAWithP256AndSHA256},
1185                         SupportedVersions: []uint16{VersionTLS12},
1186                 }, "doesn't support ECDHE"},
1187                 {ecdsaCert, &ClientHelloInfo{
1188                         CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1189                         SupportedCurves:   []CurveID{CurveP256},
1190                         SupportedPoints:   []uint8{pointFormatUncompressed},
1191                         SignatureSchemes:  []SignatureScheme{PSSWithSHA256},
1192                         SupportedVersions: []uint16{VersionTLS12},
1193                 }, "signature algorithms"},
1194
1195                 {ed25519Cert, &ClientHelloInfo{
1196                         CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1197                         SupportedCurves:   []CurveID{CurveP256}, // only relevant for ECDHE support
1198                         SupportedPoints:   []uint8{pointFormatUncompressed},
1199                         SignatureSchemes:  []SignatureScheme{Ed25519},
1200                         SupportedVersions: []uint16{VersionTLS12},
1201                 }, ""},
1202                 {ed25519Cert, &ClientHelloInfo{
1203                         CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1204                         SupportedCurves:   []CurveID{CurveP256}, // only relevant for ECDHE support
1205                         SupportedPoints:   []uint8{pointFormatUncompressed},
1206                         SignatureSchemes:  []SignatureScheme{Ed25519},
1207                         SupportedVersions: []uint16{VersionTLS10},
1208                 }, "doesn't support Ed25519"},
1209                 {ed25519Cert, &ClientHelloInfo{
1210                         CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1211                         SupportedCurves:   []CurveID{},
1212                         SupportedPoints:   []uint8{pointFormatUncompressed},
1213                         SignatureSchemes:  []SignatureScheme{Ed25519},
1214                         SupportedVersions: []uint16{VersionTLS12},
1215                 }, "doesn't support ECDHE"},
1216
1217                 {rsaCert, &ClientHelloInfo{
1218                         CipherSuites:      []uint16{TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA},
1219                         SupportedCurves:   []CurveID{CurveP256}, // only relevant for ECDHE support
1220                         SupportedPoints:   []uint8{pointFormatUncompressed},
1221                         SupportedVersions: []uint16{VersionTLS10},
1222                 }, ""},
1223                 {rsaCert, &ClientHelloInfo{
1224                         CipherSuites:      []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
1225                         SupportedVersions: []uint16{VersionTLS12},
1226                 }, ""}, // static RSA fallback
1227         }
1228         for i, tt := range tests {
1229                 err := tt.chi.SupportsCertificate(tt.c)
1230                 switch {
1231                 case tt.wantErr == "" && err != nil:
1232                         t.Errorf("%d: unexpected error: %v", i, err)
1233                 case tt.wantErr != "" && err == nil:
1234                         t.Errorf("%d: unexpected success", i)
1235                 case tt.wantErr != "" && !strings.Contains(err.Error(), tt.wantErr):
1236                         t.Errorf("%d: got error %q, expected %q", i, err, tt.wantErr)
1237                 }
1238         }
1239 }
1240
1241 func TestCipherSuites(t *testing.T) {
1242         var lastID uint16
1243         for _, c := range CipherSuites() {
1244                 if lastID > c.ID {
1245                         t.Errorf("CipherSuites are not ordered by ID: got %#04x after %#04x", c.ID, lastID)
1246                 } else {
1247                         lastID = c.ID
1248                 }
1249
1250                 if c.Insecure {
1251                         t.Errorf("%#04x: Insecure CipherSuite returned by CipherSuites()", c.ID)
1252                 }
1253         }
1254         lastID = 0
1255         for _, c := range InsecureCipherSuites() {
1256                 if lastID > c.ID {
1257                         t.Errorf("InsecureCipherSuites are not ordered by ID: got %#04x after %#04x", c.ID, lastID)
1258                 } else {
1259                         lastID = c.ID
1260                 }
1261
1262                 if !c.Insecure {
1263                         t.Errorf("%#04x: not Insecure CipherSuite returned by InsecureCipherSuites()", c.ID)
1264                 }
1265         }
1266
1267         cipherSuiteByID := func(id uint16) *CipherSuite {
1268                 for _, c := range CipherSuites() {
1269                         if c.ID == id {
1270                                 return c
1271                         }
1272                 }
1273                 for _, c := range InsecureCipherSuites() {
1274                         if c.ID == id {
1275                                 return c
1276                         }
1277                 }
1278                 return nil
1279         }
1280
1281         for _, c := range cipherSuites {
1282                 cc := cipherSuiteByID(c.id)
1283                 if cc == nil {
1284                         t.Errorf("%#04x: no CipherSuite entry", c.id)
1285                         continue
1286                 }
1287
1288                 if defaultOff := c.flags&suiteDefaultOff != 0; defaultOff != cc.Insecure {
1289                         t.Errorf("%#04x: Insecure %v, expected %v", c.id, cc.Insecure, defaultOff)
1290                 }
1291                 if tls12Only := c.flags&suiteTLS12 != 0; tls12Only && len(cc.SupportedVersions) != 1 {
1292                         t.Errorf("%#04x: suite is TLS 1.2 only, but SupportedVersions is %v", c.id, cc.SupportedVersions)
1293                 } else if !tls12Only && len(cc.SupportedVersions) != 3 {
1294                         t.Errorf("%#04x: suite TLS 1.0-1.2, but SupportedVersions is %v", c.id, cc.SupportedVersions)
1295                 }
1296
1297                 if got := CipherSuiteName(c.id); got != cc.Name {
1298                         t.Errorf("%#04x: unexpected CipherSuiteName: got %q, expected %q", c.id, got, cc.Name)
1299                 }
1300         }
1301         for _, c := range cipherSuitesTLS13 {
1302                 cc := cipherSuiteByID(c.id)
1303                 if cc == nil {
1304                         t.Errorf("%#04x: no CipherSuite entry", c.id)
1305                         continue
1306                 }
1307
1308                 if cc.Insecure {
1309                         t.Errorf("%#04x: Insecure %v, expected false", c.id, cc.Insecure)
1310                 }
1311                 if len(cc.SupportedVersions) != 1 || cc.SupportedVersions[0] != VersionTLS13 {
1312                         t.Errorf("%#04x: suite is TLS 1.3 only, but SupportedVersions is %v", c.id, cc.SupportedVersions)
1313                 }
1314
1315                 if got := CipherSuiteName(c.id); got != cc.Name {
1316                         t.Errorf("%#04x: unexpected CipherSuiteName: got %q, expected %q", c.id, got, cc.Name)
1317                 }
1318         }
1319
1320         if got := CipherSuiteName(0xabc); got != "0x0ABC" {
1321                 t.Errorf("unexpected fallback CipherSuiteName: got %q, expected 0x0ABC", got)
1322         }
1323 }
1324
// brokenSigner wraps a crypto.Signer and strips the signer options down to
// their hash function before delegating. Algorithm-specific options such as
// *rsa.PSSOptions never reach the wrapped signer, so an RSA key behind it
// signs with PKCS #1 v1.5 even when RSA-PSS was requested.
type brokenSigner struct{ crypto.Signer }

// Sign forwards rand and digest to the embedded Signer unchanged, but passes
// the bare crypto.Hash from opts.HashFunc() in place of opts.
func (s brokenSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) {
	hashOnly := opts.HashFunc()
	signature, err = s.Signer.Sign(rand, digest, hashOnly)
	return signature, err
}
1331
1332 // TestPKCS1OnlyCert uses a client certificate with a broken crypto.Signer that
1333 // always makes PKCS#1 v1.5 signatures, so can't be used with RSA-PSS.
1334 func TestPKCS1OnlyCert(t *testing.T) {
1335         clientConfig := testConfig.Clone()
1336         clientConfig.Certificates = []Certificate{{
1337                 Certificate: [][]byte{testRSACertificate},
1338                 PrivateKey:  brokenSigner{testRSAPrivateKey},
1339         }}
1340         serverConfig := testConfig.Clone()
1341         serverConfig.MaxVersion = VersionTLS12 // TLS 1.3 doesn't support PKCS#1 v1.5
1342         serverConfig.ClientAuth = RequireAnyClientCert
1343
1344         // If RSA-PSS is selected, the handshake should fail.
1345         if _, _, err := testHandshake(t, clientConfig, serverConfig); err == nil {
1346                 t.Fatal("expected broken certificate to cause connection to fail")
1347         }
1348
1349         clientConfig.Certificates[0].SupportedSignatureAlgorithms =
1350                 []SignatureScheme{PKCS1WithSHA1, PKCS1WithSHA256}
1351
1352         // But if the certificate restricts supported algorithms, RSA-PSS should not
1353         // be selected, and the handshake should succeed.
1354         if _, _, err := testHandshake(t, clientConfig, serverConfig); err != nil {
1355                 t.Error(err)
1356         }
1357 }