]> Cypherpunks.ru repositories - gostls13.git/blob - src/crypto/tls/tls_test.go
crypto/tls: add WrapSession and UnwrapSession
[gostls13.git] / src / crypto / tls / tls_test.go
1 // Copyright 2012 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package tls
6
7 import (
8         "bytes"
9         "context"
10         "crypto"
11         "crypto/x509"
12         "encoding/json"
13         "errors"
14         "fmt"
15         "internal/testenv"
16         "io"
17         "math"
18         "net"
19         "os"
20         "reflect"
21         "sort"
22         "strings"
23         "testing"
24         "time"
25 )
26
// rsaCertPEM is a PEM-encoded RSA test certificate. keyPairTests pairs it
// with rsaKeyPEM and keyPEM; the X509KeyPair error tests also use it as a
// deliberately wrong or mismatched argument.
var rsaCertPEM = `-----BEGIN CERTIFICATE-----
MIIB0zCCAX2gAwIBAgIJAI/M7BYjwB+uMA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV
BAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX
aWRnaXRzIFB0eSBMdGQwHhcNMTIwOTEyMjE1MjAyWhcNMTUwOTEyMjE1MjAyWjBF
MQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50
ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBANLJ
hPHhITqQbPklG3ibCVxwGMRfp/v4XqhfdQHdcVfHap6NQ5Wok/4xIA+ui35/MmNa
rtNuC+BdZ1tMuVCPFZcCAwEAAaNQME4wHQYDVR0OBBYEFJvKs8RfJaXTH08W+SGv
zQyKn0H8MB8GA1UdIwQYMBaAFJvKs8RfJaXTH08W+SGvzQyKn0H8MAwGA1UdEwQF
MAMBAf8wDQYJKoZIhvcNAQEFBQADQQBJlffJHybjDGxRMqaRmDhX0+6v02TUKZsW
r5QuVbpQhH6u+0UgcW0jp9QwpxoPTLTWGXEWBBBurxFwiCBhkQ+V
-----END CERTIFICATE-----
`
40
// rsaKeyPEM is the RSA private key matching rsaCertPEM. It is written with a
// "TESTING KEY" label and passed through testingKey, which is defined
// elsewhere; presumably it rewrites the label into a real PEM key type —
// confirm against testingKey.
var rsaKeyPEM = testingKey(`-----BEGIN RSA TESTING KEY-----
MIIBOwIBAAJBANLJhPHhITqQbPklG3ibCVxwGMRfp/v4XqhfdQHdcVfHap6NQ5Wo
k/4xIA+ui35/MmNartNuC+BdZ1tMuVCPFZcCAwEAAQJAEJ2N+zsR0Xn8/Q6twa4G
6OB1M1WO+k+ztnX/1SvNeWu8D6GImtupLTYgjZcHufykj09jiHmjHx8u8ZZB/o1N
MQIhAPW+eyZo7ay3lMz1V01WVjNKK9QSn1MJlb06h/LuYv9FAiEA25WPedKgVyCW
SmUwbPw8fnTcpqDWE3yTO3vKcebqMSsCIBF3UmVue8YU3jybC3NxuXq3wNm34R8T
xVLHwDXh/6NJAiEAl2oHGGLz64BuAfjKrqwz7qMYr9HCLIe/YsoWq/olzScCIQDi
D2lWusoe2/nEqfDVVWGWlyJ7yOmqaVm/iNUN9B2N2g==
-----END RSA TESTING KEY-----
`)
51
// keyPEM is the same as rsaKeyPEM, but declares itself as just
// "PRIVATE KEY", not "RSA PRIVATE KEY".  https://golang.org/issue/4477
//
// The literal spells the label as "TESTING KEY". This assumes testingKey
// (defined elsewhere) turns it into "PRIVATE KEY" — confirm against its
// definition.
var keyPEM = testingKey(`-----BEGIN TESTING KEY-----
MIIBOwIBAAJBANLJhPHhITqQbPklG3ibCVxwGMRfp/v4XqhfdQHdcVfHap6NQ5Wo
k/4xIA+ui35/MmNartNuC+BdZ1tMuVCPFZcCAwEAAQJAEJ2N+zsR0Xn8/Q6twa4G
6OB1M1WO+k+ztnX/1SvNeWu8D6GImtupLTYgjZcHufykj09jiHmjHx8u8ZZB/o1N
MQIhAPW+eyZo7ay3lMz1V01WVjNKK9QSn1MJlb06h/LuYv9FAiEA25WPedKgVyCW
SmUwbPw8fnTcpqDWE3yTO3vKcebqMSsCIBF3UmVue8YU3jybC3NxuXq3wNm34R8T
xVLHwDXh/6NJAiEAl2oHGGLz64BuAfjKrqwz7qMYr9HCLIe/YsoWq/olzScCIQDi
D2lWusoe2/nEqfDVVWGWlyJ7yOmqaVm/iNUN9B2N2g==
-----END TESTING KEY-----
`)
64
// ecdsaCertPEM is a PEM-encoded ECDSA test certificate matching ecdsaKeyPEM.
// TestX509MixedKeyPair also pairs it with rsaKeyPEM to check that a
// mismatched key is rejected.
var ecdsaCertPEM = `-----BEGIN CERTIFICATE-----
MIIB/jCCAWICCQDscdUxw16XFDAJBgcqhkjOPQQBMEUxCzAJBgNVBAYTAkFVMRMw
EQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBXaWRnaXRzIFB0
eSBMdGQwHhcNMTIxMTE0MTI0MDQ4WhcNMTUxMTE0MTI0MDQ4WjBFMQswCQYDVQQG
EwJBVTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lk
Z2l0cyBQdHkgTHRkMIGbMBAGByqGSM49AgEGBSuBBAAjA4GGAAQBY9+my9OoeSUR
lDQdV/x8LsOuLilthhiS1Tz4aGDHIPwC1mlvnf7fg5lecYpMCrLLhauAc1UJXcgl
01xoLuzgtAEAgv2P/jgytzRSpUYvgLBt1UA0leLYBy6mQQbrNEuqT3INapKIcUv8
XxYP0xMEUksLPq6Ca+CRSqTtrd/23uTnapkwCQYHKoZIzj0EAQOBigAwgYYCQXJo
A7Sl2nLVf+4Iu/tAX/IF4MavARKC4PPHK3zfuGfPR3oCCcsAoz3kAzOeijvd0iXb
H5jBImIxPL4WxQNiBTexAkF8D1EtpYuWdlVQ80/h/f4pBcGiXPqX5h2PQSQY7hP1
+jwM1FGS4fREIOvlBYr/SzzQRtwrvrzGYxDEDbsC0ZGRnA==
-----END CERTIFICATE-----
`
79
// ecdsaKeyPEM is the EC private key matching ecdsaCertPEM. The input starts
// with an "EC PARAMETERS" block before the key block, so loading it exercises
// skipping a non-key PEM block. The key label is rewritten by testingKey,
// which is defined elsewhere.
var ecdsaKeyPEM = testingKey(`-----BEGIN EC PARAMETERS-----
BgUrgQQAIw==
-----END EC PARAMETERS-----
-----BEGIN EC TESTING KEY-----
MIHcAgEBBEIBrsoKp0oqcv6/JovJJDoDVSGWdirrkgCWxrprGlzB9o0X8fV675X0
NwuBenXFfeZvVcwluO7/Q9wkYoPd/t3jGImgBwYFK4EEACOhgYkDgYYABAFj36bL
06h5JRGUNB1X/Hwuw64uKW2GGJLVPPhoYMcg/ALWaW+d/t+DmV5xikwKssuFq4Bz
VQldyCXTXGgu7OC0AQCC/Y/+ODK3NFKlRi+AsG3VQDSV4tgHLqZBBus0S6pPcg1q
kohxS/xfFg/TEwRSSws+roJr4JFKpO2t3/be5OdqmQ==
-----END EC TESTING KEY-----
`)
91
92 var keyPairTests = []struct {
93         algo string
94         cert string
95         key  string
96 }{
97         {"ECDSA", ecdsaCertPEM, ecdsaKeyPEM},
98         {"RSA", rsaCertPEM, rsaKeyPEM},
99         {"RSA-untyped", rsaCertPEM, keyPEM}, // golang.org/issue/4477
100 }
101
102 func TestX509KeyPair(t *testing.T) {
103         t.Parallel()
104         var pem []byte
105         for _, test := range keyPairTests {
106                 pem = []byte(test.cert + test.key)
107                 if _, err := X509KeyPair(pem, pem); err != nil {
108                         t.Errorf("Failed to load %s cert followed by %s key: %s", test.algo, test.algo, err)
109                 }
110                 pem = []byte(test.key + test.cert)
111                 if _, err := X509KeyPair(pem, pem); err != nil {
112                         t.Errorf("Failed to load %s key followed by %s cert: %s", test.algo, test.algo, err)
113                 }
114         }
115 }
116
117 func TestX509KeyPairErrors(t *testing.T) {
118         _, err := X509KeyPair([]byte(rsaKeyPEM), []byte(rsaCertPEM))
119         if err == nil {
120                 t.Fatalf("X509KeyPair didn't return an error when arguments were switched")
121         }
122         if subStr := "been switched"; !strings.Contains(err.Error(), subStr) {
123                 t.Fatalf("Expected %q in the error when switching arguments to X509KeyPair, but the error was %q", subStr, err)
124         }
125
126         _, err = X509KeyPair([]byte(rsaCertPEM), []byte(rsaCertPEM))
127         if err == nil {
128                 t.Fatalf("X509KeyPair didn't return an error when both arguments were certificates")
129         }
130         if subStr := "certificate"; !strings.Contains(err.Error(), subStr) {
131                 t.Fatalf("Expected %q in the error when both arguments to X509KeyPair were certificates, but the error was %q", subStr, err)
132         }
133
134         const nonsensePEM = `
135 -----BEGIN NONSENSE-----
136 Zm9vZm9vZm9v
137 -----END NONSENSE-----
138 `
139
140         _, err = X509KeyPair([]byte(nonsensePEM), []byte(nonsensePEM))
141         if err == nil {
142                 t.Fatalf("X509KeyPair didn't return an error when both arguments were nonsense")
143         }
144         if subStr := "NONSENSE"; !strings.Contains(err.Error(), subStr) {
145                 t.Fatalf("Expected %q in the error when both arguments to X509KeyPair were nonsense, but the error was %q", subStr, err)
146         }
147 }
148
149 func TestX509MixedKeyPair(t *testing.T) {
150         if _, err := X509KeyPair([]byte(rsaCertPEM), []byte(ecdsaKeyPEM)); err == nil {
151                 t.Error("Load of RSA certificate succeeded with ECDSA private key")
152         }
153         if _, err := X509KeyPair([]byte(ecdsaCertPEM), []byte(rsaKeyPEM)); err == nil {
154                 t.Error("Load of ECDSA certificate succeeded with RSA private key")
155         }
156 }
157
// newLocalListener returns a TCP listener on an ephemeral loopback port.
// It tries IPv4 (127.0.0.1) first and falls back to IPv6 ([::1]). If neither
// works, the test fails via t.Fatal.
func newLocalListener(t testing.TB) net.Listener {
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err == nil {
		return ln
	}
	ln, err = net.Listen("tcp6", "[::1]:0")
	if err != nil {
		t.Fatal(err)
	}
	return ln
}
168
169 func TestDialTimeout(t *testing.T) {
170         if testing.Short() {
171                 t.Skip("skipping in short mode")
172         }
173
174         timeout := 100 * time.Microsecond
175         for !t.Failed() {
176                 acceptc := make(chan net.Conn)
177                 listener := newLocalListener(t)
178                 go func() {
179                         for {
180                                 conn, err := listener.Accept()
181                                 if err != nil {
182                                         close(acceptc)
183                                         return
184                                 }
185                                 acceptc <- conn
186                         }
187                 }()
188
189                 addr := listener.Addr().String()
190                 dialer := &net.Dialer{
191                         Timeout: timeout,
192                 }
193                 if conn, err := DialWithDialer(dialer, "tcp", addr, nil); err == nil {
194                         conn.Close()
195                         t.Errorf("DialWithTimeout unexpectedly completed successfully")
196                 } else if !isTimeoutError(err) {
197                         t.Errorf("resulting error not a timeout: %v\nType %T: %#v", err, err, err)
198                 }
199
200                 listener.Close()
201
202                 // We're looking for a timeout during the handshake, so check that the
203                 // Listener actually accepted the connection to initiate it. (If the server
204                 // takes too long to accept the connection, we might cancel before the
205                 // underlying net.Conn is ever dialed — without ever attempting a
206                 // handshake.)
207                 lconn, ok := <-acceptc
208                 if ok {
209                         // The Listener accepted a connection, so assume that it was from our
210                         // Dial: we triggered the timeout at the point where we wanted it!
211                         t.Logf("Listener accepted a connection from %s", lconn.RemoteAddr())
212                         lconn.Close()
213                 }
214                 // Close any spurious extra connecitions from the listener. (This is
215                 // possible if there are, for example, stray Dial calls from other tests.)
216                 for extraConn := range acceptc {
217                         t.Logf("spurious extra connection from %s", extraConn.RemoteAddr())
218                         extraConn.Close()
219                 }
220                 if ok {
221                         break
222                 }
223
224                 t.Logf("with timeout %v, DialWithDialer returned before listener accepted any connections; retrying", timeout)
225                 timeout *= 2
226         }
227 }
228
229 func TestDeadlineOnWrite(t *testing.T) {
230         if testing.Short() {
231                 t.Skip("skipping in short mode")
232         }
233
234         ln := newLocalListener(t)
235         defer ln.Close()
236
237         srvCh := make(chan *Conn, 1)
238
239         go func() {
240                 sconn, err := ln.Accept()
241                 if err != nil {
242                         srvCh <- nil
243                         return
244                 }
245                 srv := Server(sconn, testConfig.Clone())
246                 if err := srv.Handshake(); err != nil {
247                         srvCh <- nil
248                         return
249                 }
250                 srvCh <- srv
251         }()
252
253         clientConfig := testConfig.Clone()
254         clientConfig.MaxVersion = VersionTLS12
255         conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
256         if err != nil {
257                 t.Fatal(err)
258         }
259         defer conn.Close()
260
261         srv := <-srvCh
262         if srv == nil {
263                 t.Error(err)
264         }
265
266         // Make sure the client/server is setup correctly and is able to do a typical Write/Read
267         buf := make([]byte, 6)
268         if _, err := srv.Write([]byte("foobar")); err != nil {
269                 t.Errorf("Write err: %v", err)
270         }
271         if n, err := conn.Read(buf); n != 6 || err != nil || string(buf) != "foobar" {
272                 t.Errorf("Read = %d, %v, data %q; want 6, nil, foobar", n, err, buf)
273         }
274
275         // Set a deadline which should cause Write to timeout
276         if err = srv.SetDeadline(time.Now()); err != nil {
277                 t.Fatalf("SetDeadline(time.Now()) err: %v", err)
278         }
279         if _, err = srv.Write([]byte("should fail")); err == nil {
280                 t.Fatal("Write should have timed out")
281         }
282
283         // Clear deadline and make sure it still times out
284         if err = srv.SetDeadline(time.Time{}); err != nil {
285                 t.Fatalf("SetDeadline(time.Time{}) err: %v", err)
286         }
287         if _, err = srv.Write([]byte("This connection is permanently broken")); err == nil {
288                 t.Fatal("Write which previously failed should still time out")
289         }
290
291         // Verify the error
292         if ne := err.(net.Error); ne.Temporary() != false {
293                 t.Error("Write timed out but incorrectly classified the error as Temporary")
294         }
295         if !isTimeoutError(err) {
296                 t.Error("Write timed out but did not classify the error as a Timeout")
297         }
298 }
299
// readerFunc adapts a plain function to the io.Reader interface.
type readerFunc func([]byte) (int, error)

