Cypherpunks.ru repositories - gostls13.git/blob - src/crypto/tls/tls_test.go
[dev.boringcrypto] all: merge master into dev.boringcrypto
[gostls13.git] / src / crypto / tls / tls_test.go
1 // Copyright 2012 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package tls
6
7 import (
8         "bytes"
9         "context"
10         "crypto"
11         "crypto/x509"
12         "encoding/json"
13         "errors"
14         "fmt"
15         "internal/testenv"
16         "io"
17         "io/ioutil"
18         "math"
19         "net"
20         "os"
21         "reflect"
22         "strings"
23         "testing"
24         "time"
25 )
26
// rsaCertPEM is an RSA test certificate. Its private key is rsaKeyPEM (and,
// under a generic label, keyPEM); keyPairTests pairs them up.
var rsaCertPEM = `-----BEGIN CERTIFICATE-----
MIIB0zCCAX2gAwIBAgIJAI/M7BYjwB+uMA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV
BAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX
aWRnaXRzIFB0eSBMdGQwHhcNMTIwOTEyMjE1MjAyWhcNMTUwOTEyMjE1MjAyWjBF
MQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50
ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBANLJ
hPHhITqQbPklG3ibCVxwGMRfp/v4XqhfdQHdcVfHap6NQ5Wok/4xIA+ui35/MmNa
rtNuC+BdZ1tMuVCPFZcCAwEAAaNQME4wHQYDVR0OBBYEFJvKs8RfJaXTH08W+SGv
zQyKn0H8MB8GA1UdIwQYMBaAFJvKs8RfJaXTH08W+SGvzQyKn0H8MAwGA1UdEwQF
MAMBAf8wDQYJKoZIhvcNAQEFBQADQQBJlffJHybjDGxRMqaRmDhX0+6v02TUKZsW
r5QuVbpQhH6u+0UgcW0jp9QwpxoPTLTWGXEWBBBurxFwiCBhkQ+V
-----END CERTIFICATE-----
`
40
// rsaKeyPEM is the private key for rsaCertPEM: a 512-bit RSA key in PKCS #1
// form. The PEM labels say "TESTING KEY"; testingKey (defined elsewhere in
// this package) presumably rewrites them to "PRIVATE KEY" so that the source
// file carries no literal private-key headers. Confirm against its definition.
var rsaKeyPEM = testingKey(`-----BEGIN RSA TESTING KEY-----
MIIBOwIBAAJBANLJhPHhITqQbPklG3ibCVxwGMRfp/v4XqhfdQHdcVfHap6NQ5Wo
k/4xIA+ui35/MmNartNuC+BdZ1tMuVCPFZcCAwEAAQJAEJ2N+zsR0Xn8/Q6twa4G
6OB1M1WO+k+ztnX/1SvNeWu8D6GImtupLTYgjZcHufykj09jiHmjHx8u8ZZB/o1N
MQIhAPW+eyZo7ay3lMz1V01WVjNKK9QSn1MJlb06h/LuYv9FAiEA25WPedKgVyCW
SmUwbPw8fnTcpqDWE3yTO3vKcebqMSsCIBF3UmVue8YU3jybC3NxuXq3wNm34R8T
xVLHwDXh/6NJAiEAl2oHGGLz64BuAfjKrqwz7qMYr9HCLIe/YsoWq/olzScCIQDi
D2lWusoe2/nEqfDVVWGWlyJ7yOmqaVm/iNUN9B2N2g==
-----END RSA TESTING KEY-----
`)
51
// keyPEM holds exactly the same key bytes as rsaKeyPEM, but the PEM block
// declares itself as just "PRIVATE KEY" (after testingKey's label rewrite),
// not "RSA PRIVATE KEY". X509KeyPair must still accept it; keyPairTests runs
// it as the "RSA-untyped" case. See https://golang.org/issue/4477.
var keyPEM = testingKey(`-----BEGIN TESTING KEY-----
MIIBOwIBAAJBANLJhPHhITqQbPklG3ibCVxwGMRfp/v4XqhfdQHdcVfHap6NQ5Wo
k/4xIA+ui35/MmNartNuC+BdZ1tMuVCPFZcCAwEAAQJAEJ2N+zsR0Xn8/Q6twa4G
6OB1M1WO+k+ztnX/1SvNeWu8D6GImtupLTYgjZcHufykj09jiHmjHx8u8ZZB/o1N
MQIhAPW+eyZo7ay3lMz1V01WVjNKK9QSn1MJlb06h/LuYv9FAiEA25WPedKgVyCW
SmUwbPw8fnTcpqDWE3yTO3vKcebqMSsCIBF3UmVue8YU3jybC3NxuXq3wNm34R8T
xVLHwDXh/6NJAiEAl2oHGGLz64BuAfjKrqwz7qMYr9HCLIe/YsoWq/olzScCIQDi
D2lWusoe2/nEqfDVVWGWlyJ7yOmqaVm/iNUN9B2N2g==
-----END TESTING KEY-----
`)
64
// ecdsaCertPEM is an ECDSA test certificate whose private key is
// ecdsaKeyPEM (the two are paired in keyPairTests).
var ecdsaCertPEM = `-----BEGIN CERTIFICATE-----
MIIB/jCCAWICCQDscdUxw16XFDAJBgcqhkjOPQQBMEUxCzAJBgNVBAYTAkFVMRMw
EQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBXaWRnaXRzIFB0
eSBMdGQwHhcNMTIxMTE0MTI0MDQ4WhcNMTUxMTE0MTI0MDQ4WjBFMQswCQYDVQQG
EwJBVTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lk
Z2l0cyBQdHkgTHRkMIGbMBAGByqGSM49AgEGBSuBBAAjA4GGAAQBY9+my9OoeSUR
lDQdV/x8LsOuLilthhiS1Tz4aGDHIPwC1mlvnf7fg5lecYpMCrLLhauAc1UJXcgl
01xoLuzgtAEAgv2P/jgytzRSpUYvgLBt1UA0leLYBy6mQQbrNEuqT3INapKIcUv8
XxYP0xMEUksLPq6Ca+CRSqTtrd/23uTnapkwCQYHKoZIzj0EAQOBigAwgYYCQXJo
A7Sl2nLVf+4Iu/tAX/IF4MavARKC4PPHK3zfuGfPR3oCCcsAoz3kAzOeijvd0iXb
H5jBImIxPL4WxQNiBTexAkF8D1EtpYuWdlVQ80/h/f4pBcGiXPqX5h2PQSQY7hP1
+jwM1FGS4fREIOvlBYr/SzzQRtwrvrzGYxDEDbsC0ZGRnA==
-----END CERTIFICATE-----
`
79
// ecdsaKeyPEM is the private key for ecdsaCertPEM. The leading EC PARAMETERS
// block encodes the curve OID 1.3.132.0.35 (NIST P-521). It precedes the key
// block, so loading this fixture (as TestX509KeyPair does) also checks that
// X509KeyPair skips past PEM blocks it does not need.
var ecdsaKeyPEM = testingKey(`-----BEGIN EC PARAMETERS-----
BgUrgQQAIw==
-----END EC PARAMETERS-----
-----BEGIN EC TESTING KEY-----
MIHcAgEBBEIBrsoKp0oqcv6/JovJJDoDVSGWdirrkgCWxrprGlzB9o0X8fV675X0
NwuBenXFfeZvVcwluO7/Q9wkYoPd/t3jGImgBwYFK4EEACOhgYkDgYYABAFj36bL
06h5JRGUNB1X/Hwuw64uKW2GGJLVPPhoYMcg/ALWaW+d/t+DmV5xikwKssuFq4Bz
VQldyCXTXGgu7OC0AQCC/Y/+ODK3NFKlRi+AsG3VQDSV4tgHLqZBBus0S6pPcg1q
kohxS/xfFg/TEwRSSws+roJr4JFKpO2t3/be5OdqmQ==
-----END EC TESTING KEY-----
`)
91
92 var keyPairTests = []struct {
93         algo string
94         cert string
95         key  string
96 }{
97         {"ECDSA", ecdsaCertPEM, ecdsaKeyPEM},
98         {"RSA", rsaCertPEM, rsaKeyPEM},
99         {"RSA-untyped", rsaCertPEM, keyPEM}, // golang.org/issue/4477
100 }
101
102 func TestX509KeyPair(t *testing.T) {
103         t.Parallel()
104         var pem []byte
105         for _, test := range keyPairTests {
106                 pem = []byte(test.cert + test.key)
107                 if _, err := X509KeyPair(pem, pem); err != nil {
108                         t.Errorf("Failed to load %s cert followed by %s key: %s", test.algo, test.algo, err)
109                 }
110                 pem = []byte(test.key + test.cert)
111                 if _, err := X509KeyPair(pem, pem); err != nil {
112                         t.Errorf("Failed to load %s key followed by %s cert: %s", test.algo, test.algo, err)
113                 }
114         }
115 }
116
117 func TestX509KeyPairErrors(t *testing.T) {
118         _, err := X509KeyPair([]byte(rsaKeyPEM), []byte(rsaCertPEM))
119         if err == nil {
120                 t.Fatalf("X509KeyPair didn't return an error when arguments were switched")
121         }
122         if subStr := "been switched"; !strings.Contains(err.Error(), subStr) {
123                 t.Fatalf("Expected %q in the error when switching arguments to X509KeyPair, but the error was %q", subStr, err)
124         }
125
126         _, err = X509KeyPair([]byte(rsaCertPEM), []byte(rsaCertPEM))
127         if err == nil {
128                 t.Fatalf("X509KeyPair didn't return an error when both arguments were certificates")
129         }
130         if subStr := "certificate"; !strings.Contains(err.Error(), subStr) {
131                 t.Fatalf("Expected %q in the error when both arguments to X509KeyPair were certificates, but the error was %q", subStr, err)
132         }
133
134         const nonsensePEM = `
135 -----BEGIN NONSENSE-----
136 Zm9vZm9vZm9v
137 -----END NONSENSE-----
138 `
139
140         _, err = X509KeyPair([]byte(nonsensePEM), []byte(nonsensePEM))
141         if err == nil {
142                 t.Fatalf("X509KeyPair didn't return an error when both arguments were nonsense")
143         }
144         if subStr := "NONSENSE"; !strings.Contains(err.Error(), subStr) {
145                 t.Fatalf("Expected %q in the error when both arguments to X509KeyPair were nonsense, but the error was %q", subStr, err)
146         }
147 }
148
149 func TestX509MixedKeyPair(t *testing.T) {
150         if _, err := X509KeyPair([]byte(rsaCertPEM), []byte(ecdsaKeyPEM)); err == nil {
151                 t.Error("Load of RSA certificate succeeded with ECDSA private key")
152         }
153         if _, err := X509KeyPair([]byte(ecdsaCertPEM), []byte(rsaKeyPEM)); err == nil {
154                 t.Error("Load of ECDSA certificate succeeded with RSA private key")
155         }
156 }
157
// newLocalListener returns a TCP listener on an ephemeral loopback port. It
// tries IPv4 (127.0.0.1) first and falls back to IPv6 ([::1]). If neither
// can be bound it fails the test, reporting both errors. It is marked as a
// helper so that failures point at the caller.
func newLocalListener(t testing.TB) net.Listener {
	t.Helper()
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err == nil {
		return ln
	}
	ln, err6 := net.Listen("tcp6", "[::1]:0")
	if err6 != nil {
		t.Fatalf("listen on loopback: IPv4: %v; IPv6: %v", err, err6)
	}
	return ln
}
168
169 func TestDialTimeout(t *testing.T) {
170         if testing.Short() {
171                 t.Skip("skipping in short mode")
172         }
173         listener := newLocalListener(t)
174
175         addr := listener.Addr().String()
176         defer listener.Close()
177
178         complete := make(chan bool)
179         defer close(complete)
180
181         go func() {
182                 conn, err := listener.Accept()
183                 if err != nil {
184                         t.Error(err)
185                         return
186                 }
187                 <-complete
188                 conn.Close()
189         }()
190
191         dialer := &net.Dialer{
192                 Timeout: 10 * time.Millisecond,
193         }
194
195         var err error
196         if _, err = DialWithDialer(dialer, "tcp", addr, nil); err == nil {
197                 t.Fatal("DialWithTimeout completed successfully")
198         }
199
200         if !isTimeoutError(err) {
201                 t.Errorf("resulting error not a timeout: %v\nType %T: %#v", err, err, err)
202         }
203 }
204
205 func TestDeadlineOnWrite(t *testing.T) {
206         if testing.Short() {
207                 t.Skip("skipping in short mode")
208         }
209
210         ln := newLocalListener(t)
211         defer ln.Close()
212
213         srvCh := make(chan *Conn, 1)
214
215         go func() {
216                 sconn, err := ln.Accept()
217                 if err != nil {
218                         srvCh <- nil
219                         return
220                 }
221                 srv := Server(sconn, testConfig.Clone())
222                 if err := srv.Handshake(); err != nil {
223                         srvCh <- nil
224                         return
225                 }
226                 srvCh <- srv
227         }()
228
229         clientConfig := testConfig.Clone()
230         clientConfig.MaxVersion = VersionTLS12
231         conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
232         if err != nil {
233                 t.Fatal(err)
234         }
235         defer conn.Close()
236
237         srv := <-srvCh
238         if srv == nil {
239                 t.Error(err)
240         }
241
242         // Make sure the client/server is setup correctly and is able to do a typical Write/Read
243         buf := make([]byte, 6)
244         if _, err := srv.Write([]byte("foobar")); err != nil {
245                 t.Errorf("Write err: %v", err)
246         }
247         if n, err := conn.Read(buf); n != 6 || err != nil || string(buf) != "foobar" {
248                 t.Errorf("Read = %d, %v, data %q; want 6, nil, foobar", n, err, buf)
249         }
250
251         // Set a deadline which should cause Write to timeout
252         if err = srv.SetDeadline(time.Now()); err != nil {
253                 t.Fatalf("SetDeadline(time.Now()) err: %v", err)
254         }
255         if _, err = srv.Write([]byte("should fail")); err == nil {
256                 t.Fatal("Write should have timed out")
257         }
258
259         // Clear deadline and make sure it still times out
260         if err = srv.SetDeadline(time.Time{}); err != nil {
261                 t.Fatalf("SetDeadline(time.Time{}) err: %v", err)
262         }
263         if _, err = srv.Write([]byte("This connection is permanently broken")); err == nil {
264                 t.Fatal("Write which previously failed should still time out")
265         }
266
267         // Verify the error
268         if ne := err.(net.Error); ne.Temporary() != false {
269                 t.Error("Write timed out but incorrectly classified the error as Temporary")
270         }
271         if !isTimeoutError(err) {
272                 t.Error("Write timed out but did not classify the error as a Timeout")
273         }
274 }
275
// readerFunc lets an ordinary function stand in for an io.Reader.
// TestDialer uses it to supply Config.Rand.
type readerFunc func([]byte) (int, error)

