]> Cypherpunks.ru repositories - gostls13.git/blob - src/crypto/tls/tls_test.go
crypto/tls: failed tls.Conn.Write returns a permanent error
[gostls13.git] / src / crypto / tls / tls_test.go
1 // Copyright 2012 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package tls
6
7 import (
8         "bytes"
9         "crypto"
10         "crypto/x509"
11         "encoding/json"
12         "errors"
13         "fmt"
14         "internal/testenv"
15         "io"
16         "io/ioutil"
17         "math"
18         "net"
19         "os"
20         "reflect"
21         "strings"
22         "testing"
23         "time"
24 )
25
// rsaCertPEM is an RSA test certificate whose public key matches rsaKeyPEM
// (TestX509KeyPair below relies on the two pairing successfully). Its issuer
// and subject appear identical, i.e. it looks self-signed.
// NOTE(review): the encoded validity period appears to end in 2015. The
// X509KeyPair tests in this file only parse and pair the PEM blocks; confirm
// that no other user needs a currently-valid certificate.
var rsaCertPEM = `-----BEGIN CERTIFICATE-----
MIIB0zCCAX2gAwIBAgIJAI/M7BYjwB+uMA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV
BAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX
aWRnaXRzIFB0eSBMdGQwHhcNMTIwOTEyMjE1MjAyWhcNMTUwOTEyMjE1MjAyWjBF
MQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50
ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBANLJ
hPHhITqQbPklG3ibCVxwGMRfp/v4XqhfdQHdcVfHap6NQ5Wok/4xIA+ui35/MmNa
rtNuC+BdZ1tMuVCPFZcCAwEAAaNQME4wHQYDVR0OBBYEFJvKs8RfJaXTH08W+SGv
zQyKn0H8MB8GA1UdIwQYMBaAFJvKs8RfJaXTH08W+SGvzQyKn0H8MAwGA1UdEwQF
MAMBAf8wDQYJKoZIhvcNAQEFBQADQQBJlffJHybjDGxRMqaRmDhX0+6v02TUKZsW
r5QuVbpQhH6u+0UgcW0jp9QwpxoPTLTWGXEWBBBurxFwiCBhkQ+V
-----END CERTIFICATE-----
`

// rsaKeyPEM is the RSA private key matching rsaCertPEM, in an "RSA ... KEY"
// PEM block. The literal uses a "TESTING KEY" label; testingKey (defined
// elsewhere) presumably rewrites it to "PRIVATE KEY" so the source does not
// look like a leaked private key — TODO confirm against testingKey.
var rsaKeyPEM = testingKey(`-----BEGIN RSA TESTING KEY-----
MIIBOwIBAAJBANLJhPHhITqQbPklG3ibCVxwGMRfp/v4XqhfdQHdcVfHap6NQ5Wo
k/4xIA+ui35/MmNartNuC+BdZ1tMuVCPFZcCAwEAAQJAEJ2N+zsR0Xn8/Q6twa4G
6OB1M1WO+k+ztnX/1SvNeWu8D6GImtupLTYgjZcHufykj09jiHmjHx8u8ZZB/o1N
MQIhAPW+eyZo7ay3lMz1V01WVjNKK9QSn1MJlb06h/LuYv9FAiEA25WPedKgVyCW
SmUwbPw8fnTcpqDWE3yTO3vKcebqMSsCIBF3UmVue8YU3jybC3NxuXq3wNm34R8T
xVLHwDXh/6NJAiEAl2oHGGLz64BuAfjKrqwz7qMYr9HCLIe/YsoWq/olzScCIQDi
D2lWusoe2/nEqfDVVWGWlyJ7yOmqaVm/iNUN9B2N2g==
-----END RSA TESTING KEY-----
`)
50
// keyPEM holds the same key bytes as rsaKeyPEM, but its PEM block is labeled
// just "PRIVATE KEY" rather than "RSA PRIVATE KEY" (assuming testingKey
// rewrites "TESTING KEY" to "PRIVATE KEY"). The "RSA-untyped" entry in
// keyPairTests uses it to check that X509KeyPair still accepts the key under
// the generic label. https://golang.org/issue/4477
var keyPEM = testingKey(`-----BEGIN TESTING KEY-----
MIIBOwIBAAJBANLJhPHhITqQbPklG3ibCVxwGMRfp/v4XqhfdQHdcVfHap6NQ5Wo
k/4xIA+ui35/MmNartNuC+BdZ1tMuVCPFZcCAwEAAQJAEJ2N+zsR0Xn8/Q6twa4G
6OB1M1WO+k+ztnX/1SvNeWu8D6GImtupLTYgjZcHufykj09jiHmjHx8u8ZZB/o1N
MQIhAPW+eyZo7ay3lMz1V01WVjNKK9QSn1MJlb06h/LuYv9FAiEA25WPedKgVyCW
SmUwbPw8fnTcpqDWE3yTO3vKcebqMSsCIBF3UmVue8YU3jybC3NxuXq3wNm34R8T
xVLHwDXh/6NJAiEAl2oHGGLz64BuAfjKrqwz7qMYr9HCLIe/YsoWq/olzScCIQDi
D2lWusoe2/nEqfDVVWGWlyJ7yOmqaVm/iNUN9B2N2g==
-----END TESTING KEY-----
`)
63
// ecdsaCertPEM is an ECDSA test certificate whose public key matches
// ecdsaKeyPEM (TestX509KeyPair relies on the two pairing successfully).
// NOTE(review): like rsaCertPEM, its encoded validity period appears to have
// ended in 2015.
var ecdsaCertPEM = `-----BEGIN CERTIFICATE-----
MIIB/jCCAWICCQDscdUxw16XFDAJBgcqhkjOPQQBMEUxCzAJBgNVBAYTAkFVMRMw
EQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBXaWRnaXRzIFB0
eSBMdGQwHhcNMTIxMTE0MTI0MDQ4WhcNMTUxMTE0MTI0MDQ4WjBFMQswCQYDVQQG
EwJBVTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lk
Z2l0cyBQdHkgTHRkMIGbMBAGByqGSM49AgEGBSuBBAAjA4GGAAQBY9+my9OoeSUR
lDQdV/x8LsOuLilthhiS1Tz4aGDHIPwC1mlvnf7fg5lecYpMCrLLhauAc1UJXcgl
01xoLuzgtAEAgv2P/jgytzRSpUYvgLBt1UA0leLYBy6mQQbrNEuqT3INapKIcUv8
XxYP0xMEUksLPq6Ca+CRSqTtrd/23uTnapkwCQYHKoZIzj0EAQOBigAwgYYCQXJo
A7Sl2nLVf+4Iu/tAX/IF4MavARKC4PPHK3zfuGfPR3oCCcsAoz3kAzOeijvd0iXb
H5jBImIxPL4WxQNiBTexAkF8D1EtpYuWdlVQ80/h/f4pBcGiXPqX5h2PQSQY7hP1
+jwM1FGS4fREIOvlBYr/SzzQRtwrvrzGYxDEDbsC0ZGRnA==
-----END CERTIFICATE-----
`

