]> Cypherpunks.ru repositories - gostls13.git/blob - src/crypto/tls/tls_test.go
crypto/tls: make cipher suite preference ordering automatic
[gostls13.git] / src / crypto / tls / tls_test.go
1 // Copyright 2012 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package tls
6
7 import (
8         "bytes"
9         "context"
10         "crypto"
11         "crypto/x509"
12         "encoding/json"
13         "errors"
14         "fmt"
15         "internal/testenv"
16         "io"
17         "math"
18         "net"
19         "os"
20         "reflect"
21         "sort"
22         "strings"
23         "testing"
24         "time"
25 )
26
// rsaCertPEM is an RSA test certificate. Its private key is rsaKeyPEM (also
// available in untyped form as keyPEM); TestX509KeyPair loads them as a pair.
var rsaCertPEM = `-----BEGIN CERTIFICATE-----
MIIB0zCCAX2gAwIBAgIJAI/M7BYjwB+uMA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV
BAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX
aWRnaXRzIFB0eSBMdGQwHhcNMTIwOTEyMjE1MjAyWhcNMTUwOTEyMjE1MjAyWjBF
MQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50
ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBANLJ
hPHhITqQbPklG3ibCVxwGMRfp/v4XqhfdQHdcVfHap6NQ5Wok/4xIA+ui35/MmNa
rtNuC+BdZ1tMuVCPFZcCAwEAAaNQME4wHQYDVR0OBBYEFJvKs8RfJaXTH08W+SGv
zQyKn0H8MB8GA1UdIwQYMBaAFJvKs8RfJaXTH08W+SGvzQyKn0H8MAwGA1UdEwQF
MAMBAf8wDQYJKoZIhvcNAQEFBQADQQBJlffJHybjDGxRMqaRmDhX0+6v02TUKZsW
r5QuVbpQhH6u+0UgcW0jp9QwpxoPTLTWGXEWBBBurxFwiCBhkQ+V
-----END CERTIFICATE-----
`

// rsaKeyPEM is the RSA private key for rsaCertPEM. The PEM label is spelled
// "RSA TESTING KEY" and the text is passed through testingKey, a helper
// defined elsewhere in this package. NOTE(review): presumably testingKey
// rewrites "TESTING KEY" to "PRIVATE KEY" so the source never contains a
// literal private-key block — confirm against its definition.
var rsaKeyPEM = testingKey(`-----BEGIN RSA TESTING KEY-----
MIIBOwIBAAJBANLJhPHhITqQbPklG3ibCVxwGMRfp/v4XqhfdQHdcVfHap6NQ5Wo
k/4xIA+ui35/MmNartNuC+BdZ1tMuVCPFZcCAwEAAQJAEJ2N+zsR0Xn8/Q6twa4G
6OB1M1WO+k+ztnX/1SvNeWu8D6GImtupLTYgjZcHufykj09jiHmjHx8u8ZZB/o1N
MQIhAPW+eyZo7ay3lMz1V01WVjNKK9QSn1MJlb06h/LuYv9FAiEA25WPedKgVyCW
SmUwbPw8fnTcpqDWE3yTO3vKcebqMSsCIBF3UmVue8YU3jybC3NxuXq3wNm34R8T
xVLHwDXh/6NJAiEAl2oHGGLz64BuAfjKrqwz7qMYr9HCLIe/YsoWq/olzScCIQDi
D2lWusoe2/nEqfDVVWGWlyJ7yOmqaVm/iNUN9B2N2g==
-----END RSA TESTING KEY-----
`)

// keyPEM is the same as rsaKeyPEM, but declares itself as just
// "PRIVATE KEY", not "RSA PRIVATE KEY".  https://golang.org/issue/4477
var keyPEM = testingKey(`-----BEGIN TESTING KEY-----
MIIBOwIBAAJBANLJhPHhITqQbPklG3ibCVxwGMRfp/v4XqhfdQHdcVfHap6NQ5Wo
k/4xIA+ui35/MmNartNuC+BdZ1tMuVCPFZcCAwEAAQJAEJ2N+zsR0Xn8/Q6twa4G
6OB1M1WO+k+ztnX/1SvNeWu8D6GImtupLTYgjZcHufykj09jiHmjHx8u8ZZB/o1N
MQIhAPW+eyZo7ay3lMz1V01WVjNKK9QSn1MJlb06h/LuYv9FAiEA25WPedKgVyCW
SmUwbPw8fnTcpqDWE3yTO3vKcebqMSsCIBF3UmVue8YU3jybC3NxuXq3wNm34R8T
xVLHwDXh/6NJAiEAl2oHGGLz64BuAfjKrqwz7qMYr9HCLIe/YsoWq/olzScCIQDi
D2lWusoe2/nEqfDVVWGWlyJ7yOmqaVm/iNUN9B2N2g==
-----END TESTING KEY-----
`)

// ecdsaCertPEM is an ECDSA test certificate whose private key is ecdsaKeyPEM.
var ecdsaCertPEM = `-----BEGIN CERTIFICATE-----
MIIB/jCCAWICCQDscdUxw16XFDAJBgcqhkjOPQQBMEUxCzAJBgNVBAYTAkFVMRMw
EQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBXaWRnaXRzIFB0
eSBMdGQwHhcNMTIxMTE0MTI0MDQ4WhcNMTUxMTE0MTI0MDQ4WjBFMQswCQYDVQQG
EwJBVTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lk
Z2l0cyBQdHkgTHRkMIGbMBAGByqGSM49AgEGBSuBBAAjA4GGAAQBY9+my9OoeSUR
lDQdV/x8LsOuLilthhiS1Tz4aGDHIPwC1mlvnf7fg5lecYpMCrLLhauAc1UJXcgl
01xoLuzgtAEAgv2P/jgytzRSpUYvgLBt1UA0leLYBy6mQQbrNEuqT3INapKIcUv8
XxYP0xMEUksLPq6Ca+CRSqTtrd/23uTnapkwCQYHKoZIzj0EAQOBigAwgYYCQXJo
A7Sl2nLVf+4Iu/tAX/IF4MavARKC4PPHK3zfuGfPR3oCCcsAoz3kAzOeijvd0iXb
H5jBImIxPL4WxQNiBTexAkF8D1EtpYuWdlVQ80/h/f4pBcGiXPqX5h2PQSQY7hP1
+jwM1FGS4fREIOvlBYr/SzzQRtwrvrzGYxDEDbsC0ZGRnA==
-----END CERTIFICATE-----
`