// Compile-time check that readerFunc keeps satisfying io.Reader.
var _ io.Reader = readerFunc(nil)

// Read passes p to the wrapped function and returns its results unchanged.
func (rf readerFunc) Read(p []byte) (int, error) {
	return rf(p)
}
279
280 // TestDialer tests that tls.Dialer.DialContext can abort in the middle of a handshake.
281 // (The other cases are all handled by the existing dial tests in this package, which
282 // all also flow through the same code shared code paths)
283 func TestDialer(t *testing.T) {
284         ln := newLocalListener(t)
285         defer ln.Close()
286
287         unblockServer := make(chan struct{}) // close-only
288         defer close(unblockServer)
289         go func() {
290                 conn, err := ln.Accept()
291                 if err != nil {
292                         return
293                 }
294                 defer conn.Close()
295                 <-unblockServer
296         }()
297
298         ctx, cancel := context.WithCancel(context.Background())
299         d := Dialer{Config: &Config{
300                 Rand: readerFunc(func(b []byte) (n int, err error) {
301                         // By the time crypto/tls wants randomness, that means it has a TCP
302                         // connection, so we're past the Dialer's dial and now blocked
303                         // in a handshake. Cancel our context and see if we get unstuck.
304                         // (Our TCP listener above never reads or writes, so the Handshake
305                         // would otherwise be stuck forever)
306                         cancel()
307                         return len(b), nil
308                 }),
309                 ServerName: "foo",
310         }}
311         _, err := d.DialContext(ctx, "tcp", ln.Addr().String())
312         if err != context.Canceled {
313                 t.Errorf("err = %v; want context.Canceled", err)
314         }
315 }
316
// isTimeoutError reports whether err itself implements net.Error and reports
// Timeout() == true. Wrapped errors are not unwrapped, and a nil err yields
// false.
func isTimeoutError(err error) bool {
	ne, ok := err.(net.Error)
	return ok && ne.Timeout()
}
323
324 // tests that Conn.Read returns (non-zero, io.EOF) instead of
325 // (non-zero, nil) when a Close (alertCloseNotify) is sitting right
326 // behind the application data in the buffer.
327 func TestConnReadNonzeroAndEOF(t *testing.T) {
328         // This test is racy: it assumes that after a write to a
329         // localhost TCP connection, the peer TCP connection can
330         // immediately read it. Because it's racy, we skip this test
331         // in short mode, and then retry it several times with an
332         // increasing sleep in between our final write (via srv.Close
333         // below) and the following read.
334         if testing.Short() {
335                 t.Skip("skipping in short mode")
336         }
337         var err error
338         for delay := time.Millisecond; delay <= 64*time.Millisecond; delay *= 2 {
339                 if err = testConnReadNonzeroAndEOF(t, delay); err == nil {
340                         return
341                 }
342         }
343         t.Error(err)
344 }
345
346 func testConnReadNonzeroAndEOF(t *testing.T, delay time.Duration) error {
347         ln := newLocalListener(t)
348         defer ln.Close()
349
350         srvCh := make(chan *Conn, 1)
351         var serr error
352         go func() {
353                 sconn, err := ln.Accept()
354                 if err != nil {
355                         serr = err
356                         srvCh <- nil
357                         return
358                 }
359                 serverConfig := testConfig.Clone()
360                 srv := Server(sconn, serverConfig)
361                 if err := srv.Handshake(); err != nil {
362                         serr = fmt.Errorf("handshake: %v", err)
363                         srvCh <- nil
364                         return
365                 }
366                 srvCh <- srv
367         }()
368
369         clientConfig := testConfig.Clone()
370         // In TLS 1.3, alerts are encrypted and disguised as application data, so
371         // the opportunistic peek won't work.
372         clientConfig.MaxVersion = VersionTLS12
373         conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
374         if err != nil {
375                 t.Fatal(err)
376         }
377         defer conn.Close()
378
379         srv := <-srvCh
380         if srv == nil {
381                 return serr
382         }
383
384         buf := make([]byte, 6)
385
386         srv.Write([]byte("foobar"))
387         n, err := conn.Read(buf)
388         if n != 6 || err != nil || string(buf) != "foobar" {
389                 return fmt.Errorf("Read = %d, %v, data %q; want 6, nil, foobar", n, err, buf)
390         }
391
392         srv.Write([]byte("abcdef"))
393         srv.Close()
394         time.Sleep(delay)
395         n, err = conn.Read(buf)
396         if n != 6 || string(buf) != "abcdef" {
397                 return fmt.Errorf("Read = %d, buf= %q; want 6, abcdef", n, buf)
398         }
399         if err != io.EOF {
400                 return fmt.Errorf("Second Read error = %v; want io.EOF", err)
401         }
402         return nil
403 }
404
405 func TestTLSUniqueMatches(t *testing.T) {
406         ln := newLocalListener(t)
407         defer ln.Close()
408
409         serverTLSUniques := make(chan []byte)
410         parentDone := make(chan struct{})
411         childDone := make(chan struct{})
412         defer close(parentDone)
413         go func() {
414                 defer close(childDone)
415                 for i := 0; i < 2; i++ {
416                         sconn, err := ln.Accept()
417                         if err != nil {
418                                 t.Error(err)
419                                 return
420                         }
421                         serverConfig := testConfig.Clone()
422                         serverConfig.MaxVersion = VersionTLS12 // TLSUnique is not defined in TLS 1.3
423                         srv := Server(sconn, serverConfig)
424                         if err := srv.Handshake(); err != nil {
425                                 t.Error(err)
426                                 return
427                         }
428                         select {
429                         case <-parentDone:
430                                 return
431                         case serverTLSUniques <- srv.ConnectionState().TLSUnique:
432                         }
433                 }
434         }()
435
436         clientConfig := testConfig.Clone()
437         clientConfig.ClientSessionCache = NewLRUClientSessionCache(1)
438         conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
439         if err != nil {
440                 t.Fatal(err)
441         }
442
443         var serverTLSUniquesValue []byte
444         select {
445         case <-childDone:
446                 return
447         case serverTLSUniquesValue = <-serverTLSUniques:
448         }
449
450         if !bytes.Equal(conn.ConnectionState().TLSUnique, serverTLSUniquesValue) {
451                 t.Error("client and server channel bindings differ")
452         }
453         conn.Close()
454
455         conn, err = Dial("tcp", ln.Addr().String(), clientConfig)
456         if err != nil {
457                 t.Fatal(err)
458         }
459         defer conn.Close()
460         if !conn.ConnectionState().DidResume {
461                 t.Error("second session did not use resumption")
462         }
463
464         select {
465         case <-childDone:
466                 return
467         case serverTLSUniquesValue = <-serverTLSUniques:
468         }
469
470         if !bytes.Equal(conn.ConnectionState().TLSUnique, serverTLSUniquesValue) {
471                 t.Error("client and server channel bindings differ when session resumption is used")
472         }
473 }
474
475 func TestVerifyHostname(t *testing.T) {
476         testenv.MustHaveExternalNetwork(t)
477
478         c, err := Dial("tcp", "www.google.com:https", nil)
479         if err != nil {
480                 t.Fatal(err)
481         }
482         if err := c.VerifyHostname("www.google.com"); err != nil {
483                 t.Fatalf("verify www.google.com: %v", err)
484         }
485         if err := c.VerifyHostname("www.yahoo.com"); err == nil {
486                 t.Fatalf("verify www.yahoo.com succeeded")
487         }
488
489         c, err = Dial("tcp", "www.google.com:https", &Config{InsecureSkipVerify: true})
490         if err != nil {
491                 t.Fatal(err)
492         }
493         if err := c.VerifyHostname("www.google.com"); err == nil {
494                 t.Fatalf("verify www.google.com succeeded with InsecureSkipVerify=true")
495         }
496 }
497
// TestConnCloseBreakingWrite checks three things about Conn.Close called
// while a Write is stuck inside the underlying transport:
//   - Close returns rather than blocking forever;
//   - the stuck Write then returns the transport's error;
//   - a later, second Close reports errClosed.
func TestConnCloseBreakingWrite(t *testing.T) {
	ln := newLocalListener(t)
	defer ln.Close()

	// The server goroutine hands back its handshaken Conn on srvCh, or nil
	// with the reason in serr. sconn is kept so the raw server socket can be
	// closed at the end of the test.
	srvCh := make(chan *Conn, 1)
	var serr error
	var sconn net.Conn
	go func() {
		var err error
		sconn, err = ln.Accept()
		if err != nil {
			serr = err
			srvCh <- nil
			return
		}
		serverConfig := testConfig.Clone()
		srv := Server(sconn, serverConfig)
		if err := srv.Handshake(); err != nil {
			serr = fmt.Errorf("handshake: %v", err)
			srvCh <- nil
			return
		}
		srvCh <- srv
	}()

	cconn, err := net.Dial("tcp", ln.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer cconn.Close()

	// Wrap the raw client socket so its Write and Close can be replaced after
	// the handshake. Until the hooks are set, both fall through to cconn.
	conn := &changeImplConn{
		Conn: cconn,
	}

	clientConfig := testConfig.Clone()
	tconn := Client(conn, clientConfig)
	if err := tconn.Handshake(); err != nil {
		t.Fatal(err)
	}

	srv := <-srvCh
	if srv == nil {
		t.Fatal(serr)
	}
	defer sconn.Close()

	// From here on, closing the wrapper only signals connClosed. The real
	// socket is left to the deferred cconn.Close above.
	connClosed := make(chan struct{})
	conn.closeFunc = func() error {
		close(connClosed)
		return nil
	}

	// Simulate a Write stuck in the transport: announce entry on inWrite,
	// then block until the wrapper is closed and fail with errConnClosed.
	inWrite := make(chan bool, 1)
	var errConnClosed = errors.New("conn closed for test")
	conn.writeFunc = func(p []byte) (n int, err error) {
		inWrite <- true
		<-connClosed
		return 0, errConnClosed
	}

	// Once the Write below is blocked in the transport, call Close from
	// another goroutine. It must return even though the Write is in flight.
	closeReturned := make(chan bool, 1)
	go func() {
		<-inWrite
		tconn.Close() // test that this doesn't block forever.
		closeReturned <- true
	}()

	// The Write is released by the concurrent Close and must surface the
	// transport's error unchanged.
	_, err = tconn.Write([]byte("foo"))
	if err != errConnClosed {
		t.Errorf("Write error = %v; want errConnClosed", err)
	}

	// The Conn has already been closed once, so this Close must report errClosed.
	<-closeReturned
	if err := tconn.Close(); err != errClosed {
		t.Errorf("Close error = %v; want errClosed", err)
	}
}
576
// TestConnCloseWrite exercises Conn.CloseWrite (half-close):
//   - After the client calls CloseWrite, its further Writes fail with
//     errShutdown, and the server reads no data before end of stream.
//   - After the server then calls CloseWrite, the client likewise reads
//     no data before end of stream.
//   - CloseWrite on a Conn whose handshake has not completed returns
//     errEarlyCloseWrite.
func TestConnCloseWrite(t *testing.T) {
	ln := newLocalListener(t)
	defer ln.Close()

	// Closed when clientCloseWrite returns; the server waits on it so that
	// its deferred socket Close cannot stand in for CloseWrite.
	clientDoneChan := make(chan struct{})

	serverCloseWrite := func() error {
		sconn, err := ln.Accept()
		if err != nil {
			return fmt.Errorf("accept: %v", err)
		}
		defer sconn.Close()

		serverConfig := testConfig.Clone()
		srv := Server(sconn, serverConfig)
		if err := srv.Handshake(); err != nil {
			return fmt.Errorf("handshake: %v", err)
		}
		defer srv.Close()

		// ReadAll only returns nil once the stream ends, which here must be
		// caused by the client's CloseWrite, with nothing sent before it.
		data, err := ioutil.ReadAll(srv)
		if err != nil {
			return err
		}
		if len(data) > 0 {
			return fmt.Errorf("Read data = %q; want nothing", data)
		}

		if err := srv.CloseWrite(); err != nil {
			return fmt.Errorf("server CloseWrite: %v", err)
		}

		// Wait for clientCloseWrite to finish, so we know the client was
		// unblocked by CloseWrite and not by the deferred sconn.Close
		// above, which would also let the client's read finish.
		<-clientDoneChan
		return nil
	}

	clientCloseWrite := func() error {
		defer close(clientDoneChan)

		clientConfig := testConfig.Clone()
		conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
		if err != nil {
			return err
		}
		if err := conn.Handshake(); err != nil {
			return err
		}
		defer conn.Close()

		if err := conn.CloseWrite(); err != nil {
			return fmt.Errorf("client CloseWrite: %v", err)
		}

		// After CloseWrite the write side is shut down for good.
		if _, err := conn.Write([]byte{0}); err != errShutdown {
			return fmt.Errorf("CloseWrite error = %v; want errShutdown", err)
		}

		// Reading must still work, and ends at the server's CloseWrite.
		data, err := ioutil.ReadAll(conn)
		if err != nil {
			return err
		}
		if len(data) > 0 {
			return fmt.Errorf("Read data = %q; want nothing", data)
		}
		return nil
	}

	// Buffered for both results so neither goroutine blocks on its send if
	// the loop below exits early via t.Fatal.
	errChan := make(chan error, 2)

	go func() { errChan <- serverCloseWrite() }()
	go func() { errChan <- clientCloseWrite() }()

	for i := 0; i < 2; i++ {
		select {
		case err := <-errChan:
			if err != nil {
				t.Fatal(err)
			}
		case <-time.After(10 * time.Second):
			t.Fatal("deadlock")
		}
	}

	// Also test CloseWrite being called before the handshake is
	// finished:
	{
		ln2 := newLocalListener(t)
		defer ln2.Close()

		netConn, err := net.Dial("tcp", ln2.Addr().String())
		if err != nil {
			t.Fatal(err)
		}
		defer netConn.Close()
		conn := Client(netConn, testConfig.Clone())

		if err := conn.CloseWrite(); err != errEarlyCloseWrite {
			t.Errorf("CloseWrite error = %v; want errEarlyCloseWrite", err)
		}
	}
}
682
683 func TestWarningAlertFlood(t *testing.T) {
684         ln := newLocalListener(t)
685         defer ln.Close()
686
687         server := func() error {
688                 sconn, err := ln.Accept()
689                 if err != nil {
690                         return fmt.Errorf("accept: %v", err)
691                 }
692                 defer sconn.Close()
693
694                 serverConfig := testConfig.Clone()
695                 srv := Server(sconn, serverConfig)
696                 if err := srv.Handshake(); err != nil {
697                         return fmt.Errorf("handshake: %v", err)
698                 }
699                 defer srv.Close()
700
701                 _, err = ioutil.ReadAll(srv)
702                 if err == nil {
703                         return errors.New("unexpected lack of error from server")
704                 }
705                 const expected = "too many ignored"
706                 if str := err.Error(); !strings.Contains(str, expected) {
707                         return fmt.Errorf("expected error containing %q, but saw: %s", expected, str)
708                 }
709
710                 return nil
711         }
712
713         errChan := make(chan error, 1)
714         go func() { errChan <- server() }()
715
716         clientConfig := testConfig.Clone()
717         clientConfig.MaxVersion = VersionTLS12 // there are no warning alerts in TLS 1.3
718         conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
719         if err != nil {
720                 t.Fatal(err)
721         }
722         defer conn.Close()
723         if err := conn.Handshake(); err != nil {
724                 t.Fatal(err)
725         }
726
727         for i := 0; i < maxUselessRecords+1; i++ {
728                 conn.sendAlert(alertNoRenegotiation)
729         }
730
731         if err := <-errChan; err != nil {
732                 t.Fatal(err)
733         }
734 }
735
736 func TestCloneFuncFields(t *testing.T) {
737         const expectedCount = 5
738         called := 0
739
740         c1 := Config{
741                 Time: func() time.Time {
742                         called |= 1 << 0
743                         return time.Time{}
744                 },
745                 GetCertificate: func(*ClientHelloInfo) (*Certificate, error) {
746                         called |= 1 << 1
747                         return nil, nil
748                 },
749                 GetClientCertificate: func(*CertificateRequestInfo) (*Certificate, error) {
750                         called |= 1 << 2
751                         return nil, nil
752                 },
753                 GetConfigForClient: func(*ClientHelloInfo) (*Config, error) {
754                         called |= 1 << 3
755                         return nil, nil
756                 },
757                 VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
758                         called |= 1 << 4
759                         return nil
760                 },
761         }
762
763         c2 := c1.Clone()
764
765         c2.Time()
766         c2.GetCertificate(nil)
767         c2.GetClientCertificate(nil)
768         c2.GetConfigForClient(nil)
769         c2.VerifyPeerCertificate(nil, nil)
770
771         if called != (1<<expectedCount)-1 {
772                 t.Fatalf("expected %d calls but saw calls %b", expectedCount, called)
773         }
774 }
775
776 func TestCloneNonFuncFields(t *testing.T) {
777         var c1 Config
778         v := reflect.ValueOf(&c1).Elem()
779
780         typ := v.Type()
781         for i := 0; i < typ.NumField(); i++ {
782                 f := v.Field(i)
783                 if !f.CanSet() {
784                         // unexported field; not cloned.
785                         continue
786                 }
787
788                 // testing/quick can't handle functions or interfaces and so
789                 // isn't used here.
790                 switch fn := typ.Field(i).Name; fn {
791                 case "Rand":
792                         f.Set(reflect.ValueOf(io.Reader(os.Stdin)))
793                 case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate", "GetClientCertificate":
794                         // DeepEqual can't compare functions. If you add a
795                         // function field to this list, you must also change
796                         // TestCloneFuncFields to ensure that the func field is
797                         // cloned.
798                 case "Certificates":
799                         f.Set(reflect.ValueOf([]Certificate{
800                                 {Certificate: [][]byte{{'b'}}},
801                         }))
802                 case "NameToCertificate":
803                         f.Set(reflect.ValueOf(map[string]*Certificate{"a": nil}))
804                 case "RootCAs", "ClientCAs":
805                         f.Set(reflect.ValueOf(x509.NewCertPool()))
806                 case "ClientSessionCache":
807                         f.Set(reflect.ValueOf(NewLRUClientSessionCache(10)))
808                 case "KeyLogWriter":
809                         f.Set(reflect.ValueOf(io.Writer(os.Stdout)))
810                 case "NextProtos":
811                         f.Set(reflect.ValueOf([]string{"a", "b"}))
812                 case "ServerName":
813                         f.Set(reflect.ValueOf("b"))
814                 case "ClientAuth":
815                         f.Set(reflect.ValueOf(VerifyClientCertIfGiven))
816                 case "InsecureSkipVerify", "SessionTicketsDisabled", "DynamicRecordSizingDisabled", "PreferServerCipherSuites":
817                         f.Set(reflect.ValueOf(true))
818                 case "MinVersion", "MaxVersion":
819                         f.Set(reflect.ValueOf(uint16(VersionTLS12)))
820                 case "SessionTicketKey":
821                         f.Set(reflect.ValueOf([32]byte{}))
822                 case "CipherSuites":
823                         f.Set(reflect.ValueOf([]uint16{1, 2}))
824                 case "CurvePreferences":
825                         f.Set(reflect.ValueOf([]CurveID{CurveP256}))
826                 case "Renegotiation":
827                         f.Set(reflect.ValueOf(RenegotiateOnceAsClient))
828                 default:
829                         t.Errorf("all fields must be accounted for, but saw unknown field %q", fn)
830                 }
831         }
832
833         c2 := c1.Clone()
834         // DeepEqual also compares unexported fields, thus c2 needs to have run
835         // serverInit in order to be DeepEqual to c1. Cloning it and discarding
836         // the result is sufficient.
837         c2.Clone()
838
839         if !reflect.DeepEqual(&c1, c2) {
840                 t.Errorf("clone failed to copy a field")
841         }
842 }
843
// changeImplConn wraps a net.Conn so that a test can substitute its own
// Write and Close behavior. When writeFunc or closeFunc is nil, the
// corresponding call falls through to the embedded Conn.
type changeImplConn struct {
	net.Conn
	writeFunc func([]byte) (int, error)
	closeFunc func() error
}

