]> Cypherpunks.ru repositories - gostls13.git/blob - src/crypto/tls/tls_test.go
crypto/tls: make Config.Clone also clone the GetClientCertificate field
[gostls13.git] / src / crypto / tls / tls_test.go
1 // Copyright 2012 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package tls
6
7 import (
8         "bytes"
9         "crypto/x509"
10         "errors"
11         "fmt"
12         "internal/testenv"
13         "io"
14         "io/ioutil"
15         "math"
16         "net"
17         "os"
18         "reflect"
19         "strings"
20         "testing"
21         "time"
22 )
23
24 var rsaCertPEM = `-----BEGIN CERTIFICATE-----
25 MIIB0zCCAX2gAwIBAgIJAI/M7BYjwB+uMA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV
26 BAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX
27 aWRnaXRzIFB0eSBMdGQwHhcNMTIwOTEyMjE1MjAyWhcNMTUwOTEyMjE1MjAyWjBF
28 MQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50
29 ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBANLJ
30 hPHhITqQbPklG3ibCVxwGMRfp/v4XqhfdQHdcVfHap6NQ5Wok/4xIA+ui35/MmNa
31 rtNuC+BdZ1tMuVCPFZcCAwEAAaNQME4wHQYDVR0OBBYEFJvKs8RfJaXTH08W+SGv
32 zQyKn0H8MB8GA1UdIwQYMBaAFJvKs8RfJaXTH08W+SGvzQyKn0H8MAwGA1UdEwQF
33 MAMBAf8wDQYJKoZIhvcNAQEFBQADQQBJlffJHybjDGxRMqaRmDhX0+6v02TUKZsW
34 r5QuVbpQhH6u+0UgcW0jp9QwpxoPTLTWGXEWBBBurxFwiCBhkQ+V
35 -----END CERTIFICATE-----
36 `
37
38 var rsaKeyPEM = `-----BEGIN RSA PRIVATE KEY-----
39 MIIBOwIBAAJBANLJhPHhITqQbPklG3ibCVxwGMRfp/v4XqhfdQHdcVfHap6NQ5Wo
40 k/4xIA+ui35/MmNartNuC+BdZ1tMuVCPFZcCAwEAAQJAEJ2N+zsR0Xn8/Q6twa4G
41 6OB1M1WO+k+ztnX/1SvNeWu8D6GImtupLTYgjZcHufykj09jiHmjHx8u8ZZB/o1N
42 MQIhAPW+eyZo7ay3lMz1V01WVjNKK9QSn1MJlb06h/LuYv9FAiEA25WPedKgVyCW
43 SmUwbPw8fnTcpqDWE3yTO3vKcebqMSsCIBF3UmVue8YU3jybC3NxuXq3wNm34R8T
44 xVLHwDXh/6NJAiEAl2oHGGLz64BuAfjKrqwz7qMYr9HCLIe/YsoWq/olzScCIQDi
45 D2lWusoe2/nEqfDVVWGWlyJ7yOmqaVm/iNUN9B2N2g==
46 -----END RSA PRIVATE KEY-----
47 `
48
49 // keyPEM is the same as rsaKeyPEM, but declares itself as just
50 // "PRIVATE KEY", not "RSA PRIVATE KEY".  https://golang.org/issue/4477
51 var keyPEM = `-----BEGIN PRIVATE KEY-----
52 MIIBOwIBAAJBANLJhPHhITqQbPklG3ibCVxwGMRfp/v4XqhfdQHdcVfHap6NQ5Wo
53 k/4xIA+ui35/MmNartNuC+BdZ1tMuVCPFZcCAwEAAQJAEJ2N+zsR0Xn8/Q6twa4G
54 6OB1M1WO+k+ztnX/1SvNeWu8D6GImtupLTYgjZcHufykj09jiHmjHx8u8ZZB/o1N
55 MQIhAPW+eyZo7ay3lMz1V01WVjNKK9QSn1MJlb06h/LuYv9FAiEA25WPedKgVyCW
56 SmUwbPw8fnTcpqDWE3yTO3vKcebqMSsCIBF3UmVue8YU3jybC3NxuXq3wNm34R8T
57 xVLHwDXh/6NJAiEAl2oHGGLz64BuAfjKrqwz7qMYr9HCLIe/YsoWq/olzScCIQDi
58 D2lWusoe2/nEqfDVVWGWlyJ7yOmqaVm/iNUN9B2N2g==
59 -----END PRIVATE KEY-----
60 `
61
62 var ecdsaCertPEM = `-----BEGIN CERTIFICATE-----
63 MIIB/jCCAWICCQDscdUxw16XFDAJBgcqhkjOPQQBMEUxCzAJBgNVBAYTAkFVMRMw
64 EQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBXaWRnaXRzIFB0
65 eSBMdGQwHhcNMTIxMTE0MTI0MDQ4WhcNMTUxMTE0MTI0MDQ4WjBFMQswCQYDVQQG
66 EwJBVTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lk
67 Z2l0cyBQdHkgTHRkMIGbMBAGByqGSM49AgEGBSuBBAAjA4GGAAQBY9+my9OoeSUR
68 lDQdV/x8LsOuLilthhiS1Tz4aGDHIPwC1mlvnf7fg5lecYpMCrLLhauAc1UJXcgl
69 01xoLuzgtAEAgv2P/jgytzRSpUYvgLBt1UA0leLYBy6mQQbrNEuqT3INapKIcUv8
70 XxYP0xMEUksLPq6Ca+CRSqTtrd/23uTnapkwCQYHKoZIzj0EAQOBigAwgYYCQXJo
71 A7Sl2nLVf+4Iu/tAX/IF4MavARKC4PPHK3zfuGfPR3oCCcsAoz3kAzOeijvd0iXb
72 H5jBImIxPL4WxQNiBTexAkF8D1EtpYuWdlVQ80/h/f4pBcGiXPqX5h2PQSQY7hP1
73 +jwM1FGS4fREIOvlBYr/SzzQRtwrvrzGYxDEDbsC0ZGRnA==
74 -----END CERTIFICATE-----
75 `
76
77 var ecdsaKeyPEM = `-----BEGIN EC PARAMETERS-----
78 BgUrgQQAIw==
79 -----END EC PARAMETERS-----
80 -----BEGIN EC PRIVATE KEY-----
81 MIHcAgEBBEIBrsoKp0oqcv6/JovJJDoDVSGWdirrkgCWxrprGlzB9o0X8fV675X0
82 NwuBenXFfeZvVcwluO7/Q9wkYoPd/t3jGImgBwYFK4EEACOhgYkDgYYABAFj36bL
83 06h5JRGUNB1X/Hwuw64uKW2GGJLVPPhoYMcg/ALWaW+d/t+DmV5xikwKssuFq4Bz
84 VQldyCXTXGgu7OC0AQCC/Y/+ODK3NFKlRi+AsG3VQDSV4tgHLqZBBus0S6pPcg1q
85 kohxS/xfFg/TEwRSSws+roJr4JFKpO2t3/be5OdqmQ==
86 -----END EC PRIVATE KEY-----
87 `
88
89 var keyPairTests = []struct {
90         algo string
91         cert string
92         key  string
93 }{
94         {"ECDSA", ecdsaCertPEM, ecdsaKeyPEM},
95         {"RSA", rsaCertPEM, rsaKeyPEM},
96         {"RSA-untyped", rsaCertPEM, keyPEM}, // golang.org/issue/4477
97 }
98
99 func TestX509KeyPair(t *testing.T) {
100         t.Parallel()
101         var pem []byte
102         for _, test := range keyPairTests {
103                 pem = []byte(test.cert + test.key)
104                 if _, err := X509KeyPair(pem, pem); err != nil {
105                         t.Errorf("Failed to load %s cert followed by %s key: %s", test.algo, test.algo, err)
106                 }
107                 pem = []byte(test.key + test.cert)
108                 if _, err := X509KeyPair(pem, pem); err != nil {
109                         t.Errorf("Failed to load %s key followed by %s cert: %s", test.algo, test.algo, err)
110                 }
111         }
112 }
113
114 func TestX509KeyPairErrors(t *testing.T) {
115         _, err := X509KeyPair([]byte(rsaKeyPEM), []byte(rsaCertPEM))
116         if err == nil {
117                 t.Fatalf("X509KeyPair didn't return an error when arguments were switched")
118         }
119         if subStr := "been switched"; !strings.Contains(err.Error(), subStr) {
120                 t.Fatalf("Expected %q in the error when switching arguments to X509KeyPair, but the error was %q", subStr, err)
121         }
122
123         _, err = X509KeyPair([]byte(rsaCertPEM), []byte(rsaCertPEM))
124         if err == nil {
125                 t.Fatalf("X509KeyPair didn't return an error when both arguments were certificates")
126         }
127         if subStr := "certificate"; !strings.Contains(err.Error(), subStr) {
128                 t.Fatalf("Expected %q in the error when both arguments to X509KeyPair were certificates, but the error was %q", subStr, err)
129         }
130
131         const nonsensePEM = `
132 -----BEGIN NONSENSE-----
133 Zm9vZm9vZm9v
134 -----END NONSENSE-----
135 `
136
137         _, err = X509KeyPair([]byte(nonsensePEM), []byte(nonsensePEM))
138         if err == nil {
139                 t.Fatalf("X509KeyPair didn't return an error when both arguments were nonsense")
140         }
141         if subStr := "NONSENSE"; !strings.Contains(err.Error(), subStr) {
142                 t.Fatalf("Expected %q in the error when both arguments to X509KeyPair were nonsense, but the error was %q", subStr, err)
143         }
144 }
145
146 func TestX509MixedKeyPair(t *testing.T) {
147         if _, err := X509KeyPair([]byte(rsaCertPEM), []byte(ecdsaKeyPEM)); err == nil {
148                 t.Error("Load of RSA certificate succeeded with ECDSA private key")
149         }
150         if _, err := X509KeyPair([]byte(ecdsaCertPEM), []byte(rsaKeyPEM)); err == nil {
151                 t.Error("Load of ECDSA certificate succeeded with RSA private key")
152         }
153 }
154
155 func newLocalListener(t testing.TB) net.Listener {
156         ln, err := net.Listen("tcp", "127.0.0.1:0")
157         if err != nil {
158                 ln, err = net.Listen("tcp6", "[::1]:0")
159         }
160         if err != nil {
161                 t.Fatal(err)
162         }
163         return ln
164 }
165
166 func TestDialTimeout(t *testing.T) {
167         if testing.Short() {
168                 t.Skip("skipping in short mode")
169         }
170         listener := newLocalListener(t)
171
172         addr := listener.Addr().String()
173         defer listener.Close()
174
175         complete := make(chan bool)
176         defer close(complete)
177
178         go func() {
179                 conn, err := listener.Accept()
180                 if err != nil {
181                         t.Error(err)
182                         return
183                 }
184                 <-complete
185                 conn.Close()
186         }()
187
188         dialer := &net.Dialer{
189                 Timeout: 10 * time.Millisecond,
190         }
191
192         var err error
193         if _, err = DialWithDialer(dialer, "tcp", addr, nil); err == nil {
194                 t.Fatal("DialWithTimeout completed successfully")
195         }
196
197         if !isTimeoutError(err) {
198                 t.Errorf("resulting error not a timeout: %v\nType %T: %#v", err, err, err)
199         }
200 }
201
202 func isTimeoutError(err error) bool {
203         if ne, ok := err.(net.Error); ok {
204                 return ne.Timeout()
205         }
206         return false
207 }
208
209 // tests that Conn.Read returns (non-zero, io.EOF) instead of
210 // (non-zero, nil) when a Close (alertCloseNotify) is sitting right
211 // behind the application data in the buffer.
212 func TestConnReadNonzeroAndEOF(t *testing.T) {
213         // This test is racy: it assumes that after a write to a
214         // localhost TCP connection, the peer TCP connection can
215         // immediately read it. Because it's racy, we skip this test
216         // in short mode, and then retry it several times with an
217         // increasing sleep in between our final write (via srv.Close
218         // below) and the following read.
219         if testing.Short() {
220                 t.Skip("skipping in short mode")
221         }
222         var err error
223         for delay := time.Millisecond; delay <= 64*time.Millisecond; delay *= 2 {
224                 if err = testConnReadNonzeroAndEOF(t, delay); err == nil {
225                         return
226                 }
227         }
228         t.Error(err)
229 }
230
231 func testConnReadNonzeroAndEOF(t *testing.T, delay time.Duration) error {
232         ln := newLocalListener(t)
233         defer ln.Close()
234
235         srvCh := make(chan *Conn, 1)
236         var serr error
237         go func() {
238                 sconn, err := ln.Accept()
239                 if err != nil {
240                         serr = err
241                         srvCh <- nil
242                         return
243                 }
244                 serverConfig := testConfig.Clone()
245                 srv := Server(sconn, serverConfig)
246                 if err := srv.Handshake(); err != nil {
247                         serr = fmt.Errorf("handshake: %v", err)
248                         srvCh <- nil
249                         return
250                 }
251                 srvCh <- srv
252         }()
253
254         clientConfig := testConfig.Clone()
255         conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
256         if err != nil {
257                 t.Fatal(err)
258         }
259         defer conn.Close()
260
261         srv := <-srvCh
262         if srv == nil {
263                 return serr
264         }
265
266         buf := make([]byte, 6)
267
268         srv.Write([]byte("foobar"))
269         n, err := conn.Read(buf)
270         if n != 6 || err != nil || string(buf) != "foobar" {
271                 return fmt.Errorf("Read = %d, %v, data %q; want 6, nil, foobar", n, err, buf)
272         }
273
274         srv.Write([]byte("abcdef"))
275         srv.Close()
276         time.Sleep(delay)
277         n, err = conn.Read(buf)
278         if n != 6 || string(buf) != "abcdef" {
279                 return fmt.Errorf("Read = %d, buf= %q; want 6, abcdef", n, buf)
280         }
281         if err != io.EOF {
282                 return fmt.Errorf("Second Read error = %v; want io.EOF", err)
283         }
284         return nil
285 }
286
287 func TestTLSUniqueMatches(t *testing.T) {
288         ln := newLocalListener(t)
289         defer ln.Close()
290
291         serverTLSUniques := make(chan []byte)
292         go func() {
293                 for i := 0; i < 2; i++ {
294                         sconn, err := ln.Accept()
295                         if err != nil {
296                                 t.Error(err)
297                                 return
298                         }
299                         serverConfig := testConfig.Clone()
300                         srv := Server(sconn, serverConfig)
301                         if err := srv.Handshake(); err != nil {
302                                 t.Error(err)
303                                 return
304                         }
305                         serverTLSUniques <- srv.ConnectionState().TLSUnique
306                 }
307         }()
308
309         clientConfig := testConfig.Clone()
310         clientConfig.ClientSessionCache = NewLRUClientSessionCache(1)
311         conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
312         if err != nil {
313                 t.Fatal(err)
314         }
315         if !bytes.Equal(conn.ConnectionState().TLSUnique, <-serverTLSUniques) {
316                 t.Error("client and server channel bindings differ")
317         }
318         conn.Close()
319
320         conn, err = Dial("tcp", ln.Addr().String(), clientConfig)
321         if err != nil {
322                 t.Fatal(err)
323         }
324         defer conn.Close()
325         if !conn.ConnectionState().DidResume {
326                 t.Error("second session did not use resumption")
327         }
328         if !bytes.Equal(conn.ConnectionState().TLSUnique, <-serverTLSUniques) {
329                 t.Error("client and server channel bindings differ when session resumption is used")
330         }
331 }
332
333 func TestVerifyHostname(t *testing.T) {
334         testenv.MustHaveExternalNetwork(t)
335
336         c, err := Dial("tcp", "www.google.com:https", nil)
337         if err != nil {
338                 t.Fatal(err)
339         }
340         if err := c.VerifyHostname("www.google.com"); err != nil {
341                 t.Fatalf("verify www.google.com: %v", err)
342         }
343         if err := c.VerifyHostname("www.yahoo.com"); err == nil {
344                 t.Fatalf("verify www.yahoo.com succeeded")
345         }
346
347         c, err = Dial("tcp", "www.google.com:https", &Config{InsecureSkipVerify: true})
348         if err != nil {
349                 t.Fatal(err)
350         }
351         if err := c.VerifyHostname("www.google.com"); err == nil {
352                 t.Fatalf("verify www.google.com succeeded with InsecureSkipVerify=true")
353         }
354         if err := c.VerifyHostname("www.yahoo.com"); err == nil {
355                 t.Fatalf("verify www.google.com succeeded with InsecureSkipVerify=true")
356         }
357 }
358
359 func TestVerifyHostnameResumed(t *testing.T) {
360         testenv.MustHaveExternalNetwork(t)
361
362         config := &Config{
363                 ClientSessionCache: NewLRUClientSessionCache(32),
364         }
365         for i := 0; i < 2; i++ {
366                 c, err := Dial("tcp", "www.google.com:https", config)
367                 if err != nil {
368                         t.Fatalf("Dial #%d: %v", i, err)
369                 }
370                 cs := c.ConnectionState()
371                 if i > 0 && !cs.DidResume {
372                         t.Fatalf("Subsequent connection unexpectedly didn't resume")
373                 }
374                 if cs.VerifiedChains == nil {
375                         t.Fatalf("Dial #%d: cs.VerifiedChains == nil", i)
376                 }
377                 if err := c.VerifyHostname("www.google.com"); err != nil {
378                         t.Fatalf("verify www.google.com #%d: %v", i, err)
379                 }
380                 c.Close()
381         }
382 }
383
384 func TestConnCloseBreakingWrite(t *testing.T) {
385         ln := newLocalListener(t)
386         defer ln.Close()
387
388         srvCh := make(chan *Conn, 1)
389         var serr error
390         var sconn net.Conn
391         go func() {
392                 var err error
393                 sconn, err = ln.Accept()
394                 if err != nil {
395                         serr = err
396                         srvCh <- nil
397                         return
398                 }
399                 serverConfig := testConfig.Clone()
400                 srv := Server(sconn, serverConfig)
401                 if err := srv.Handshake(); err != nil {
402                         serr = fmt.Errorf("handshake: %v", err)
403                         srvCh <- nil
404                         return
405                 }
406                 srvCh <- srv
407         }()
408
409         cconn, err := net.Dial("tcp", ln.Addr().String())
410         if err != nil {
411                 t.Fatal(err)
412         }
413         defer cconn.Close()
414
415         conn := &changeImplConn{
416                 Conn: cconn,
417         }
418
419         clientConfig := testConfig.Clone()
420         tconn := Client(conn, clientConfig)
421         if err := tconn.Handshake(); err != nil {
422                 t.Fatal(err)
423         }
424
425         srv := <-srvCh
426         if srv == nil {
427                 t.Fatal(serr)
428         }
429         defer sconn.Close()
430
431         connClosed := make(chan struct{})
432         conn.closeFunc = func() error {
433                 close(connClosed)
434                 return nil
435         }
436
437         inWrite := make(chan bool, 1)
438         var errConnClosed = errors.New("conn closed for test")
439         conn.writeFunc = func(p []byte) (n int, err error) {
440                 inWrite <- true
441                 <-connClosed
442                 return 0, errConnClosed
443         }
444
445         closeReturned := make(chan bool, 1)
446         go func() {
447                 <-inWrite
448                 tconn.Close() // test that this doesn't block forever.
449                 closeReturned <- true
450         }()
451
452         _, err = tconn.Write([]byte("foo"))
453         if err != errConnClosed {
454                 t.Errorf("Write error = %v; want errConnClosed", err)
455         }
456
457         <-closeReturned
458         if err := tconn.Close(); err != errClosed {
459                 t.Errorf("Close error = %v; want errClosed", err)
460         }
461 }
462
463 func TestConnCloseWrite(t *testing.T) {
464         ln := newLocalListener(t)
465         defer ln.Close()
466
467         clientDoneChan := make(chan struct{})
468
469         serverCloseWrite := func() error {
470                 sconn, err := ln.Accept()
471                 if err != nil {
472                         return fmt.Errorf("accept: %v", err)
473                 }
474                 defer sconn.Close()
475
476                 serverConfig := testConfig.Clone()
477                 srv := Server(sconn, serverConfig)
478                 if err := srv.Handshake(); err != nil {
479                         return fmt.Errorf("handshake: %v", err)
480                 }
481                 defer srv.Close()
482
483                 data, err := ioutil.ReadAll(srv)
484                 if err != nil {
485                         return err
486                 }
487                 if len(data) > 0 {
488                         return fmt.Errorf("Read data = %q; want nothing", data)
489                 }
490
491                 if err := srv.CloseWrite(); err != nil {
492                         return fmt.Errorf("server CloseWrite: %v", err)
493                 }
494
495                 // Wait for clientCloseWrite to finish, so we know we
496                 // tested the CloseWrite before we defer the
497                 // sconn.Close above, which would also cause the
498                 // client to unblock like CloseWrite.
499                 <-clientDoneChan
500                 return nil
501         }
502
503         clientCloseWrite := func() error {
504                 defer close(clientDoneChan)
505
506                 clientConfig := testConfig.Clone()
507                 conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
508                 if err != nil {
509                         return err
510                 }
511                 if err := conn.Handshake(); err != nil {
512                         return err
513                 }
514                 defer conn.Close()
515
516                 if err := conn.CloseWrite(); err != nil {
517                         return fmt.Errorf("client CloseWrite: %v", err)
518                 }
519
520                 if _, err := conn.Write([]byte{0}); err != errShutdown {
521                         return fmt.Errorf("CloseWrite error = %v; want errShutdown", err)
522                 }
523
524                 data, err := ioutil.ReadAll(conn)
525                 if err != nil {
526                         return err
527                 }
528                 if len(data) > 0 {
529                         return fmt.Errorf("Read data = %q; want nothing", data)
530                 }
531                 return nil
532         }
533
534         errChan := make(chan error, 2)
535
536         go func() { errChan <- serverCloseWrite() }()
537         go func() { errChan <- clientCloseWrite() }()
538
539         for i := 0; i < 2; i++ {
540                 select {
541                 case err := <-errChan:
542                         if err != nil {
543                                 t.Fatal(err)
544                         }
545                 case <-time.After(10 * time.Second):
546                         t.Fatal("deadlock")
547                 }
548         }
549
550         // Also test CloseWrite being called before the handshake is
551         // finished:
552         {
553                 ln2 := newLocalListener(t)
554                 defer ln2.Close()
555
556                 netConn, err := net.Dial("tcp", ln2.Addr().String())
557                 if err != nil {
558                         t.Fatal(err)
559                 }
560                 defer netConn.Close()
561                 conn := Client(netConn, testConfig.Clone())
562
563                 if err := conn.CloseWrite(); err != errEarlyCloseWrite {
564                         t.Errorf("CloseWrite error = %v; want errEarlyCloseWrite", err)
565                 }
566         }
567 }
568
569 func TestCloneFuncFields(t *testing.T) {
570         const expectedCount = 5
571         called := 0
572
573         c1 := Config{
574                 Time: func() time.Time {
575                         called |= 1 << 0
576                         return time.Time{}
577                 },
578                 GetCertificate: func(*ClientHelloInfo) (*Certificate, error) {
579                         called |= 1 << 1
580                         return nil, nil
581                 },
582                 GetClientCertificate: func(*CertificateRequestInfo) (*Certificate, error) {
583                         called |= 1 << 2
584                         return nil, nil
585                 },
586                 GetConfigForClient: func(*ClientHelloInfo) (*Config, error) {
587                         called |= 1 << 3
588                         return nil, nil
589                 },
590                 VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
591                         called |= 1 << 4
592                         return nil
593                 },
594         }
595
596         c2 := c1.Clone()
597
598         c2.Time()
599         c2.GetCertificate(nil)
600         c2.GetClientCertificate(nil)
601         c2.GetConfigForClient(nil)
602         c2.VerifyPeerCertificate(nil, nil)
603
604         if called != (1<<expectedCount)-1 {
605                 t.Fatalf("expected %d calls but saw calls %b", expectedCount, called)
606         }
607 }
608
609 func TestCloneNonFuncFields(t *testing.T) {
610         var c1 Config
611         v := reflect.ValueOf(&c1).Elem()
612
613         typ := v.Type()
614         for i := 0; i < typ.NumField(); i++ {
615                 f := v.Field(i)
616                 if !f.CanSet() {
617                         // unexported field; not cloned.
618                         continue
619                 }
620
621                 // testing/quick can't handle functions or interfaces and so
622                 // isn't used here.
623                 switch fn := typ.Field(i).Name; fn {
624                 case "Rand":
625                         f.Set(reflect.ValueOf(io.Reader(os.Stdin)))
626                 case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate", "GetClientCertificate":
627                         // DeepEqual can't compare functions. If you add a
628                         // function field to this list, you must also change
629                         // TestCloneFuncFields to ensure that the func field is
630                         // cloned.
631                 case "Certificates":
632                         f.Set(reflect.ValueOf([]Certificate{
633                                 {Certificate: [][]byte{{'b'}}},
634                         }))
635                 case "NameToCertificate":
636                         f.Set(reflect.ValueOf(map[string]*Certificate{"a": nil}))
637                 case "RootCAs", "ClientCAs":
638                         f.Set(reflect.ValueOf(x509.NewCertPool()))
639                 case "ClientSessionCache":
640                         f.Set(reflect.ValueOf(NewLRUClientSessionCache(10)))
641                 case "KeyLogWriter":
642                         f.Set(reflect.ValueOf(io.Writer(os.Stdout)))
643                 case "NextProtos":
644                         f.Set(reflect.ValueOf([]string{"a", "b"}))
645                 case "ServerName":
646                         f.Set(reflect.ValueOf("b"))
647                 case "ClientAuth":
648                         f.Set(reflect.ValueOf(VerifyClientCertIfGiven))
649                 case "InsecureSkipVerify", "SessionTicketsDisabled", "DynamicRecordSizingDisabled", "PreferServerCipherSuites":
650                         f.Set(reflect.ValueOf(true))
651                 case "MinVersion", "MaxVersion":
652                         f.Set(reflect.ValueOf(uint16(VersionTLS12)))
653                 case "SessionTicketKey":
654                         f.Set(reflect.ValueOf([32]byte{}))
655                 case "CipherSuites":
656                         f.Set(reflect.ValueOf([]uint16{1, 2}))
657                 case "CurvePreferences":
658                         f.Set(reflect.ValueOf([]CurveID{CurveP256}))
659                 case "Renegotiation":
660                         f.Set(reflect.ValueOf(RenegotiateOnceAsClient))
661                 default:
662                         t.Errorf("all fields must be accounted for, but saw unknown field %q", fn)
663                 }
664         }
665
666         c2 := c1.Clone()
667         // DeepEqual also compares unexported fields, thus c2 needs to have run
668         // serverInit in order to be DeepEqual to c1. Cloning it and discarding
669         // the result is sufficient.
670         c2.Clone()
671
672         if !reflect.DeepEqual(&c1, c2) {
673                 t.Errorf("clone failed to copy a field")
674         }
675 }
676
677 // changeImplConn is a net.Conn which can change its Write and Close
678 // methods.
679 type changeImplConn struct {
680         net.Conn
681         writeFunc func([]byte) (int, error)
682         closeFunc func() error
683 }
684
685 func (w *changeImplConn) Write(p []byte) (n int, err error) {
686         if w.writeFunc != nil {
687                 return w.writeFunc(p)
688         }
689         return w.Conn.Write(p)
690 }
691
692 func (w *changeImplConn) Close() error {
693         if w.closeFunc != nil {
694                 return w.closeFunc()
695         }
696         return w.Conn.Close()
697 }
698
699 func throughput(b *testing.B, totalBytes int64, dynamicRecordSizingDisabled bool) {
700         ln := newLocalListener(b)
701         defer ln.Close()
702
703         N := b.N
704
705         // Less than 64KB because Windows appears to use a TCP rwin < 64KB.
706         // See Issue #15899.
707         const bufsize = 32 << 10
708
709         go func() {
710                 buf := make([]byte, bufsize)
711                 for i := 0; i < N; i++ {
712                         sconn, err := ln.Accept()
713                         if err != nil {
714                                 // panic rather than synchronize to avoid benchmark overhead
715                                 // (cannot call b.Fatal in goroutine)
716                                 panic(fmt.Errorf("accept: %v", err))
717                         }
718                         serverConfig := testConfig.Clone()
719                         serverConfig.CipherSuites = nil // the defaults may prefer faster ciphers
720                         serverConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled
721                         srv := Server(sconn, serverConfig)
722                         if err := srv.Handshake(); err != nil {
723                                 panic(fmt.Errorf("handshake: %v", err))
724                         }
725                         if _, err := io.CopyBuffer(srv, srv, buf); err != nil {
726                                 panic(fmt.Errorf("copy buffer: %v", err))
727                         }
728                 }
729         }()
730
731         b.SetBytes(totalBytes)
732         clientConfig := testConfig.Clone()
733         clientConfig.CipherSuites = nil // the defaults may prefer faster ciphers
734         clientConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled
735
736         buf := make([]byte, bufsize)
737         chunks := int(math.Ceil(float64(totalBytes) / float64(len(buf))))
738         for i := 0; i < N; i++ {
739                 conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
740                 if err != nil {
741                         b.Fatal(err)
742                 }
743                 for j := 0; j < chunks; j++ {
744                         _, err := conn.Write(buf)
745                         if err != nil {
746                                 b.Fatal(err)
747                         }
748                         _, err = io.ReadFull(conn, buf)
749                         if err != nil {
750                                 b.Fatal(err)
751                         }
752                 }
753                 conn.Close()
754         }
755 }
756
757 func BenchmarkThroughput(b *testing.B) {
758         for _, mode := range []string{"Max", "Dynamic"} {
759                 for size := 1; size <= 64; size <<= 1 {
760                         name := fmt.Sprintf("%sPacket/%dMB", mode, size)
761                         b.Run(name, func(b *testing.B) {
762                                 throughput(b, int64(size<<20), mode == "Max")
763                         })
764                 }
765         }
766 }
767
768 type slowConn struct {
769         net.Conn
770         bps int
771 }
772
773 func (c *slowConn) Write(p []byte) (int, error) {
774         if c.bps == 0 {
775                 panic("too slow")
776         }
777         t0 := time.Now()
778         wrote := 0
779         for wrote < len(p) {
780                 time.Sleep(100 * time.Microsecond)
781                 allowed := int(time.Since(t0).Seconds()*float64(c.bps)) / 8
782                 if allowed > len(p) {
783                         allowed = len(p)
784                 }
785                 if wrote < allowed {
786                         n, err := c.Conn.Write(p[wrote:allowed])
787                         wrote += n
788                         if err != nil {
789                                 return wrote, err
790                         }
791                 }
792         }
793         return len(p), nil
794 }
795
796 func latency(b *testing.B, bps int, dynamicRecordSizingDisabled bool) {
797         ln := newLocalListener(b)
798         defer ln.Close()
799
800         N := b.N
801
802         go func() {
803                 for i := 0; i < N; i++ {
804                         sconn, err := ln.Accept()
805                         if err != nil {
806                                 // panic rather than synchronize to avoid benchmark overhead
807                                 // (cannot call b.Fatal in goroutine)
808                                 panic(fmt.Errorf("accept: %v", err))
809                         }
810                         serverConfig := testConfig.Clone()
811                         serverConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled
812                         srv := Server(&slowConn{sconn, bps}, serverConfig)
813                         if err := srv.Handshake(); err != nil {
814                                 panic(fmt.Errorf("handshake: %v", err))
815                         }
816                         io.Copy(srv, srv)
817                 }
818         }()
819
820         clientConfig := testConfig.Clone()
821         clientConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled
822
823         buf := make([]byte, 16384)
824         peek := make([]byte, 1)
825
826         for i := 0; i < N; i++ {
827                 conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
828                 if err != nil {
829                         b.Fatal(err)
830                 }
831                 // make sure we're connected and previous connection has stopped
832                 if _, err := conn.Write(buf[:1]); err != nil {
833                         b.Fatal(err)
834                 }
835                 if _, err := io.ReadFull(conn, peek); err != nil {
836                         b.Fatal(err)
837                 }
838                 if _, err := conn.Write(buf); err != nil {
839                         b.Fatal(err)
840                 }
841                 if _, err = io.ReadFull(conn, peek); err != nil {
842                         b.Fatal(err)
843                 }
844                 conn.Close()
845         }
846 }
847
848 func BenchmarkLatency(b *testing.B) {
849         for _, mode := range []string{"Max", "Dynamic"} {
850                 for _, kbps := range []int{200, 500, 1000, 2000, 5000} {
851                         name := fmt.Sprintf("%sPacket/%dkbps", mode, kbps)
852                         b.Run(name, func(b *testing.B) {
853                                 latency(b, kbps*1000, mode == "Max")
854                         })
855                 }
856         }
857 }