// Read forwards p to the underlying function and returns its results.
func (fn readerFunc) Read(p []byte) (int, error) {
	return fn(p)
}
303
304 // TestDialer tests that tls.Dialer.DialContext can abort in the middle of a handshake.
305 // (The other cases are all handled by the existing dial tests in this package, which
306 // all also flow through the same code shared code paths)
307 func TestDialer(t *testing.T) {
308         ln := newLocalListener(t)
309         defer ln.Close()
310
311         unblockServer := make(chan struct{}) // close-only
312         defer close(unblockServer)
313         go func() {
314                 conn, err := ln.Accept()
315                 if err != nil {
316                         return
317                 }
318                 defer conn.Close()
319                 <-unblockServer
320         }()
321
322         ctx, cancel := context.WithCancel(context.Background())
323         d := Dialer{Config: &Config{
324                 Rand: readerFunc(func(b []byte) (n int, err error) {
325                         // By the time crypto/tls wants randomness, that means it has a TCP
326                         // connection, so we're past the Dialer's dial and now blocked
327                         // in a handshake. Cancel our context and see if we get unstuck.
328                         // (Our TCP listener above never reads or writes, so the Handshake
329                         // would otherwise be stuck forever)
330                         cancel()
331                         return len(b), nil
332                 }),
333                 ServerName: "foo",
334         }}
335         _, err := d.DialContext(ctx, "tcp", ln.Addr().String())
336         if err != context.Canceled {
337                 t.Errorf("err = %v; want context.Canceled", err)
338         }
339 }
340
// isTimeoutError reports whether err itself (not any wrapped error)
// implements net.Error and reports a timeout. A nil error yields false.
func isTimeoutError(err error) bool {
	ne, ok := err.(net.Error)
	return ok && ne.Timeout()
}
347
348 // tests that Conn.Read returns (non-zero, io.EOF) instead of
349 // (non-zero, nil) when a Close (alertCloseNotify) is sitting right
350 // behind the application data in the buffer.
351 func TestConnReadNonzeroAndEOF(t *testing.T) {
352         // This test is racy: it assumes that after a write to a
353         // localhost TCP connection, the peer TCP connection can
354         // immediately read it. Because it's racy, we skip this test
355         // in short mode, and then retry it several times with an
356         // increasing sleep in between our final write (via srv.Close
357         // below) and the following read.
358         if testing.Short() {
359                 t.Skip("skipping in short mode")
360         }
361         var err error
362         for delay := time.Millisecond; delay <= 64*time.Millisecond; delay *= 2 {
363                 if err = testConnReadNonzeroAndEOF(t, delay); err == nil {
364                         return
365                 }
366         }
367         t.Error(err)
368 }
369
370 func testConnReadNonzeroAndEOF(t *testing.T, delay time.Duration) error {
371         ln := newLocalListener(t)
372         defer ln.Close()
373
374         srvCh := make(chan *Conn, 1)
375         var serr error
376         go func() {
377                 sconn, err := ln.Accept()
378                 if err != nil {
379                         serr = err
380                         srvCh <- nil
381                         return
382                 }
383                 serverConfig := testConfig.Clone()
384                 srv := Server(sconn, serverConfig)
385                 if err := srv.Handshake(); err != nil {
386                         serr = fmt.Errorf("handshake: %v", err)
387                         srvCh <- nil
388                         return
389                 }
390                 srvCh <- srv
391         }()
392
393         clientConfig := testConfig.Clone()
394         // In TLS 1.3, alerts are encrypted and disguised as application data, so
395         // the opportunistic peek won't work.
396         clientConfig.MaxVersion = VersionTLS12
397         conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
398         if err != nil {
399                 t.Fatal(err)
400         }
401         defer conn.Close()
402
403         srv := <-srvCh
404         if srv == nil {
405                 return serr
406         }
407
408         buf := make([]byte, 6)
409
410         srv.Write([]byte("foobar"))
411         n, err := conn.Read(buf)
412         if n != 6 || err != nil || string(buf) != "foobar" {
413                 return fmt.Errorf("Read = %d, %v, data %q; want 6, nil, foobar", n, err, buf)
414         }
415
416         srv.Write([]byte("abcdef"))
417         srv.Close()
418         time.Sleep(delay)
419         n, err = conn.Read(buf)
420         if n != 6 || string(buf) != "abcdef" {
421                 return fmt.Errorf("Read = %d, buf= %q; want 6, abcdef", n, buf)
422         }
423         if err != io.EOF {
424                 return fmt.Errorf("Second Read error = %v; want io.EOF", err)
425         }
426         return nil
427 }
428
429 func TestTLSUniqueMatches(t *testing.T) {
430         ln := newLocalListener(t)
431         defer ln.Close()
432
433         serverTLSUniques := make(chan []byte)
434         parentDone := make(chan struct{})
435         childDone := make(chan struct{})
436         defer close(parentDone)
437         go func() {
438                 defer close(childDone)
439                 for i := 0; i < 2; i++ {
440                         sconn, err := ln.Accept()
441                         if err != nil {
442                                 t.Error(err)
443                                 return
444                         }
445                         serverConfig := testConfig.Clone()
446                         serverConfig.MaxVersion = VersionTLS12 // TLSUnique is not defined in TLS 1.3
447                         srv := Server(sconn, serverConfig)
448                         if err := srv.Handshake(); err != nil {
449                                 t.Error(err)
450                                 return
451                         }
452                         select {
453                         case <-parentDone:
454                                 return
455                         case serverTLSUniques <- srv.ConnectionState().TLSUnique:
456                         }
457                 }
458         }()
459
460         clientConfig := testConfig.Clone()
461         clientConfig.ClientSessionCache = NewLRUClientSessionCache(1)
462         conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
463         if err != nil {
464                 t.Fatal(err)
465         }
466
467         var serverTLSUniquesValue []byte
468         select {
469         case <-childDone:
470                 return
471         case serverTLSUniquesValue = <-serverTLSUniques:
472         }
473
474         if !bytes.Equal(conn.ConnectionState().TLSUnique, serverTLSUniquesValue) {
475                 t.Error("client and server channel bindings differ")
476         }
477         conn.Close()
478
479         conn, err = Dial("tcp", ln.Addr().String(), clientConfig)
480         if err != nil {
481                 t.Fatal(err)
482         }
483         defer conn.Close()
484         if !conn.ConnectionState().DidResume {
485                 t.Error("second session did not use resumption")
486         }
487
488         select {
489         case <-childDone:
490                 return
491         case serverTLSUniquesValue = <-serverTLSUniques:
492         }
493
494         if !bytes.Equal(conn.ConnectionState().TLSUnique, serverTLSUniquesValue) {
495                 t.Error("client and server channel bindings differ when session resumption is used")
496         }
497 }
498
499 func TestVerifyHostname(t *testing.T) {
500         testenv.MustHaveExternalNetwork(t)
501
502         c, err := Dial("tcp", "www.google.com:https", nil)
503         if err != nil {
504                 t.Fatal(err)
505         }
506         if err := c.VerifyHostname("www.google.com"); err != nil {
507                 t.Fatalf("verify www.google.com: %v", err)
508         }
509         if err := c.VerifyHostname("www.yahoo.com"); err == nil {
510                 t.Fatalf("verify www.yahoo.com succeeded")
511         }
512
513         c, err = Dial("tcp", "www.google.com:https", &Config{InsecureSkipVerify: true})
514         if err != nil {
515                 t.Fatal(err)
516         }
517         if err := c.VerifyHostname("www.google.com"); err == nil {
518                 t.Fatalf("verify www.google.com succeeded with InsecureSkipVerify=true")
519         }
520 }
521
// TestConnCloseBreakingWrite checks that Conn.Close returns even while another
// goroutine is blocked inside a Write on the underlying connection. The
// blocked Write must then fail with the transport's error, and a later
// Close must report net.ErrClosed.
func TestConnCloseBreakingWrite(t *testing.T) {
	ln := newLocalListener(t)
	defer ln.Close()

	// The server goroutine reports failure through serr plus a nil *Conn.
	// sconn is kept so the raw server socket can be closed when the test ends.
	srvCh := make(chan *Conn, 1)
	var serr error
	var sconn net.Conn
	go func() {
		var err error
		sconn, err = ln.Accept()
		if err != nil {
			serr = err
			srvCh <- nil
			return
		}
		serverConfig := testConfig.Clone()
		srv := Server(sconn, serverConfig)
		if err := srv.Handshake(); err != nil {
			serr = fmt.Errorf("handshake: %v", err)
			srvCh <- nil
			return
		}
		srvCh <- srv
	}()

	cconn, err := net.Dial("tcp", ln.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer cconn.Close()

	// Wrap the client socket so its Write/Close behavior can be swapped out
	// after the handshake. changeImplConn is defined elsewhere; this assumes
	// it routes Write and Close through writeFunc/closeFunc once they are
	// set — confirm against its definition.
	conn := &changeImplConn{
		Conn: cconn,
	}

	clientConfig := testConfig.Clone()
	tconn := Client(conn, clientConfig)
	if err := tconn.Handshake(); err != nil {
		t.Fatal(err)
	}

	srv := <-srvCh
	if srv == nil {
		t.Fatal(serr)
	}
	defer sconn.Close()

	// Closing the wrapped conn now only signals connClosed.
	connClosed := make(chan struct{})
	conn.closeFunc = func() error {
		close(connClosed)
		return nil
	}

	// A Write announces itself on inWrite, then blocks until the conn is
	// closed and fails with errConnClosed.
	inWrite := make(chan bool, 1)
	var errConnClosed = errors.New("conn closed for test")
	conn.writeFunc = func(p []byte) (n int, err error) {
		inWrite <- true
		<-connClosed
		return 0, errConnClosed
	}

	// Once the Write below is blocked inside writeFunc, call Close from
	// another goroutine.
	closeReturned := make(chan bool, 1)
	go func() {
		<-inWrite
		tconn.Close() // test that this doesn't block forever.
		closeReturned <- true
	}()

	_, err = tconn.Write([]byte("foo"))
	if err != errConnClosed {
		t.Errorf("Write error = %v; want errConnClosed", err)
	}

	<-closeReturned
	// The connection has already been closed once, so Close again must
	// report net.ErrClosed.
	if err := tconn.Close(); err != net.ErrClosed {
		t.Errorf("Close error = %v; want net.ErrClosed", err)
	}
}
600
// TestConnCloseWrite checks Conn.CloseWrite from both ends:
//   - after the peer's CloseWrite, io.ReadAll ends without error and with no data;
//   - a Write after CloseWrite fails with errShutdown;
//   - CloseWrite before the handshake fails with errEarlyCloseWrite.
// The two sides run concurrently with a 10-second deadlock guard.
func TestConnCloseWrite(t *testing.T) {
	ln := newLocalListener(t)
	defer ln.Close()

	// Closed when clientCloseWrite returns, so the server keeps its socket
	// open until the client side has finished its checks.
	clientDoneChan := make(chan struct{})

	serverCloseWrite := func() error {
		sconn, err := ln.Accept()
		if err != nil {
			return fmt.Errorf("accept: %v", err)
		}
		defer sconn.Close()

		serverConfig := testConfig.Clone()
		srv := Server(sconn, serverConfig)
		if err := srv.Handshake(); err != nil {
			return fmt.Errorf("handshake: %v", err)
		}
		defer srv.Close()

		// The client sends no application data. This read is expected to end
		// once the client calls CloseWrite.
		data, err := io.ReadAll(srv)
		if err != nil {
			return err
		}
		if len(data) > 0 {
			return fmt.Errorf("Read data = %q; want nothing", data)
		}

		if err := srv.CloseWrite(); err != nil {
			return fmt.Errorf("server CloseWrite: %v", err)
		}

		// Wait for clientCloseWrite to finish, so we know we
		// tested the CloseWrite before we defer the
		// sconn.Close above, which would also cause the
		// client to unblock like CloseWrite.
		<-clientDoneChan
		return nil
	}

	clientCloseWrite := func() error {
		defer close(clientDoneChan)

		clientConfig := testConfig.Clone()
		conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
		if err != nil {
			return err
		}
		if err := conn.Handshake(); err != nil {
			return err
		}
		defer conn.Close()

		if err := conn.CloseWrite(); err != nil {
			return fmt.Errorf("client CloseWrite: %v", err)
		}

		// Writing after CloseWrite must be refused.
		if _, err := conn.Write([]byte{0}); err != errShutdown {
			return fmt.Errorf("CloseWrite error = %v; want errShutdown", err)
		}

		// Ends once the server's CloseWrite arrives; no data is expected.
		data, err := io.ReadAll(conn)
		if err != nil {
			return err
		}
		if len(data) > 0 {
			return fmt.Errorf("Read data = %q; want nothing", data)
		}
		return nil
	}

	errChan := make(chan error, 2)

	go func() { errChan <- serverCloseWrite() }()
	go func() { errChan <- clientCloseWrite() }()

	// Collect both results; a side that hangs for 10 seconds is a deadlock.
	for i := 0; i < 2; i++ {
		select {
		case err := <-errChan:
			if err != nil {
				t.Fatal(err)
			}
		case <-time.After(10 * time.Second):
			t.Fatal("deadlock")
		}
	}

	// Also test CloseWrite being called before the handshake is
	// finished:
	{
		ln2 := newLocalListener(t)
		defer ln2.Close()

		netConn, err := net.Dial("tcp", ln2.Addr().String())
		if err != nil {
			t.Fatal(err)
		}
		defer netConn.Close()
		conn := Client(netConn, testConfig.Clone())

		if err := conn.CloseWrite(); err != errEarlyCloseWrite {
			t.Errorf("CloseWrite error = %v; want errEarlyCloseWrite", err)
		}
	}
}
706
707 func TestWarningAlertFlood(t *testing.T) {
708         ln := newLocalListener(t)
709         defer ln.Close()
710
711         server := func() error {
712                 sconn, err := ln.Accept()
713                 if err != nil {
714                         return fmt.Errorf("accept: %v", err)
715                 }
716                 defer sconn.Close()
717
718                 serverConfig := testConfig.Clone()
719                 srv := Server(sconn, serverConfig)
720                 if err := srv.Handshake(); err != nil {
721                         return fmt.Errorf("handshake: %v", err)
722                 }
723                 defer srv.Close()
724
725                 _, err = io.ReadAll(srv)
726                 if err == nil {
727                         return errors.New("unexpected lack of error from server")
728                 }
729                 const expected = "too many ignored"
730                 if str := err.Error(); !strings.Contains(str, expected) {
731                         return fmt.Errorf("expected error containing %q, but saw: %s", expected, str)
732                 }
733
734                 return nil
735         }
736
737         errChan := make(chan error, 1)
738         go func() { errChan <- server() }()
739
740         clientConfig := testConfig.Clone()
741         clientConfig.MaxVersion = VersionTLS12 // there are no warning alerts in TLS 1.3
742         conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
743         if err != nil {
744                 t.Fatal(err)
745         }
746         defer conn.Close()
747         if err := conn.Handshake(); err != nil {
748                 t.Fatal(err)
749         }
750
751         for i := 0; i < maxUselessRecords+1; i++ {
752                 conn.sendAlert(alertNoRenegotiation)
753         }
754
755         if err := <-errChan; err != nil {
756                 t.Fatal(err)
757         }
758 }
759
760 func TestCloneFuncFields(t *testing.T) {
761         const expectedCount = 8
762         called := 0
763
764         c1 := Config{
765                 Time: func() time.Time {
766                         called |= 1 << 0
767                         return time.Time{}
768                 },
769                 GetCertificate: func(*ClientHelloInfo) (*Certificate, error) {
770                         called |= 1 << 1
771                         return nil, nil
772                 },
773                 GetClientCertificate: func(*CertificateRequestInfo) (*Certificate, error) {
774                         called |= 1 << 2
775                         return nil, nil
776                 },
777                 GetConfigForClient: func(*ClientHelloInfo) (*Config, error) {
778                         called |= 1 << 3
779                         return nil, nil
780                 },
781                 VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
782                         called |= 1 << 4
783                         return nil
784                 },
785                 VerifyConnection: func(ConnectionState) error {
786                         called |= 1 << 5
787                         return nil
788                 },
789                 UnwrapSession: func(identity []byte, cs ConnectionState) (*SessionState, error) {
790                         called |= 1 << 6
791                         return nil, nil
792                 },
793                 WrapSession: func(cs ConnectionState, ss *SessionState) ([]byte, error) {
794                         called |= 1 << 7
795                         return nil, nil
796                 },
797         }
798
799         c2 := c1.Clone()
800
801         c2.Time()
802         c2.GetCertificate(nil)
803         c2.GetClientCertificate(nil)
804         c2.GetConfigForClient(nil)
805         c2.VerifyPeerCertificate(nil, nil)
806         c2.VerifyConnection(ConnectionState{})
807         c2.UnwrapSession(nil, ConnectionState{})
808         c2.WrapSession(ConnectionState{}, nil)
809
810         if called != (1<<expectedCount)-1 {
811                 t.Fatalf("expected %d calls but saw calls %b", expectedCount, called)
812         }
813 }
814
815 func TestCloneNonFuncFields(t *testing.T) {
816         var c1 Config
817         v := reflect.ValueOf(&c1).Elem()
818
819         typ := v.Type()
820         for i := 0; i < typ.NumField(); i++ {
821                 f := v.Field(i)
822                 // testing/quick can't handle functions or interfaces and so
823                 // isn't used here.
824                 switch fn := typ.Field(i).Name; fn {
825                 case "Rand":
826                         f.Set(reflect.ValueOf(io.Reader(os.Stdin)))
827                 case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate", "VerifyConnection", "GetClientCertificate", "WrapSession", "UnwrapSession":
828                         // DeepEqual can't compare functions. If you add a
829                         // function field to this list, you must also change
830                         // TestCloneFuncFields to ensure that the func field is
831                         // cloned.
832                 case "Certificates":
833                         f.Set(reflect.ValueOf([]Certificate{
834                                 {Certificate: [][]byte{{'b'}}},
835                         }))
836                 case "NameToCertificate":
837                         f.Set(reflect.ValueOf(map[string]*Certificate{"a": nil}))
838                 case "RootCAs", "ClientCAs":
839                         f.Set(reflect.ValueOf(x509.NewCertPool()))
840                 case "ClientSessionCache":
841                         f.Set(reflect.ValueOf(NewLRUClientSessionCache(10)))
842                 case "KeyLogWriter":
843                         f.Set(reflect.ValueOf(io.Writer(os.Stdout)))
844                 case "NextProtos":
845                         f.Set(reflect.ValueOf([]string{"a", "b"}))
846                 case "ServerName":
847                         f.Set(reflect.ValueOf("b"))
848                 case "ClientAuth":
849                         f.Set(reflect.ValueOf(VerifyClientCertIfGiven))
850                 case "InsecureSkipVerify", "SessionTicketsDisabled", "DynamicRecordSizingDisabled", "PreferServerCipherSuites":
851                         f.Set(reflect.ValueOf(true))
852                 case "MinVersion", "MaxVersion":
853                         f.Set(reflect.ValueOf(uint16(VersionTLS12)))
854                 case "SessionTicketKey":
855                         f.Set(reflect.ValueOf([32]byte{}))
856                 case "CipherSuites":
857                         f.Set(reflect.ValueOf([]uint16{1, 2}))
858                 case "CurvePreferences":
859                         f.Set(reflect.ValueOf([]CurveID{CurveP256}))
860                 case "Renegotiation":
861                         f.Set(reflect.ValueOf(RenegotiateOnceAsClient))
862                 case "mutex", "autoSessionTicketKeys", "sessionTicketKeys":
863                         continue // these are unexported fields that are handled separately
864                 default:
865                         t.Errorf("all fields must be accounted for, but saw unknown field %q", fn)
866                 }
867         }
868         // Set the unexported fields related to session ticket keys, which are copied with Clone().
869         c1.autoSessionTicketKeys = []ticketKey{c1.ticketKeyFromBytes(c1.SessionTicketKey)}
870         c1.sessionTicketKeys = []ticketKey{c1.ticketKeyFromBytes(c1.SessionTicketKey)}
871
872         c2 := c1.Clone()
873         if !reflect.DeepEqual(&c1, c2) {
874                 t.Errorf("clone failed to copy a field")
875         }
876 }
877
878 func TestCloneNilConfig(t *testing.T) {
879         var config *Config
880         if cc := config.Clone(); cc != nil {
881                 t.Fatalf("Clone with nil should return nil, got: %+v", cc)
882         }
883 }
884
// changeImplConn wraps a net.Conn so that tests can substitute their own
// Write and Close behavior. When a hook is nil, the call is forwarded to the
// embedded Conn.
type changeImplConn struct {
	net.Conn
	writeFunc func([]byte) (int, error)
	closeFunc func() error
}