// Write hands p to writeFunc if one is installed; otherwise it forwards the
// call to the embedded Conn.
func (w *changeImplConn) Write(p []byte) (n int, err error) {
	if w.writeFunc == nil {
		return w.Conn.Write(p)
	}
	return w.writeFunc(p)
}

// Close calls closeFunc if one is installed; otherwise it closes the
// embedded Conn.
func (w *changeImplConn) Close() error {
	if w.closeFunc == nil {
		return w.Conn.Close()
	}
	return w.closeFunc()
}
865
866 func throughput(b *testing.B, version uint16, totalBytes int64, dynamicRecordSizingDisabled bool) {
867         ln := newLocalListener(b)
868         defer ln.Close()
869
870         N := b.N
871
872         // Less than 64KB because Windows appears to use a TCP rwin < 64KB.
873         // See Issue #15899.
874         const bufsize = 32 << 10
875
876         go func() {
877                 buf := make([]byte, bufsize)
878                 for i := 0; i < N; i++ {
879                         sconn, err := ln.Accept()
880                         if err != nil {
881                                 // panic rather than synchronize to avoid benchmark overhead
882                                 // (cannot call b.Fatal in goroutine)
883                                 panic(fmt.Errorf("accept: %v", err))
884                         }
885                         serverConfig := testConfig.Clone()
886                         serverConfig.CipherSuites = nil // the defaults may prefer faster ciphers
887                         serverConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled
888                         srv := Server(sconn, serverConfig)
889                         if err := srv.Handshake(); err != nil {
890                                 panic(fmt.Errorf("handshake: %v", err))
891                         }
892                         if _, err := io.CopyBuffer(srv, srv, buf); err != nil {
893                                 panic(fmt.Errorf("copy buffer: %v", err))
894                         }
895                 }
896         }()
897
898         b.SetBytes(totalBytes)
899         clientConfig := testConfig.Clone()
900         clientConfig.CipherSuites = nil // the defaults may prefer faster ciphers
901         clientConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled
902         clientConfig.MaxVersion = version
903
904         buf := make([]byte, bufsize)
905         chunks := int(math.Ceil(float64(totalBytes) / float64(len(buf))))
906         for i := 0; i < N; i++ {
907                 conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
908                 if err != nil {
909                         b.Fatal(err)
910                 }
911                 for j := 0; j < chunks; j++ {
912                         _, err := conn.Write(buf)
913                         if err != nil {
914                                 b.Fatal(err)
915                         }
916                         _, err = io.ReadFull(conn, buf)
917                         if err != nil {
918                                 b.Fatal(err)
919                         }
920                 }
921                 conn.Close()
922         }
923 }
924
925 func BenchmarkThroughput(b *testing.B) {
926         for _, mode := range []string{"Max", "Dynamic"} {
927                 for size := 1; size <= 64; size <<= 1 {
928                         name := fmt.Sprintf("%sPacket/%dMB", mode, size)
929                         b.Run(name, func(b *testing.B) {
930                                 b.Run("TLSv12", func(b *testing.B) {
931                                         throughput(b, VersionTLS12, int64(size<<20), mode == "Max")
932                                 })
933                                 b.Run("TLSv13", func(b *testing.B) {
934                                         throughput(b, VersionTLS13, int64(size<<20), mode == "Max")
935                                 })
936                         })
937                 }
938         }
939 }
940
// slowConn wraps a net.Conn and paces Write to roughly bps bits per second,
// to simulate a slow link.
type slowConn struct {
	net.Conn
	bps int
}

