]> Cypherpunks.ru repositories - gostls13.git/blob - src/crypto/tls/tls_test.go
crypto/tls: set ServerName and unset TLSUnique in ConnectionState in TLS 1.3
[gostls13.git] / src / crypto / tls / tls_test.go
1 // Copyright 2012 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package tls
6
7 import (
8         "bytes"
9         "crypto/x509"
10         "encoding/json"
11         "errors"
12         "fmt"
13         "internal/testenv"
14         "io"
15         "io/ioutil"
16         "math"
17         "net"
18         "os"
19         "reflect"
20         "strings"
21         "testing"
22         "time"
23 )
24
25 var rsaCertPEM = `-----BEGIN CERTIFICATE-----
26 MIIB0zCCAX2gAwIBAgIJAI/M7BYjwB+uMA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV
27 BAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX
28 aWRnaXRzIFB0eSBMdGQwHhcNMTIwOTEyMjE1MjAyWhcNMTUwOTEyMjE1MjAyWjBF
29 MQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50
30 ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBANLJ
31 hPHhITqQbPklG3ibCVxwGMRfp/v4XqhfdQHdcVfHap6NQ5Wok/4xIA+ui35/MmNa
32 rtNuC+BdZ1tMuVCPFZcCAwEAAaNQME4wHQYDVR0OBBYEFJvKs8RfJaXTH08W+SGv
33 zQyKn0H8MB8GA1UdIwQYMBaAFJvKs8RfJaXTH08W+SGvzQyKn0H8MAwGA1UdEwQF
34 MAMBAf8wDQYJKoZIhvcNAQEFBQADQQBJlffJHybjDGxRMqaRmDhX0+6v02TUKZsW
35 r5QuVbpQhH6u+0UgcW0jp9QwpxoPTLTWGXEWBBBurxFwiCBhkQ+V
36 -----END CERTIFICATE-----
37 `
38
39 var rsaKeyPEM = `-----BEGIN RSA PRIVATE KEY-----
40 MIIBOwIBAAJBANLJhPHhITqQbPklG3ibCVxwGMRfp/v4XqhfdQHdcVfHap6NQ5Wo
41 k/4xIA+ui35/MmNartNuC+BdZ1tMuVCPFZcCAwEAAQJAEJ2N+zsR0Xn8/Q6twa4G
42 6OB1M1WO+k+ztnX/1SvNeWu8D6GImtupLTYgjZcHufykj09jiHmjHx8u8ZZB/o1N
43 MQIhAPW+eyZo7ay3lMz1V01WVjNKK9QSn1MJlb06h/LuYv9FAiEA25WPedKgVyCW
44 SmUwbPw8fnTcpqDWE3yTO3vKcebqMSsCIBF3UmVue8YU3jybC3NxuXq3wNm34R8T
45 xVLHwDXh/6NJAiEAl2oHGGLz64BuAfjKrqwz7qMYr9HCLIe/YsoWq/olzScCIQDi
46 D2lWusoe2/nEqfDVVWGWlyJ7yOmqaVm/iNUN9B2N2g==
47 -----END RSA PRIVATE KEY-----
48 `
49
50 // keyPEM is the same as rsaKeyPEM, but declares itself as just
51 // "PRIVATE KEY", not "RSA PRIVATE KEY".  https://golang.org/issue/4477
52 var keyPEM = `-----BEGIN PRIVATE KEY-----
53 MIIBOwIBAAJBANLJhPHhITqQbPklG3ibCVxwGMRfp/v4XqhfdQHdcVfHap6NQ5Wo
54 k/4xIA+ui35/MmNartNuC+BdZ1tMuVCPFZcCAwEAAQJAEJ2N+zsR0Xn8/Q6twa4G
55 6OB1M1WO+k+ztnX/1SvNeWu8D6GImtupLTYgjZcHufykj09jiHmjHx8u8ZZB/o1N
56 MQIhAPW+eyZo7ay3lMz1V01WVjNKK9QSn1MJlb06h/LuYv9FAiEA25WPedKgVyCW
57 SmUwbPw8fnTcpqDWE3yTO3vKcebqMSsCIBF3UmVue8YU3jybC3NxuXq3wNm34R8T
58 xVLHwDXh/6NJAiEAl2oHGGLz64BuAfjKrqwz7qMYr9HCLIe/YsoWq/olzScCIQDi
59 D2lWusoe2/nEqfDVVWGWlyJ7yOmqaVm/iNUN9B2N2g==
60 -----END PRIVATE KEY-----
61 `
62
63 var ecdsaCertPEM = `-----BEGIN CERTIFICATE-----
64 MIIB/jCCAWICCQDscdUxw16XFDAJBgcqhkjOPQQBMEUxCzAJBgNVBAYTAkFVMRMw
65 EQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBXaWRnaXRzIFB0
66 eSBMdGQwHhcNMTIxMTE0MTI0MDQ4WhcNMTUxMTE0MTI0MDQ4WjBFMQswCQYDVQQG
67 EwJBVTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lk
68 Z2l0cyBQdHkgTHRkMIGbMBAGByqGSM49AgEGBSuBBAAjA4GGAAQBY9+my9OoeSUR
69 lDQdV/x8LsOuLilthhiS1Tz4aGDHIPwC1mlvnf7fg5lecYpMCrLLhauAc1UJXcgl
70 01xoLuzgtAEAgv2P/jgytzRSpUYvgLBt1UA0leLYBy6mQQbrNEuqT3INapKIcUv8
71 XxYP0xMEUksLPq6Ca+CRSqTtrd/23uTnapkwCQYHKoZIzj0EAQOBigAwgYYCQXJo
72 A7Sl2nLVf+4Iu/tAX/IF4MavARKC4PPHK3zfuGfPR3oCCcsAoz3kAzOeijvd0iXb
73 H5jBImIxPL4WxQNiBTexAkF8D1EtpYuWdlVQ80/h/f4pBcGiXPqX5h2PQSQY7hP1
74 +jwM1FGS4fREIOvlBYr/SzzQRtwrvrzGYxDEDbsC0ZGRnA==
75 -----END CERTIFICATE-----
76 `
77
78 var ecdsaKeyPEM = `-----BEGIN EC PARAMETERS-----
79 BgUrgQQAIw==
80 -----END EC PARAMETERS-----
81 -----BEGIN EC PRIVATE KEY-----
82 MIHcAgEBBEIBrsoKp0oqcv6/JovJJDoDVSGWdirrkgCWxrprGlzB9o0X8fV675X0
83 NwuBenXFfeZvVcwluO7/Q9wkYoPd/t3jGImgBwYFK4EEACOhgYkDgYYABAFj36bL
84 06h5JRGUNB1X/Hwuw64uKW2GGJLVPPhoYMcg/ALWaW+d/t+DmV5xikwKssuFq4Bz
85 VQldyCXTXGgu7OC0AQCC/Y/+ODK3NFKlRi+AsG3VQDSV4tgHLqZBBus0S6pPcg1q
86 kohxS/xfFg/TEwRSSws+roJr4JFKpO2t3/be5OdqmQ==
87 -----END EC PRIVATE KEY-----
88 `
89
90 var keyPairTests = []struct {
91         algo string
92         cert string
93         key  string
94 }{
95         {"ECDSA", ecdsaCertPEM, ecdsaKeyPEM},
96         {"RSA", rsaCertPEM, rsaKeyPEM},
97         {"RSA-untyped", rsaCertPEM, keyPEM}, // golang.org/issue/4477
98 }
99
100 func TestX509KeyPair(t *testing.T) {
101         t.Parallel()
102         var pem []byte
103         for _, test := range keyPairTests {
104                 pem = []byte(test.cert + test.key)
105                 if _, err := X509KeyPair(pem, pem); err != nil {
106                         t.Errorf("Failed to load %s cert followed by %s key: %s", test.algo, test.algo, err)
107                 }
108                 pem = []byte(test.key + test.cert)
109                 if _, err := X509KeyPair(pem, pem); err != nil {
110                         t.Errorf("Failed to load %s key followed by %s cert: %s", test.algo, test.algo, err)
111                 }
112         }
113 }
114
115 func TestX509KeyPairErrors(t *testing.T) {
116         _, err := X509KeyPair([]byte(rsaKeyPEM), []byte(rsaCertPEM))
117         if err == nil {
118                 t.Fatalf("X509KeyPair didn't return an error when arguments were switched")
119         }
120         if subStr := "been switched"; !strings.Contains(err.Error(), subStr) {
121                 t.Fatalf("Expected %q in the error when switching arguments to X509KeyPair, but the error was %q", subStr, err)
122         }
123
124         _, err = X509KeyPair([]byte(rsaCertPEM), []byte(rsaCertPEM))
125         if err == nil {
126                 t.Fatalf("X509KeyPair didn't return an error when both arguments were certificates")
127         }
128         if subStr := "certificate"; !strings.Contains(err.Error(), subStr) {
129                 t.Fatalf("Expected %q in the error when both arguments to X509KeyPair were certificates, but the error was %q", subStr, err)
130         }
131
132         const nonsensePEM = `
133 -----BEGIN NONSENSE-----
134 Zm9vZm9vZm9v
135 -----END NONSENSE-----
136 `
137
138         _, err = X509KeyPair([]byte(nonsensePEM), []byte(nonsensePEM))
139         if err == nil {
140                 t.Fatalf("X509KeyPair didn't return an error when both arguments were nonsense")
141         }
142         if subStr := "NONSENSE"; !strings.Contains(err.Error(), subStr) {
143                 t.Fatalf("Expected %q in the error when both arguments to X509KeyPair were nonsense, but the error was %q", subStr, err)
144         }
145 }
146
147 func TestX509MixedKeyPair(t *testing.T) {
148         if _, err := X509KeyPair([]byte(rsaCertPEM), []byte(ecdsaKeyPEM)); err == nil {
149                 t.Error("Load of RSA certificate succeeded with ECDSA private key")
150         }
151         if _, err := X509KeyPair([]byte(ecdsaCertPEM), []byte(rsaKeyPEM)); err == nil {
152                 t.Error("Load of ECDSA certificate succeeded with RSA private key")
153         }
154 }
155
156 func newLocalListener(t testing.TB) net.Listener {
157         ln, err := net.Listen("tcp", "127.0.0.1:0")
158         if err != nil {
159                 ln, err = net.Listen("tcp6", "[::1]:0")
160         }
161         if err != nil {
162                 t.Fatal(err)
163         }
164         return ln
165 }
166
167 func TestDialTimeout(t *testing.T) {
168         if testing.Short() {
169                 t.Skip("skipping in short mode")
170         }
171         listener := newLocalListener(t)
172
173         addr := listener.Addr().String()
174         defer listener.Close()
175
176         complete := make(chan bool)
177         defer close(complete)
178
179         go func() {
180                 conn, err := listener.Accept()
181                 if err != nil {
182                         t.Error(err)
183                         return
184                 }
185                 <-complete
186                 conn.Close()
187         }()
188
189         dialer := &net.Dialer{
190                 Timeout: 10 * time.Millisecond,
191         }
192
193         var err error
194         if _, err = DialWithDialer(dialer, "tcp", addr, nil); err == nil {
195                 t.Fatal("DialWithTimeout completed successfully")
196         }
197
198         if !isTimeoutError(err) {
199                 t.Errorf("resulting error not a timeout: %v\nType %T: %#v", err, err, err)
200         }
201 }
202
203 func isTimeoutError(err error) bool {
204         if ne, ok := err.(net.Error); ok {
205                 return ne.Timeout()
206         }
207         return false
208 }
209
210 // tests that Conn.Read returns (non-zero, io.EOF) instead of
211 // (non-zero, nil) when a Close (alertCloseNotify) is sitting right
212 // behind the application data in the buffer.
213 func TestConnReadNonzeroAndEOF(t *testing.T) {
214         // This test is racy: it assumes that after a write to a
215         // localhost TCP connection, the peer TCP connection can
216         // immediately read it. Because it's racy, we skip this test
217         // in short mode, and then retry it several times with an
218         // increasing sleep in between our final write (via srv.Close
219         // below) and the following read.
220         if testing.Short() {
221                 t.Skip("skipping in short mode")
222         }
223         var err error
224         for delay := time.Millisecond; delay <= 64*time.Millisecond; delay *= 2 {
225                 if err = testConnReadNonzeroAndEOF(t, delay); err == nil {
226                         return
227                 }
228         }
229         t.Error(err)
230 }
231
232 func testConnReadNonzeroAndEOF(t *testing.T, delay time.Duration) error {
233         ln := newLocalListener(t)
234         defer ln.Close()
235
236         srvCh := make(chan *Conn, 1)
237         var serr error
238         go func() {
239                 sconn, err := ln.Accept()
240                 if err != nil {
241                         serr = err
242                         srvCh <- nil
243                         return
244                 }
245                 serverConfig := testConfig.Clone()
246                 srv := Server(sconn, serverConfig)
247                 if err := srv.Handshake(); err != nil {
248                         serr = fmt.Errorf("handshake: %v", err)
249                         srvCh <- nil
250                         return
251                 }
252                 srvCh <- srv
253         }()
254
255         clientConfig := testConfig.Clone()
256         conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
257         if err != nil {
258                 t.Fatal(err)
259         }
260         defer conn.Close()
261
262         srv := <-srvCh
263         if srv == nil {
264                 return serr
265         }
266
267         buf := make([]byte, 6)
268
269         srv.Write([]byte("foobar"))
270         n, err := conn.Read(buf)
271         if n != 6 || err != nil || string(buf) != "foobar" {
272                 return fmt.Errorf("Read = %d, %v, data %q; want 6, nil, foobar", n, err, buf)
273         }
274
275         srv.Write([]byte("abcdef"))
276         srv.Close()
277         time.Sleep(delay)
278         n, err = conn.Read(buf)
279         if n != 6 || string(buf) != "abcdef" {
280                 return fmt.Errorf("Read = %d, buf= %q; want 6, abcdef", n, buf)
281         }
282         if err != io.EOF {
283                 return fmt.Errorf("Second Read error = %v; want io.EOF", err)
284         }
285         return nil
286 }
287
288 func TestTLSUniqueMatches(t *testing.T) {
289         ln := newLocalListener(t)
290         defer ln.Close()
291
292         serverTLSUniques := make(chan []byte)
293         go func() {
294                 for i := 0; i < 2; i++ {
295                         sconn, err := ln.Accept()
296                         if err != nil {
297                                 t.Error(err)
298                                 return
299                         }
300                         serverConfig := testConfig.Clone()
301                         srv := Server(sconn, serverConfig)
302                         if err := srv.Handshake(); err != nil {
303                                 t.Error(err)
304                                 return
305                         }
306                         serverTLSUniques <- srv.ConnectionState().TLSUnique
307                 }
308         }()
309
310         clientConfig := testConfig.Clone()
311         clientConfig.ClientSessionCache = NewLRUClientSessionCache(1)
312         conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
313         if err != nil {
314                 t.Fatal(err)
315         }
316         if !bytes.Equal(conn.ConnectionState().TLSUnique, <-serverTLSUniques) {
317                 t.Error("client and server channel bindings differ")
318         }
319         conn.Close()
320
321         conn, err = Dial("tcp", ln.Addr().String(), clientConfig)
322         if err != nil {
323                 t.Fatal(err)
324         }
325         defer conn.Close()
326         if !conn.ConnectionState().DidResume {
327                 t.Error("second session did not use resumption")
328         }
329         if !bytes.Equal(conn.ConnectionState().TLSUnique, <-serverTLSUniques) {
330                 t.Error("client and server channel bindings differ when session resumption is used")
331         }
332 }
333
334 func TestVerifyHostname(t *testing.T) {
335         testenv.MustHaveExternalNetwork(t)
336
337         c, err := Dial("tcp", "www.google.com:https", nil)
338         if err != nil {
339                 t.Fatal(err)
340         }
341         if err := c.VerifyHostname("www.google.com"); err != nil {
342                 t.Fatalf("verify www.google.com: %v", err)
343         }
344         if err := c.VerifyHostname("www.yahoo.com"); err == nil {
345                 t.Fatalf("verify www.yahoo.com succeeded")
346         }
347
348         c, err = Dial("tcp", "www.google.com:https", &Config{InsecureSkipVerify: true})
349         if err != nil {
350                 t.Fatal(err)
351         }
352         if err := c.VerifyHostname("www.google.com"); err == nil {
353                 t.Fatalf("verify www.google.com succeeded with InsecureSkipVerify=true")
354         }
355         if err := c.VerifyHostname("www.yahoo.com"); err == nil {
356                 t.Fatalf("verify www.google.com succeeded with InsecureSkipVerify=true")
357         }
358 }
359
360 func TestVerifyHostnameResumed(t *testing.T) {
361         testenv.MustHaveExternalNetwork(t)
362
363         config := &Config{
364                 ClientSessionCache: NewLRUClientSessionCache(32),
365         }
366         for i := 0; i < 2; i++ {
367                 c, err := Dial("tcp", "www.google.com:https", config)
368                 if err != nil {
369                         t.Fatalf("Dial #%d: %v", i, err)
370                 }
371                 cs := c.ConnectionState()
372                 if i > 0 && !cs.DidResume {
373                         t.Fatalf("Subsequent connection unexpectedly didn't resume")
374                 }
375                 if cs.VerifiedChains == nil {
376                         t.Fatalf("Dial #%d: cs.VerifiedChains == nil", i)
377                 }
378                 if err := c.VerifyHostname("www.google.com"); err != nil {
379                         t.Fatalf("verify www.google.com #%d: %v", i, err)
380                 }
381                 c.Close()
382         }
383 }
384
385 func TestConnCloseBreakingWrite(t *testing.T) {
386         ln := newLocalListener(t)
387         defer ln.Close()
388
389         srvCh := make(chan *Conn, 1)
390         var serr error
391         var sconn net.Conn
392         go func() {
393                 var err error
394                 sconn, err = ln.Accept()
395                 if err != nil {
396                         serr = err
397                         srvCh <- nil
398                         return
399                 }
400                 serverConfig := testConfig.Clone()
401                 srv := Server(sconn, serverConfig)
402                 if err := srv.Handshake(); err != nil {
403                         serr = fmt.Errorf("handshake: %v", err)
404                         srvCh <- nil
405                         return
406                 }
407                 srvCh <- srv
408         }()
409
410         cconn, err := net.Dial("tcp", ln.Addr().String())
411         if err != nil {
412                 t.Fatal(err)
413         }
414         defer cconn.Close()
415
416         conn := &changeImplConn{
417                 Conn: cconn,
418         }
419
420         clientConfig := testConfig.Clone()
421         tconn := Client(conn, clientConfig)
422         if err := tconn.Handshake(); err != nil {
423                 t.Fatal(err)
424         }
425
426         srv := <-srvCh
427         if srv == nil {
428                 t.Fatal(serr)
429         }
430         defer sconn.Close()
431
432         connClosed := make(chan struct{})
433         conn.closeFunc = func() error {
434                 close(connClosed)
435                 return nil
436         }
437
438         inWrite := make(chan bool, 1)
439         var errConnClosed = errors.New("conn closed for test")
440         conn.writeFunc = func(p []byte) (n int, err error) {
441                 inWrite <- true
442                 <-connClosed
443                 return 0, errConnClosed
444         }
445
446         closeReturned := make(chan bool, 1)
447         go func() {
448                 <-inWrite
449                 tconn.Close() // test that this doesn't block forever.
450                 closeReturned <- true
451         }()
452
453         _, err = tconn.Write([]byte("foo"))
454         if err != errConnClosed {
455                 t.Errorf("Write error = %v; want errConnClosed", err)
456         }
457
458         <-closeReturned
459         if err := tconn.Close(); err != errClosed {
460                 t.Errorf("Close error = %v; want errClosed", err)
461         }
462 }
463
464 func TestConnCloseWrite(t *testing.T) {
465         ln := newLocalListener(t)
466         defer ln.Close()
467
468         clientDoneChan := make(chan struct{})
469
470         serverCloseWrite := func() error {
471                 sconn, err := ln.Accept()
472                 if err != nil {
473                         return fmt.Errorf("accept: %v", err)
474                 }
475                 defer sconn.Close()
476
477                 serverConfig := testConfig.Clone()
478                 srv := Server(sconn, serverConfig)
479                 if err := srv.Handshake(); err != nil {
480                         return fmt.Errorf("handshake: %v", err)
481                 }
482                 defer srv.Close()
483
484                 data, err := ioutil.ReadAll(srv)
485                 if err != nil {
486                         return err
487                 }
488                 if len(data) > 0 {
489                         return fmt.Errorf("Read data = %q; want nothing", data)
490                 }
491
492                 if err := srv.CloseWrite(); err != nil {
493                         return fmt.Errorf("server CloseWrite: %v", err)
494                 }
495
496                 // Wait for clientCloseWrite to finish, so we know we
497                 // tested the CloseWrite before we defer the
498                 // sconn.Close above, which would also cause the
499                 // client to unblock like CloseWrite.
500                 <-clientDoneChan
501                 return nil
502         }
503
504         clientCloseWrite := func() error {
505                 defer close(clientDoneChan)
506
507                 clientConfig := testConfig.Clone()
508                 conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
509                 if err != nil {
510                         return err
511                 }
512                 if err := conn.Handshake(); err != nil {
513                         return err
514                 }
515                 defer conn.Close()
516
517                 if err := conn.CloseWrite(); err != nil {
518                         return fmt.Errorf("client CloseWrite: %v", err)
519                 }
520
521                 if _, err := conn.Write([]byte{0}); err != errShutdown {
522                         return fmt.Errorf("CloseWrite error = %v; want errShutdown", err)
523                 }
524
525                 data, err := ioutil.ReadAll(conn)
526                 if err != nil {
527                         return err
528                 }
529                 if len(data) > 0 {
530                         return fmt.Errorf("Read data = %q; want nothing", data)
531                 }
532                 return nil
533         }
534
535         errChan := make(chan error, 2)
536
537         go func() { errChan <- serverCloseWrite() }()
538         go func() { errChan <- clientCloseWrite() }()
539
540         for i := 0; i < 2; i++ {
541                 select {
542                 case err := <-errChan:
543                         if err != nil {
544                                 t.Fatal(err)
545                         }
546                 case <-time.After(10 * time.Second):
547                         t.Fatal("deadlock")
548                 }
549         }
550
551         // Also test CloseWrite being called before the handshake is
552         // finished:
553         {
554                 ln2 := newLocalListener(t)
555                 defer ln2.Close()
556
557                 netConn, err := net.Dial("tcp", ln2.Addr().String())
558                 if err != nil {
559                         t.Fatal(err)
560                 }
561                 defer netConn.Close()
562                 conn := Client(netConn, testConfig.Clone())
563
564                 if err := conn.CloseWrite(); err != errEarlyCloseWrite {
565                         t.Errorf("CloseWrite error = %v; want errEarlyCloseWrite", err)
566                 }
567         }
568 }
569
570 func TestWarningAlertFlood(t *testing.T) {
571         ln := newLocalListener(t)
572         defer ln.Close()
573
574         server := func() error {
575                 sconn, err := ln.Accept()
576                 if err != nil {
577                         return fmt.Errorf("accept: %v", err)
578                 }
579                 defer sconn.Close()
580
581                 serverConfig := testConfig.Clone()
582                 srv := Server(sconn, serverConfig)
583                 if err := srv.Handshake(); err != nil {
584                         return fmt.Errorf("handshake: %v", err)
585                 }
586                 defer srv.Close()
587
588                 _, err = ioutil.ReadAll(srv)
589                 if err == nil {
590                         return errors.New("unexpected lack of error from server")
591                 }
592                 const expected = "too many ignored"
593                 if str := err.Error(); !strings.Contains(str, expected) {
594                         return fmt.Errorf("expected error containing %q, but saw: %s", expected, str)
595                 }
596
597                 return nil
598         }
599
600         errChan := make(chan error, 1)
601         go func() { errChan <- server() }()
602
603         clientConfig := testConfig.Clone()
604         conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
605         if err != nil {
606                 t.Fatal(err)
607         }
608         defer conn.Close()
609         if err := conn.Handshake(); err != nil {
610                 t.Fatal(err)
611         }
612
613         for i := 0; i < maxUselessRecords+1; i++ {
614                 conn.sendAlert(alertNoRenegotiation)
615         }
616
617         if err := <-errChan; err != nil {
618                 t.Fatal(err)
619         }
620 }
621
622 func TestCloneFuncFields(t *testing.T) {
623         const expectedCount = 5
624         called := 0
625
626         c1 := Config{
627                 Time: func() time.Time {
628                         called |= 1 << 0
629                         return time.Time{}
630                 },
631                 GetCertificate: func(*ClientHelloInfo) (*Certificate, error) {
632                         called |= 1 << 1
633                         return nil, nil
634                 },
635                 GetClientCertificate: func(*CertificateRequestInfo) (*Certificate, error) {
636                         called |= 1 << 2
637                         return nil, nil
638                 },
639                 GetConfigForClient: func(*ClientHelloInfo) (*Config, error) {
640                         called |= 1 << 3
641                         return nil, nil
642                 },
643                 VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
644                         called |= 1 << 4
645                         return nil
646                 },
647         }
648
649         c2 := c1.Clone()
650
651         c2.Time()
652         c2.GetCertificate(nil)
653         c2.GetClientCertificate(nil)
654         c2.GetConfigForClient(nil)
655         c2.VerifyPeerCertificate(nil, nil)
656
657         if called != (1<<expectedCount)-1 {
658                 t.Fatalf("expected %d calls but saw calls %b", expectedCount, called)
659         }
660 }
661
662 func TestCloneNonFuncFields(t *testing.T) {
663         var c1 Config
664         v := reflect.ValueOf(&c1).Elem()
665
666         typ := v.Type()
667         for i := 0; i < typ.NumField(); i++ {
668                 f := v.Field(i)
669                 if !f.CanSet() {
670                         // unexported field; not cloned.
671                         continue
672                 }
673
674                 // testing/quick can't handle functions or interfaces and so
675                 // isn't used here.
676                 switch fn := typ.Field(i).Name; fn {
677                 case "Rand":
678                         f.Set(reflect.ValueOf(io.Reader(os.Stdin)))
679                 case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate", "GetClientCertificate":
680                         // DeepEqual can't compare functions. If you add a
681                         // function field to this list, you must also change
682                         // TestCloneFuncFields to ensure that the func field is
683                         // cloned.
684                 case "Certificates":
685                         f.Set(reflect.ValueOf([]Certificate{
686                                 {Certificate: [][]byte{{'b'}}},
687                         }))
688                 case "NameToCertificate":
689                         f.Set(reflect.ValueOf(map[string]*Certificate{"a": nil}))
690                 case "RootCAs", "ClientCAs":
691                         f.Set(reflect.ValueOf(x509.NewCertPool()))
692                 case "ClientSessionCache":
693                         f.Set(reflect.ValueOf(NewLRUClientSessionCache(10)))
694                 case "KeyLogWriter":
695                         f.Set(reflect.ValueOf(io.Writer(os.Stdout)))
696                 case "NextProtos":
697                         f.Set(reflect.ValueOf([]string{"a", "b"}))
698                 case "ServerName":
699                         f.Set(reflect.ValueOf("b"))
700                 case "ClientAuth":
701                         f.Set(reflect.ValueOf(VerifyClientCertIfGiven))
702                 case "InsecureSkipVerify", "SessionTicketsDisabled", "DynamicRecordSizingDisabled", "PreferServerCipherSuites":
703                         f.Set(reflect.ValueOf(true))
704                 case "MinVersion", "MaxVersion":
705                         f.Set(reflect.ValueOf(uint16(VersionTLS12)))
706                 case "SessionTicketKey":
707                         f.Set(reflect.ValueOf([32]byte{}))
708                 case "CipherSuites":
709                         f.Set(reflect.ValueOf([]uint16{1, 2}))
710                 case "CurvePreferences":
711                         f.Set(reflect.ValueOf([]CurveID{CurveP256}))
712                 case "Renegotiation":
713                         f.Set(reflect.ValueOf(RenegotiateOnceAsClient))
714                 default:
715                         t.Errorf("all fields must be accounted for, but saw unknown field %q", fn)
716                 }
717         }
718
719         c2 := c1.Clone()
720         // DeepEqual also compares unexported fields, thus c2 needs to have run
721         // serverInit in order to be DeepEqual to c1. Cloning it and discarding
722         // the result is sufficient.
723         c2.Clone()
724
725         if !reflect.DeepEqual(&c1, c2) {
726                 t.Errorf("clone failed to copy a field")
727         }
728 }
729
730 // changeImplConn is a net.Conn which can change its Write and Close
731 // methods.
732 type changeImplConn struct {
733         net.Conn
734         writeFunc func([]byte) (int, error)
735         closeFunc func() error
736 }
737
738 func (w *changeImplConn) Write(p []byte) (n int, err error) {
739         if w.writeFunc != nil {
740                 return w.writeFunc(p)
741         }
742         return w.Conn.Write(p)
743 }
744
745 func (w *changeImplConn) Close() error {
746         if w.closeFunc != nil {
747                 return w.closeFunc()
748         }
749         return w.Conn.Close()
750 }
751
752 func throughput(b *testing.B, totalBytes int64, dynamicRecordSizingDisabled bool) {
753         ln := newLocalListener(b)
754         defer ln.Close()
755
756         N := b.N
757
758         // Less than 64KB because Windows appears to use a TCP rwin < 64KB.
759         // See Issue #15899.
760         const bufsize = 32 << 10
761
762         go func() {
763                 buf := make([]byte, bufsize)
764                 for i := 0; i < N; i++ {
765                         sconn, err := ln.Accept()
766                         if err != nil {
767                                 // panic rather than synchronize to avoid benchmark overhead
768                                 // (cannot call b.Fatal in goroutine)
769                                 panic(fmt.Errorf("accept: %v", err))
770                         }
771                         serverConfig := testConfig.Clone()
772                         serverConfig.CipherSuites = nil // the defaults may prefer faster ciphers
773                         serverConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled
774                         srv := Server(sconn, serverConfig)
775                         if err := srv.Handshake(); err != nil {
776                                 panic(fmt.Errorf("handshake: %v", err))
777                         }
778                         if _, err := io.CopyBuffer(srv, srv, buf); err != nil {
779                                 panic(fmt.Errorf("copy buffer: %v", err))
780                         }
781                 }
782         }()
783
784         b.SetBytes(totalBytes)
785         clientConfig := testConfig.Clone()
786         clientConfig.CipherSuites = nil // the defaults may prefer faster ciphers
787         clientConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled
788
789         buf := make([]byte, bufsize)
790         chunks := int(math.Ceil(float64(totalBytes) / float64(len(buf))))
791         for i := 0; i < N; i++ {
792                 conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
793                 if err != nil {
794                         b.Fatal(err)
795                 }
796                 for j := 0; j < chunks; j++ {
797                         _, err := conn.Write(buf)
798                         if err != nil {
799                                 b.Fatal(err)
800                         }
801                         _, err = io.ReadFull(conn, buf)
802                         if err != nil {
803                                 b.Fatal(err)
804                         }
805                 }
806                 conn.Close()
807         }
808 }
809
810 func BenchmarkThroughput(b *testing.B) {
811         for _, mode := range []string{"Max", "Dynamic"} {
812                 for size := 1; size <= 64; size <<= 1 {
813                         name := fmt.Sprintf("%sPacket/%dMB", mode, size)
814                         b.Run(name, func(b *testing.B) {
815                                 throughput(b, int64(size<<20), mode == "Max")
816                         })
817                 }
818         }
819 }
820
821 type slowConn struct {
822         net.Conn
823         bps int
824 }
825
826 func (c *slowConn) Write(p []byte) (int, error) {
827         if c.bps == 0 {
828                 panic("too slow")
829         }
830         t0 := time.Now()
831         wrote := 0
832         for wrote < len(p) {
833                 time.Sleep(100 * time.Microsecond)
834                 allowed := int(time.Since(t0).Seconds()*float64(c.bps)) / 8
835                 if allowed > len(p) {
836                         allowed = len(p)
837                 }
838                 if wrote < allowed {
839                         n, err := c.Conn.Write(p[wrote:allowed])
840                         wrote += n
841                         if err != nil {
842                                 return wrote, err
843                         }
844                 }
845         }
846         return len(p), nil
847 }
848
849 func latency(b *testing.B, bps int, dynamicRecordSizingDisabled bool) {
850         ln := newLocalListener(b)
851         defer ln.Close()
852
853         N := b.N
854
855         go func() {
856                 for i := 0; i < N; i++ {
857                         sconn, err := ln.Accept()
858                         if err != nil {
859                                 // panic rather than synchronize to avoid benchmark overhead
860                                 // (cannot call b.Fatal in goroutine)
861                                 panic(fmt.Errorf("accept: %v", err))
862                         }
863                         serverConfig := testConfig.Clone()
864                         serverConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled
865                         srv := Server(&slowConn{sconn, bps}, serverConfig)
866                         if err := srv.Handshake(); err != nil {
867                                 panic(fmt.Errorf("handshake: %v", err))
868                         }
869                         io.Copy(srv, srv)
870                 }
871         }()
872
873         clientConfig := testConfig.Clone()
874         clientConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled
875
876         buf := make([]byte, 16384)
877         peek := make([]byte, 1)
878
879         for i := 0; i < N; i++ {
880                 conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
881                 if err != nil {
882                         b.Fatal(err)
883                 }
884                 // make sure we're connected and previous connection has stopped
885                 if _, err := conn.Write(buf[:1]); err != nil {
886                         b.Fatal(err)
887                 }
888                 if _, err := io.ReadFull(conn, peek); err != nil {
889                         b.Fatal(err)
890                 }
891                 if _, err := conn.Write(buf); err != nil {
892                         b.Fatal(err)
893                 }
894                 if _, err = io.ReadFull(conn, peek); err != nil {
895                         b.Fatal(err)
896                 }
897                 conn.Close()
898         }
899 }
900
901 func BenchmarkLatency(b *testing.B) {
902         for _, mode := range []string{"Max", "Dynamic"} {
903                 for _, kbps := range []int{200, 500, 1000, 2000, 5000} {
904                         name := fmt.Sprintf("%sPacket/%dkbps", mode, kbps)
905                         b.Run(name, func(b *testing.B) {
906                                 latency(b, kbps*1000, mode == "Max")
907                         })
908                 }
909         }
910 }
911
912 func TestConnectionStateMarshal(t *testing.T) {
913         cs := &ConnectionState{}
914         _, err := json.Marshal(cs)
915         if err != nil {
916                 t.Errorf("json.Marshal failed on ConnectionState: %v", err)
917         }
918 }
919
920 func TestConnectionState(t *testing.T) {
921         issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
922         if err != nil {
923                 panic(err)
924         }
925         rootCAs := x509.NewCertPool()
926         rootCAs.AddCert(issuer)
927
928         now := func() time.Time { return time.Unix(1476984729, 0) }
929
930         const alpnProtocol = "golang"
931         const serverName = "example.golang"
932         var scts = [][]byte{[]byte("dummy sct 1"), []byte("dummy sct 2")}
933         var ocsp = []byte("dummy ocsp")
934
935         for _, v := range []uint16{VersionTLS12, VersionTLS13} {
936                 var name string
937                 switch v {
938                 case VersionTLS12:
939                         name = "TLSv12"
940                 case VersionTLS13:
941                         name = "TLSv13"
942                 }
943                 t.Run(name, func(t *testing.T) {
944                         config := &Config{
945                                 Time:         now,
946                                 Rand:         zeroSource{},
947                                 Certificates: make([]Certificate, 1),
948                                 MaxVersion:   v,
949                                 RootCAs:      rootCAs,
950                                 ClientCAs:    rootCAs,
951                                 ClientAuth:   RequireAndVerifyClientCert,
952                                 NextProtos:   []string{alpnProtocol},
953                                 ServerName:   serverName,
954                         }
955                         config.Certificates[0].Certificate = [][]byte{testRSACertificate}
956                         config.Certificates[0].PrivateKey = testRSAPrivateKey
957                         config.Certificates[0].SignedCertificateTimestamps = scts
958                         config.Certificates[0].OCSPStaple = ocsp
959
960                         ss, cs, err := testHandshake(t, config, config)
961                         if err != nil {
962                                 t.Fatalf("Handshake failed: %v", err)
963                         }
964
965                         if ss.Version != v || cs.Version != v {
966                                 t.Errorf("Got versions %x (server) and %x (client), expected %x", ss.Version, cs.Version, v)
967                         }
968
969                         if !ss.HandshakeComplete || !cs.HandshakeComplete {
970                                 t.Errorf("Got HandshakeComplete %v (server) and %v (client), expected true", ss.HandshakeComplete, cs.HandshakeComplete)
971                         }
972
973                         if ss.DidResume || cs.DidResume {
974                                 t.Errorf("Got DidResume %v (server) and %v (client), expected false", ss.DidResume, cs.DidResume)
975                         }
976
977                         if ss.CipherSuite == 0 || cs.CipherSuite == 0 {
978                                 t.Errorf("Got invalid cipher suite: %v (server) and %v (client)", ss.CipherSuite, cs.CipherSuite)
979                         }
980
981                         if ss.NegotiatedProtocol != alpnProtocol || cs.NegotiatedProtocol != alpnProtocol {
982                                 t.Errorf("Got negotiated protocol %q (server) and %q (client), expected %q", ss.NegotiatedProtocol, cs.NegotiatedProtocol, alpnProtocol)
983                         }
984
985                         if !cs.NegotiatedProtocolIsMutual {
986                                 t.Errorf("Got false NegotiatedProtocolIsMutual on the client side")
987                         }
988                         // NegotiatedProtocolIsMutual on the server side is unspecified.
989
990                         if ss.ServerName != serverName {
991                                 t.Errorf("Got server name %q, expected %q", ss.ServerName, serverName)
992                         }
993                         if cs.ServerName != "" {
994                                 t.Errorf("Got unexpected server name on the client side")
995                         }
996
997                         if len(ss.PeerCertificates) != 1 || len(cs.PeerCertificates) != 1 {
998                                 t.Errorf("Got %d (server) and %d (client) peer certificates, expected %d", len(ss.PeerCertificates), len(cs.PeerCertificates), 1)
999                         }
1000
1001                         if len(ss.VerifiedChains) != 1 || len(cs.VerifiedChains) != 1 {
1002                                 t.Errorf("Got %d (server) and %d (client) verified chains, expected %d", len(ss.VerifiedChains), len(cs.VerifiedChains), 1)
1003                         } else if len(ss.VerifiedChains[0]) != 2 || len(cs.VerifiedChains[0]) != 2 {
1004                                 t.Errorf("Got %d (server) and %d (client) long verified chain, expected %d", len(ss.VerifiedChains[0]), len(cs.VerifiedChains[0]), 2)
1005                         }
1006
1007                         if len(cs.SignedCertificateTimestamps) != 2 {
1008                                 t.Errorf("Got %d SCTs, expected %d", len(cs.SignedCertificateTimestamps), 2)
1009                         }
1010                         if !bytes.Equal(cs.OCSPResponse, ocsp) {
1011                                 t.Errorf("Got OCSPs %x, expected %x", cs.OCSPResponse, ocsp)
1012                         }
1013                         // Only TLS 1.3 supports OCSP and SCTs on client certs.
1014                         if v == VersionTLS13 {
1015                                 if len(ss.SignedCertificateTimestamps) != 2 {
1016                                         t.Errorf("Got %d client SCTs, expected %d", len(ss.SignedCertificateTimestamps), 2)
1017                                 }
1018                                 if !bytes.Equal(ss.OCSPResponse, ocsp) {
1019                                         t.Errorf("Got client OCSPs %x, expected %x", ss.OCSPResponse, ocsp)
1020                                 }
1021                         }
1022
1023                         if v == VersionTLS13 {
1024                                 if ss.TLSUnique != nil || cs.TLSUnique != nil {
1025                                         t.Errorf("Got TLSUnique %x (server) and %x (client), expected nil in TLS 1.3", ss.TLSUnique, cs.TLSUnique)
1026                                 }
1027                         } else {
1028                                 if ss.TLSUnique == nil || cs.TLSUnique == nil {
1029                                         t.Errorf("Got TLSUnique %x (server) and %x (client), expected non-nil", ss.TLSUnique, cs.TLSUnique)
1030                                 }
1031                         }
1032                 })
1033         }
1034 }