// Write calls writeFunc if it is set, and otherwise the embedded Conn's Write.
func (w *changeImplConn) Write(p []byte) (n int, err error) {
	if hook := w.writeFunc; hook != nil {
		return hook(p)
	}
	return w.Conn.Write(p)
}

// Close calls closeFunc if it is set, and otherwise the embedded Conn's Close.
func (w *changeImplConn) Close() error {
	if hook := w.closeFunc; hook != nil {
		return hook()
	}
	return w.Conn.Close()
}
906
907 func throughput(b *testing.B, version uint16, totalBytes int64, dynamicRecordSizingDisabled bool) {
908         ln := newLocalListener(b)
909         defer ln.Close()
910
911         N := b.N
912
913         // Less than 64KB because Windows appears to use a TCP rwin < 64KB.
914         // See Issue #15899.
915         const bufsize = 32 << 10
916
917         go func() {
918                 buf := make([]byte, bufsize)
919                 for i := 0; i < N; i++ {
920                         sconn, err := ln.Accept()
921                         if err != nil {
922                                 // panic rather than synchronize to avoid benchmark overhead
923                                 // (cannot call b.Fatal in goroutine)
924                                 panic(fmt.Errorf("accept: %v", err))
925                         }
926                         serverConfig := testConfig.Clone()
927                         serverConfig.CipherSuites = nil // the defaults may prefer faster ciphers
928                         serverConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled
929                         srv := Server(sconn, serverConfig)
930                         if err := srv.Handshake(); err != nil {
931                                 panic(fmt.Errorf("handshake: %v", err))
932                         }
933                         if _, err := io.CopyBuffer(srv, srv, buf); err != nil {
934                                 panic(fmt.Errorf("copy buffer: %v", err))
935                         }
936                 }
937         }()
938
939         b.SetBytes(totalBytes)
940         clientConfig := testConfig.Clone()
941         clientConfig.CipherSuites = nil // the defaults may prefer faster ciphers
942         clientConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled
943         clientConfig.MaxVersion = version
944
945         buf := make([]byte, bufsize)
946         chunks := int(math.Ceil(float64(totalBytes) / float64(len(buf))))
947         for i := 0; i < N; i++ {
948                 conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
949                 if err != nil {
950                         b.Fatal(err)
951                 }
952                 for j := 0; j < chunks; j++ {
953                         _, err := conn.Write(buf)
954                         if err != nil {
955                                 b.Fatal(err)
956                         }
957                         _, err = io.ReadFull(conn, buf)
958                         if err != nil {
959                                 b.Fatal(err)
960                         }
961                 }
962                 conn.Close()
963         }
964 }
965
966 func BenchmarkThroughput(b *testing.B) {
967         for _, mode := range []string{"Max", "Dynamic"} {
968                 for size := 1; size <= 64; size <<= 1 {
969                         name := fmt.Sprintf("%sPacket/%dMB", mode, size)
970                         b.Run(name, func(b *testing.B) {
971                                 b.Run("TLSv12", func(b *testing.B) {
972                                         throughput(b, VersionTLS12, int64(size<<20), mode == "Max")
973                                 })
974                                 b.Run("TLSv13", func(b *testing.B) {
975                                         throughput(b, VersionTLS13, int64(size<<20), mode == "Max")
976                                 })
977                         })
978                 }
979         }
980 }
981
// slowConn throttles writes on an embedded net.Conn to roughly bps bits per
// second. Only Write is overridden; every other method comes from the
// embedded Conn.
type slowConn struct {
	net.Conn
	bps int
}