// Write trickles p into the embedded Conn. Every 100µs it forwards whatever
// portion the elapsed time's bit budget allows. It returns early with the
// byte count so far if the underlying Write fails. It panics if bps is 0.
func (c *slowConn) Write(p []byte) (int, error) {
	if c.bps == 0 {
		panic("too slow")
	}
	start := time.Now()
	for sent := 0; sent < len(p); {
		time.Sleep(100 * time.Microsecond)
		// Bytes earned since start, capped at the size of p.
		budget := int(time.Since(start).Seconds()*float64(c.bps)) / 8
		if budget > len(p) {
			budget = len(p)
		}
		if sent >= budget {
			continue
		}
		n, err := c.Conn.Write(p[sent:budget])
		sent += n
		if err != nil {
			return sent, err
		}
	}
	return len(p), nil
}
968
969 func latency(b *testing.B, version uint16, bps int, dynamicRecordSizingDisabled bool) {
970         ln := newLocalListener(b)
971         defer ln.Close()
972
973         N := b.N
974
975         go func() {
976                 for i := 0; i < N; i++ {
977                         sconn, err := ln.Accept()
978                         if err != nil {
979                                 // panic rather than synchronize to avoid benchmark overhead
980                                 // (cannot call b.Fatal in goroutine)
981                                 panic(fmt.Errorf("accept: %v", err))
982                         }
983                         serverConfig := testConfig.Clone()
984                         serverConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled
985                         srv := Server(&slowConn{sconn, bps}, serverConfig)
986                         if err := srv.Handshake(); err != nil {
987                                 panic(fmt.Errorf("handshake: %v", err))
988                         }
989                         io.Copy(srv, srv)
990                 }
991         }()
992
993         clientConfig := testConfig.Clone()
994         clientConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled
995         clientConfig.MaxVersion = version
996
997         buf := make([]byte, 16384)
998         peek := make([]byte, 1)
999
1000         for i := 0; i < N; i++ {
1001                 conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
1002                 if err != nil {
1003                         b.Fatal(err)
1004                 }
1005                 // make sure we're connected and previous connection has stopped
1006                 if _, err := conn.Write(buf[:1]); err != nil {
1007                         b.Fatal(err)
1008                 }
1009                 if _, err := io.ReadFull(conn, peek); err != nil {
1010                         b.Fatal(err)
1011                 }
1012                 if _, err := conn.Write(buf); err != nil {
1013                         b.Fatal(err)
1014                 }
1015                 if _, err = io.ReadFull(conn, peek); err != nil {
1016                         b.Fatal(err)
1017                 }
1018                 conn.Close()
1019         }
1020 }
1021
1022 func BenchmarkLatency(b *testing.B) {
1023         for _, mode := range []string{"Max", "Dynamic"} {
1024                 for _, kbps := range []int{200, 500, 1000, 2000, 5000} {
1025                         name := fmt.Sprintf("%sPacket/%dkbps", mode, kbps)
1026                         b.Run(name, func(b *testing.B) {
1027                                 b.Run("TLSv12", func(b *testing.B) {
1028                                         latency(b, VersionTLS12, kbps*1000, mode == "Max")
1029                                 })
1030                                 b.Run("TLSv13", func(b *testing.B) {
1031                                         latency(b, VersionTLS13, kbps*1000, mode == "Max")
1032                                 })
1033                         })
1034                 }
1035         }
1036 }
1037
1038 func TestConnectionStateMarshal(t *testing.T) {
1039         cs := &ConnectionState{}
1040         _, err := json.Marshal(cs)
1041         if err != nil {
1042                 t.Errorf("json.Marshal failed on ConnectionState: %v", err)
1043         }
1044 }
1045
1046 func TestConnectionState(t *testing.T) {
1047         issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
1048         if err != nil {
1049                 panic(err)
1050         }
1051         rootCAs := x509.NewCertPool()
1052         rootCAs.AddCert(issuer)
1053
1054         now := func() time.Time { return time.Unix(1476984729, 0) }
1055
1056         const alpnProtocol = "golang"
1057         const serverName = "example.golang"
1058         var scts = [][]byte{[]byte("dummy sct 1"), []byte("dummy sct 2")}
1059         var ocsp = []byte("dummy ocsp")
1060
1061         for _, v := range []uint16{VersionTLS12, VersionTLS13} {
1062                 var name string
1063                 switch v {
1064                 case VersionTLS12:
1065                         name = "TLSv12"
1066                 case VersionTLS13:
1067                         name = "TLSv13"
1068                 }
1069                 t.Run(name, func(t *testing.T) {
1070                         config := &Config{
1071                                 Time:         now,
1072                                 Rand:         zeroSource{},
1073                                 Certificates: make([]Certificate, 1),
1074                                 MaxVersion:   v,
1075                                 RootCAs:      rootCAs,
1076                                 ClientCAs:    rootCAs,
1077                                 ClientAuth:   RequireAndVerifyClientCert,
1078                                 NextProtos:   []string{alpnProtocol},
1079                                 ServerName:   serverName,
1080                         }
1081                         config.Certificates[0].Certificate = [][]byte{testRSACertificate}
1082                         config.Certificates[0].PrivateKey = testRSAPrivateKey
1083                         config.Certificates[0].SignedCertificateTimestamps = scts
1084                         config.Certificates[0].OCSPStaple = ocsp
1085
1086                         ss, cs, err := testHandshake(t, config, config)
1087                         if err != nil {
1088                                 t.Fatalf("Handshake failed: %v", err)
1089                         }
1090
1091                         if ss.Version != v || cs.Version != v {
1092                                 t.Errorf("Got versions %x (server) and %x (client), expected %x", ss.Version, cs.Version, v)
1093                         }
1094
1095                         if !ss.HandshakeComplete || !cs.HandshakeComplete {
1096                                 t.Errorf("Got HandshakeComplete %v (server) and %v (client), expected true", ss.HandshakeComplete, cs.HandshakeComplete)
1097                         }
1098
1099                         if ss.DidResume || cs.DidResume {
1100                                 t.Errorf("Got DidResume %v (server) and %v (client), expected false", ss.DidResume, cs.DidResume)
1101                         }
1102
1103                         if ss.CipherSuite == 0 || cs.CipherSuite == 0 {
1104                                 t.Errorf("Got invalid cipher suite: %v (server) and %v (client)", ss.CipherSuite, cs.CipherSuite)
1105                         }
1106
1107                         if ss.NegotiatedProtocol != alpnProtocol || cs.NegotiatedProtocol != alpnProtocol {
1108                                 t.Errorf("Got negotiated protocol %q (server) and %q (client), expected %q", ss.NegotiatedProtocol, cs.NegotiatedProtocol, alpnProtocol)
1109                         }
1110
1111                         if !cs.NegotiatedProtocolIsMutual {
1112                                 t.Errorf("Got false NegotiatedProtocolIsMutual on the client side")
1113                         }
1114                         // NegotiatedProtocolIsMutual on the server side is unspecified.
1115
1116                         if ss.ServerName != serverName {
1117                                 t.Errorf("Got server name %q, expected %q", ss.ServerName, serverName)
1118                         }
1119                         if cs.ServerName != "" {
1120                                 t.Errorf("Got unexpected server name on the client side")
1121                         }
1122
1123                         if len(ss.PeerCertificates) != 1 || len(cs.PeerCertificates) != 1 {
1124                                 t.Errorf("Got %d (server) and %d (client) peer certificates, expected %d", len(ss.PeerCertificates), len(cs.PeerCertificates), 1)
1125                         }
1126
1127                         if len(ss.VerifiedChains) != 1 || len(cs.VerifiedChains) != 1 {
1128                                 t.Errorf("Got %d (server) and %d (client) verified chains, expected %d", len(ss.VerifiedChains), len(cs.VerifiedChains), 1)
1129                         } else if len(ss.VerifiedChains[0]) != 2 || len(cs.VerifiedChains[0]) != 2 {
1130                                 t.Errorf("Got %d (server) and %d (client) long verified chain, expected %d", len(ss.VerifiedChains[0]), len(cs.VerifiedChains[0]), 2)
1131                         }
1132
1133                         if len(cs.SignedCertificateTimestamps) != 2 {
1134                                 t.Errorf("Got %d SCTs, expected %d", len(cs.SignedCertificateTimestamps), 2)
1135                         }
1136                         if !bytes.Equal(cs.OCSPResponse, ocsp) {
1137                                 t.Errorf("Got OCSPs %x, expected %x", cs.OCSPResponse, ocsp)
1138                         }
1139                         // Only TLS 1.3 supports OCSP and SCTs on client certs.
1140                         if v == VersionTLS13 {
1141                                 if len(ss.SignedCertificateTimestamps) != 2 {
1142                                         t.Errorf("Got %d client SCTs, expected %d", len(ss.SignedCertificateTimestamps), 2)
1143                                 }
1144                                 if !bytes.Equal(ss.OCSPResponse, ocsp) {
1145                                         t.Errorf("Got client OCSPs %x, expected %x", ss.OCSPResponse, ocsp)
1146                                 }
1147                         }
1148
1149                         if v == VersionTLS13 {
1150                                 if ss.TLSUnique != nil || cs.TLSUnique != nil {
1151                                         t.Errorf("Got TLSUnique %x (server) and %x (client), expected nil in TLS 1.3", ss.TLSUnique, cs.TLSUnique)
1152                                 }
1153                         } else {
1154                                 if ss.TLSUnique == nil || cs.TLSUnique == nil {
1155                                         t.Errorf("Got TLSUnique %x (server) and %x (client), expected non-nil", ss.TLSUnique, cs.TLSUnique)
1156                                 }
1157                         }
1158                 })
1159         }
1160 }
1161
1162 // Issue 28744: Ensure that we don't modify memory
1163 // that Config doesn't own such as Certificates.
1164 func TestBuildNameToCertificate_doesntModifyCertificates(t *testing.T) {
1165         c0 := Certificate{
1166                 Certificate: [][]byte{testRSACertificate},
1167                 PrivateKey:  testRSAPrivateKey,
1168         }
1169         c1 := Certificate{
1170                 Certificate: [][]byte{testSNICertificate},
1171                 PrivateKey:  testRSAPrivateKey,
1172         }
1173         config := testConfig.Clone()
1174         config.Certificates = []Certificate{c0, c1}
1175
1176         config.BuildNameToCertificate()
1177         got := config.Certificates
1178         want := []Certificate{c0, c1}
1179         if !reflect.DeepEqual(got, want) {
1180                 t.Fatalf("Certificates were mutated by BuildNameToCertificate\nGot: %#v\nWant: %#v\n", got, want)
1181         }
1182 }
1183
// testingKey rewrites every "TESTING KEY" label in s to "PRIVATE KEY", which
// turns the placeholder PEM blocks in this file into parseable ones.
// NOTE(review): presumably the placeholder keeps the literal "PRIVATE KEY"
// marker out of the source so key scanners do not flag these test keys.
// Confirm this before relying on it.
func testingKey(s string) string {
	const placeholder, label = "TESTING KEY", "PRIVATE KEY"
	return strings.ReplaceAll(s, placeholder, label)
}
1185
1186 func TestClientHelloInfo_SupportsCertificate(t *testing.T) {
1187         rsaCert := &Certificate{
1188                 Certificate: [][]byte{testRSACertificate},
1189                 PrivateKey:  testRSAPrivateKey,
1190         }
1191         pkcs1Cert := &Certificate{
1192                 Certificate:                  [][]byte{testRSACertificate},
1193                 PrivateKey:                   testRSAPrivateKey,
1194                 SupportedSignatureAlgorithms: []SignatureScheme{PKCS1WithSHA1, PKCS1WithSHA256},
1195         }
1196         ecdsaCert := &Certificate{
1197                 // ECDSA P-256 certificate
1198                 Certificate: [][]byte{testP256Certificate},
1199                 PrivateKey:  testP256PrivateKey,
1200         }
1201         ed25519Cert := &Certificate{
1202                 Certificate: [][]byte{testEd25519Certificate},
1203                 PrivateKey:  testEd25519PrivateKey,
1204         }
1205
1206         tests := []struct {
1207                 c       *Certificate
1208                 chi     *ClientHelloInfo
1209                 wantErr string
1210         }{
1211                 {rsaCert, &ClientHelloInfo{
1212                         ServerName:        "example.golang",
1213                         SignatureSchemes:  []SignatureScheme{PSSWithSHA256},
1214                         SupportedVersions: []uint16{VersionTLS13},
1215                 }, ""},
1216                 {ecdsaCert, &ClientHelloInfo{
1217                         SignatureSchemes:  []SignatureScheme{PSSWithSHA256, ECDSAWithP256AndSHA256},
1218                         SupportedVersions: []uint16{VersionTLS13, VersionTLS12},
1219                 }, ""},
1220                 {rsaCert, &ClientHelloInfo{
1221                         ServerName:        "example.com",
1222                         SignatureSchemes:  []SignatureScheme{PSSWithSHA256},
1223                         SupportedVersions: []uint16{VersionTLS13},
1224                 }, "not valid for requested server name"},
1225                 {ecdsaCert, &ClientHelloInfo{
1226                         SignatureSchemes:  []SignatureScheme{ECDSAWithP384AndSHA384},
1227                         SupportedVersions: []uint16{VersionTLS13},
1228                 }, "signature algorithms"},
1229                 {pkcs1Cert, &ClientHelloInfo{
1230                         SignatureSchemes:  []SignatureScheme{PSSWithSHA256, ECDSAWithP256AndSHA256},
1231                         SupportedVersions: []uint16{VersionTLS13},
1232                 }, "signature algorithms"},
1233
1234                 {rsaCert, &ClientHelloInfo{
1235                         CipherSuites:      []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
1236                         SignatureSchemes:  []SignatureScheme{PKCS1WithSHA1},
1237                         SupportedVersions: []uint16{VersionTLS13, VersionTLS12},
1238                 }, "signature algorithms"},
1239                 {rsaCert, &ClientHelloInfo{
1240                         CipherSuites:      []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
1241                         SignatureSchemes:  []SignatureScheme{PKCS1WithSHA1},
1242                         SupportedVersions: []uint16{VersionTLS13, VersionTLS12},
1243                         config: &Config{
1244                                 MaxVersion: VersionTLS12,
1245                         },
1246                 }, ""}, // Check that mutual version selection works.
1247
1248                 {ecdsaCert, &ClientHelloInfo{
1249                         CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1250                         SupportedCurves:   []CurveID{CurveP256},
1251                         SupportedPoints:   []uint8{pointFormatUncompressed},
1252                         SignatureSchemes:  []SignatureScheme{ECDSAWithP256AndSHA256},
1253                         SupportedVersions: []uint16{VersionTLS12},
1254                 }, ""},
1255                 {ecdsaCert, &ClientHelloInfo{
1256                         CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1257                         SupportedCurves:   []CurveID{CurveP256},
1258                         SupportedPoints:   []uint8{pointFormatUncompressed},
1259                         SignatureSchemes:  []SignatureScheme{ECDSAWithP384AndSHA384},
1260                         SupportedVersions: []uint16{VersionTLS12},
1261                 }, ""}, // TLS 1.2 does not restrict curves based on the SignatureScheme.
1262                 {ecdsaCert, &ClientHelloInfo{
1263                         CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1264                         SupportedCurves:   []CurveID{CurveP256},
1265                         SupportedPoints:   []uint8{pointFormatUncompressed},
1266                         SignatureSchemes:  nil,
1267                         SupportedVersions: []uint16{VersionTLS12},
1268                 }, ""}, // TLS 1.2 comes with default signature schemes.
1269                 {ecdsaCert, &ClientHelloInfo{
1270                         CipherSuites:      []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
1271                         SupportedCurves:   []CurveID{CurveP256},
1272                         SupportedPoints:   []uint8{pointFormatUncompressed},
1273                         SignatureSchemes:  []SignatureScheme{ECDSAWithP256AndSHA256},
1274                         SupportedVersions: []uint16{VersionTLS12},
1275                 }, "cipher suite"},
1276                 {ecdsaCert, &ClientHelloInfo{
1277                         CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1278                         SupportedCurves:   []CurveID{CurveP256},
1279                         SupportedPoints:   []uint8{pointFormatUncompressed},
1280                         SignatureSchemes:  []SignatureScheme{ECDSAWithP256AndSHA256},
1281                         SupportedVersions: []uint16{VersionTLS12},
1282                         config: &Config{
1283                                 CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
1284                         },
1285                 }, "cipher suite"},
1286                 {ecdsaCert, &ClientHelloInfo{
1287                         CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1288                         SupportedCurves:   []CurveID{CurveP384},
1289                         SupportedPoints:   []uint8{pointFormatUncompressed},
1290                         SignatureSchemes:  []SignatureScheme{ECDSAWithP256AndSHA256},
1291                         SupportedVersions: []uint16{VersionTLS12},
1292                 }, "certificate curve"},
1293                 {ecdsaCert, &ClientHelloInfo{
1294                         CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1295                         SupportedCurves:   []CurveID{CurveP256},
1296                         SupportedPoints:   []uint8{1},
1297                         SignatureSchemes:  []SignatureScheme{ECDSAWithP256AndSHA256},
1298                         SupportedVersions: []uint16{VersionTLS12},
1299                 }, "doesn't support ECDHE"},
1300                 {ecdsaCert, &ClientHelloInfo{
1301                         CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1302                         SupportedCurves:   []CurveID{CurveP256},
1303                         SupportedPoints:   []uint8{pointFormatUncompressed},
1304                         SignatureSchemes:  []SignatureScheme{PSSWithSHA256},
1305                         SupportedVersions: []uint16{VersionTLS12},
1306                 }, "signature algorithms"},
1307
1308                 {ed25519Cert, &ClientHelloInfo{
1309                         CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1310                         SupportedCurves:   []CurveID{CurveP256}, // only relevant for ECDHE support
1311                         SupportedPoints:   []uint8{pointFormatUncompressed},
1312                         SignatureSchemes:  []SignatureScheme{Ed25519},
1313                         SupportedVersions: []uint16{VersionTLS12},
1314                 }, ""},
1315                 {ed25519Cert, &ClientHelloInfo{
1316                         CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1317                         SupportedCurves:   []CurveID{CurveP256}, // only relevant for ECDHE support
1318                         SupportedPoints:   []uint8{pointFormatUncompressed},
1319                         SignatureSchemes:  []SignatureScheme{Ed25519},
1320                         SupportedVersions: []uint16{VersionTLS10},
1321                 }, "doesn't support Ed25519"},
1322                 {ed25519Cert, &ClientHelloInfo{
1323                         CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1324                         SupportedCurves:   []CurveID{},
1325                         SupportedPoints:   []uint8{pointFormatUncompressed},
1326                         SignatureSchemes:  []SignatureScheme{Ed25519},
1327                         SupportedVersions: []uint16{VersionTLS12},
1328                 }, "doesn't support ECDHE"},
1329
1330                 {rsaCert, &ClientHelloInfo{
1331                         CipherSuites:      []uint16{TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA},
1332                         SupportedCurves:   []CurveID{CurveP256}, // only relevant for ECDHE support
1333                         SupportedPoints:   []uint8{pointFormatUncompressed},
1334                         SupportedVersions: []uint16{VersionTLS10},
1335                 }, ""},
1336                 {rsaCert, &ClientHelloInfo{
1337                         CipherSuites:      []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
1338                         SupportedVersions: []uint16{VersionTLS12},
1339                 }, ""}, // static RSA fallback
1340         }
1341         for i, tt := range tests {
1342                 err := tt.chi.SupportsCertificate(tt.c)
1343                 switch {
1344                 case tt.wantErr == "" && err != nil:
1345                         t.Errorf("%d: unexpected error: %v", i, err)
1346                 case tt.wantErr != "" && err == nil:
1347                         t.Errorf("%d: unexpected success", i)
1348                 case tt.wantErr != "" && !strings.Contains(err.Error(), tt.wantErr):
1349                         t.Errorf("%d: got error %q, expected %q", i, err, tt.wantErr)
1350                 }
1351         }
1352 }
1353
1354 func TestCipherSuites(t *testing.T) {
1355         var lastID uint16
1356         for _, c := range CipherSuites() {
1357                 if lastID > c.ID {
1358                         t.Errorf("CipherSuites are not ordered by ID: got %#04x after %#04x", c.ID, lastID)
1359                 } else {
1360                         lastID = c.ID
1361                 }
1362
1363                 if c.Insecure {
1364                         t.Errorf("%#04x: Insecure CipherSuite returned by CipherSuites()", c.ID)
1365                 }
1366         }
1367         lastID = 0
1368         for _, c := range InsecureCipherSuites() {
1369                 if lastID > c.ID {
1370                         t.Errorf("InsecureCipherSuites are not ordered by ID: got %#04x after %#04x", c.ID, lastID)
1371                 } else {
1372                         lastID = c.ID
1373                 }
1374
1375                 if !c.Insecure {
1376                         t.Errorf("%#04x: not Insecure CipherSuite returned by InsecureCipherSuites()", c.ID)
1377                 }
1378         }
1379
1380         cipherSuiteByID := func(id uint16) *CipherSuite {
1381                 for _, c := range CipherSuites() {
1382                         if c.ID == id {
1383                                 return c
1384                         }
1385                 }
1386                 for _, c := range InsecureCipherSuites() {
1387                         if c.ID == id {
1388                                 return c
1389                         }
1390                 }
1391                 return nil
1392         }
1393
1394         for _, c := range cipherSuites {
1395                 cc := cipherSuiteByID(c.id)
1396                 if cc == nil {
1397                         t.Errorf("%#04x: no CipherSuite entry", c.id)
1398                         continue
1399                 }
1400
1401                 if defaultOff := c.flags&suiteDefaultOff != 0; defaultOff != cc.Insecure {
1402                         t.Errorf("%#04x: Insecure %v, expected %v", c.id, cc.Insecure, defaultOff)
1403                 }
1404                 if tls12Only := c.flags&suiteTLS12 != 0; tls12Only && len(cc.SupportedVersions) != 1 {
1405                         t.Errorf("%#04x: suite is TLS 1.2 only, but SupportedVersions is %v", c.id, cc.SupportedVersions)
1406                 } else if !tls12Only && len(cc.SupportedVersions) != 3 {
1407                         t.Errorf("%#04x: suite TLS 1.0-1.2, but SupportedVersions is %v", c.id, cc.SupportedVersions)
1408                 }
1409
1410                 if got := CipherSuiteName(c.id); got != cc.Name {
1411                         t.Errorf("%#04x: unexpected CipherSuiteName: got %q, expected %q", c.id, got, cc.Name)
1412                 }
1413         }
1414         for _, c := range cipherSuitesTLS13 {
1415                 cc := cipherSuiteByID(c.id)
1416                 if cc == nil {
1417                         t.Errorf("%#04x: no CipherSuite entry", c.id)
1418                         continue
1419                 }
1420
1421                 if cc.Insecure {
1422                         t.Errorf("%#04x: Insecure %v, expected false", c.id, cc.Insecure)
1423                 }
1424                 if len(cc.SupportedVersions) != 1 || cc.SupportedVersions[0] != VersionTLS13 {
1425                         t.Errorf("%#04x: suite is TLS 1.3 only, but SupportedVersions is %v", c.id, cc.SupportedVersions)
1426                 }
1427
1428                 if got := CipherSuiteName(c.id); got != cc.Name {
1429                         t.Errorf("%#04x: unexpected CipherSuiteName: got %q, expected %q", c.id, got, cc.Name)
1430                 }
1431         }
1432
1433         if got := CipherSuiteName(0xabc); got != "0x0ABC" {
1434                 t.Errorf("unexpected fallback CipherSuiteName: got %q, expected 0x0ABC", got)
1435         }
1436 }
1437
// brokenSigner wraps a crypto.Signer but passes only the hash from the
// caller's options to it. Richer options, such as rsa.PSSOptions, are
// therefore lost.
type brokenSigner struct{ crypto.Signer }