// ecdsaKeyPEM is the private key for ecdsaCertPEM. The key block is preceded
// by an "EC PARAMETERS" block, so X509KeyPair must skip a non-key PEM block
// before it finds the key (TestX509KeyPair expects this to succeed). As with
// rsaKeyPEM, testingKey is assumed to rewrite the "TESTING KEY" label.
var ecdsaKeyPEM = testingKey(`-----BEGIN EC PARAMETERS-----
BgUrgQQAIw==
-----END EC PARAMETERS-----
-----BEGIN EC TESTING KEY-----
MIHcAgEBBEIBrsoKp0oqcv6/JovJJDoDVSGWdirrkgCWxrprGlzB9o0X8fV675X0
NwuBenXFfeZvVcwluO7/Q9wkYoPd/t3jGImgBwYFK4EEACOhgYkDgYYABAFj36bL
06h5JRGUNB1X/Hwuw64uKW2GGJLVPPhoYMcg/ALWaW+d/t+DmV5xikwKssuFq4Bz
VQldyCXTXGgu7OC0AQCC/Y/+ODK3NFKlRi+AsG3VQDSV4tgHLqZBBus0S6pPcg1q
kohxS/xfFg/TEwRSSws+roJr4JFKpO2t3/be5OdqmQ==
-----END EC TESTING KEY-----
`)
91
92 var keyPairTests = []struct {
93         algo string
94         cert string
95         key  string
96 }{
97         {"ECDSA", ecdsaCertPEM, ecdsaKeyPEM},
98         {"RSA", rsaCertPEM, rsaKeyPEM},
99         {"RSA-untyped", rsaCertPEM, keyPEM}, // golang.org/issue/4477
100 }
101
102 func TestX509KeyPair(t *testing.T) {
103         t.Parallel()
104         var pem []byte
105         for _, test := range keyPairTests {
106                 pem = []byte(test.cert + test.key)
107                 if _, err := X509KeyPair(pem, pem); err != nil {
108                         t.Errorf("Failed to load %s cert followed by %s key: %s", test.algo, test.algo, err)
109                 }
110                 pem = []byte(test.key + test.cert)
111                 if _, err := X509KeyPair(pem, pem); err != nil {
112                         t.Errorf("Failed to load %s key followed by %s cert: %s", test.algo, test.algo, err)
113                 }
114         }
115 }
116
117 func TestX509KeyPairErrors(t *testing.T) {
118         _, err := X509KeyPair([]byte(rsaKeyPEM), []byte(rsaCertPEM))
119         if err == nil {
120                 t.Fatalf("X509KeyPair didn't return an error when arguments were switched")
121         }
122         if subStr := "been switched"; !strings.Contains(err.Error(), subStr) {
123                 t.Fatalf("Expected %q in the error when switching arguments to X509KeyPair, but the error was %q", subStr, err)
124         }
125
126         _, err = X509KeyPair([]byte(rsaCertPEM), []byte(rsaCertPEM))
127         if err == nil {
128                 t.Fatalf("X509KeyPair didn't return an error when both arguments were certificates")
129         }
130         if subStr := "certificate"; !strings.Contains(err.Error(), subStr) {
131                 t.Fatalf("Expected %q in the error when both arguments to X509KeyPair were certificates, but the error was %q", subStr, err)
132         }
133
134         const nonsensePEM = `
135 -----BEGIN NONSENSE-----
136 Zm9vZm9vZm9v
137 -----END NONSENSE-----
138 `
139
140         _, err = X509KeyPair([]byte(nonsensePEM), []byte(nonsensePEM))
141         if err == nil {
142                 t.Fatalf("X509KeyPair didn't return an error when both arguments were nonsense")
143         }
144         if subStr := "NONSENSE"; !strings.Contains(err.Error(), subStr) {
145                 t.Fatalf("Expected %q in the error when both arguments to X509KeyPair were nonsense, but the error was %q", subStr, err)
146         }
147 }
148
149 func TestX509MixedKeyPair(t *testing.T) {
150         if _, err := X509KeyPair([]byte(rsaCertPEM), []byte(ecdsaKeyPEM)); err == nil {
151                 t.Error("Load of RSA certificate succeeded with ECDSA private key")
152         }
153         if _, err := X509KeyPair([]byte(ecdsaCertPEM), []byte(rsaKeyPEM)); err == nil {
154                 t.Error("Load of ECDSA certificate succeeded with RSA private key")
155         }
156 }
157
// newLocalListener returns a TCP listener on an ephemeral loopback port. It
// tries IPv4 first and falls back to IPv6; if neither can be bound it fails
// the test with the error from the last attempt.
func newLocalListener(t testing.TB) net.Listener {
	candidates := []struct{ network, address string }{
		{"tcp", "127.0.0.1:0"},
		{"tcp6", "[::1]:0"},
	}

	var lastErr error
	for _, c := range candidates {
		ln, err := net.Listen(c.network, c.address)
		if err == nil {
			return ln
		}
		lastErr = err
	}
	t.Fatal(lastErr)
	return nil // unreachable: t.Fatal stops the goroutine
}
168
169 func TestDialTimeout(t *testing.T) {
170         if testing.Short() {
171                 t.Skip("skipping in short mode")
172         }
173         listener := newLocalListener(t)
174
175         addr := listener.Addr().String()
176         defer listener.Close()
177
178         complete := make(chan bool)
179         defer close(complete)
180
181         go func() {
182                 conn, err := listener.Accept()
183                 if err != nil {
184                         t.Error(err)
185                         return
186                 }
187                 <-complete
188                 conn.Close()
189         }()
190
191         dialer := &net.Dialer{
192                 Timeout: 10 * time.Millisecond,
193         }
194
195         var err error
196         if _, err = DialWithDialer(dialer, "tcp", addr, nil); err == nil {
197                 t.Fatal("DialWithTimeout completed successfully")
198         }
199
200         if !isTimeoutError(err) {
201                 t.Errorf("resulting error not a timeout: %v\nType %T: %#v", err, err, err)
202         }
203 }
204
205 func TestDeadlineOnWrite(t *testing.T) {
206         if testing.Short() {
207                 t.Skip("skipping in short mode")
208         }
209
210         ln := newLocalListener(t)
211         defer ln.Close()
212
213         srvCh := make(chan *Conn, 1)
214
215         go func() {
216                 sconn, err := ln.Accept()
217                 if err != nil {
218                         srvCh <- nil
219                         return
220                 }
221                 srv := Server(sconn, testConfig.Clone())
222                 if err := srv.Handshake(); err != nil {
223                         srvCh <- nil
224                         return
225                 }
226                 srvCh <- srv
227         }()
228
229         clientConfig := testConfig.Clone()
230         clientConfig.MaxVersion = VersionTLS12
231         conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
232         if err != nil {
233                 t.Fatal(err)
234         }
235         defer conn.Close()
236
237         srv := <-srvCh
238         if srv == nil {
239                 t.Error(err)
240         }
241
242         // Make sure the client/server is setup correctly and is able to do a typical Write/Read
243         buf := make([]byte, 6)
244         if _, err := srv.Write([]byte("foobar")); err != nil {
245                 t.Errorf("Write err: %v", err)
246         }
247         if n, err := conn.Read(buf); n != 6 || err != nil || string(buf) != "foobar" {
248                 t.Errorf("Read = %d, %v, data %q; want 6, nil, foobar", n, err, buf)
249         }
250
251         // Set a deadline which should cause Write to timeout
252         if err = srv.SetDeadline(time.Now()); err != nil {
253                 t.Fatalf("SetDeadline(time.Now()) err: %v", err)
254         }
255         if _, err = srv.Write([]byte("should fail")); err == nil {
256                 t.Fatal("Write should have timed out")
257         }
258
259         // Clear deadline and make sure it still times out
260         if err = srv.SetDeadline(time.Time{}); err != nil {
261                 t.Fatalf("SetDeadline(time.Time{}) err: %v", err)
262         }
263         if _, err = srv.Write([]byte("This connection is permanently broken")); err == nil {
264                 t.Fatal("Write which previously failed should still time out")
265         }
266
267         // Verify the error
268         if ne := err.(net.Error); ne.Temporary() != false {
269                 t.Error("Write timed out but incorrectly classified the error as Temporary")
270         }
271         if !isTimeoutError(err) {
272                 t.Error("Write timed out but did not classify the error as a Timeout")
273         }
274 }
275
// readerFunc lets a plain function stand in for an io.Reader.
type readerFunc func([]byte) (int, error)