// ecdsaKeyPEM is the EC private key matching ecdsaCertPEM. The key block is
// preceded by an "EC PARAMETERS" block. Presumably this is so the tests also
// cover X509KeyPair skipping a leading non-key PEM block — confirm.
// NOTE(review): the parameters OID looks like P-521 (secp521r1); verify if
// the curve matters to a caller.
var ecdsaKeyPEM = testingKey(`-----BEGIN EC PARAMETERS-----
BgUrgQQAIw==
-----END EC PARAMETERS-----
-----BEGIN EC TESTING KEY-----
MIHcAgEBBEIBrsoKp0oqcv6/JovJJDoDVSGWdirrkgCWxrprGlzB9o0X8fV675X0
NwuBenXFfeZvVcwluO7/Q9wkYoPd/t3jGImgBwYFK4EEACOhgYkDgYYABAFj36bL
06h5JRGUNB1X/Hwuw64uKW2GGJLVPPhoYMcg/ALWaW+d/t+DmV5xikwKssuFq4Bz
VQldyCXTXGgu7OC0AQCC/Y/+ODK3NFKlRi+AsG3VQDSV4tgHLqZBBus0S6pPcg1q
kohxS/xfFg/TEwRSSws+roJr4JFKpO2t3/be5OdqmQ==
-----END EC TESTING KEY-----
`)
90
91 var keyPairTests = []struct {
92         algo string
93         cert string
94         key  string
95 }{
96         {"ECDSA", ecdsaCertPEM, ecdsaKeyPEM},
97         {"RSA", rsaCertPEM, rsaKeyPEM},
98         {"RSA-untyped", rsaCertPEM, keyPEM}, // golang.org/issue/4477
99 }
100
101 func TestX509KeyPair(t *testing.T) {
102         t.Parallel()
103         var pem []byte
104         for _, test := range keyPairTests {
105                 pem = []byte(test.cert + test.key)
106                 if _, err := X509KeyPair(pem, pem); err != nil {
107                         t.Errorf("Failed to load %s cert followed by %s key: %s", test.algo, test.algo, err)
108                 }
109                 pem = []byte(test.key + test.cert)
110                 if _, err := X509KeyPair(pem, pem); err != nil {
111                         t.Errorf("Failed to load %s key followed by %s cert: %s", test.algo, test.algo, err)
112                 }
113         }
114 }
115
116 func TestX509KeyPairErrors(t *testing.T) {
117         _, err := X509KeyPair([]byte(rsaKeyPEM), []byte(rsaCertPEM))
118         if err == nil {
119                 t.Fatalf("X509KeyPair didn't return an error when arguments were switched")
120         }
121         if subStr := "been switched"; !strings.Contains(err.Error(), subStr) {
122                 t.Fatalf("Expected %q in the error when switching arguments to X509KeyPair, but the error was %q", subStr, err)
123         }
124
125         _, err = X509KeyPair([]byte(rsaCertPEM), []byte(rsaCertPEM))
126         if err == nil {
127                 t.Fatalf("X509KeyPair didn't return an error when both arguments were certificates")
128         }
129         if subStr := "certificate"; !strings.Contains(err.Error(), subStr) {
130                 t.Fatalf("Expected %q in the error when both arguments to X509KeyPair were certificates, but the error was %q", subStr, err)
131         }
132
133         const nonsensePEM = `
134 -----BEGIN NONSENSE-----
135 Zm9vZm9vZm9v
136 -----END NONSENSE-----
137 `
138
139         _, err = X509KeyPair([]byte(nonsensePEM), []byte(nonsensePEM))
140         if err == nil {
141                 t.Fatalf("X509KeyPair didn't return an error when both arguments were nonsense")
142         }
143         if subStr := "NONSENSE"; !strings.Contains(err.Error(), subStr) {
144                 t.Fatalf("Expected %q in the error when both arguments to X509KeyPair were nonsense, but the error was %q", subStr, err)
145         }
146 }
147
148 func TestX509MixedKeyPair(t *testing.T) {
149         if _, err := X509KeyPair([]byte(rsaCertPEM), []byte(ecdsaKeyPEM)); err == nil {
150                 t.Error("Load of RSA certificate succeeded with ECDSA private key")
151         }
152         if _, err := X509KeyPair([]byte(ecdsaCertPEM), []byte(rsaKeyPEM)); err == nil {
153                 t.Error("Load of ECDSA certificate succeeded with RSA private key")
154         }
155 }
156
// newLocalListener returns a TCP listener on an ephemeral loopback port. It
// tries IPv4 (127.0.0.1) first and falls back to IPv6 ([::1]). If neither
// address can be bound it fails the test via t.Fatalf, reporting both errors.
// The caller owns the returned listener and must close it.
func newLocalListener(t testing.TB) net.Listener {
	t.Helper() // attribute a failure here to the calling test's line
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err == nil {
		return ln
	}
	ln, err6 := net.Listen("tcp6", "[::1]:0")
	if err6 != nil {
		// Report both failures rather than only the IPv6 one, which hid
		// the usually more relevant IPv4 error.
		t.Fatalf("failed to listen on a loopback address: tcp4: %v; tcp6: %v", err, err6)
	}
	return ln
}
167
168 func TestDialTimeout(t *testing.T) {
169         if testing.Short() {
170                 t.Skip("skipping in short mode")
171         }
172         listener := newLocalListener(t)
173
174         addr := listener.Addr().String()
175         defer listener.Close()
176
177         complete := make(chan bool)
178         defer close(complete)
179
180         go func() {
181                 conn, err := listener.Accept()
182                 if err != nil {
183                         t.Error(err)
184                         return
185                 }
186                 <-complete
187                 conn.Close()
188         }()
189
190         dialer := &net.Dialer{
191                 Timeout: 10 * time.Millisecond,
192         }
193
194         var err error
195         if _, err = DialWithDialer(dialer, "tcp", addr, nil); err == nil {
196                 t.Fatal("DialWithTimeout completed successfully")
197         }
198
199         if !isTimeoutError(err) {
200                 t.Errorf("resulting error not a timeout: %v\nType %T: %#v", err, err, err)
201         }
202 }
203
204 func TestDeadlineOnWrite(t *testing.T) {
205         if testing.Short() {
206                 t.Skip("skipping in short mode")
207         }
208
209         ln := newLocalListener(t)
210         defer ln.Close()
211
212         srvCh := make(chan *Conn, 1)
213
214         go func() {
215                 sconn, err := ln.Accept()
216                 if err != nil {
217                         srvCh <- nil
218                         return
219                 }
220                 srv := Server(sconn, testConfig.Clone())
221                 if err := srv.Handshake(); err != nil {
222                         srvCh <- nil
223                         return
224                 }
225                 srvCh <- srv
226         }()
227
228         clientConfig := testConfig.Clone()
229         clientConfig.MaxVersion = VersionTLS12
230         conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
231         if err != nil {
232                 t.Fatal(err)
233         }
234         defer conn.Close()
235
236         srv := <-srvCh
237         if srv == nil {
238                 t.Error(err)
239         }
240
241         // Make sure the client/server is setup correctly and is able to do a typical Write/Read
242         buf := make([]byte, 6)
243         if _, err := srv.Write([]byte("foobar")); err != nil {
244                 t.Errorf("Write err: %v", err)
245         }
246         if n, err := conn.Read(buf); n != 6 || err != nil || string(buf) != "foobar" {
247                 t.Errorf("Read = %d, %v, data %q; want 6, nil, foobar", n, err, buf)
248         }
249
250         // Set a deadline which should cause Write to timeout
251         if err = srv.SetDeadline(time.Now()); err != nil {
252                 t.Fatalf("SetDeadline(time.Now()) err: %v", err)
253         }
254         if _, err = srv.Write([]byte("should fail")); err == nil {
255                 t.Fatal("Write should have timed out")
256         }
257
258         // Clear deadline and make sure it still times out
259         if err = srv.SetDeadline(time.Time{}); err != nil {
260                 t.Fatalf("SetDeadline(time.Time{}) err: %v", err)
261         }
262         if _, err = srv.Write([]byte("This connection is permanently broken")); err == nil {
263                 t.Fatal("Write which previously failed should still time out")
264         }
265
266         // Verify the error
267         if ne := err.(net.Error); ne.Temporary() != false {
268                 t.Error("Write timed out but incorrectly classified the error as Temporary")
269         }
270         if !isTimeoutError(err) {
271                 t.Error("Write timed out but did not classify the error as a Timeout")
272         }
273 }
274
// isTimeoutError reports whether err implements net.Error and its Timeout
// method returns true. nil and errors that do not implement net.Error are
// never timeouts. Wrapped errors are not unwrapped.
func isTimeoutError(err error) bool {
	ne, ok := err.(net.Error)
	return ok && ne.Timeout()
}
281
282 // tests that Conn.Read returns (non-zero, io.EOF) instead of
283 // (non-zero, nil) when a Close (alertCloseNotify) is sitting right
284 // behind the application data in the buffer.
285 func TestConnReadNonzeroAndEOF(t *testing.T) {
286         // This test is racy: it assumes that after a write to a
287         // localhost TCP connection, the peer TCP connection can
288         // immediately read it. Because it's racy, we skip this test
289         // in short mode, and then retry it several times with an
290         // increasing sleep in between our final write (via srv.Close
291         // below) and the following read.
292         if testing.Short() {
293                 t.Skip("skipping in short mode")
294         }
295         var err error
296         for delay := time.Millisecond; delay <= 64*time.Millisecond; delay *= 2 {
297                 if err = testConnReadNonzeroAndEOF(t, delay); err == nil {
298                         return
299                 }
300         }
301         t.Error(err)
302 }
303
304 func testConnReadNonzeroAndEOF(t *testing.T, delay time.Duration) error {
305         ln := newLocalListener(t)
306         defer ln.Close()
307
308         srvCh := make(chan *Conn, 1)
309         var serr error
310         go func() {
311                 sconn, err := ln.Accept()
312                 if err != nil {
313                         serr = err
314                         srvCh <- nil
315                         return
316                 }
317                 serverConfig := testConfig.Clone()
318                 srv := Server(sconn, serverConfig)
319                 if err := srv.Handshake(); err != nil {
320                         serr = fmt.Errorf("handshake: %v", err)
321                         srvCh <- nil
322                         return
323                 }
324                 srvCh <- srv
325         }()
326
327         clientConfig := testConfig.Clone()
328         // In TLS 1.3, alerts are encrypted and disguised as application data, so
329         // the opportunistic peek won't work.
330         clientConfig.MaxVersion = VersionTLS12
331         conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
332         if err != nil {
333                 t.Fatal(err)
334         }
335         defer conn.Close()
336
337         srv := <-srvCh
338         if srv == nil {
339                 return serr
340         }
341
342         buf := make([]byte, 6)
343
344         srv.Write([]byte("foobar"))
345         n, err := conn.Read(buf)
346         if n != 6 || err != nil || string(buf) != "foobar" {
347                 return fmt.Errorf("Read = %d, %v, data %q; want 6, nil, foobar", n, err, buf)
348         }
349
350         srv.Write([]byte("abcdef"))
351         srv.Close()
352         time.Sleep(delay)
353         n, err = conn.Read(buf)
354         if n != 6 || string(buf) != "abcdef" {
355                 return fmt.Errorf("Read = %d, buf= %q; want 6, abcdef", n, buf)
356         }
357         if err != io.EOF {
358                 return fmt.Errorf("Second Read error = %v; want io.EOF", err)
359         }
360         return nil
361 }
362
363 func TestTLSUniqueMatches(t *testing.T) {
364         ln := newLocalListener(t)
365         defer ln.Close()
366
367         serverTLSUniques := make(chan []byte)
368         parentDone := make(chan struct{})
369         childDone := make(chan struct{})
370         defer close(parentDone)
371         go func() {
372                 defer close(childDone)
373                 for i := 0; i < 2; i++ {
374                         sconn, err := ln.Accept()
375                         if err != nil {
376                                 t.Error(err)
377                                 return
378                         }
379                         serverConfig := testConfig.Clone()
380                         serverConfig.MaxVersion = VersionTLS12 // TLSUnique is not defined in TLS 1.3
381                         srv := Server(sconn, serverConfig)
382                         if err := srv.Handshake(); err != nil {
383                                 t.Error(err)
384                                 return
385                         }
386                         select {
387                         case <-parentDone:
388                                 return
389                         case serverTLSUniques <- srv.ConnectionState().TLSUnique:
390                         }
391                 }
392         }()
393
394         clientConfig := testConfig.Clone()
395         clientConfig.ClientSessionCache = NewLRUClientSessionCache(1)
396         conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
397         if err != nil {
398                 t.Fatal(err)
399         }
400
401         var serverTLSUniquesValue []byte
402         select {
403         case <-childDone:
404                 return
405         case serverTLSUniquesValue = <-serverTLSUniques:
406         }
407
408         if !bytes.Equal(conn.ConnectionState().TLSUnique, serverTLSUniquesValue) {
409                 t.Error("client and server channel bindings differ")
410         }
411         conn.Close()
412
413         conn, err = Dial("tcp", ln.Addr().String(), clientConfig)
414         if err != nil {
415                 t.Fatal(err)
416         }
417         defer conn.Close()
418         if !conn.ConnectionState().DidResume {
419                 t.Error("second session did not use resumption")
420         }
421
422         select {
423         case <-childDone:
424                 return
425         case serverTLSUniquesValue = <-serverTLSUniques:
426         }
427
428         if !bytes.Equal(conn.ConnectionState().TLSUnique, serverTLSUniquesValue) {
429                 t.Error("client and server channel bindings differ when session resumption is used")
430         }
431 }
432
433 func TestVerifyHostname(t *testing.T) {
434         testenv.MustHaveExternalNetwork(t)
435
436         c, err := Dial("tcp", "www.google.com:https", nil)
437         if err != nil {
438                 t.Fatal(err)
439         }
440         if err := c.VerifyHostname("www.google.com"); err != nil {
441                 t.Fatalf("verify www.google.com: %v", err)
442         }
443         if err := c.VerifyHostname("www.yahoo.com"); err == nil {
444                 t.Fatalf("verify www.yahoo.com succeeded")
445         }
446
447         c, err = Dial("tcp", "www.google.com:https", &Config{InsecureSkipVerify: true})
448         if err != nil {
449                 t.Fatal(err)
450         }
451         if err := c.VerifyHostname("www.google.com"); err == nil {
452                 t.Fatalf("verify www.google.com succeeded with InsecureSkipVerify=true")
453         }
454 }
455
// TestConnCloseBreakingWrite checks two things while another goroutine is
// blocked inside Conn.Write on the underlying net.Conn:
//   - Conn.Close does not block forever.
//   - The Write returns the underlying connection's error.
//
// It also checks that a later Close reports errClosed.
func TestConnCloseBreakingWrite(t *testing.T) {
	ln := newLocalListener(t)
	defer ln.Close()

	// Server side: accept and complete a handshake. serr and sconn are written
	// by the goroutine before it sends on srvCh, and are read only after the
	// receive from srvCh below.
	srvCh := make(chan *Conn, 1)
	var serr error
	var sconn net.Conn
	go func() {
		var err error
		sconn, err = ln.Accept()
		if err != nil {
			serr = err
			srvCh <- nil
			return
		}
		serverConfig := testConfig.Clone()
		srv := Server(sconn, serverConfig)
		if err := srv.Handshake(); err != nil {
			serr = fmt.Errorf("handshake: %v", err)
			srvCh <- nil
			return
		}
		srvCh <- srv
	}()

	cconn, err := net.Dial("tcp", ln.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer cconn.Close()

	// Wrap the client's TCP conn so that its Write and Close can be replaced
	// after the handshake. Until then (nil funcs) both pass through to cconn.
	conn := &changeImplConn{
		Conn: cconn,
	}

	clientConfig := testConfig.Clone()
	tconn := Client(conn, clientConfig)
	if err := tconn.Handshake(); err != nil {
		t.Fatal(err)
	}

	srv := <-srvCh
	if srv == nil {
		t.Fatal(serr)
	}
	defer sconn.Close()

	// From now on, closing the underlying conn only signals connClosed.
	connClosed := make(chan struct{})
	conn.closeFunc = func() error {
		close(connClosed)
		return nil
	}

	// The next underlying Write announces itself on inWrite and then blocks
	// until the underlying Close runs. The TLS Write therefore cannot return
	// until tconn.Close reaches the underlying conn's Close.
	inWrite := make(chan bool, 1)
	var errConnClosed = errors.New("conn closed for test")
	conn.writeFunc = func(p []byte) (n int, err error) {
		inWrite <- true
		<-connClosed
		return 0, errConnClosed
	}

	// Once the Write below is blocked in the underlying conn, call Close from
	// another goroutine.
	closeReturned := make(chan bool, 1)
	go func() {
		<-inWrite
		tconn.Close() // test that this doesn't block forever.
		closeReturned <- true
	}()

	// The TLS Write must surface the underlying conn's error unchanged.
	_, err = tconn.Write([]byte("foo"))
	if err != errConnClosed {
		t.Errorf("Write error = %v; want errConnClosed", err)
	}

	// After the concurrent Close has returned, a second Close must report
	// errClosed.
	<-closeReturned
	if err := tconn.Close(); err != errClosed {
		t.Errorf("Close error = %v; want errClosed", err)
	}
}
534
// TestConnCloseWrite checks Conn.CloseWrite over a real TLS connection:
//   - The client half-closes first; a further client Write must then fail
//     with errShutdown.
//   - The server must read EOF with no data, then half-close its own side.
//   - The client, after its own CloseWrite, must still be able to read until
//     EOF.
//
// It also checks that CloseWrite before the handshake has completed returns
// errEarlyCloseWrite.
func TestConnCloseWrite(t *testing.T) {
	ln := newLocalListener(t)
	defer ln.Close()

	// Closed when clientCloseWrite returns (see its deferred close).
	clientDoneChan := make(chan struct{})

	// serverCloseWrite accepts one connection and reads until EOF, expecting
	// no data. It then half-closes its own write side and waits for the
	// client to finish.
	serverCloseWrite := func() error {
		sconn, err := ln.Accept()
		if err != nil {
			return fmt.Errorf("accept: %v", err)
		}
		defer sconn.Close()

		serverConfig := testConfig.Clone()
		srv := Server(sconn, serverConfig)
		if err := srv.Handshake(); err != nil {
			return fmt.Errorf("handshake: %v", err)
		}
		defer srv.Close()

		data, err := ioutil.ReadAll(srv)
		if err != nil {
			return err
		}
		if len(data) > 0 {
			return fmt.Errorf("Read data = %q; want nothing", data)
		}

		if err := srv.CloseWrite(); err != nil {
			return fmt.Errorf("server CloseWrite: %v", err)
		}

		// Wait for clientCloseWrite to finish. This ensures the client saw
		// EOF because of our CloseWrite, and not because of the deferred
		// sconn.Close above, which would also unblock the client's read.
		<-clientDoneChan
		return nil
	}

	// clientCloseWrite dials (Dial already performs the handshake; the extra
	// Handshake call is expected to be a no-op) and half-closes its write
	// side. It checks that a further Write fails with errShutdown, then reads
	// until EOF, expecting no data.
	clientCloseWrite := func() error {
		defer close(clientDoneChan)

		clientConfig := testConfig.Clone()
		conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
		if err != nil {
			return err
		}
		if err := conn.Handshake(); err != nil {
			return err
		}
		defer conn.Close()

		if err := conn.CloseWrite(); err != nil {
			return fmt.Errorf("client CloseWrite: %v", err)
		}

		if _, err := conn.Write([]byte{0}); err != errShutdown {
			return fmt.Errorf("CloseWrite error = %v; want errShutdown", err)
		}

		data, err := ioutil.ReadAll(conn)
		if err != nil {
			return err
		}
		if len(data) > 0 {
			return fmt.Errorf("Read data = %q; want nothing", data)
		}
		return nil
	}

	// Run both sides concurrently. Fail on the first error, or if both
	// have not finished within 10 seconds.
	errChan := make(chan error, 2)

	go func() { errChan <- serverCloseWrite() }()
	go func() { errChan <- clientCloseWrite() }()

	for i := 0; i < 2; i++ {
		select {
		case err := <-errChan:
			if err != nil {
				t.Fatal(err)
			}
		case <-time.After(10 * time.Second):
			t.Fatal("deadlock")
		}
	}

	// Also test CloseWrite being called before the handshake is
	// finished. Nothing accepts on ln2; the client never starts a handshake,
	// so CloseWrite must report errEarlyCloseWrite.
	{
		ln2 := newLocalListener(t)
		defer ln2.Close()

		netConn, err := net.Dial("tcp", ln2.Addr().String())
		if err != nil {
			t.Fatal(err)
		}
		defer netConn.Close()
		conn := Client(netConn, testConfig.Clone())

		if err := conn.CloseWrite(); err != errEarlyCloseWrite {
			t.Errorf("CloseWrite error = %v; want errEarlyCloseWrite", err)
		}
	}
}
640
641 func TestWarningAlertFlood(t *testing.T) {
642         ln := newLocalListener(t)
643         defer ln.Close()
644
645         server := func() error {
646                 sconn, err := ln.Accept()
647                 if err != nil {
648                         return fmt.Errorf("accept: %v", err)
649                 }
650                 defer sconn.Close()
651
652                 serverConfig := testConfig.Clone()
653                 srv := Server(sconn, serverConfig)
654                 if err := srv.Handshake(); err != nil {
655                         return fmt.Errorf("handshake: %v", err)
656                 }
657                 defer srv.Close()
658
659                 _, err = ioutil.ReadAll(srv)
660                 if err == nil {
661                         return errors.New("unexpected lack of error from server")
662                 }
663                 const expected = "too many ignored"
664                 if str := err.Error(); !strings.Contains(str, expected) {
665                         return fmt.Errorf("expected error containing %q, but saw: %s", expected, str)
666                 }
667
668                 return nil
669         }
670
671         errChan := make(chan error, 1)
672         go func() { errChan <- server() }()
673
674         clientConfig := testConfig.Clone()
675         clientConfig.MaxVersion = VersionTLS12 // there are no warning alerts in TLS 1.3
676         conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
677         if err != nil {
678                 t.Fatal(err)
679         }
680         defer conn.Close()
681         if err := conn.Handshake(); err != nil {
682                 t.Fatal(err)
683         }
684
685         for i := 0; i < maxUselessRecords+1; i++ {
686                 conn.sendAlert(alertNoRenegotiation)
687         }
688
689         if err := <-errChan; err != nil {
690                 t.Fatal(err)
691         }
692 }
693
694 func TestCloneFuncFields(t *testing.T) {
695         const expectedCount = 5
696         called := 0
697
698         c1 := Config{
699                 Time: func() time.Time {
700                         called |= 1 << 0
701                         return time.Time{}
702                 },
703                 GetCertificate: func(*ClientHelloInfo) (*Certificate, error) {
704                         called |= 1 << 1
705                         return nil, nil
706                 },
707                 GetClientCertificate: func(*CertificateRequestInfo) (*Certificate, error) {
708                         called |= 1 << 2
709                         return nil, nil
710                 },
711                 GetConfigForClient: func(*ClientHelloInfo) (*Config, error) {
712                         called |= 1 << 3
713                         return nil, nil
714                 },
715                 VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
716                         called |= 1 << 4
717                         return nil
718                 },
719         }
720
721         c2 := c1.Clone()
722
723         c2.Time()
724         c2.GetCertificate(nil)
725         c2.GetClientCertificate(nil)
726         c2.GetConfigForClient(nil)
727         c2.VerifyPeerCertificate(nil, nil)
728
729         if called != (1<<expectedCount)-1 {
730                 t.Fatalf("expected %d calls but saw calls %b", expectedCount, called)
731         }
732 }
733
734 func TestCloneNonFuncFields(t *testing.T) {
735         var c1 Config
736         v := reflect.ValueOf(&c1).Elem()
737
738         typ := v.Type()
739         for i := 0; i < typ.NumField(); i++ {
740                 f := v.Field(i)
741                 if !f.CanSet() {
742                         // unexported field; not cloned.
743                         continue
744                 }
745
746                 // testing/quick can't handle functions or interfaces and so
747                 // isn't used here.
748                 switch fn := typ.Field(i).Name; fn {
749                 case "Rand":
750                         f.Set(reflect.ValueOf(io.Reader(os.Stdin)))
751                 case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate", "GetClientCertificate":
752                         // DeepEqual can't compare functions. If you add a
753                         // function field to this list, you must also change
754                         // TestCloneFuncFields to ensure that the func field is
755                         // cloned.
756                 case "Certificates":
757                         f.Set(reflect.ValueOf([]Certificate{
758                                 {Certificate: [][]byte{{'b'}}},
759                         }))
760                 case "NameToCertificate":
761                         f.Set(reflect.ValueOf(map[string]*Certificate{"a": nil}))
762                 case "RootCAs", "ClientCAs":
763                         f.Set(reflect.ValueOf(x509.NewCertPool()))
764                 case "ClientSessionCache":
765                         f.Set(reflect.ValueOf(NewLRUClientSessionCache(10)))
766                 case "KeyLogWriter":
767                         f.Set(reflect.ValueOf(io.Writer(os.Stdout)))
768                 case "NextProtos":
769                         f.Set(reflect.ValueOf([]string{"a", "b"}))
770                 case "ServerName":
771                         f.Set(reflect.ValueOf("b"))
772                 case "ClientAuth":
773                         f.Set(reflect.ValueOf(VerifyClientCertIfGiven))
774                 case "InsecureSkipVerify", "SessionTicketsDisabled", "DynamicRecordSizingDisabled", "PreferServerCipherSuites":
775                         f.Set(reflect.ValueOf(true))
776                 case "MinVersion", "MaxVersion":
777                         f.Set(reflect.ValueOf(uint16(VersionTLS12)))
778                 case "SessionTicketKey":
779                         f.Set(reflect.ValueOf([32]byte{}))
780                 case "CipherSuites":
781                         f.Set(reflect.ValueOf([]uint16{1, 2}))
782                 case "CurvePreferences":
783                         f.Set(reflect.ValueOf([]CurveID{CurveP256}))
784                 case "Renegotiation":
785                         f.Set(reflect.ValueOf(RenegotiateOnceAsClient))
786                 default:
787                         t.Errorf("all fields must be accounted for, but saw unknown field %q", fn)
788                 }
789         }
790
791         c2 := c1.Clone()
792         // DeepEqual also compares unexported fields, thus c2 needs to have run
793         // serverInit in order to be DeepEqual to c1. Cloning it and discarding
794         // the result is sufficient.
795         c2.Clone()
796
797         if !reflect.DeepEqual(&c1, c2) {
798                 t.Errorf("clone failed to copy a field")
799         }
800 }
801
// changeImplConn wraps a net.Conn and lets a test override its Write and
// Close behavior. When writeFunc or closeFunc is nil, the call is forwarded
// to the embedded Conn.
type changeImplConn struct {
	net.Conn
	writeFunc func([]byte) (int, error) // used instead of Conn.Write when non-nil
	closeFunc func() error              // used instead of Conn.Close when non-nil
}