// Sign forwards to the wrapped Signer, replacing opts with opts.HashFunc().
func (s brokenSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) {
	var hashOnly crypto.SignerOpts = opts.HashFunc()
	return s.Signer.Sign(rand, digest, hashOnly)
}
1444
1445 // TestPKCS1OnlyCert uses a client certificate with a broken crypto.Signer that
1446 // always makes PKCS#1 v1.5 signatures, so can't be used with RSA-PSS.
1447 func TestPKCS1OnlyCert(t *testing.T) {
1448         clientConfig := testConfig.Clone()
1449         clientConfig.Certificates = []Certificate{{
1450                 Certificate: [][]byte{testRSACertificate},
1451                 PrivateKey:  brokenSigner{testRSAPrivateKey},
1452         }}
1453         serverConfig := testConfig.Clone()
1454         serverConfig.MaxVersion = VersionTLS12 // TLS 1.3 doesn't support PKCS#1 v1.5
1455         serverConfig.ClientAuth = RequireAnyClientCert
1456
1457         // If RSA-PSS is selected, the handshake should fail.
1458         if _, _, err := testHandshake(t, clientConfig, serverConfig); err == nil {
1459                 t.Fatal("expected broken certificate to cause connection to fail")
1460         }
1461
1462         clientConfig.Certificates[0].SupportedSignatureAlgorithms =
1463                 []SignatureScheme{PKCS1WithSHA1, PKCS1WithSHA256}
1464
1465         // But if the certificate restricts supported algorithms, RSA-PSS should not
1466         // be selected, and the handshake should succeed.
1467         if _, _, err := testHandshake(t, clientConfig, serverConfig); err != nil {
1468                 t.Error(err)
1469         }
1470 }