// Write forwards p to the underlying connection in pieces. Every 100µs it
// sends as many bytes as the bit budget accumulated since the call began
// allows. It panics if bps is zero, because no progress would be possible.
// If the underlying Write fails, it returns the byte count forwarded so far
// and the error. Otherwise it reports len(p) bytes written.
func (c *slowConn) Write(p []byte) (int, error) {
	if c.bps == 0 {
		panic("too slow")
	}
	start := time.Now()
	sent := 0
	for sent < len(p) {
		time.Sleep(100 * time.Microsecond)
		budget := int(time.Since(start).Seconds()*float64(c.bps)) / 8
		if budget > len(p) {
			budget = len(p)
		}
		if sent >= budget {
			continue // not enough budget yet; wait for the next tick
		}
		n, err := c.Conn.Write(p[sent:budget])
		sent += n
		if err != nil {
			return sent, err
		}
	}
	return len(p), nil
}
1009
1010 func latency(b *testing.B, version uint16, bps int, dynamicRecordSizingDisabled bool) {
1011         ln := newLocalListener(b)
1012         defer ln.Close()
1013
1014         N := b.N
1015
1016         go func() {
1017                 for i := 0; i < N; i++ {
1018                         sconn, err := ln.Accept()
1019                         if err != nil {
1020                                 // panic rather than synchronize to avoid benchmark overhead
1021                                 // (cannot call b.Fatal in goroutine)
1022                                 panic(fmt.Errorf("accept: %v", err))
1023                         }
1024                         serverConfig := testConfig.Clone()
1025                         serverConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled
1026                         srv := Server(&slowConn{sconn, bps}, serverConfig)
1027                         if err := srv.Handshake(); err != nil {
1028                                 panic(fmt.Errorf("handshake: %v", err))
1029                         }
1030                         io.Copy(srv, srv)
1031                 }
1032         }()
1033
1034         clientConfig := testConfig.Clone()
1035         clientConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled
1036         clientConfig.MaxVersion = version
1037
1038         buf := make([]byte, 16384)
1039         peek := make([]byte, 1)
1040
1041         for i := 0; i < N; i++ {
1042                 conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
1043                 if err != nil {
1044                         b.Fatal(err)
1045                 }
1046                 // make sure we're connected and previous connection has stopped
1047                 if _, err := conn.Write(buf[:1]); err != nil {
1048                         b.Fatal(err)
1049                 }
1050                 if _, err := io.ReadFull(conn, peek); err != nil {
1051                         b.Fatal(err)
1052                 }
1053                 if _, err := conn.Write(buf); err != nil {
1054                         b.Fatal(err)
1055                 }
1056                 if _, err = io.ReadFull(conn, peek); err != nil {
1057                         b.Fatal(err)
1058                 }
1059                 conn.Close()
1060         }
1061 }
1062
1063 func BenchmarkLatency(b *testing.B) {
1064         for _, mode := range []string{"Max", "Dynamic"} {
1065                 for _, kbps := range []int{200, 500, 1000, 2000, 5000} {
1066                         name := fmt.Sprintf("%sPacket/%dkbps", mode, kbps)
1067                         b.Run(name, func(b *testing.B) {
1068                                 b.Run("TLSv12", func(b *testing.B) {
1069                                         latency(b, VersionTLS12, kbps*1000, mode == "Max")
1070                                 })
1071                                 b.Run("TLSv13", func(b *testing.B) {
1072                                         latency(b, VersionTLS13, kbps*1000, mode == "Max")
1073                                 })
1074                         })
1075                 }
1076         }
1077 }
1078
1079 func TestConnectionStateMarshal(t *testing.T) {
1080         cs := &ConnectionState{}
1081         _, err := json.Marshal(cs)
1082         if err != nil {
1083                 t.Errorf("json.Marshal failed on ConnectionState: %v", err)
1084         }
1085 }
1086
1087 func TestConnectionState(t *testing.T) {
1088         issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
1089         if err != nil {
1090                 panic(err)
1091         }
1092         rootCAs := x509.NewCertPool()
1093         rootCAs.AddCert(issuer)
1094
1095         now := func() time.Time { return time.Unix(1476984729, 0) }
1096
1097         const alpnProtocol = "golang"
1098         const serverName = "example.golang"
1099         var scts = [][]byte{[]byte("dummy sct 1"), []byte("dummy sct 2")}
1100         var ocsp = []byte("dummy ocsp")
1101
1102         for _, v := range []uint16{VersionTLS12, VersionTLS13} {
1103                 var name string
1104                 switch v {
1105                 case VersionTLS12:
1106                         name = "TLSv12"
1107                 case VersionTLS13:
1108                         name = "TLSv13"
1109                 }
1110                 t.Run(name, func(t *testing.T) {
1111                         config := &Config{
1112                                 Time:         now,
1113                                 Rand:         zeroSource{},
1114                                 Certificates: make([]Certificate, 1),
1115                                 MaxVersion:   v,
1116                                 RootCAs:      rootCAs,
1117                                 ClientCAs:    rootCAs,
1118                                 ClientAuth:   RequireAndVerifyClientCert,
1119                                 NextProtos:   []string{alpnProtocol},
1120                                 ServerName:   serverName,
1121                         }
1122                         config.Certificates[0].Certificate = [][]byte{testRSACertificate}
1123                         config.Certificates[0].PrivateKey = testRSAPrivateKey
1124                         config.Certificates[0].SignedCertificateTimestamps = scts
1125                         config.Certificates[0].OCSPStaple = ocsp
1126
1127                         ss, cs, err := testHandshake(t, config, config)
1128                         if err != nil {
1129                                 t.Fatalf("Handshake failed: %v", err)
1130                         }
1131
1132                         if ss.Version != v || cs.Version != v {
1133                                 t.Errorf("Got versions %x (server) and %x (client), expected %x", ss.Version, cs.Version, v)
1134                         }
1135
1136                         if !ss.HandshakeComplete || !cs.HandshakeComplete {
1137                                 t.Errorf("Got HandshakeComplete %v (server) and %v (client), expected true", ss.HandshakeComplete, cs.HandshakeComplete)
1138                         }
1139
1140                         if ss.DidResume || cs.DidResume {
1141                                 t.Errorf("Got DidResume %v (server) and %v (client), expected false", ss.DidResume, cs.DidResume)
1142                         }
1143
1144                         if ss.CipherSuite == 0 || cs.CipherSuite == 0 {
1145                                 t.Errorf("Got invalid cipher suite: %v (server) and %v (client)", ss.CipherSuite, cs.CipherSuite)
1146                         }
1147
1148                         if ss.NegotiatedProtocol != alpnProtocol || cs.NegotiatedProtocol != alpnProtocol {
1149                                 t.Errorf("Got negotiated protocol %q (server) and %q (client), expected %q", ss.NegotiatedProtocol, cs.NegotiatedProtocol, alpnProtocol)
1150                         }
1151
1152                         if !cs.NegotiatedProtocolIsMutual {
1153                                 t.Errorf("Got false NegotiatedProtocolIsMutual on the client side")
1154                         }
1155                         // NegotiatedProtocolIsMutual on the server side is unspecified.
1156
1157                         if ss.ServerName != serverName {
1158                                 t.Errorf("Got server name %q, expected %q", ss.ServerName, serverName)
1159                         }
1160                         if cs.ServerName != serverName {
1161                                 t.Errorf("Got server name on client connection %q, expected %q", cs.ServerName, serverName)
1162                         }
1163
1164                         if len(ss.PeerCertificates) != 1 || len(cs.PeerCertificates) != 1 {
1165                                 t.Errorf("Got %d (server) and %d (client) peer certificates, expected %d", len(ss.PeerCertificates), len(cs.PeerCertificates), 1)
1166                         }
1167
1168                         if len(ss.VerifiedChains) != 1 || len(cs.VerifiedChains) != 1 {
1169                                 t.Errorf("Got %d (server) and %d (client) verified chains, expected %d", len(ss.VerifiedChains), len(cs.VerifiedChains), 1)
1170                         } else if len(ss.VerifiedChains[0]) != 2 || len(cs.VerifiedChains[0]) != 2 {
1171                                 t.Errorf("Got %d (server) and %d (client) long verified chain, expected %d", len(ss.VerifiedChains[0]), len(cs.VerifiedChains[0]), 2)
1172                         }
1173
1174                         if len(cs.SignedCertificateTimestamps) != 2 {
1175                                 t.Errorf("Got %d SCTs, expected %d", len(cs.SignedCertificateTimestamps), 2)
1176                         }
1177                         if !bytes.Equal(cs.OCSPResponse, ocsp) {
1178                                 t.Errorf("Got OCSPs %x, expected %x", cs.OCSPResponse, ocsp)
1179                         }
1180                         // Only TLS 1.3 supports OCSP and SCTs on client certs.
1181                         if v == VersionTLS13 {
1182                                 if len(ss.SignedCertificateTimestamps) != 2 {
1183                                         t.Errorf("Got %d client SCTs, expected %d", len(ss.SignedCertificateTimestamps), 2)
1184                                 }
1185                                 if !bytes.Equal(ss.OCSPResponse, ocsp) {
1186                                         t.Errorf("Got client OCSPs %x, expected %x", ss.OCSPResponse, ocsp)
1187                                 }
1188                         }
1189
1190                         if v == VersionTLS13 {
1191                                 if ss.TLSUnique != nil || cs.TLSUnique != nil {
1192                                         t.Errorf("Got TLSUnique %x (server) and %x (client), expected nil in TLS 1.3", ss.TLSUnique, cs.TLSUnique)
1193                                 }
1194                         } else {
1195                                 if ss.TLSUnique == nil || cs.TLSUnique == nil {
1196                                         t.Errorf("Got TLSUnique %x (server) and %x (client), expected non-nil", ss.TLSUnique, cs.TLSUnique)
1197                                 }
1198                         }
1199                 })
1200         }
1201 }
1202
1203 // Issue 28744: Ensure that we don't modify memory
1204 // that Config doesn't own such as Certificates.
1205 func TestBuildNameToCertificate_doesntModifyCertificates(t *testing.T) {
1206         c0 := Certificate{
1207                 Certificate: [][]byte{testRSACertificate},
1208                 PrivateKey:  testRSAPrivateKey,
1209         }
1210         c1 := Certificate{
1211                 Certificate: [][]byte{testSNICertificate},
1212                 PrivateKey:  testRSAPrivateKey,
1213         }
1214         config := testConfig.Clone()
1215         config.Certificates = []Certificate{c0, c1}
1216
1217         config.BuildNameToCertificate()
1218         got := config.Certificates
1219         want := []Certificate{c0, c1}
1220         if !reflect.DeepEqual(got, want) {
1221                 t.Fatalf("Certificates were mutated by BuildNameToCertificate\nGot: %#v\nWant: %#v\n", got, want)
1222         }
1223 }
1224
// testingKey replaces every occurrence of "TESTING KEY" in s with
// "PRIVATE KEY". This turns the test-only PEM labels used for this file's
// fixture keys back into standard private-key labels.
func testingKey(s string) string {
	const placeholder, label = "TESTING KEY", "PRIVATE KEY"
	return strings.Replace(s, placeholder, label, -1)
}
1226
1227 func TestClientHelloInfo_SupportsCertificate(t *testing.T) {
1228         rsaCert := &Certificate{
1229                 Certificate: [][]byte{testRSACertificate},
1230                 PrivateKey:  testRSAPrivateKey,
1231         }
1232         pkcs1Cert := &Certificate{
1233                 Certificate:                  [][]byte{testRSACertificate},
1234                 PrivateKey:                   testRSAPrivateKey,
1235                 SupportedSignatureAlgorithms: []SignatureScheme{PKCS1WithSHA1, PKCS1WithSHA256},
1236         }
1237         ecdsaCert := &Certificate{
1238                 // ECDSA P-256 certificate
1239                 Certificate: [][]byte{testP256Certificate},
1240                 PrivateKey:  testP256PrivateKey,
1241         }
1242         ed25519Cert := &Certificate{
1243                 Certificate: [][]byte{testEd25519Certificate},
1244                 PrivateKey:  testEd25519PrivateKey,
1245         }
1246
1247         tests := []struct {
1248                 c       *Certificate
1249                 chi     *ClientHelloInfo
1250                 wantErr string
1251         }{
1252                 {rsaCert, &ClientHelloInfo{
1253                         ServerName:        "example.golang",
1254                         SignatureSchemes:  []SignatureScheme{PSSWithSHA256},
1255                         SupportedVersions: []uint16{VersionTLS13},
1256                 }, ""},
1257                 {ecdsaCert, &ClientHelloInfo{
1258                         SignatureSchemes:  []SignatureScheme{PSSWithSHA256, ECDSAWithP256AndSHA256},
1259                         SupportedVersions: []uint16{VersionTLS13, VersionTLS12},
1260                 }, ""},
1261                 {rsaCert, &ClientHelloInfo{
1262                         ServerName:        "example.com",
1263                         SignatureSchemes:  []SignatureScheme{PSSWithSHA256},
1264                         SupportedVersions: []uint16{VersionTLS13},
1265                 }, "not valid for requested server name"},
1266                 {ecdsaCert, &ClientHelloInfo{
1267                         SignatureSchemes:  []SignatureScheme{ECDSAWithP384AndSHA384},
1268                         SupportedVersions: []uint16{VersionTLS13},
1269                 }, "signature algorithms"},
1270                 {pkcs1Cert, &ClientHelloInfo{
1271                         SignatureSchemes:  []SignatureScheme{PSSWithSHA256, ECDSAWithP256AndSHA256},
1272                         SupportedVersions: []uint16{VersionTLS13},
1273                 }, "signature algorithms"},
1274
1275                 {rsaCert, &ClientHelloInfo{
1276                         CipherSuites:      []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
1277                         SignatureSchemes:  []SignatureScheme{PKCS1WithSHA1},
1278                         SupportedVersions: []uint16{VersionTLS13, VersionTLS12},
1279                 }, "signature algorithms"},
1280                 {rsaCert, &ClientHelloInfo{
1281                         CipherSuites:      []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
1282                         SignatureSchemes:  []SignatureScheme{PKCS1WithSHA1},
1283                         SupportedVersions: []uint16{VersionTLS13, VersionTLS12},
1284                         config: &Config{
1285                                 MaxVersion: VersionTLS12,
1286                         },
1287                 }, ""}, // Check that mutual version selection works.
1288
1289                 {ecdsaCert, &ClientHelloInfo{
1290                         CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1291                         SupportedCurves:   []CurveID{CurveP256},
1292                         SupportedPoints:   []uint8{pointFormatUncompressed},
1293                         SignatureSchemes:  []SignatureScheme{ECDSAWithP256AndSHA256},
1294                         SupportedVersions: []uint16{VersionTLS12},
1295                 }, ""},
1296                 {ecdsaCert, &ClientHelloInfo{
1297                         CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1298                         SupportedCurves:   []CurveID{CurveP256},
1299                         SupportedPoints:   []uint8{pointFormatUncompressed},
1300                         SignatureSchemes:  []SignatureScheme{ECDSAWithP384AndSHA384},
1301                         SupportedVersions: []uint16{VersionTLS12},
1302                 }, ""}, // TLS 1.2 does not restrict curves based on the SignatureScheme.
1303                 {ecdsaCert, &ClientHelloInfo{
1304                         CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1305                         SupportedCurves:   []CurveID{CurveP256},
1306                         SupportedPoints:   []uint8{pointFormatUncompressed},
1307                         SignatureSchemes:  nil,
1308                         SupportedVersions: []uint16{VersionTLS12},
1309                 }, ""}, // TLS 1.2 comes with default signature schemes.
1310                 {ecdsaCert, &ClientHelloInfo{
1311                         CipherSuites:      []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
1312                         SupportedCurves:   []CurveID{CurveP256},
1313                         SupportedPoints:   []uint8{pointFormatUncompressed},
1314                         SignatureSchemes:  []SignatureScheme{ECDSAWithP256AndSHA256},
1315                         SupportedVersions: []uint16{VersionTLS12},
1316                 }, "cipher suite"},
1317                 {ecdsaCert, &ClientHelloInfo{
1318                         CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1319                         SupportedCurves:   []CurveID{CurveP256},
1320                         SupportedPoints:   []uint8{pointFormatUncompressed},
1321                         SignatureSchemes:  []SignatureScheme{ECDSAWithP256AndSHA256},
1322                         SupportedVersions: []uint16{VersionTLS12},
1323                         config: &Config{
1324                                 CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
1325                         },
1326                 }, "cipher suite"},
1327                 {ecdsaCert, &ClientHelloInfo{
1328                         CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1329                         SupportedCurves:   []CurveID{CurveP384},
1330                         SupportedPoints:   []uint8{pointFormatUncompressed},
1331                         SignatureSchemes:  []SignatureScheme{ECDSAWithP256AndSHA256},
1332                         SupportedVersions: []uint16{VersionTLS12},
1333                 }, "certificate curve"},
1334                 {ecdsaCert, &ClientHelloInfo{
1335                         CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1336                         SupportedCurves:   []CurveID{CurveP256},
1337                         SupportedPoints:   []uint8{1},
1338                         SignatureSchemes:  []SignatureScheme{ECDSAWithP256AndSHA256},
1339                         SupportedVersions: []uint16{VersionTLS12},
1340                 }, "doesn't support ECDHE"},
1341                 {ecdsaCert, &ClientHelloInfo{
1342                         CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1343                         SupportedCurves:   []CurveID{CurveP256},
1344                         SupportedPoints:   []uint8{pointFormatUncompressed},
1345                         SignatureSchemes:  []SignatureScheme{PSSWithSHA256},
1346                         SupportedVersions: []uint16{VersionTLS12},
1347                 }, "signature algorithms"},
1348
1349                 {ed25519Cert, &ClientHelloInfo{
1350                         CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1351                         SupportedCurves:   []CurveID{CurveP256}, // only relevant for ECDHE support
1352                         SupportedPoints:   []uint8{pointFormatUncompressed},
1353                         SignatureSchemes:  []SignatureScheme{Ed25519},
1354                         SupportedVersions: []uint16{VersionTLS12},
1355                 }, ""},
1356                 {ed25519Cert, &ClientHelloInfo{
1357                         CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1358                         SupportedCurves:   []CurveID{CurveP256}, // only relevant for ECDHE support
1359                         SupportedPoints:   []uint8{pointFormatUncompressed},
1360                         SignatureSchemes:  []SignatureScheme{Ed25519},
1361                         SupportedVersions: []uint16{VersionTLS10},
1362                 }, "doesn't support Ed25519"},
1363                 {ed25519Cert, &ClientHelloInfo{
1364                         CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1365                         SupportedCurves:   []CurveID{},
1366                         SupportedPoints:   []uint8{pointFormatUncompressed},
1367                         SignatureSchemes:  []SignatureScheme{Ed25519},
1368                         SupportedVersions: []uint16{VersionTLS12},
1369                 }, "doesn't support ECDHE"},
1370
1371                 {rsaCert, &ClientHelloInfo{
1372                         CipherSuites:      []uint16{TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA},
1373                         SupportedCurves:   []CurveID{CurveP256}, // only relevant for ECDHE support
1374                         SupportedPoints:   []uint8{pointFormatUncompressed},
1375                         SupportedVersions: []uint16{VersionTLS10},
1376                 }, ""},
1377                 {rsaCert, &ClientHelloInfo{
1378                         CipherSuites:      []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
1379                         SupportedVersions: []uint16{VersionTLS12},
1380                 }, ""}, // static RSA fallback
1381         }
1382         for i, tt := range tests {
1383                 err := tt.chi.SupportsCertificate(tt.c)
1384                 switch {
1385                 case tt.wantErr == "" && err != nil:
1386                         t.Errorf("%d: unexpected error: %v", i, err)
1387                 case tt.wantErr != "" && err == nil:
1388                         t.Errorf("%d: unexpected success", i)
1389                 case tt.wantErr != "" && !strings.Contains(err.Error(), tt.wantErr):
1390                         t.Errorf("%d: got error %q, expected %q", i, err, tt.wantErr)
1391                 }
1392         }
1393 }
1394
1395 func TestCipherSuites(t *testing.T) {
1396         var lastID uint16
1397         for _, c := range CipherSuites() {
1398                 if lastID > c.ID {
1399                         t.Errorf("CipherSuites are not ordered by ID: got %#04x after %#04x", c.ID, lastID)
1400                 } else {
1401                         lastID = c.ID
1402                 }
1403
1404                 if c.Insecure {
1405                         t.Errorf("%#04x: Insecure CipherSuite returned by CipherSuites()", c.ID)
1406                 }
1407         }
1408         lastID = 0
1409         for _, c := range InsecureCipherSuites() {
1410                 if lastID > c.ID {
1411                         t.Errorf("InsecureCipherSuites are not ordered by ID: got %#04x after %#04x", c.ID, lastID)
1412                 } else {
1413                         lastID = c.ID
1414                 }
1415
1416                 if !c.Insecure {
1417                         t.Errorf("%#04x: not Insecure CipherSuite returned by InsecureCipherSuites()", c.ID)
1418                 }
1419         }
1420
1421         CipherSuiteByID := func(id uint16) *CipherSuite {
1422                 for _, c := range CipherSuites() {
1423                         if c.ID == id {
1424                                 return c
1425                         }
1426                 }
1427                 for _, c := range InsecureCipherSuites() {
1428                         if c.ID == id {
1429                                 return c
1430                         }
1431                 }
1432                 return nil
1433         }
1434
1435         for _, c := range cipherSuites {
1436                 cc := CipherSuiteByID(c.id)
1437                 if cc == nil {
1438                         t.Errorf("%#04x: no CipherSuite entry", c.id)
1439                         continue
1440                 }
1441
1442                 if tls12Only := c.flags&suiteTLS12 != 0; tls12Only && len(cc.SupportedVersions) != 1 {
1443                         t.Errorf("%#04x: suite is TLS 1.2 only, but SupportedVersions is %v", c.id, cc.SupportedVersions)
1444                 } else if !tls12Only && len(cc.SupportedVersions) != 3 {
1445                         t.Errorf("%#04x: suite TLS 1.0-1.2, but SupportedVersions is %v", c.id, cc.SupportedVersions)
1446                 }
1447
1448                 if got := CipherSuiteName(c.id); got != cc.Name {
1449                         t.Errorf("%#04x: unexpected CipherSuiteName: got %q, expected %q", c.id, got, cc.Name)
1450                 }
1451         }
1452         for _, c := range cipherSuitesTLS13 {
1453                 cc := CipherSuiteByID(c.id)
1454                 if cc == nil {
1455                         t.Errorf("%#04x: no CipherSuite entry", c.id)
1456                         continue
1457                 }
1458
1459                 if cc.Insecure {
1460                         t.Errorf("%#04x: Insecure %v, expected false", c.id, cc.Insecure)
1461                 }
1462                 if len(cc.SupportedVersions) != 1 || cc.SupportedVersions[0] != VersionTLS13 {
1463                         t.Errorf("%#04x: suite is TLS 1.3 only, but SupportedVersions is %v", c.id, cc.SupportedVersions)
1464                 }
1465
1466                 if got := CipherSuiteName(c.id); got != cc.Name {
1467                         t.Errorf("%#04x: unexpected CipherSuiteName: got %q, expected %q", c.id, got, cc.Name)
1468                 }
1469         }
1470
1471         if got := CipherSuiteName(0xabc); got != "0x0ABC" {
1472                 t.Errorf("unexpected fallback CipherSuiteName: got %q, expected 0x0ABC", got)
1473         }
1474
1475         if len(cipherSuitesPreferenceOrder) != len(cipherSuites) {
1476                 t.Errorf("cipherSuitesPreferenceOrder is not the same size as cipherSuites")
1477         }
1478         if len(cipherSuitesPreferenceOrderNoAES) != len(cipherSuitesPreferenceOrder) {
1479                 t.Errorf("cipherSuitesPreferenceOrderNoAES is not the same size as cipherSuitesPreferenceOrder")
1480         }
1481
1482         // Check that disabled suites are at the end of the preference lists, and
1483         // that they are marked insecure.
1484         for i, id := range disabledCipherSuites {
1485                 offset := len(cipherSuitesPreferenceOrder) - len(disabledCipherSuites)
1486                 if cipherSuitesPreferenceOrder[offset+i] != id {
1487                         t.Errorf("disabledCipherSuites[%d]: not at the end of cipherSuitesPreferenceOrder", i)
1488                 }
1489                 if cipherSuitesPreferenceOrderNoAES[offset+i] != id {
1490                         t.Errorf("disabledCipherSuites[%d]: not at the end of cipherSuitesPreferenceOrderNoAES", i)
1491                 }
1492                 c := CipherSuiteByID(id)
1493                 if c == nil {
1494                         t.Errorf("%#04x: no CipherSuite entry", id)
1495                         continue
1496                 }
1497                 if !c.Insecure {
1498                         t.Errorf("%#04x: disabled by default but not marked insecure", id)
1499                 }
1500         }
1501
1502         for i, prefOrder := range [][]uint16{cipherSuitesPreferenceOrder, cipherSuitesPreferenceOrderNoAES} {
1503                 // Check that insecure and HTTP/2 bad cipher suites are at the end of
1504                 // the preference lists.
1505                 var sawInsecure, sawBad bool
1506                 for _, id := range prefOrder {
1507                         c := CipherSuiteByID(id)
1508                         if c == nil {
1509                                 t.Errorf("%#04x: no CipherSuite entry", id)
1510                                 continue
1511                         }
1512
1513                         if c.Insecure {
1514                                 sawInsecure = true
1515                         } else if sawInsecure {
1516                                 t.Errorf("%#04x: secure suite after insecure one(s)", id)
1517                         }
1518
1519                         if http2isBadCipher(id) {
1520                                 sawBad = true
1521                         } else if sawBad {
1522                                 t.Errorf("%#04x: non-bad suite after bad HTTP/2 one(s)", id)
1523                         }
1524                 }
1525
1526                 // Check that the list is sorted according to the documented criteria.
1527                 isBetter := func(a, b int) bool {
1528                         aSuite, bSuite := cipherSuiteByID(prefOrder[a]), cipherSuiteByID(prefOrder[b])
1529                         aName, bName := CipherSuiteName(prefOrder[a]), CipherSuiteName(prefOrder[b])
1530                         // * < RC4
1531                         if !strings.Contains(aName, "RC4") && strings.Contains(bName, "RC4") {
1532                                 return true
1533                         } else if strings.Contains(aName, "RC4") && !strings.Contains(bName, "RC4") {
1534                                 return false
1535                         }
1536                         // * < CBC_SHA256
1537                         if !strings.Contains(aName, "CBC_SHA256") && strings.Contains(bName, "CBC_SHA256") {
1538                                 return true
1539                         } else if strings.Contains(aName, "CBC_SHA256") && !strings.Contains(bName, "CBC_SHA256") {
1540                                 return false
1541                         }
1542                         // * < 3DES
1543                         if !strings.Contains(aName, "3DES") && strings.Contains(bName, "3DES") {
1544                                 return true
1545                         } else if strings.Contains(aName, "3DES") && !strings.Contains(bName, "3DES") {
1546                                 return false
1547                         }
1548                         // ECDHE < *
1549                         if aSuite.flags&suiteECDHE != 0 && bSuite.flags&suiteECDHE == 0 {
1550                                 return true
1551                         } else if aSuite.flags&suiteECDHE == 0 && bSuite.flags&suiteECDHE != 0 {
1552                                 return false
1553                         }
1554                         // AEAD < CBC
1555                         if aSuite.aead != nil && bSuite.aead == nil {
1556                                 return true
1557                         } else if aSuite.aead == nil && bSuite.aead != nil {
1558                                 return false
1559                         }
1560                         // AES < ChaCha20
1561                         if strings.Contains(aName, "AES") && strings.Contains(bName, "CHACHA20") {
1562                                 return i == 0 // true for cipherSuitesPreferenceOrder
1563                         } else if strings.Contains(aName, "CHACHA20") && strings.Contains(bName, "AES") {
1564                                 return i != 0 // true for cipherSuitesPreferenceOrderNoAES
1565                         }
1566                         // AES-128 < AES-256
1567                         if strings.Contains(aName, "AES_128") && strings.Contains(bName, "AES_256") {
1568                                 return true
1569                         } else if strings.Contains(aName, "AES_256") && strings.Contains(bName, "AES_128") {
1570                                 return false
1571                         }
1572                         // ECDSA < RSA
1573                         if aSuite.flags&suiteECSign != 0 && bSuite.flags&suiteECSign == 0 {
1574                                 return true
1575                         } else if aSuite.flags&suiteECSign == 0 && bSuite.flags&suiteECSign != 0 {
1576                                 return false
1577                         }
1578                         t.Fatalf("two ciphersuites are equal by all criteria: %v and %v", aName, bName)
1579                         panic("unreachable")
1580                 }
1581                 if !sort.SliceIsSorted(prefOrder, isBetter) {
1582                         t.Error("preference order is not sorted according to the rules")
1583                 }
1584         }
1585 }
1586
1587 // http2isBadCipher is copied from net/http.
1588 // TODO: if it ends up exposed somewhere, use that instead.
1589 func http2isBadCipher(cipher uint16) bool {
1590         switch cipher {
1591         case TLS_RSA_WITH_RC4_128_SHA,
1592                 TLS_RSA_WITH_3DES_EDE_CBC_SHA,
1593                 TLS_RSA_WITH_AES_128_CBC_SHA,
1594                 TLS_RSA_WITH_AES_256_CBC_SHA,
1595                 TLS_RSA_WITH_AES_128_CBC_SHA256,
1596                 TLS_RSA_WITH_AES_128_GCM_SHA256,
1597                 TLS_RSA_WITH_AES_256_GCM_SHA384,
1598                 TLS_ECDHE_ECDSA_WITH_RC4_128_SHA,
1599                 TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
1600                 TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
1601                 TLS_ECDHE_RSA_WITH_RC4_128_SHA,
1602                 TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
1603                 TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
1604                 TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
1605                 TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256,
1606                 TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256:
1607                 return true
1608         default:
1609                 return false
1610         }
1611 }
1612
// brokenSigner wraps a crypto.Signer and strips every signing option except
// the hash function before delegating. Any rsa.PSSOptions passed in are
// therefore lost, which is how tests simulate a signer that can only produce
// PKCS #1 v1.5 signatures.
type brokenSigner struct{ crypto.Signer }