// Read satisfies io.Reader by delegating to the wrapped function.
func (rf readerFunc) Read(p []byte) (int, error) {
	return rf(p)
}
279
280 // TestDialer tests that tls.Dialer.DialContext can abort in the middle of a handshake.
281 // (The other cases are all handled by the existing dial tests in this package, which
282 // all also flow through the same code shared code paths)
283 func TestDialer(t *testing.T) {
284         ln := newLocalListener(t)
285         defer ln.Close()
286
287         unblockServer := make(chan struct{}) // close-only
288         defer close(unblockServer)
289         go func() {
290                 conn, err := ln.Accept()
291                 if err != nil {
292                         return
293                 }
294                 defer conn.Close()
295                 <-unblockServer
296         }()
297
298         ctx, cancel := context.WithCancel(context.Background())
299         d := Dialer{Config: &Config{
300                 Rand: readerFunc(func(b []byte) (n int, err error) {
301                         // By the time crypto/tls wants randomness, that means it has a TCP
302                         // connection, so we're past the Dialer's dial and now blocked
303                         // in a handshake. Cancel our context and see if we get unstuck.
304                         // (Our TCP listener above never reads or writes, so the Handshake
305                         // would otherwise be stuck forever)
306                         cancel()
307                         return len(b), nil
308                 }),
309                 ServerName: "foo",
310         }}
311         _, err := d.DialContext(ctx, "tcp", ln.Addr().String())
312         if err != context.Canceled {
313                 t.Errorf("err = %v; want context.Canceled", err)
314         }
315 }
316
// isTimeoutError reports whether err, or any error it wraps, is a net.Error
// whose Timeout method returns true. A nil error is never a timeout.
//
// errors.As is used rather than a plain type assertion so that a timeout
// wrapped with fmt.Errorf("...: %w", err) is still recognized.
func isTimeoutError(err error) bool {
	var ne net.Error
	return errors.As(err, &ne) && ne.Timeout()
}
323
324 // tests that Conn.Read returns (non-zero, io.EOF) instead of
325 // (non-zero, nil) when a Close (alertCloseNotify) is sitting right
326 // behind the application data in the buffer.
327 func TestConnReadNonzeroAndEOF(t *testing.T) {
328         // This test is racy: it assumes that after a write to a
329         // localhost TCP connection, the peer TCP connection can
330         // immediately read it. Because it's racy, we skip this test
331         // in short mode, and then retry it several times with an
332         // increasing sleep in between our final write (via srv.Close
333         // below) and the following read.
334         if testing.Short() {
335                 t.Skip("skipping in short mode")
336         }
337         var err error
338         for delay := time.Millisecond; delay <= 64*time.Millisecond; delay *= 2 {
339                 if err = testConnReadNonzeroAndEOF(t, delay); err == nil {
340                         return
341                 }
342         }
343         t.Error(err)
344 }
345
346 func testConnReadNonzeroAndEOF(t *testing.T, delay time.Duration) error {
347         ln := newLocalListener(t)
348         defer ln.Close()
349
350         srvCh := make(chan *Conn, 1)
351         var serr error
352         go func() {
353                 sconn, err := ln.Accept()
354                 if err != nil {
355                         serr = err
356                         srvCh <- nil
357                         return
358                 }
359                 serverConfig := testConfig.Clone()
360                 srv := Server(sconn, serverConfig)
361                 if err := srv.Handshake(); err != nil {
362                         serr = fmt.Errorf("handshake: %v", err)
363                         srvCh <- nil
364                         return
365                 }
366                 srvCh <- srv
367         }()
368
369         clientConfig := testConfig.Clone()
370         // In TLS 1.3, alerts are encrypted and disguised as application data, so
371         // the opportunistic peek won't work.
372         clientConfig.MaxVersion = VersionTLS12
373         conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
374         if err != nil {
375                 t.Fatal(err)
376         }
377         defer conn.Close()
378
379         srv := <-srvCh
380         if srv == nil {
381                 return serr
382         }
383
384         buf := make([]byte, 6)
385
386         srv.Write([]byte("foobar"))
387         n, err := conn.Read(buf)
388         if n != 6 || err != nil || string(buf) != "foobar" {
389                 return fmt.Errorf("Read = %d, %v, data %q; want 6, nil, foobar", n, err, buf)
390         }
391
392         srv.Write([]byte("abcdef"))
393         srv.Close()
394         time.Sleep(delay)
395         n, err = conn.Read(buf)
396         if n != 6 || string(buf) != "abcdef" {
397                 return fmt.Errorf("Read = %d, buf= %q; want 6, abcdef", n, buf)
398         }
399         if err != io.EOF {
400                 return fmt.Errorf("Second Read error = %v; want io.EOF", err)
401         }
402         return nil
403 }
404
405 func TestTLSUniqueMatches(t *testing.T) {
406         ln := newLocalListener(t)
407         defer ln.Close()
408
409         serverTLSUniques := make(chan []byte)
410         parentDone := make(chan struct{})
411         childDone := make(chan struct{})
412         defer close(parentDone)
413         go func() {
414                 defer close(childDone)
415                 for i := 0; i < 2; i++ {
416                         sconn, err := ln.Accept()
417                         if err != nil {
418                                 t.Error(err)
419                                 return
420                         }
421                         serverConfig := testConfig.Clone()
422                         serverConfig.MaxVersion = VersionTLS12 // TLSUnique is not defined in TLS 1.3
423                         srv := Server(sconn, serverConfig)
424                         if err := srv.Handshake(); err != nil {
425                                 t.Error(err)
426                                 return
427                         }
428                         select {
429                         case <-parentDone:
430                                 return
431                         case serverTLSUniques <- srv.ConnectionState().TLSUnique:
432                         }
433                 }
434         }()
435
436         clientConfig := testConfig.Clone()
437         clientConfig.ClientSessionCache = NewLRUClientSessionCache(1)
438         conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
439         if err != nil {
440                 t.Fatal(err)
441         }
442
443         var serverTLSUniquesValue []byte
444         select {
445         case <-childDone:
446                 return
447         case serverTLSUniquesValue = <-serverTLSUniques:
448         }
449
450         if !bytes.Equal(conn.ConnectionState().TLSUnique, serverTLSUniquesValue) {
451                 t.Error("client and server channel bindings differ")
452         }
453         conn.Close()
454
455         conn, err = Dial("tcp", ln.Addr().String(), clientConfig)
456         if err != nil {
457                 t.Fatal(err)
458         }
459         defer conn.Close()
460         if !conn.ConnectionState().DidResume {
461                 t.Error("second session did not use resumption")
462         }
463
464         select {
465         case <-childDone:
466                 return
467         case serverTLSUniquesValue = <-serverTLSUniques:
468         }
469
470         if !bytes.Equal(conn.ConnectionState().TLSUnique, serverTLSUniquesValue) {
471                 t.Error("client and server channel bindings differ when session resumption is used")
472         }
473 }
474
475 func TestVerifyHostname(t *testing.T) {
476         testenv.MustHaveExternalNetwork(t)
477
478         c, err := Dial("tcp", "www.google.com:https", nil)
479         if err != nil {
480                 t.Fatal(err)
481         }
482         if err := c.VerifyHostname("www.google.com"); err != nil {
483                 t.Fatalf("verify www.google.com: %v", err)
484         }
485         if err := c.VerifyHostname("www.yahoo.com"); err == nil {
486                 t.Fatalf("verify www.yahoo.com succeeded")
487         }
488
489         c, err = Dial("tcp", "www.google.com:https", &Config{InsecureSkipVerify: true})
490         if err != nil {
491                 t.Fatal(err)
492         }
493         if err := c.VerifyHostname("www.google.com"); err == nil {
494                 t.Fatalf("verify www.google.com succeeded with InsecureSkipVerify=true")
495         }
496 }
497
// TestConnCloseBreakingWrite checks that calling Close while a Write is blocked
// inside the underlying connection does not hang. The blocked Write must
// return the transport's error, and a subsequent Close must report
// net.ErrClosed.
func TestConnCloseBreakingWrite(t *testing.T) {
	ln := newLocalListener(t)
	defer ln.Close()

	// The server goroutine publishes its *Conn (or nil plus serr) on srvCh.
	// sconn and serr are written before the send, so reading them after the
	// receive below is ordered.
	srvCh := make(chan *Conn, 1)
	var serr error
	var sconn net.Conn
	go func() {
		var err error
		sconn, err = ln.Accept()
		if err != nil {
			serr = err
			srvCh <- nil
			return
		}
		serverConfig := testConfig.Clone()
		srv := Server(sconn, serverConfig)
		if err := srv.Handshake(); err != nil {
			serr = fmt.Errorf("handshake: %v", err)
			srvCh <- nil
			return
		}
		srvCh <- srv
	}()

	cconn, err := net.Dial("tcp", ln.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer cconn.Close()

	// Wrap the raw client connection so its Write and Close can be swapped
	// out after the handshake completes normally.
	conn := &changeImplConn{
		Conn: cconn,
	}

	clientConfig := testConfig.Clone()
	tconn := Client(conn, clientConfig)
	if err := tconn.Handshake(); err != nil {
		t.Fatal(err)
	}

	srv := <-srvCh
	if srv == nil {
		t.Fatal(serr)
	}
	defer sconn.Close()

	// From here on, closing the underlying conn only signals connClosed.
	// NOTE(review): assumes changeImplConn's Close/Write call closeFunc/writeFunc
	// when they are set — those methods are outside this view.
	connClosed := make(chan struct{})
	conn.closeFunc = func() error {
		close(connClosed)
		return nil
	}

	// Writes now announce themselves on inWrite, then block until the conn is
	// closed and fail with errConnClosed.
	inWrite := make(chan bool, 1)
	var errConnClosed = errors.New("conn closed for test")
	conn.writeFunc = func(p []byte) (n int, err error) {
		inWrite <- true
		<-connClosed
		return 0, errConnClosed
	}

	// Once the Write below is blocked inside writeFunc, Close the TLS conn from
	// another goroutine; Close must return rather than wait for the Write.
	closeReturned := make(chan bool, 1)
	go func() {
		<-inWrite
		tconn.Close() // test that this doesn't block forever.
		closeReturned <- true
	}()

	_, err = tconn.Write([]byte("foo"))
	if err != errConnClosed {
		t.Errorf("Write error = %v; want errConnClosed", err)
	}

	<-closeReturned
	// The connection is already closed, so closing again must report net.ErrClosed.
	if err := tconn.Close(); err != net.ErrClosed {
		t.Errorf("Close error = %v; want net.ErrClosed", err)
	}
}
576
577 func TestConnCloseWrite(t *testing.T) {
578         ln := newLocalListener(t)
579         defer ln.Close()
580
581         clientDoneChan := make(chan struct{})
582
583         serverCloseWrite := func() error {
584                 sconn, err := ln.Accept()
585                 if err != nil {
586                         return fmt.Errorf("accept: %v", err)
587                 }
588                 defer sconn.Close()
589
590                 serverConfig := testConfig.Clone()
591                 srv := Server(sconn, serverConfig)
592                 if err := srv.Handshake(); err != nil {
593                         return fmt.Errorf("handshake: %v", err)
594                 }
595                 defer srv.Close()
596
597                 data, err := io.ReadAll(srv)
598                 if err != nil {
599                         return err
600                 }
601                 if len(data) > 0 {
602                         return fmt.Errorf("Read data = %q; want nothing", data)
603                 }
604
605                 if err := srv.CloseWrite(); err != nil {
606                         return fmt.Errorf("server CloseWrite: %v", err)
607                 }
608
609                 // Wait for clientCloseWrite to finish, so we know we
610                 // tested the CloseWrite before we defer the
611                 // sconn.Close above, which would also cause the
612                 // client to unblock like CloseWrite.
613                 <-clientDoneChan
614                 return nil
615         }
616
617         clientCloseWrite := func() error {
618                 defer close(clientDoneChan)
619
620                 clientConfig := testConfig.Clone()
621                 conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
622                 if err != nil {
623                         return err
624                 }
625                 if err := conn.Handshake(); err != nil {
626                         return err
627                 }
628                 defer conn.Close()
629
630                 if err := conn.CloseWrite(); err != nil {
631                         return fmt.Errorf("client CloseWrite: %v", err)
632                 }
633
634                 if _, err := conn.Write([]byte{0}); err != errShutdown {
635                         return fmt.Errorf("CloseWrite error = %v; want errShutdown", err)
636                 }
637
638                 data, err := io.ReadAll(conn)
639                 if err != nil {
640                         return err
641                 }
642                 if len(data) > 0 {
643                         return fmt.Errorf("Read data = %q; want nothing", data)
644                 }
645                 return nil
646         }
647
648         errChan := make(chan error, 2)
649
650         go func() { errChan <- serverCloseWrite() }()
651         go func() { errChan <- clientCloseWrite() }()
652
653         for i := 0; i < 2; i++ {
654                 select {
655                 case err := <-errChan:
656                         if err != nil {
657                                 t.Fatal(err)
658                         }
659                 case <-time.After(10 * time.Second):
660                         t.Fatal("deadlock")
661                 }
662         }
663
664         // Also test CloseWrite being called before the handshake is
665         // finished:
666         {
667                 ln2 := newLocalListener(t)
668                 defer ln2.Close()
669
670                 netConn, err := net.Dial("tcp", ln2.Addr().String())
671                 if err != nil {
672                         t.Fatal(err)
673                 }
674                 defer netConn.Close()
675                 conn := Client(netConn, testConfig.Clone())
676
677                 if err := conn.CloseWrite(); err != errEarlyCloseWrite {
678                         t.Errorf("CloseWrite error = %v; want errEarlyCloseWrite", err)
679                 }
680         }
681 }
682
683 func TestWarningAlertFlood(t *testing.T) {
684         ln := newLocalListener(t)
685         defer ln.Close()
686
687         server := func() error {
688                 sconn, err := ln.Accept()
689                 if err != nil {
690                         return fmt.Errorf("accept: %v", err)
691                 }
692                 defer sconn.Close()
693
694                 serverConfig := testConfig.Clone()
695                 srv := Server(sconn, serverConfig)
696                 if err := srv.Handshake(); err != nil {
697                         return fmt.Errorf("handshake: %v", err)
698                 }
699                 defer srv.Close()
700
701                 _, err = io.ReadAll(srv)
702                 if err == nil {
703                         return errors.New("unexpected lack of error from server")
704                 }
705                 const expected = "too many ignored"
706                 if str := err.Error(); !strings.Contains(str, expected) {
707                         return fmt.Errorf("expected error containing %q, but saw: %s", expected, str)
708                 }
709
710                 return nil
711         }
712
713         errChan := make(chan error, 1)
714         go func() { errChan <- server() }()
715
716         clientConfig := testConfig.Clone()
717         clientConfig.MaxVersion = VersionTLS12 // there are no warning alerts in TLS 1.3
718         conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
719         if err != nil {
720                 t.Fatal(err)
721         }
722         defer conn.Close()
723         if err := conn.Handshake(); err != nil {
724                 t.Fatal(err)
725         }
726
727         for i := 0; i < maxUselessRecords+1; i++ {
728                 conn.sendAlert(alertNoRenegotiation)
729         }
730
731         if err := <-errChan; err != nil {
732                 t.Fatal(err)
733         }
734 }
735
736 func TestCloneFuncFields(t *testing.T) {
737         const expectedCount = 6
738         called := 0
739
740         c1 := Config{
741                 Time: func() time.Time {
742                         called |= 1 << 0
743                         return time.Time{}
744                 },
745                 GetCertificate: func(*ClientHelloInfo) (*Certificate, error) {
746                         called |= 1 << 1
747                         return nil, nil
748                 },
749                 GetClientCertificate: func(*CertificateRequestInfo) (*Certificate, error) {
750                         called |= 1 << 2
751                         return nil, nil
752                 },
753                 GetConfigForClient: func(*ClientHelloInfo) (*Config, error) {
754                         called |= 1 << 3
755                         return nil, nil
756                 },
757                 VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
758                         called |= 1 << 4
759                         return nil
760                 },
761                 VerifyConnection: func(ConnectionState) error {
762                         called |= 1 << 5
763                         return nil
764                 },
765         }
766
767         c2 := c1.Clone()
768
769         c2.Time()
770         c2.GetCertificate(nil)
771         c2.GetClientCertificate(nil)
772         c2.GetConfigForClient(nil)
773         c2.VerifyPeerCertificate(nil, nil)
774         c2.VerifyConnection(ConnectionState{})
775
776         if called != (1<<expectedCount)-1 {
777                 t.Fatalf("expected %d calls but saw calls %b", expectedCount, called)
778         }
779 }
780
// TestCloneNonFuncFields checks that Config.Clone copies every non-function
// field. It uses reflection to set each field it recognizes to a sample value,
// clones the config, and compares original and clone with reflect.DeepEqual.
// The switch must name every Config field, so adding a field to Config
// without updating this test makes it fail.
func TestCloneNonFuncFields(t *testing.T) {
	var c1 Config
	v := reflect.ValueOf(&c1).Elem()

	typ := v.Type()
	for i := 0; i < typ.NumField(); i++ {
		f := v.Field(i)
		// testing/quick can't handle functions or interfaces and so
		// isn't used here.
		switch fn := typ.Field(i).Name; fn {
		case "Rand":
			f.Set(reflect.ValueOf(io.Reader(os.Stdin)))
		case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate", "VerifyConnection", "GetClientCertificate":
			// DeepEqual can't compare functions. If you add a
			// function field to this list, you must also change
			// TestCloneFuncFields to ensure that the func field is
			// cloned.
		case "Certificates":
			f.Set(reflect.ValueOf([]Certificate{
				{Certificate: [][]byte{{'b'}}},
			}))
		case "NameToCertificate":
			f.Set(reflect.ValueOf(map[string]*Certificate{"a": nil}))
		case "RootCAs", "ClientCAs":
			f.Set(reflect.ValueOf(x509.NewCertPool()))
		case "ClientSessionCache":
			f.Set(reflect.ValueOf(NewLRUClientSessionCache(10)))
		case "KeyLogWriter":
			f.Set(reflect.ValueOf(io.Writer(os.Stdout)))
		case "NextProtos":
			f.Set(reflect.ValueOf([]string{"a", "b"}))
		case "ServerName":
			f.Set(reflect.ValueOf("b"))
		case "ClientAuth":
			f.Set(reflect.ValueOf(VerifyClientCertIfGiven))
		case "InsecureSkipVerify", "SessionTicketsDisabled", "DynamicRecordSizingDisabled", "PreferServerCipherSuites":
			f.Set(reflect.ValueOf(true))
		case "MinVersion", "MaxVersion":
			f.Set(reflect.ValueOf(uint16(VersionTLS12)))
		case "SessionTicketKey":
			f.Set(reflect.ValueOf([32]byte{}))
		case "CipherSuites":
			f.Set(reflect.ValueOf([]uint16{1, 2}))
		case "CurvePreferences":
			f.Set(reflect.ValueOf([]CurveID{CurveP256}))
		case "Renegotiation":
			f.Set(reflect.ValueOf(RenegotiateOnceAsClient))
		case "mutex", "autoSessionTicketKeys", "sessionTicketKeys":
			continue // these are unexported fields that are handled separately
		default:
			// An unlisted field means this test has fallen out of sync with Config.
			t.Errorf("all fields must be accounted for, but saw unknown field %q", fn)
		}
	}
	// Set the unexported fields related to session ticket keys, which are copied with Clone().
	c1.autoSessionTicketKeys = []ticketKey{c1.ticketKeyFromBytes(c1.SessionTicketKey)}
	c1.sessionTicketKeys = []ticketKey{c1.ticketKeyFromBytes(c1.SessionTicketKey)}

	c2 := c1.Clone()
	if !reflect.DeepEqual(&c1, c2) {
		t.Errorf("clone failed to copy a field")
	}
}
843
844 func TestCloneNilConfig(t *testing.T) {
845         var config *Config
846         if cc := config.Clone(); cc != nil {
847                 t.Fatalf("Clone with nil should return nil, got: %+v", cc)
848         }
849 }
850
// changeImplConn is a net.Conn whose Write and Close methods can be
// overridden by setting writeFunc and closeFunc. When a hook is nil, the
// call is forwarded to the embedded net.Conn.
type changeImplConn struct {
	net.Conn
	writeFunc func([]byte) (int, error)
	closeFunc func() error
}