// Write calls writeFunc if it is set, and the embedded Conn's Write otherwise.
func (w *changeImplConn) Write(p []byte) (n int, err error) {
	if w.writeFunc == nil {
		return w.Conn.Write(p)
	}
	return w.writeFunc(p)
}

// Close calls closeFunc if it is set, and the embedded Conn's Close otherwise.
func (w *changeImplConn) Close() error {
	if w.closeFunc == nil {
		return w.Conn.Close()
	}
	return w.closeFunc()
}
823
824 func throughput(b *testing.B, version uint16, totalBytes int64, dynamicRecordSizingDisabled bool) {
825         ln := newLocalListener(b)
826         defer ln.Close()
827
828         N := b.N
829
830         // Less than 64KB because Windows appears to use a TCP rwin < 64KB.
831         // See Issue #15899.
832         const bufsize = 32 << 10
833
834         go func() {
835                 buf := make([]byte, bufsize)
836                 for i := 0; i < N; i++ {
837                         sconn, err := ln.Accept()
838                         if err != nil {
839                                 // panic rather than synchronize to avoid benchmark overhead
840                                 // (cannot call b.Fatal in goroutine)
841                                 panic(fmt.Errorf("accept: %v", err))
842                         }
843                         serverConfig := testConfig.Clone()
844                         serverConfig.CipherSuites = nil // the defaults may prefer faster ciphers
845                         serverConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled
846                         srv := Server(sconn, serverConfig)
847                         if err := srv.Handshake(); err != nil {
848                                 panic(fmt.Errorf("handshake: %v", err))
849                         }
850                         if _, err := io.CopyBuffer(srv, srv, buf); err != nil {
851                                 panic(fmt.Errorf("copy buffer: %v", err))
852                         }
853                 }
854         }()
855
856         b.SetBytes(totalBytes)
857         clientConfig := testConfig.Clone()
858         clientConfig.CipherSuites = nil // the defaults may prefer faster ciphers
859         clientConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled
860         clientConfig.MaxVersion = version
861
862         buf := make([]byte, bufsize)
863         chunks := int(math.Ceil(float64(totalBytes) / float64(len(buf))))
864         for i := 0; i < N; i++ {
865                 conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
866                 if err != nil {
867                         b.Fatal(err)
868                 }
869                 for j := 0; j < chunks; j++ {
870                         _, err := conn.Write(buf)
871                         if err != nil {
872                                 b.Fatal(err)
873                         }
874                         _, err = io.ReadFull(conn, buf)
875                         if err != nil {
876                                 b.Fatal(err)
877                         }
878                 }
879                 conn.Close()
880         }
881 }
882
883 func BenchmarkThroughput(b *testing.B) {
884         for _, mode := range []string{"Max", "Dynamic"} {
885                 for size := 1; size <= 64; size <<= 1 {
886                         name := fmt.Sprintf("%sPacket/%dMB", mode, size)
887                         b.Run(name, func(b *testing.B) {
888                                 b.Run("TLSv12", func(b *testing.B) {
889                                         throughput(b, VersionTLS12, int64(size<<20), mode == "Max")
890                                 })
891                                 b.Run("TLSv13", func(b *testing.B) {
892                                         throughput(b, VersionTLS13, int64(size<<20), mode == "Max")
893                                 })
894                         })
895                 }
896         }
897 }
898
899 type slowConn struct {
900         net.Conn
901         bps int
902 }
903
904 func (c *slowConn) Write(p []byte) (int, error) {
905         if c.bps == 0 {
906                 panic("too slow")
907         }
908         t0 := time.Now()
909         wrote := 0
910         for wrote < len(p) {
911                 time.Sleep(100 * time.Microsecond)
912                 allowed := int(time.Since(t0).Seconds()*float64(c.bps)) / 8
913                 if allowed > len(p) {
914                         allowed = len(p)
915                 }
916                 if wrote < allowed {
917                         n, err := c.Conn.Write(p[wrote:allowed])
918                         wrote += n
919                         if err != nil {
920                                 return wrote, err
921                         }
922                 }
923         }
924         return len(p), nil
925 }
926
927 func latency(b *testing.B, version uint16, bps int, dynamicRecordSizingDisabled bool) {
928         ln := newLocalListener(b)
929         defer ln.Close()
930
931         N := b.N
932
933         go func() {
934                 for i := 0; i < N; i++ {
935                         sconn, err := ln.Accept()
936                         if err != nil {
937                                 // panic rather than synchronize to avoid benchmark overhead
938                                 // (cannot call b.Fatal in goroutine)
939                                 panic(fmt.Errorf("accept: %v", err))
940                         }
941                         serverConfig := testConfig.Clone()
942                         serverConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled
943                         srv := Server(&slowConn{sconn, bps}, serverConfig)
944                         if err := srv.Handshake(); err != nil {
945                                 panic(fmt.Errorf("handshake: %v", err))
946                         }
947                         io.Copy(srv, srv)
948                 }
949         }()
950
951         clientConfig := testConfig.Clone()
952         clientConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled
953         clientConfig.MaxVersion = version
954
955         buf := make([]byte, 16384)
956         peek := make([]byte, 1)
957
958         for i := 0; i < N; i++ {
959                 conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
960                 if err != nil {
961                         b.Fatal(err)
962                 }
963                 // make sure we're connected and previous connection has stopped
964                 if _, err := conn.Write(buf[:1]); err != nil {
965                         b.Fatal(err)
966                 }
967                 if _, err := io.ReadFull(conn, peek); err != nil {
968                         b.Fatal(err)
969                 }
970                 if _, err := conn.Write(buf); err != nil {
971                         b.Fatal(err)
972                 }
973                 if _, err = io.ReadFull(conn, peek); err != nil {
974                         b.Fatal(err)
975                 }
976                 conn.Close()
977         }
978 }
979
980 func BenchmarkLatency(b *testing.B) {
981         for _, mode := range []string{"Max", "Dynamic"} {
982                 for _, kbps := range []int{200, 500, 1000, 2000, 5000} {
983                         name := fmt.Sprintf("%sPacket/%dkbps", mode, kbps)
984                         b.Run(name, func(b *testing.B) {
985                                 b.Run("TLSv12", func(b *testing.B) {
986                                         latency(b, VersionTLS12, kbps*1000, mode == "Max")
987                                 })
988                                 b.Run("TLSv13", func(b *testing.B) {
989                                         latency(b, VersionTLS13, kbps*1000, mode == "Max")
990                                 })
991                         })
992                 }
993         }
994 }
995
996 func TestConnectionStateMarshal(t *testing.T) {
997         cs := &ConnectionState{}
998         _, err := json.Marshal(cs)
999         if err != nil {
1000                 t.Errorf("json.Marshal failed on ConnectionState: %v", err)
1001         }
1002 }
1003
1004 func TestConnectionState(t *testing.T) {
1005         issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
1006         if err != nil {
1007                 panic(err)
1008         }
1009         rootCAs := x509.NewCertPool()
1010         rootCAs.AddCert(issuer)
1011
1012         now := func() time.Time { return time.Unix(1476984729, 0) }
1013
1014         const alpnProtocol = "golang"
1015         const serverName = "example.golang"
1016         var scts = [][]byte{[]byte("dummy sct 1"), []byte("dummy sct 2")}
1017         var ocsp = []byte("dummy ocsp")
1018
1019         for _, v := range []uint16{VersionTLS12, VersionTLS13} {
1020                 var name string
1021                 switch v {
1022                 case VersionTLS12:
1023                         name = "TLSv12"
1024                 case VersionTLS13:
1025                         name = "TLSv13"
1026                 }
1027                 t.Run(name, func(t *testing.T) {
1028                         config := &Config{
1029                                 Time:         now,
1030                                 Rand:         zeroSource{},
1031                                 Certificates: make([]Certificate, 1),
1032                                 MaxVersion:   v,
1033                                 RootCAs:      rootCAs,
1034                                 ClientCAs:    rootCAs,
1035                                 ClientAuth:   RequireAndVerifyClientCert,
1036                                 NextProtos:   []string{alpnProtocol},
1037                                 ServerName:   serverName,
1038                         }
1039                         config.Certificates[0].Certificate = [][]byte{testRSACertificate}
1040                         config.Certificates[0].PrivateKey = testRSAPrivateKey
1041                         config.Certificates[0].SignedCertificateTimestamps = scts
1042                         config.Certificates[0].OCSPStaple = ocsp
1043
1044                         ss, cs, err := testHandshake(t, config, config)
1045                         if err != nil {
1046                                 t.Fatalf("Handshake failed: %v", err)
1047                         }
1048
1049                         if ss.Version != v || cs.Version != v {
1050                                 t.Errorf("Got versions %x (server) and %x (client), expected %x", ss.Version, cs.Version, v)
1051                         }
1052
1053                         if !ss.HandshakeComplete || !cs.HandshakeComplete {
1054                                 t.Errorf("Got HandshakeComplete %v (server) and %v (client), expected true", ss.HandshakeComplete, cs.HandshakeComplete)
1055                         }
1056
1057                         if ss.DidResume || cs.DidResume {
1058                                 t.Errorf("Got DidResume %v (server) and %v (client), expected false", ss.DidResume, cs.DidResume)
1059                         }
1060
1061                         if ss.CipherSuite == 0 || cs.CipherSuite == 0 {
1062                                 t.Errorf("Got invalid cipher suite: %v (server) and %v (client)", ss.CipherSuite, cs.CipherSuite)
1063                         }
1064
1065                         if ss.NegotiatedProtocol != alpnProtocol || cs.NegotiatedProtocol != alpnProtocol {
1066                                 t.Errorf("Got negotiated protocol %q (server) and %q (client), expected %q", ss.NegotiatedProtocol, cs.NegotiatedProtocol, alpnProtocol)
1067                         }
1068
1069                         if !cs.NegotiatedProtocolIsMutual {
1070                                 t.Errorf("Got false NegotiatedProtocolIsMutual on the client side")
1071                         }
1072                         // NegotiatedProtocolIsMutual on the server side is unspecified.
1073
1074                         if ss.ServerName != serverName {
1075                                 t.Errorf("Got server name %q, expected %q", ss.ServerName, serverName)
1076                         }
1077                         if cs.ServerName != "" {
1078                                 t.Errorf("Got unexpected server name on the client side")
1079                         }
1080
1081                         if len(ss.PeerCertificates) != 1 || len(cs.PeerCertificates) != 1 {
1082                                 t.Errorf("Got %d (server) and %d (client) peer certificates, expected %d", len(ss.PeerCertificates), len(cs.PeerCertificates), 1)
1083                         }
1084
1085                         if len(ss.VerifiedChains) != 1 || len(cs.VerifiedChains) != 1 {
1086                                 t.Errorf("Got %d (server) and %d (client) verified chains, expected %d", len(ss.VerifiedChains), len(cs.VerifiedChains), 1)
1087                         } else if len(ss.VerifiedChains[0]) != 2 || len(cs.VerifiedChains[0]) != 2 {
1088                                 t.Errorf("Got %d (server) and %d (client) long verified chain, expected %d", len(ss.VerifiedChains[0]), len(cs.VerifiedChains[0]), 2)
1089                         }
1090
1091                         if len(cs.SignedCertificateTimestamps) != 2 {
1092                                 t.Errorf("Got %d SCTs, expected %d", len(cs.SignedCertificateTimestamps), 2)
1093                         }
1094                         if !bytes.Equal(cs.OCSPResponse, ocsp) {
1095                                 t.Errorf("Got OCSPs %x, expected %x", cs.OCSPResponse, ocsp)
1096                         }
1097                         // Only TLS 1.3 supports OCSP and SCTs on client certs.
1098                         if v == VersionTLS13 {
1099                                 if len(ss.SignedCertificateTimestamps) != 2 {
1100                                         t.Errorf("Got %d client SCTs, expected %d", len(ss.SignedCertificateTimestamps), 2)
1101                                 }
1102                                 if !bytes.Equal(ss.OCSPResponse, ocsp) {
1103                                         t.Errorf("Got client OCSPs %x, expected %x", ss.OCSPResponse, ocsp)
1104                                 }
1105                         }
1106
1107                         if v == VersionTLS13 {
1108                                 if ss.TLSUnique != nil || cs.TLSUnique != nil {
1109                                         t.Errorf("Got TLSUnique %x (server) and %x (client), expected nil in TLS 1.3", ss.TLSUnique, cs.TLSUnique)
1110                                 }
1111                         } else {
1112                                 if ss.TLSUnique == nil || cs.TLSUnique == nil {
1113                                         t.Errorf("Got TLSUnique %x (server) and %x (client), expected non-nil", ss.TLSUnique, cs.TLSUnique)
1114                                 }
1115                         }
1116                 })
1117         }
1118 }
1119
1120 // Issue 28744: Ensure that we don't modify memory
1121 // that Config doesn't own such as Certificates.
1122 func TestBuildNameToCertificate_doesntModifyCertificates(t *testing.T) {
1123         c0 := Certificate{
1124                 Certificate: [][]byte{testRSACertificate},
1125                 PrivateKey:  testRSAPrivateKey,
1126         }
1127         c1 := Certificate{
1128                 Certificate: [][]byte{testSNICertificate},
1129                 PrivateKey:  testRSAPrivateKey,
1130         }
1131         config := testConfig.Clone()
1132         config.Certificates = []Certificate{c0, c1}
1133
1134         config.BuildNameToCertificate()
1135         got := config.Certificates
1136         want := []Certificate{c0, c1}
1137         if !reflect.DeepEqual(got, want) {
1138                 t.Fatalf("Certificates were mutated by BuildNameToCertificate\nGot: %#v\nWant: %#v\n", got, want)
1139         }
1140 }
1141
1142 func testingKey(s string) string { return strings.ReplaceAll(s, "TESTING KEY", "PRIVATE KEY") }
1143
1144 func TestClientHelloInfo_SupportsCertificate(t *testing.T) {
1145         rsaCert := &Certificate{
1146                 Certificate: [][]byte{testRSACertificate},
1147                 PrivateKey:  testRSAPrivateKey,
1148         }
1149         pkcs1Cert := &Certificate{
1150                 Certificate:                  [][]byte{testRSACertificate},
1151                 PrivateKey:                   testRSAPrivateKey,
1152                 SupportedSignatureAlgorithms: []SignatureScheme{PKCS1WithSHA1, PKCS1WithSHA256},
1153         }
1154         ecdsaCert := &Certificate{
1155                 // ECDSA P-256 certificate
1156                 Certificate: [][]byte{testP256Certificate},
1157                 PrivateKey:  testP256PrivateKey,
1158         }
1159         ed25519Cert := &Certificate{
1160                 Certificate: [][]byte{testEd25519Certificate},
1161                 PrivateKey:  testEd25519PrivateKey,
1162         }
1163
1164         tests := []struct {
1165                 c       *Certificate
1166                 chi     *ClientHelloInfo
1167                 wantErr string
1168         }{
1169                 {rsaCert, &ClientHelloInfo{
1170                         ServerName:        "example.golang",
1171                         SignatureSchemes:  []SignatureScheme{PSSWithSHA256},
1172                         SupportedVersions: []uint16{VersionTLS13},
1173                 }, ""},
1174                 {ecdsaCert, &ClientHelloInfo{
1175                         SignatureSchemes:  []SignatureScheme{PSSWithSHA256, ECDSAWithP256AndSHA256},
1176                         SupportedVersions: []uint16{VersionTLS13, VersionTLS12},
1177                 }, ""},
1178                 {rsaCert, &ClientHelloInfo{
1179                         ServerName:        "example.com",
1180                         SignatureSchemes:  []SignatureScheme{PSSWithSHA256},
1181                         SupportedVersions: []uint16{VersionTLS13},
1182                 }, "not valid for requested server name"},
1183                 {ecdsaCert, &ClientHelloInfo{
1184                         SignatureSchemes:  []SignatureScheme{ECDSAWithP384AndSHA384},
1185                         SupportedVersions: []uint16{VersionTLS13},
1186                 }, "signature algorithms"},
1187                 {pkcs1Cert, &ClientHelloInfo{
1188                         SignatureSchemes:  []SignatureScheme{PSSWithSHA256, ECDSAWithP256AndSHA256},
1189                         SupportedVersions: []uint16{VersionTLS13},
1190                 }, "signature algorithms"},
1191
1192                 {rsaCert, &ClientHelloInfo{
1193                         CipherSuites:      []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
1194                         SignatureSchemes:  []SignatureScheme{PKCS1WithSHA1},
1195                         SupportedVersions: []uint16{VersionTLS13, VersionTLS12},
1196                 }, "signature algorithms"},
1197                 {rsaCert, &ClientHelloInfo{
1198                         CipherSuites:      []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
1199                         SignatureSchemes:  []SignatureScheme{PKCS1WithSHA1},
1200                         SupportedVersions: []uint16{VersionTLS13, VersionTLS12},
1201                         config: &Config{
1202                                 MaxVersion: VersionTLS12,
1203                         },
1204                 }, ""}, // Check that mutual version selection works.
1205
1206                 {ecdsaCert, &ClientHelloInfo{
1207                         CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1208                         SupportedCurves:   []CurveID{CurveP256},
1209                         SupportedPoints:   []uint8{pointFormatUncompressed},
1210                         SignatureSchemes:  []SignatureScheme{ECDSAWithP256AndSHA256},
1211                         SupportedVersions: []uint16{VersionTLS12},
1212                 }, ""},
1213                 {ecdsaCert, &ClientHelloInfo{
1214                         CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1215                         SupportedCurves:   []CurveID{CurveP256},
1216                         SupportedPoints:   []uint8{pointFormatUncompressed},
1217                         SignatureSchemes:  []SignatureScheme{ECDSAWithP384AndSHA384},
1218                         SupportedVersions: []uint16{VersionTLS12},
1219                 }, ""}, // TLS 1.2 does not restrict curves based on the SignatureScheme.
1220                 {ecdsaCert, &ClientHelloInfo{
1221                         CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1222                         SupportedCurves:   []CurveID{CurveP256},
1223                         SupportedPoints:   []uint8{pointFormatUncompressed},
1224                         SignatureSchemes:  nil,
1225                         SupportedVersions: []uint16{VersionTLS12},
1226                 }, ""}, // TLS 1.2 comes with default signature schemes.
1227                 {ecdsaCert, &ClientHelloInfo{
1228                         CipherSuites:      []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
1229                         SupportedCurves:   []CurveID{CurveP256},
1230                         SupportedPoints:   []uint8{pointFormatUncompressed},
1231                         SignatureSchemes:  []SignatureScheme{ECDSAWithP256AndSHA256},
1232                         SupportedVersions: []uint16{VersionTLS12},
1233                 }, "cipher suite"},
1234                 {ecdsaCert, &ClientHelloInfo{
1235                         CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1236                         SupportedCurves:   []CurveID{CurveP256},
1237                         SupportedPoints:   []uint8{pointFormatUncompressed},
1238                         SignatureSchemes:  []SignatureScheme{ECDSAWithP256AndSHA256},
1239                         SupportedVersions: []uint16{VersionTLS12},
1240                         config: &Config{
1241                                 CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
1242                         },
1243                 }, "cipher suite"},
1244                 {ecdsaCert, &ClientHelloInfo{
1245                         CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1246                         SupportedCurves:   []CurveID{CurveP384},
1247                         SupportedPoints:   []uint8{pointFormatUncompressed},
1248                         SignatureSchemes:  []SignatureScheme{ECDSAWithP256AndSHA256},
1249                         SupportedVersions: []uint16{VersionTLS12},
1250                 }, "certificate curve"},
1251                 {ecdsaCert, &ClientHelloInfo{
1252                         CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1253                         SupportedCurves:   []CurveID{CurveP256},
1254                         SupportedPoints:   []uint8{1},
1255                         SignatureSchemes:  []SignatureScheme{ECDSAWithP256AndSHA256},
1256                         SupportedVersions: []uint16{VersionTLS12},
1257                 }, "doesn't support ECDHE"},
1258                 {ecdsaCert, &ClientHelloInfo{
1259                         CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1260                         SupportedCurves:   []CurveID{CurveP256},
1261                         SupportedPoints:   []uint8{pointFormatUncompressed},
1262                         SignatureSchemes:  []SignatureScheme{PSSWithSHA256},
1263                         SupportedVersions: []uint16{VersionTLS12},
1264                 }, "signature algorithms"},
1265
1266                 {ed25519Cert, &ClientHelloInfo{
1267                         CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1268                         SupportedCurves:   []CurveID{CurveP256}, // only relevant for ECDHE support
1269                         SupportedPoints:   []uint8{pointFormatUncompressed},
1270                         SignatureSchemes:  []SignatureScheme{Ed25519},
1271                         SupportedVersions: []uint16{VersionTLS12},
1272                 }, ""},
1273                 {ed25519Cert, &ClientHelloInfo{
1274                         CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1275                         SupportedCurves:   []CurveID{CurveP256}, // only relevant for ECDHE support
1276                         SupportedPoints:   []uint8{pointFormatUncompressed},
1277                         SignatureSchemes:  []SignatureScheme{Ed25519},
1278                         SupportedVersions: []uint16{VersionTLS10},
1279                 }, "doesn't support Ed25519"},
1280                 {ed25519Cert, &ClientHelloInfo{
1281                         CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1282                         SupportedCurves:   []CurveID{},
1283                         SupportedPoints:   []uint8{pointFormatUncompressed},
1284                         SignatureSchemes:  []SignatureScheme{Ed25519},
1285                         SupportedVersions: []uint16{VersionTLS12},
1286                 }, "doesn't support ECDHE"},
1287
1288                 {rsaCert, &ClientHelloInfo{
1289                         CipherSuites:      []uint16{TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA},
1290                         SupportedCurves:   []CurveID{CurveP256}, // only relevant for ECDHE support
1291                         SupportedPoints:   []uint8{pointFormatUncompressed},
1292                         SupportedVersions: []uint16{VersionTLS10},
1293                 }, ""},
1294                 {rsaCert, &ClientHelloInfo{
1295                         CipherSuites:      []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
1296                         SupportedVersions: []uint16{VersionTLS12},
1297                 }, ""}, // static RSA fallback
1298         }
1299         for i, tt := range tests {
1300                 err := tt.chi.SupportsCertificate(tt.c)
1301                 switch {
1302                 case tt.wantErr == "" && err != nil:
1303                         t.Errorf("%d: unexpected error: %v", i, err)
1304                 case tt.wantErr != "" && err == nil:
1305                         t.Errorf("%d: unexpected success", i)
1306                 case tt.wantErr != "" && !strings.Contains(err.Error(), tt.wantErr):
1307                         t.Errorf("%d: got error %q, expected %q", i, err, tt.wantErr)
1308                 }
1309         }
1310 }
1311
1312 func TestCipherSuites(t *testing.T) {
1313         var lastID uint16
1314         for _, c := range CipherSuites() {
1315                 if lastID > c.ID {
1316                         t.Errorf("CipherSuites are not ordered by ID: got %#04x after %#04x", c.ID, lastID)
1317                 } else {
1318                         lastID = c.ID
1319                 }
1320
1321                 if c.Insecure {
1322                         t.Errorf("%#04x: Insecure CipherSuite returned by CipherSuites()", c.ID)
1323                 }
1324         }
1325         lastID = 0
1326         for _, c := range InsecureCipherSuites() {
1327                 if lastID > c.ID {
1328                         t.Errorf("InsecureCipherSuites are not ordered by ID: got %#04x after %#04x", c.ID, lastID)
1329                 } else {
1330                         lastID = c.ID
1331                 }
1332
1333                 if !c.Insecure {
1334                         t.Errorf("%#04x: not Insecure CipherSuite returned by InsecureCipherSuites()", c.ID)
1335                 }
1336         }
1337
1338         cipherSuiteByID := func(id uint16) *CipherSuite {
1339                 for _, c := range CipherSuites() {
1340                         if c.ID == id {
1341                                 return c
1342                         }
1343                 }
1344                 for _, c := range InsecureCipherSuites() {
1345                         if c.ID == id {
1346                                 return c
1347                         }
1348                 }
1349                 return nil
1350         }
1351
1352         for _, c := range cipherSuites {
1353                 cc := cipherSuiteByID(c.id)
1354                 if cc == nil {
1355                         t.Errorf("%#04x: no CipherSuite entry", c.id)
1356                         continue
1357                 }
1358
1359                 if defaultOff := c.flags&suiteDefaultOff != 0; defaultOff != cc.Insecure {
1360                         t.Errorf("%#04x: Insecure %v, expected %v", c.id, cc.Insecure, defaultOff)
1361                 }
1362                 if tls12Only := c.flags&suiteTLS12 != 0; tls12Only && len(cc.SupportedVersions) != 1 {
1363                         t.Errorf("%#04x: suite is TLS 1.2 only, but SupportedVersions is %v", c.id, cc.SupportedVersions)
1364                 } else if !tls12Only && len(cc.SupportedVersions) != 3 {
1365                         t.Errorf("%#04x: suite TLS 1.0-1.2, but SupportedVersions is %v", c.id, cc.SupportedVersions)
1366                 }
1367
1368                 if got := CipherSuiteName(c.id); got != cc.Name {
1369                         t.Errorf("%#04x: unexpected CipherSuiteName: got %q, expected %q", c.id, got, cc.Name)
1370                 }
1371         }
1372         for _, c := range cipherSuitesTLS13 {
1373                 cc := cipherSuiteByID(c.id)
1374                 if cc == nil {
1375                         t.Errorf("%#04x: no CipherSuite entry", c.id)
1376                         continue
1377                 }
1378
1379                 if cc.Insecure {
1380                         t.Errorf("%#04x: Insecure %v, expected false", c.id, cc.Insecure)
1381                 }
1382                 if len(cc.SupportedVersions) != 1 || cc.SupportedVersions[0] != VersionTLS13 {
1383                         t.Errorf("%#04x: suite is TLS 1.3 only, but SupportedVersions is %v", c.id, cc.SupportedVersions)
1384                 }
1385
1386                 if got := CipherSuiteName(c.id); got != cc.Name {
1387                         t.Errorf("%#04x: unexpected CipherSuiteName: got %q, expected %q", c.id, got, cc.Name)
1388                 }
1389         }
1390
1391         if got := CipherSuiteName(0xabc); got != "0x0ABC" {
1392                 t.Errorf("unexpected fallback CipherSuiteName: got %q, expected 0x0ABC", got)
1393         }
1394 }
1395
1396 type brokenSigner struct{ crypto.Signer }
1397
1398 func (s brokenSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) {
1399         // Replace opts with opts.HashFunc(), so rsa.PSSOptions are discarded.
1400         return s.Signer.Sign(rand, digest, opts.HashFunc())
1401 }
1402
1403 // TestPKCS1OnlyCert uses a client certificate with a broken crypto.Signer that
1404 // always makes PKCS#1 v1.5 signatures, so can't be used with RSA-PSS.
1405 func TestPKCS1OnlyCert(t *testing.T) {
1406         clientConfig := testConfig.Clone()
1407         clientConfig.Certificates = []Certificate{{
1408                 Certificate: [][]byte{testRSACertificate},
1409                 PrivateKey:  brokenSigner{testRSAPrivateKey},
1410         }}
1411         serverConfig := testConfig.Clone()
1412         serverConfig.MaxVersion = VersionTLS12 // TLS 1.3 doesn't support PKCS#1 v1.5
1413         serverConfig.ClientAuth = RequireAnyClientCert
1414
1415         // If RSA-PSS is selected, the handshake should fail.
1416         if _, _, err := testHandshake(t, clientConfig, serverConfig); err == nil {
1417                 t.Fatal("expected broken certificate to cause connection to fail")
1418         }
1419
1420         clientConfig.Certificates[0].SupportedSignatureAlgorithms =
1421                 []SignatureScheme{PKCS1WithSHA1, PKCS1WithSHA256}
1422
1423         // But if the certificate restricts supported algorithms, RSA-PSS should not
1424         // be selected, and the handshake should succeed.
1425         if _, _, err := testHandshake(t, clientConfig, serverConfig); err != nil {
1426                 t.Error(err)
1427         }
1428 }