// Sign delegates to the embedded Signer, passing opts.HashFunc() in place of
// the caller's opts.
func (b brokenSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) {
	hashOnly := opts.HashFunc()
	return b.Signer.Sign(rand, digest, hashOnly)
}
1619
1620 // TestPKCS1OnlyCert uses a client certificate with a broken crypto.Signer that
1621 // always makes PKCS #1 v1.5 signatures, so can't be used with RSA-PSS.
1622 func TestPKCS1OnlyCert(t *testing.T) {
1623         clientConfig := testConfig.Clone()
1624         clientConfig.Certificates = []Certificate{{
1625                 Certificate: [][]byte{testRSACertificate},
1626                 PrivateKey:  brokenSigner{testRSAPrivateKey},
1627         }}
1628         serverConfig := testConfig.Clone()
1629         serverConfig.MaxVersion = VersionTLS12 // TLS 1.3 doesn't support PKCS #1 v1.5
1630         serverConfig.ClientAuth = RequireAnyClientCert
1631
1632         // If RSA-PSS is selected, the handshake should fail.
1633         if _, _, err := testHandshake(t, clientConfig, serverConfig); err == nil {
1634                 t.Fatal("expected broken certificate to cause connection to fail")
1635         }
1636
1637         clientConfig.Certificates[0].SupportedSignatureAlgorithms =
1638                 []SignatureScheme{PKCS1WithSHA1, PKCS1WithSHA256}
1639
1640         // But if the certificate restricts supported algorithms, RSA-PSS should not
1641         // be selected, and the handshake should succeed.
1642         if _, _, err := testHandshake(t, clientConfig, serverConfig); err != nil {
1643                 t.Error(err)
1644         }
1645 }