// Write calls writeFunc when it is set, otherwise the embedded Conn's Write.
func (w *changeImplConn) Write(p []byte) (n int, err error) {
	if w.writeFunc == nil {
		return w.Conn.Write(p)
	}
	return w.writeFunc(p)
}

// Close calls closeFunc when it is set, otherwise the embedded Conn's Close.
func (w *changeImplConn) Close() error {
	if w.closeFunc == nil {
		return w.Conn.Close()
	}
	return w.closeFunc()
}
872
873 func throughput(b *testing.B, version uint16, totalBytes int64, dynamicRecordSizingDisabled bool) {
874         ln := newLocalListener(b)
875         defer ln.Close()
876
877         N := b.N
878
879         // Less than 64KB because Windows appears to use a TCP rwin < 64KB.
880         // See Issue #15899.
881         const bufsize = 32 << 10
882
883         go func() {
884                 buf := make([]byte, bufsize)
885                 for i := 0; i < N; i++ {
886                         sconn, err := ln.Accept()
887                         if err != nil {
888                                 // panic rather than synchronize to avoid benchmark overhead
889                                 // (cannot call b.Fatal in goroutine)
890                                 panic(fmt.Errorf("accept: %v", err))
891                         }
892                         serverConfig := testConfig.Clone()
893                         serverConfig.CipherSuites = nil // the defaults may prefer faster ciphers
894                         serverConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled
895                         srv := Server(sconn, serverConfig)
896                         if err := srv.Handshake(); err != nil {
897                                 panic(fmt.Errorf("handshake: %v", err))
898                         }
899                         if _, err := io.CopyBuffer(srv, srv, buf); err != nil {
900                                 panic(fmt.Errorf("copy buffer: %v", err))
901                         }
902                 }
903         }()
904
905         b.SetBytes(totalBytes)
906         clientConfig := testConfig.Clone()
907         clientConfig.CipherSuites = nil // the defaults may prefer faster ciphers
908         clientConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled
909         clientConfig.MaxVersion = version
910
911         buf := make([]byte, bufsize)
912         chunks := int(math.Ceil(float64(totalBytes) / float64(len(buf))))
913         for i := 0; i < N; i++ {
914                 conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
915                 if err != nil {
916                         b.Fatal(err)
917                 }
918                 for j := 0; j < chunks; j++ {
919                         _, err := conn.Write(buf)
920                         if err != nil {
921                                 b.Fatal(err)
922                         }
923                         _, err = io.ReadFull(conn, buf)
924                         if err != nil {
925                                 b.Fatal(err)
926                         }
927                 }
928                 conn.Close()
929         }
930 }
931
932 func BenchmarkThroughput(b *testing.B) {
933         for _, mode := range []string{"Max", "Dynamic"} {
934                 for size := 1; size <= 64; size <<= 1 {
935                         name := fmt.Sprintf("%sPacket/%dMB", mode, size)
936                         b.Run(name, func(b *testing.B) {
937                                 b.Run("TLSv12", func(b *testing.B) {
938                                         throughput(b, VersionTLS12, int64(size<<20), mode == "Max")
939                                 })
940                                 b.Run("TLSv13", func(b *testing.B) {
941                                         throughput(b, VersionTLS13, int64(size<<20), mode == "Max")
942                                 })
943                         })
944                 }
945         }
946 }
947
// slowConn wraps a net.Conn and throttles Write to roughly bps bits per
// second. The elapsed time is checked every 100µs, and any newly allowed
// bytes are forwarded to the embedded connection. Write panics if bps is 0.
type slowConn struct {
	net.Conn
	bps int
}

// Write blocks until all of p has been forwarded at the throttled rate. If
// the underlying Write fails, it returns the count sent so far and the error.
func (c *slowConn) Write(p []byte) (int, error) {
	if c.bps == 0 {
		panic("too slow")
	}
	start := time.Now()
	sent := 0
	for sent < len(p) {
		time.Sleep(100 * time.Microsecond)
		// Bytes permitted so far: elapsed seconds * bits/s, converted to bytes.
		budget := int(time.Since(start).Seconds()*float64(c.bps)) / 8
		if budget > len(p) {
			budget = len(p)
		}
		if sent >= budget {
			continue
		}
		n, err := c.Conn.Write(p[sent:budget])
		sent += n
		if err != nil {
			return sent, err
		}
	}
	return len(p), nil
}
975
976 func latency(b *testing.B, version uint16, bps int, dynamicRecordSizingDisabled bool) {
977         ln := newLocalListener(b)
978         defer ln.Close()
979
980         N := b.N
981
982         go func() {
983                 for i := 0; i < N; i++ {
984                         sconn, err := ln.Accept()
985                         if err != nil {
986                                 // panic rather than synchronize to avoid benchmark overhead
987                                 // (cannot call b.Fatal in goroutine)
988                                 panic(fmt.Errorf("accept: %v", err))
989                         }
990                         serverConfig := testConfig.Clone()
991                         serverConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled
992                         srv := Server(&slowConn{sconn, bps}, serverConfig)
993                         if err := srv.Handshake(); err != nil {
994                                 panic(fmt.Errorf("handshake: %v", err))
995                         }
996                         io.Copy(srv, srv)
997                 }
998         }()
999
1000         clientConfig := testConfig.Clone()
1001         clientConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled
1002         clientConfig.MaxVersion = version
1003
1004         buf := make([]byte, 16384)
1005         peek := make([]byte, 1)
1006
1007         for i := 0; i < N; i++ {
1008                 conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
1009                 if err != nil {
1010                         b.Fatal(err)
1011                 }
1012                 // make sure we're connected and previous connection has stopped
1013                 if _, err := conn.Write(buf[:1]); err != nil {
1014                         b.Fatal(err)
1015                 }
1016                 if _, err := io.ReadFull(conn, peek); err != nil {
1017                         b.Fatal(err)
1018                 }
1019                 if _, err := conn.Write(buf); err != nil {
1020                         b.Fatal(err)
1021                 }
1022                 if _, err = io.ReadFull(conn, peek); err != nil {
1023                         b.Fatal(err)
1024                 }
1025                 conn.Close()
1026         }
1027 }
1028
1029 func BenchmarkLatency(b *testing.B) {
1030         for _, mode := range []string{"Max", "Dynamic"} {
1031                 for _, kbps := range []int{200, 500, 1000, 2000, 5000} {
1032                         name := fmt.Sprintf("%sPacket/%dkbps", mode, kbps)
1033                         b.Run(name, func(b *testing.B) {
1034                                 b.Run("TLSv12", func(b *testing.B) {
1035                                         latency(b, VersionTLS12, kbps*1000, mode == "Max")
1036                                 })
1037                                 b.Run("TLSv13", func(b *testing.B) {
1038                                         latency(b, VersionTLS13, kbps*1000, mode == "Max")
1039                                 })
1040                         })
1041                 }
1042         }
1043 }
1044
1045 func TestConnectionStateMarshal(t *testing.T) {
1046         cs := &ConnectionState{}
1047         _, err := json.Marshal(cs)
1048         if err != nil {
1049                 t.Errorf("json.Marshal failed on ConnectionState: %v", err)
1050         }
1051 }
1052
1053 func TestConnectionState(t *testing.T) {
1054         issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
1055         if err != nil {
1056                 panic(err)
1057         }
1058         rootCAs := x509.NewCertPool()
1059         rootCAs.AddCert(issuer)
1060
1061         now := func() time.Time { return time.Unix(1476984729, 0) }
1062
1063         const alpnProtocol = "golang"
1064         const serverName = "example.golang"
1065         var scts = [][]byte{[]byte("dummy sct 1"), []byte("dummy sct 2")}
1066         var ocsp = []byte("dummy ocsp")
1067
1068         for _, v := range []uint16{VersionTLS12, VersionTLS13} {
1069                 var name string
1070                 switch v {
1071                 case VersionTLS12:
1072                         name = "TLSv12"
1073                 case VersionTLS13:
1074                         name = "TLSv13"
1075                 }
1076                 t.Run(name, func(t *testing.T) {
1077                         config := &Config{
1078                                 Time:         now,
1079                                 Rand:         zeroSource{},
1080                                 Certificates: make([]Certificate, 1),
1081                                 MaxVersion:   v,
1082                                 RootCAs:      rootCAs,
1083                                 ClientCAs:    rootCAs,
1084                                 ClientAuth:   RequireAndVerifyClientCert,
1085                                 NextProtos:   []string{alpnProtocol},
1086                                 ServerName:   serverName,
1087                         }
1088                         config.Certificates[0].Certificate = [][]byte{testRSACertificate}
1089                         config.Certificates[0].PrivateKey = testRSAPrivateKey
1090                         config.Certificates[0].SignedCertificateTimestamps = scts
1091                         config.Certificates[0].OCSPStaple = ocsp
1092
1093                         ss, cs, err := testHandshake(t, config, config)
1094                         if err != nil {
1095                                 t.Fatalf("Handshake failed: %v", err)
1096                         }
1097
1098                         if ss.Version != v || cs.Version != v {
1099                                 t.Errorf("Got versions %x (server) and %x (client), expected %x", ss.Version, cs.Version, v)
1100                         }
1101
1102                         if !ss.HandshakeComplete || !cs.HandshakeComplete {
1103                                 t.Errorf("Got HandshakeComplete %v (server) and %v (client), expected true", ss.HandshakeComplete, cs.HandshakeComplete)
1104                         }
1105
1106                         if ss.DidResume || cs.DidResume {
1107                                 t.Errorf("Got DidResume %v (server) and %v (client), expected false", ss.DidResume, cs.DidResume)
1108                         }
1109
1110                         if ss.CipherSuite == 0 || cs.CipherSuite == 0 {
1111                                 t.Errorf("Got invalid cipher suite: %v (server) and %v (client)", ss.CipherSuite, cs.CipherSuite)
1112                         }
1113
1114                         if ss.NegotiatedProtocol != alpnProtocol || cs.NegotiatedProtocol != alpnProtocol {
1115                                 t.Errorf("Got negotiated protocol %q (server) and %q (client), expected %q", ss.NegotiatedProtocol, cs.NegotiatedProtocol, alpnProtocol)
1116                         }
1117
1118                         if !cs.NegotiatedProtocolIsMutual {
1119                                 t.Errorf("Got false NegotiatedProtocolIsMutual on the client side")
1120                         }
1121                         // NegotiatedProtocolIsMutual on the server side is unspecified.
1122
1123                         if ss.ServerName != serverName {
1124                                 t.Errorf("Got server name %q, expected %q", ss.ServerName, serverName)
1125                         }
1126                         if cs.ServerName != serverName {
1127                                 t.Errorf("Got server name on client connection %q, expected %q", cs.ServerName, serverName)
1128                         }
1129
1130                         if len(ss.PeerCertificates) != 1 || len(cs.PeerCertificates) != 1 {
1131                                 t.Errorf("Got %d (server) and %d (client) peer certificates, expected %d", len(ss.PeerCertificates), len(cs.PeerCertificates), 1)
1132                         }
1133
1134                         if len(ss.VerifiedChains) != 1 || len(cs.VerifiedChains) != 1 {
1135                                 t.Errorf("Got %d (server) and %d (client) verified chains, expected %d", len(ss.VerifiedChains), len(cs.VerifiedChains), 1)
1136                         } else if len(ss.VerifiedChains[0]) != 2 || len(cs.VerifiedChains[0]) != 2 {
1137                                 t.Errorf("Got %d (server) and %d (client) long verified chain, expected %d", len(ss.VerifiedChains[0]), len(cs.VerifiedChains[0]), 2)
1138                         }
1139
1140                         if len(cs.SignedCertificateTimestamps) != 2 {
1141                                 t.Errorf("Got %d SCTs, expected %d", len(cs.SignedCertificateTimestamps), 2)
1142                         }
1143                         if !bytes.Equal(cs.OCSPResponse, ocsp) {
1144                                 t.Errorf("Got OCSPs %x, expected %x", cs.OCSPResponse, ocsp)
1145                         }
1146                         // Only TLS 1.3 supports OCSP and SCTs on client certs.
1147                         if v == VersionTLS13 {
1148                                 if len(ss.SignedCertificateTimestamps) != 2 {
1149                                         t.Errorf("Got %d client SCTs, expected %d", len(ss.SignedCertificateTimestamps), 2)
1150                                 }
1151                                 if !bytes.Equal(ss.OCSPResponse, ocsp) {
1152                                         t.Errorf("Got client OCSPs %x, expected %x", ss.OCSPResponse, ocsp)
1153                                 }
1154                         }
1155
1156                         if v == VersionTLS13 {
1157                                 if ss.TLSUnique != nil || cs.TLSUnique != nil {
1158                                         t.Errorf("Got TLSUnique %x (server) and %x (client), expected nil in TLS 1.3", ss.TLSUnique, cs.TLSUnique)
1159                                 }
1160                         } else {
1161                                 if ss.TLSUnique == nil || cs.TLSUnique == nil {
1162                                         t.Errorf("Got TLSUnique %x (server) and %x (client), expected non-nil", ss.TLSUnique, cs.TLSUnique)
1163                                 }
1164                         }
1165                 })
1166         }
1167 }
1168
1169 // Issue 28744: Ensure that we don't modify memory
1170 // that Config doesn't own such as Certificates.
1171 func TestBuildNameToCertificate_doesntModifyCertificates(t *testing.T) {
1172         c0 := Certificate{
1173                 Certificate: [][]byte{testRSACertificate},
1174                 PrivateKey:  testRSAPrivateKey,
1175         }
1176         c1 := Certificate{
1177                 Certificate: [][]byte{testSNICertificate},
1178                 PrivateKey:  testRSAPrivateKey,
1179         }
1180         config := testConfig.Clone()
1181         config.Certificates = []Certificate{c0, c1}
1182
1183         config.BuildNameToCertificate()
1184         got := config.Certificates
1185         want := []Certificate{c0, c1}
1186         if !reflect.DeepEqual(got, want) {
1187                 t.Fatalf("Certificates were mutated by BuildNameToCertificate\nGot: %#v\nWant: %#v\n", got, want)
1188         }
1189 }
1190
// testingKey returns s with every occurrence of "TESTING KEY" replaced by
// "PRIVATE KEY". This lets test PEM keys be written with a "TESTING KEY"
// label in source and restored to a normal private-key label at runtime.
func testingKey(s string) string {
	return strings.Replace(s, "TESTING KEY", "PRIVATE KEY", -1)
}
1192
1193 func TestClientHelloInfo_SupportsCertificate(t *testing.T) {
1194         rsaCert := &Certificate{
1195                 Certificate: [][]byte{testRSACertificate},
1196                 PrivateKey:  testRSAPrivateKey,
1197         }
1198         pkcs1Cert := &Certificate{
1199                 Certificate:                  [][]byte{testRSACertificate},
1200                 PrivateKey:                   testRSAPrivateKey,
1201                 SupportedSignatureAlgorithms: []SignatureScheme{PKCS1WithSHA1, PKCS1WithSHA256},
1202         }
1203         ecdsaCert := &Certificate{
1204                 // ECDSA P-256 certificate
1205                 Certificate: [][]byte{testP256Certificate},
1206                 PrivateKey:  testP256PrivateKey,
1207         }
1208         ed25519Cert := &Certificate{
1209                 Certificate: [][]byte{testEd25519Certificate},
1210                 PrivateKey:  testEd25519PrivateKey,
1211         }
1212
1213         tests := []struct {
1214                 c       *Certificate
1215                 chi     *ClientHelloInfo
1216                 wantErr string
1217         }{
1218                 {rsaCert, &ClientHelloInfo{
1219                         ServerName:        "example.golang",
1220                         SignatureSchemes:  []SignatureScheme{PSSWithSHA256},
1221                         SupportedVersions: []uint16{VersionTLS13},
1222                 }, ""},
1223                 {ecdsaCert, &ClientHelloInfo{
1224                         SignatureSchemes:  []SignatureScheme{PSSWithSHA256, ECDSAWithP256AndSHA256},
1225                         SupportedVersions: []uint16{VersionTLS13, VersionTLS12},
1226                 }, ""},
1227                 {rsaCert, &ClientHelloInfo{
1228                         ServerName:        "example.com",
1229                         SignatureSchemes:  []SignatureScheme{PSSWithSHA256},
1230                         SupportedVersions: []uint16{VersionTLS13},
1231                 }, "not valid for requested server name"},
1232                 {ecdsaCert, &ClientHelloInfo{
1233                         SignatureSchemes:  []SignatureScheme{ECDSAWithP384AndSHA384},
1234                         SupportedVersions: []uint16{VersionTLS13},
1235                 }, "signature algorithms"},
1236                 {pkcs1Cert, &ClientHelloInfo{
1237                         SignatureSchemes:  []SignatureScheme{PSSWithSHA256, ECDSAWithP256AndSHA256},
1238                         SupportedVersions: []uint16{VersionTLS13},
1239                 }, "signature algorithms"},
1240
1241                 {rsaCert, &ClientHelloInfo{
1242                         CipherSuites:      []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
1243                         SignatureSchemes:  []SignatureScheme{PKCS1WithSHA1},
1244                         SupportedVersions: []uint16{VersionTLS13, VersionTLS12},
1245                 }, "signature algorithms"},
1246                 {rsaCert, &ClientHelloInfo{
1247                         CipherSuites:      []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
1248                         SignatureSchemes:  []SignatureScheme{PKCS1WithSHA1},
1249                         SupportedVersions: []uint16{VersionTLS13, VersionTLS12},
1250                         config: &Config{
1251                                 MaxVersion: VersionTLS12,
1252                         },
1253                 }, ""}, // Check that mutual version selection works.
1254
1255                 {ecdsaCert, &ClientHelloInfo{
1256                         CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1257                         SupportedCurves:   []CurveID{CurveP256},
1258                         SupportedPoints:   []uint8{pointFormatUncompressed},
1259                         SignatureSchemes:  []SignatureScheme{ECDSAWithP256AndSHA256},
1260                         SupportedVersions: []uint16{VersionTLS12},
1261                 }, ""},
1262                 {ecdsaCert, &ClientHelloInfo{
1263                         CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1264                         SupportedCurves:   []CurveID{CurveP256},
1265                         SupportedPoints:   []uint8{pointFormatUncompressed},
1266                         SignatureSchemes:  []SignatureScheme{ECDSAWithP384AndSHA384},
1267                         SupportedVersions: []uint16{VersionTLS12},
1268                 }, ""}, // TLS 1.2 does not restrict curves based on the SignatureScheme.
1269                 {ecdsaCert, &ClientHelloInfo{
1270                         CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1271                         SupportedCurves:   []CurveID{CurveP256},
1272                         SupportedPoints:   []uint8{pointFormatUncompressed},
1273                         SignatureSchemes:  nil,
1274                         SupportedVersions: []uint16{VersionTLS12},
1275                 }, ""}, // TLS 1.2 comes with default signature schemes.
1276                 {ecdsaCert, &ClientHelloInfo{
1277                         CipherSuites:      []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
1278                         SupportedCurves:   []CurveID{CurveP256},
1279                         SupportedPoints:   []uint8{pointFormatUncompressed},
1280                         SignatureSchemes:  []SignatureScheme{ECDSAWithP256AndSHA256},
1281                         SupportedVersions: []uint16{VersionTLS12},
1282                 }, "cipher suite"},
1283                 {ecdsaCert, &ClientHelloInfo{
1284                         CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1285                         SupportedCurves:   []CurveID{CurveP256},
1286                         SupportedPoints:   []uint8{pointFormatUncompressed},
1287                         SignatureSchemes:  []SignatureScheme{ECDSAWithP256AndSHA256},
1288                         SupportedVersions: []uint16{VersionTLS12},
1289                         config: &Config{
1290                                 CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
1291                         },
1292                 }, "cipher suite"},
1293                 {ecdsaCert, &ClientHelloInfo{
1294                         CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1295                         SupportedCurves:   []CurveID{CurveP384},
1296                         SupportedPoints:   []uint8{pointFormatUncompressed},
1297                         SignatureSchemes:  []SignatureScheme{ECDSAWithP256AndSHA256},
1298                         SupportedVersions: []uint16{VersionTLS12},
1299                 }, "certificate curve"},
1300                 {ecdsaCert, &ClientHelloInfo{
1301                         CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1302                         SupportedCurves:   []CurveID{CurveP256},
1303                         SupportedPoints:   []uint8{1},
1304                         SignatureSchemes:  []SignatureScheme{ECDSAWithP256AndSHA256},
1305                         SupportedVersions: []uint16{VersionTLS12},
1306                 }, "doesn't support ECDHE"},
1307                 {ecdsaCert, &ClientHelloInfo{
1308                         CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1309                         SupportedCurves:   []CurveID{CurveP256},
1310                         SupportedPoints:   []uint8{pointFormatUncompressed},
1311                         SignatureSchemes:  []SignatureScheme{PSSWithSHA256},
1312                         SupportedVersions: []uint16{VersionTLS12},
1313                 }, "signature algorithms"},
1314
1315                 {ed25519Cert, &ClientHelloInfo{
1316                         CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1317                         SupportedCurves:   []CurveID{CurveP256}, // only relevant for ECDHE support
1318                         SupportedPoints:   []uint8{pointFormatUncompressed},
1319                         SignatureSchemes:  []SignatureScheme{Ed25519},
1320                         SupportedVersions: []uint16{VersionTLS12},
1321                 }, ""},
1322                 {ed25519Cert, &ClientHelloInfo{
1323                         CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1324                         SupportedCurves:   []CurveID{CurveP256}, // only relevant for ECDHE support
1325                         SupportedPoints:   []uint8{pointFormatUncompressed},
1326                         SignatureSchemes:  []SignatureScheme{Ed25519},
1327                         SupportedVersions: []uint16{VersionTLS10},
1328                 }, "doesn't support Ed25519"},
1329                 {ed25519Cert, &ClientHelloInfo{
1330                         CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1331                         SupportedCurves:   []CurveID{},
1332                         SupportedPoints:   []uint8{pointFormatUncompressed},
1333                         SignatureSchemes:  []SignatureScheme{Ed25519},
1334                         SupportedVersions: []uint16{VersionTLS12},
1335                 }, "doesn't support ECDHE"},
1336
1337                 {rsaCert, &ClientHelloInfo{
1338                         CipherSuites:      []uint16{TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA},
1339                         SupportedCurves:   []CurveID{CurveP256}, // only relevant for ECDHE support
1340                         SupportedPoints:   []uint8{pointFormatUncompressed},
1341                         SupportedVersions: []uint16{VersionTLS10},
1342                 }, ""},
1343                 {rsaCert, &ClientHelloInfo{
1344                         CipherSuites:      []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
1345                         SupportedVersions: []uint16{VersionTLS12},
1346                 }, ""}, // static RSA fallback
1347         }
1348         for i, tt := range tests {
1349                 err := tt.chi.SupportsCertificate(tt.c)
1350                 switch {
1351                 case tt.wantErr == "" && err != nil:
1352                         t.Errorf("%d: unexpected error: %v", i, err)
1353                 case tt.wantErr != "" && err == nil:
1354                         t.Errorf("%d: unexpected success", i)
1355                 case tt.wantErr != "" && !strings.Contains(err.Error(), tt.wantErr):
1356                         t.Errorf("%d: got error %q, expected %q", i, err, tt.wantErr)
1357                 }
1358         }
1359 }
1360
1361 func TestCipherSuites(t *testing.T) {
1362         var lastID uint16
1363         for _, c := range CipherSuites() {
1364                 if lastID > c.ID {
1365                         t.Errorf("CipherSuites are not ordered by ID: got %#04x after %#04x", c.ID, lastID)
1366                 } else {
1367                         lastID = c.ID
1368                 }
1369
1370                 if c.Insecure {
1371                         t.Errorf("%#04x: Insecure CipherSuite returned by CipherSuites()", c.ID)
1372                 }
1373         }
1374         lastID = 0
1375         for _, c := range InsecureCipherSuites() {
1376                 if lastID > c.ID {
1377                         t.Errorf("InsecureCipherSuites are not ordered by ID: got %#04x after %#04x", c.ID, lastID)
1378                 } else {
1379                         lastID = c.ID
1380                 }
1381
1382                 if !c.Insecure {
1383                         t.Errorf("%#04x: not Insecure CipherSuite returned by InsecureCipherSuites()", c.ID)
1384                 }
1385         }
1386
1387         CipherSuiteByID := func(id uint16) *CipherSuite {
1388                 for _, c := range CipherSuites() {
1389                         if c.ID == id {
1390                                 return c
1391                         }
1392                 }
1393                 for _, c := range InsecureCipherSuites() {
1394                         if c.ID == id {
1395                                 return c
1396                         }
1397                 }
1398                 return nil
1399         }
1400
1401         for _, c := range cipherSuites {
1402                 cc := CipherSuiteByID(c.id)
1403                 if cc == nil {
1404                         t.Errorf("%#04x: no CipherSuite entry", c.id)
1405                         continue
1406                 }
1407
1408                 if tls12Only := c.flags&suiteTLS12 != 0; tls12Only && len(cc.SupportedVersions) != 1 {
1409                         t.Errorf("%#04x: suite is TLS 1.2 only, but SupportedVersions is %v", c.id, cc.SupportedVersions)
1410                 } else if !tls12Only && len(cc.SupportedVersions) != 3 {
1411                         t.Errorf("%#04x: suite TLS 1.0-1.2, but SupportedVersions is %v", c.id, cc.SupportedVersions)
1412                 }
1413
1414                 if got := CipherSuiteName(c.id); got != cc.Name {
1415                         t.Errorf("%#04x: unexpected CipherSuiteName: got %q, expected %q", c.id, got, cc.Name)
1416                 }
1417         }
1418         for _, c := range cipherSuitesTLS13 {
1419                 cc := CipherSuiteByID(c.id)
1420                 if cc == nil {
1421                         t.Errorf("%#04x: no CipherSuite entry", c.id)
1422                         continue
1423                 }
1424
1425                 if cc.Insecure {
1426                         t.Errorf("%#04x: Insecure %v, expected false", c.id, cc.Insecure)
1427                 }
1428                 if len(cc.SupportedVersions) != 1 || cc.SupportedVersions[0] != VersionTLS13 {
1429                         t.Errorf("%#04x: suite is TLS 1.3 only, but SupportedVersions is %v", c.id, cc.SupportedVersions)
1430                 }
1431
1432                 if got := CipherSuiteName(c.id); got != cc.Name {
1433                         t.Errorf("%#04x: unexpected CipherSuiteName: got %q, expected %q", c.id, got, cc.Name)
1434                 }
1435         }
1436
1437         if got := CipherSuiteName(0xabc); got != "0x0ABC" {
1438                 t.Errorf("unexpected fallback CipherSuiteName: got %q, expected 0x0ABC", got)
1439         }
1440
1441         if len(cipherSuitesPreferenceOrder) != len(cipherSuites) {
1442                 t.Errorf("cipherSuitesPreferenceOrder is not the same size as cipherSuites")
1443         }
1444         if len(cipherSuitesPreferenceOrderNoAES) != len(cipherSuitesPreferenceOrder) {
1445                 t.Errorf("cipherSuitesPreferenceOrderNoAES is not the same size as cipherSuitesPreferenceOrder")
1446         }
1447
1448         // Check that disabled suites are at the end of the preference lists, and
1449         // that they are marked insecure.
1450         for i, id := range disabledCipherSuites {
1451                 offset := len(cipherSuitesPreferenceOrder) - len(disabledCipherSuites)
1452                 if cipherSuitesPreferenceOrder[offset+i] != id {
1453                         t.Errorf("disabledCipherSuites[%d]: not at the end of cipherSuitesPreferenceOrder", i)
1454                 }
1455                 if cipherSuitesPreferenceOrderNoAES[offset+i] != id {
1456                         t.Errorf("disabledCipherSuites[%d]: not at the end of cipherSuitesPreferenceOrderNoAES", i)
1457                 }
1458                 c := CipherSuiteByID(id)
1459                 if c == nil {
1460                         t.Errorf("%#04x: no CipherSuite entry", id)
1461                         continue
1462                 }
1463                 if !c.Insecure {
1464                         t.Errorf("%#04x: disabled by default but not marked insecure", id)
1465                 }
1466         }
1467
1468         for i, prefOrder := range [][]uint16{cipherSuitesPreferenceOrder, cipherSuitesPreferenceOrderNoAES} {
1469                 // Check that insecure and HTTP/2 bad cipher suites are at the end of
1470                 // the preference lists.
1471                 var sawInsecure, sawBad bool
1472                 for _, id := range prefOrder {
1473                         c := CipherSuiteByID(id)
1474                         if c == nil {
1475                                 t.Errorf("%#04x: no CipherSuite entry", id)
1476                                 continue
1477                         }
1478
1479                         if c.Insecure {
1480                                 sawInsecure = true
1481                         } else if sawInsecure {
1482                                 t.Errorf("%#04x: secure suite after insecure one(s)", id)
1483                         }
1484
1485                         if http2isBadCipher(id) {
1486                                 sawBad = true
1487                         } else if sawBad {
1488                                 t.Errorf("%#04x: non-bad suite after bad HTTP/2 one(s)", id)
1489                         }
1490                 }
1491
1492                 // Check that the list is sorted according to the documented criteria.
1493                 isBetter := func(a, b int) bool {
1494                         aSuite, bSuite := cipherSuiteByID(prefOrder[a]), cipherSuiteByID(prefOrder[b])
1495                         aName, bName := CipherSuiteName(prefOrder[a]), CipherSuiteName(prefOrder[b])
1496                         // * < RC4
1497                         if !strings.Contains(aName, "RC4") && strings.Contains(bName, "RC4") {
1498                                 return true
1499                         } else if strings.Contains(aName, "RC4") && !strings.Contains(bName, "RC4") {
1500                                 return false
1501                         }
1502                         // * < CBC_SHA256
1503                         if !strings.Contains(aName, "CBC_SHA256") && strings.Contains(bName, "CBC_SHA256") {
1504                                 return true
1505                         } else if strings.Contains(aName, "CBC_SHA256") && !strings.Contains(bName, "CBC_SHA256") {
1506                                 return false
1507                         }
1508                         // * < 3DES
1509                         if !strings.Contains(aName, "3DES") && strings.Contains(bName, "3DES") {
1510                                 return true
1511                         } else if strings.Contains(aName, "3DES") && !strings.Contains(bName, "3DES") {
1512                                 return false
1513                         }
1514                         // ECDHE < *
1515                         if aSuite.flags&suiteECDHE != 0 && bSuite.flags&suiteECDHE == 0 {
1516                                 return true
1517                         } else if aSuite.flags&suiteECDHE == 0 && bSuite.flags&suiteECDHE != 0 {
1518                                 return false
1519                         }
1520                         // AEAD < CBC
1521                         if aSuite.aead != nil && bSuite.aead == nil {
1522                                 return true
1523                         } else if aSuite.aead == nil && bSuite.aead != nil {
1524                                 return false
1525                         }
1526                         // AES < ChaCha20
1527                         if strings.Contains(aName, "AES") && strings.Contains(bName, "CHACHA20") {
1528                                 return i == 0 // true for cipherSuitesPreferenceOrder
1529                         } else if strings.Contains(aName, "CHACHA20") && strings.Contains(bName, "AES") {
1530                                 return i != 0 // true for cipherSuitesPreferenceOrderNoAES
1531                         }
1532                         // AES-128 < AES-256
1533                         if strings.Contains(aName, "AES_128") && strings.Contains(bName, "AES_256") {
1534                                 return true
1535                         } else if strings.Contains(aName, "AES_256") && strings.Contains(bName, "AES_128") {
1536                                 return false
1537                         }
1538                         // ECDSA < RSA
1539                         if aSuite.flags&suiteECSign != 0 && bSuite.flags&suiteECSign == 0 {
1540                                 return true
1541                         } else if aSuite.flags&suiteECSign == 0 && bSuite.flags&suiteECSign != 0 {
1542                                 return false
1543                         }
1544                         t.Fatalf("two ciphersuites are equal by all criteria: %v and %v", aName, bName)
1545                         panic("unreachable")
1546                 }
1547                 if !sort.SliceIsSorted(prefOrder, isBetter) {
1548                         t.Error("preference order is not sorted according to the rules")
1549                 }
1550         }
1551 }
1552
1553 // http2isBadCipher is copied from net/http.
1554 // TODO: if it ends up exposed somewhere, use that instead.
1555 func http2isBadCipher(cipher uint16) bool {
1556         switch cipher {
1557         case TLS_RSA_WITH_RC4_128_SHA,
1558                 TLS_RSA_WITH_3DES_EDE_CBC_SHA,
1559                 TLS_RSA_WITH_AES_128_CBC_SHA,
1560                 TLS_RSA_WITH_AES_256_CBC_SHA,
1561                 TLS_RSA_WITH_AES_128_CBC_SHA256,
1562                 TLS_RSA_WITH_AES_128_GCM_SHA256,
1563                 TLS_RSA_WITH_AES_256_GCM_SHA384,
1564                 TLS_ECDHE_ECDSA_WITH_RC4_128_SHA,
1565                 TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
1566                 TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
1567                 TLS_ECDHE_RSA_WITH_RC4_128_SHA,
1568                 TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
1569                 TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
1570                 TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
1571                 TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256,
1572                 TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256:
1573                 return true
1574         default:
1575                 return false
1576         }
1577 }
1578
// brokenSigner wraps a crypto.Signer but reduces the SignerOpts passed to Sign
// to the bare hash function, so any rsa.PSSOptions are lost. It is used to
// build a signer that always makes PKCS #1 v1.5 signatures, even when the
// caller asked for RSA-PSS.
type brokenSigner struct{ crypto.Signer }

// Sign delegates to the embedded Signer, passing only opts.HashFunc() as the
// signing options.
func (b brokenSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) {
	// The hash value returned by HashFunc is itself accepted as SignerOpts.
	// Passing it instead of opts is what discards PSS-specific settings.
	hashOnly := opts.HashFunc()
	signature, err = b.Signer.Sign(rand, digest, hashOnly)
	return signature, err
}
1585
1586 // TestPKCS1OnlyCert uses a client certificate with a broken crypto.Signer that
1587 // always makes PKCS #1 v1.5 signatures, so can't be used with RSA-PSS.
1588 func TestPKCS1OnlyCert(t *testing.T) {
1589         clientConfig := testConfig.Clone()
1590         clientConfig.Certificates = []Certificate{{
1591                 Certificate: [][]byte{testRSACertificate},
1592                 PrivateKey:  brokenSigner{testRSAPrivateKey},
1593         }}
1594         serverConfig := testConfig.Clone()
1595         serverConfig.MaxVersion = VersionTLS12 // TLS 1.3 doesn't support PKCS #1 v1.5
1596         serverConfig.ClientAuth = RequireAnyClientCert
1597
1598         // If RSA-PSS is selected, the handshake should fail.
1599         if _, _, err := testHandshake(t, clientConfig, serverConfig); err == nil {
1600                 t.Fatal("expected broken certificate to cause connection to fail")
1601         }
1602
1603         clientConfig.Certificates[0].SupportedSignatureAlgorithms =
1604                 []SignatureScheme{PKCS1WithSHA1, PKCS1WithSHA256}
1605
1606         // But if the certificate restricts supported algorithms, RSA-PSS should not
1607         // be selected, and the handshake should succeed.
1608         if _, _, err := testHandshake(t, clientConfig, serverConfig); err != nil {
1609                 t.Error(err)
1610         }
1611 }