]> Cypherpunks.ru repositories - gostls13.git/blob - src/net/http/client_test.go
[dev.ssa] Merge remote-tracking branch 'origin/master' into ssamerge
[gostls13.git] / src / net / http / client_test.go
1 // Copyright 2009 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 // Tests for client.go
6
7 package http_test
8
9 import (
10         "bytes"
11         "crypto/tls"
12         "crypto/x509"
13         "encoding/base64"
14         "errors"
15         "fmt"
16         "io"
17         "io/ioutil"
18         "log"
19         "net"
20         . "net/http"
21         "net/http/httptest"
22         "net/url"
23         "strconv"
24         "strings"
25         "sync"
26         "testing"
27         "time"
28 )
29
30 var robotsTxtHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
31         w.Header().Set("Last-Modified", "sometime")
32         fmt.Fprintf(w, "User-agent: go\nDisallow: /something/")
33 })
34
35 // pedanticReadAll works like ioutil.ReadAll but additionally
36 // verifies that r obeys the documented io.Reader contract.
37 func pedanticReadAll(r io.Reader) (b []byte, err error) {
38         var bufa [64]byte
39         buf := bufa[:]
40         for {
41                 n, err := r.Read(buf)
42                 if n == 0 && err == nil {
43                         return nil, fmt.Errorf("Read: n=0 with err=nil")
44                 }
45                 b = append(b, buf[:n]...)
46                 if err == io.EOF {
47                         n, err := r.Read(buf)
48                         if n != 0 || err != io.EOF {
49                                 return nil, fmt.Errorf("Read: n=%d err=%#v after EOF", n, err)
50                         }
51                         return b, nil
52                 }
53                 if err != nil {
54                         return b, err
55                 }
56         }
57 }
58
59 type chanWriter chan string
60
61 func (w chanWriter) Write(p []byte) (n int, err error) {
62         w <- string(p)
63         return len(p), nil
64 }
65
66 func TestClient(t *testing.T) {
67         defer afterTest(t)
68         ts := httptest.NewServer(robotsTxtHandler)
69         defer ts.Close()
70
71         r, err := Get(ts.URL)
72         var b []byte
73         if err == nil {
74                 b, err = pedanticReadAll(r.Body)
75                 r.Body.Close()
76         }
77         if err != nil {
78                 t.Error(err)
79         } else if s := string(b); !strings.HasPrefix(s, "User-agent:") {
80                 t.Errorf("Incorrect page body (did not begin with User-agent): %q", s)
81         }
82 }
83
84 func TestClientHead_h1(t *testing.T) { testClientHead(t, h1Mode) }
85 func TestClientHead_h2(t *testing.T) { testClientHead(t, h2Mode) }
86
87 func testClientHead(t *testing.T, h2 bool) {
88         defer afterTest(t)
89         cst := newClientServerTest(t, h2, robotsTxtHandler)
90         defer cst.close()
91
92         r, err := cst.c.Head(cst.ts.URL)
93         if err != nil {
94                 t.Fatal(err)
95         }
96         if _, ok := r.Header["Last-Modified"]; !ok {
97                 t.Error("Last-Modified header not found.")
98         }
99 }
100
101 type recordingTransport struct {
102         req *Request
103 }
104
105 func (t *recordingTransport) RoundTrip(req *Request) (resp *Response, err error) {
106         t.req = req
107         return nil, errors.New("dummy impl")
108 }
109
110 func TestGetRequestFormat(t *testing.T) {
111         defer afterTest(t)
112         tr := &recordingTransport{}
113         client := &Client{Transport: tr}
114         url := "http://dummy.faketld/"
115         client.Get(url) // Note: doesn't hit network
116         if tr.req.Method != "GET" {
117                 t.Errorf("expected method %q; got %q", "GET", tr.req.Method)
118         }
119         if tr.req.URL.String() != url {
120                 t.Errorf("expected URL %q; got %q", url, tr.req.URL.String())
121         }
122         if tr.req.Header == nil {
123                 t.Errorf("expected non-nil request Header")
124         }
125 }
126
127 func TestPostRequestFormat(t *testing.T) {
128         defer afterTest(t)
129         tr := &recordingTransport{}
130         client := &Client{Transport: tr}
131
132         url := "http://dummy.faketld/"
133         json := `{"key":"value"}`
134         b := strings.NewReader(json)
135         client.Post(url, "application/json", b) // Note: doesn't hit network
136
137         if tr.req.Method != "POST" {
138                 t.Errorf("got method %q, want %q", tr.req.Method, "POST")
139         }
140         if tr.req.URL.String() != url {
141                 t.Errorf("got URL %q, want %q", tr.req.URL.String(), url)
142         }
143         if tr.req.Header == nil {
144                 t.Fatalf("expected non-nil request Header")
145         }
146         if tr.req.Close {
147                 t.Error("got Close true, want false")
148         }
149         if g, e := tr.req.ContentLength, int64(len(json)); g != e {
150                 t.Errorf("got ContentLength %d, want %d", g, e)
151         }
152 }
153
154 func TestPostFormRequestFormat(t *testing.T) {
155         defer afterTest(t)
156         tr := &recordingTransport{}
157         client := &Client{Transport: tr}
158
159         urlStr := "http://dummy.faketld/"
160         form := make(url.Values)
161         form.Set("foo", "bar")
162         form.Add("foo", "bar2")
163         form.Set("bar", "baz")
164         client.PostForm(urlStr, form) // Note: doesn't hit network
165
166         if tr.req.Method != "POST" {
167                 t.Errorf("got method %q, want %q", tr.req.Method, "POST")
168         }
169         if tr.req.URL.String() != urlStr {
170                 t.Errorf("got URL %q, want %q", tr.req.URL.String(), urlStr)
171         }
172         if tr.req.Header == nil {
173                 t.Fatalf("expected non-nil request Header")
174         }
175         if g, e := tr.req.Header.Get("Content-Type"), "application/x-www-form-urlencoded"; g != e {
176                 t.Errorf("got Content-Type %q, want %q", g, e)
177         }
178         if tr.req.Close {
179                 t.Error("got Close true, want false")
180         }
181         // Depending on map iteration, body can be either of these.
182         expectedBody := "foo=bar&foo=bar2&bar=baz"
183         expectedBody1 := "bar=baz&foo=bar&foo=bar2"
184         if g, e := tr.req.ContentLength, int64(len(expectedBody)); g != e {
185                 t.Errorf("got ContentLength %d, want %d", g, e)
186         }
187         bodyb, err := ioutil.ReadAll(tr.req.Body)
188         if err != nil {
189                 t.Fatalf("ReadAll on req.Body: %v", err)
190         }
191         if g := string(bodyb); g != expectedBody && g != expectedBody1 {
192                 t.Errorf("got body %q, want %q or %q", g, expectedBody, expectedBody1)
193         }
194 }
195
196 func TestClientRedirects(t *testing.T) {
197         defer afterTest(t)
198         var ts *httptest.Server
199         ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
200                 n, _ := strconv.Atoi(r.FormValue("n"))
201                 // Test Referer header. (7 is arbitrary position to test at)
202                 if n == 7 {
203                         if g, e := r.Referer(), ts.URL+"/?n=6"; e != g {
204                                 t.Errorf("on request ?n=7, expected referer of %q; got %q", e, g)
205                         }
206                 }
207                 if n < 15 {
208                         Redirect(w, r, fmt.Sprintf("/?n=%d", n+1), StatusFound)
209                         return
210                 }
211                 fmt.Fprintf(w, "n=%d", n)
212         }))
213         defer ts.Close()
214
215         c := &Client{}
216         _, err := c.Get(ts.URL)
217         if e, g := "Get /?n=10: stopped after 10 redirects", fmt.Sprintf("%v", err); e != g {
218                 t.Errorf("with default client Get, expected error %q, got %q", e, g)
219         }
220
221         // HEAD request should also have the ability to follow redirects.
222         _, err = c.Head(ts.URL)
223         if e, g := "Head /?n=10: stopped after 10 redirects", fmt.Sprintf("%v", err); e != g {
224                 t.Errorf("with default client Head, expected error %q, got %q", e, g)
225         }
226
227         // Do should also follow redirects.
228         greq, _ := NewRequest("GET", ts.URL, nil)
229         _, err = c.Do(greq)
230         if e, g := "Get /?n=10: stopped after 10 redirects", fmt.Sprintf("%v", err); e != g {
231                 t.Errorf("with default client Do, expected error %q, got %q", e, g)
232         }
233
234         // Requests with an empty Method should also redirect (Issue 12705)
235         greq.Method = ""
236         _, err = c.Do(greq)
237         if e, g := "Get /?n=10: stopped after 10 redirects", fmt.Sprintf("%v", err); e != g {
238                 t.Errorf("with default client Do and empty Method, expected error %q, got %q", e, g)
239         }
240
241         var checkErr error
242         var lastVia []*Request
243         var lastReq *Request
244         c = &Client{CheckRedirect: func(req *Request, via []*Request) error {
245                 lastReq = req
246                 lastVia = via
247                 return checkErr
248         }}
249         res, err := c.Get(ts.URL)
250         if err != nil {
251                 t.Fatalf("Get error: %v", err)
252         }
253         res.Body.Close()
254         finalUrl := res.Request.URL.String()
255         if e, g := "<nil>", fmt.Sprintf("%v", err); e != g {
256                 t.Errorf("with custom client, expected error %q, got %q", e, g)
257         }
258         if !strings.HasSuffix(finalUrl, "/?n=15") {
259                 t.Errorf("expected final url to end in /?n=15; got url %q", finalUrl)
260         }
261         if e, g := 15, len(lastVia); e != g {
262                 t.Errorf("expected lastVia to have contained %d elements; got %d", e, g)
263         }
264
265         // Test that Request.Cancel is propagated between requests (Issue 14053)
266         creq, _ := NewRequest("HEAD", ts.URL, nil)
267         cancel := make(chan struct{})
268         creq.Cancel = cancel
269         if _, err := c.Do(creq); err != nil {
270                 t.Fatal(err)
271         }
272         if lastReq == nil {
273                 t.Fatal("didn't see redirect")
274         }
275         if lastReq.Cancel != cancel {
276                 t.Errorf("expected lastReq to have the cancel channel set on the initial req")
277         }
278
279         checkErr = errors.New("no redirects allowed")
280         res, err = c.Get(ts.URL)
281         if urlError, ok := err.(*url.Error); !ok || urlError.Err != checkErr {
282                 t.Errorf("with redirects forbidden, expected a *url.Error with our 'no redirects allowed' error inside; got %#v (%q)", err, err)
283         }
284         if res == nil {
285                 t.Fatalf("Expected a non-nil Response on CheckRedirect failure (https://golang.org/issue/3795)")
286         }
287         res.Body.Close()
288         if res.Header.Get("Location") == "" {
289                 t.Errorf("no Location header in Response")
290         }
291 }
292
293 func TestPostRedirects(t *testing.T) {
294         defer afterTest(t)
295         var log struct {
296                 sync.Mutex
297                 bytes.Buffer
298         }
299         var ts *httptest.Server
300         ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
301                 log.Lock()
302                 fmt.Fprintf(&log.Buffer, "%s %s ", r.Method, r.RequestURI)
303                 log.Unlock()
304                 if v := r.URL.Query().Get("code"); v != "" {
305                         code, _ := strconv.Atoi(v)
306                         if code/100 == 3 {
307                                 w.Header().Set("Location", ts.URL)
308                         }
309                         w.WriteHeader(code)
310                 }
311         }))
312         defer ts.Close()
313         tests := []struct {
314                 suffix string
315                 want   int // response code
316         }{
317                 {"/", 200},
318                 {"/?code=301", 301},
319                 {"/?code=302", 200},
320                 {"/?code=303", 200},
321                 {"/?code=404", 404},
322         }
323         for _, tt := range tests {
324                 res, err := Post(ts.URL+tt.suffix, "text/plain", strings.NewReader("Some content"))
325                 if err != nil {
326                         t.Fatal(err)
327                 }
328                 if res.StatusCode != tt.want {
329                         t.Errorf("POST %s: status code = %d; want %d", tt.suffix, res.StatusCode, tt.want)
330                 }
331         }
332         log.Lock()
333         got := log.String()
334         log.Unlock()
335         want := "POST / POST /?code=301 POST /?code=302 GET / POST /?code=303 GET / POST /?code=404 "
336         if got != want {
337                 t.Errorf("Log differs.\n Got: %q\nWant: %q", got, want)
338         }
339 }
340
341 var expectedCookies = []*Cookie{
342         {Name: "ChocolateChip", Value: "tasty"},
343         {Name: "First", Value: "Hit"},
344         {Name: "Second", Value: "Hit"},
345 }
346
347 var echoCookiesRedirectHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
348         for _, cookie := range r.Cookies() {
349                 SetCookie(w, cookie)
350         }
351         if r.URL.Path == "/" {
352                 SetCookie(w, expectedCookies[1])
353                 Redirect(w, r, "/second", StatusMovedPermanently)
354         } else {
355                 SetCookie(w, expectedCookies[2])
356                 w.Write([]byte("hello"))
357         }
358 })
359
360 func TestClientSendsCookieFromJar(t *testing.T) {
361         defer afterTest(t)
362         tr := &recordingTransport{}
363         client := &Client{Transport: tr}
364         client.Jar = &TestJar{perURL: make(map[string][]*Cookie)}
365         us := "http://dummy.faketld/"
366         u, _ := url.Parse(us)
367         client.Jar.SetCookies(u, expectedCookies)
368
369         client.Get(us) // Note: doesn't hit network
370         matchReturnedCookies(t, expectedCookies, tr.req.Cookies())
371
372         client.Head(us) // Note: doesn't hit network
373         matchReturnedCookies(t, expectedCookies, tr.req.Cookies())
374
375         client.Post(us, "text/plain", strings.NewReader("body")) // Note: doesn't hit network
376         matchReturnedCookies(t, expectedCookies, tr.req.Cookies())
377
378         client.PostForm(us, url.Values{}) // Note: doesn't hit network
379         matchReturnedCookies(t, expectedCookies, tr.req.Cookies())
380
381         req, _ := NewRequest("GET", us, nil)
382         client.Do(req) // Note: doesn't hit network
383         matchReturnedCookies(t, expectedCookies, tr.req.Cookies())
384
385         req, _ = NewRequest("POST", us, nil)
386         client.Do(req) // Note: doesn't hit network
387         matchReturnedCookies(t, expectedCookies, tr.req.Cookies())
388 }
389
390 // Just enough correctness for our redirect tests. Uses the URL.Host as the
391 // scope of all cookies.
392 type TestJar struct {
393         m      sync.Mutex
394         perURL map[string][]*Cookie
395 }
396
397 func (j *TestJar) SetCookies(u *url.URL, cookies []*Cookie) {
398         j.m.Lock()
399         defer j.m.Unlock()
400         if j.perURL == nil {
401                 j.perURL = make(map[string][]*Cookie)
402         }
403         j.perURL[u.Host] = cookies
404 }
405
406 func (j *TestJar) Cookies(u *url.URL) []*Cookie {
407         j.m.Lock()
408         defer j.m.Unlock()
409         return j.perURL[u.Host]
410 }
411
412 func TestRedirectCookiesJar(t *testing.T) {
413         defer afterTest(t)
414         var ts *httptest.Server
415         ts = httptest.NewServer(echoCookiesRedirectHandler)
416         defer ts.Close()
417         c := &Client{
418                 Jar: new(TestJar),
419         }
420         u, _ := url.Parse(ts.URL)
421         c.Jar.SetCookies(u, []*Cookie{expectedCookies[0]})
422         resp, err := c.Get(ts.URL)
423         if err != nil {
424                 t.Fatalf("Get: %v", err)
425         }
426         resp.Body.Close()
427         matchReturnedCookies(t, expectedCookies, resp.Cookies())
428 }
429
430 func matchReturnedCookies(t *testing.T, expected, given []*Cookie) {
431         if len(given) != len(expected) {
432                 t.Logf("Received cookies: %v", given)
433                 t.Errorf("Expected %d cookies, got %d", len(expected), len(given))
434         }
435         for _, ec := range expected {
436                 foundC := false
437                 for _, c := range given {
438                         if ec.Name == c.Name && ec.Value == c.Value {
439                                 foundC = true
440                                 break
441                         }
442                 }
443                 if !foundC {
444                         t.Errorf("Missing cookie %v", ec)
445                 }
446         }
447 }
448
449 func TestJarCalls(t *testing.T) {
450         defer afterTest(t)
451         ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
452                 pathSuffix := r.RequestURI[1:]
453                 if r.RequestURI == "/nosetcookie" {
454                         return // don't set cookies for this path
455                 }
456                 SetCookie(w, &Cookie{Name: "name" + pathSuffix, Value: "val" + pathSuffix})
457                 if r.RequestURI == "/" {
458                         Redirect(w, r, "http://secondhost.fake/secondpath", 302)
459                 }
460         }))
461         defer ts.Close()
462         jar := new(RecordingJar)
463         c := &Client{
464                 Jar: jar,
465                 Transport: &Transport{
466                         Dial: func(_ string, _ string) (net.Conn, error) {
467                                 return net.Dial("tcp", ts.Listener.Addr().String())
468                         },
469                 },
470         }
471         _, err := c.Get("http://firsthost.fake/")
472         if err != nil {
473                 t.Fatal(err)
474         }
475         _, err = c.Get("http://firsthost.fake/nosetcookie")
476         if err != nil {
477                 t.Fatal(err)
478         }
479         got := jar.log.String()
480         want := `Cookies("http://firsthost.fake/")
481 SetCookie("http://firsthost.fake/", [name=val])
482 Cookies("http://secondhost.fake/secondpath")
483 SetCookie("http://secondhost.fake/secondpath", [namesecondpath=valsecondpath])
484 Cookies("http://firsthost.fake/nosetcookie")
485 `
486         if got != want {
487                 t.Errorf("Got Jar calls:\n%s\nWant:\n%s", got, want)
488         }
489 }
490
491 // RecordingJar keeps a log of calls made to it, without
492 // tracking any cookies.
493 type RecordingJar struct {
494         mu  sync.Mutex
495         log bytes.Buffer
496 }
497
498 func (j *RecordingJar) SetCookies(u *url.URL, cookies []*Cookie) {
499         j.logf("SetCookie(%q, %v)\n", u, cookies)
500 }
501
502 func (j *RecordingJar) Cookies(u *url.URL) []*Cookie {
503         j.logf("Cookies(%q)\n", u)
504         return nil
505 }
506
507 func (j *RecordingJar) logf(format string, args ...interface{}) {
508         j.mu.Lock()
509         defer j.mu.Unlock()
510         fmt.Fprintf(&j.log, format, args...)
511 }
512
513 func TestStreamingGet_h1(t *testing.T) { testStreamingGet(t, h1Mode) }
514 func TestStreamingGet_h2(t *testing.T) { testStreamingGet(t, h2Mode) }
515
516 func testStreamingGet(t *testing.T, h2 bool) {
517         defer afterTest(t)
518         say := make(chan string)
519         cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
520                 w.(Flusher).Flush()
521                 for str := range say {
522                         w.Write([]byte(str))
523                         w.(Flusher).Flush()
524                 }
525         }))
526         defer cst.close()
527
528         c := cst.c
529         res, err := c.Get(cst.ts.URL)
530         if err != nil {
531                 t.Fatal(err)
532         }
533         var buf [10]byte
534         for _, str := range []string{"i", "am", "also", "known", "as", "comet"} {
535                 say <- str
536                 n, err := io.ReadFull(res.Body, buf[0:len(str)])
537                 if err != nil {
538                         t.Fatalf("ReadFull on %q: %v", str, err)
539                 }
540                 if n != len(str) {
541                         t.Fatalf("Receiving %q, only read %d bytes", str, n)
542                 }
543                 got := string(buf[0:n])
544                 if got != str {
545                         t.Fatalf("Expected %q, got %q", str, got)
546                 }
547         }
548         close(say)
549         _, err = io.ReadFull(res.Body, buf[0:1])
550         if err != io.EOF {
551                 t.Fatalf("at end expected EOF, got %v", err)
552         }
553 }
554
555 type writeCountingConn struct {
556         net.Conn
557         count *int
558 }
559
560 func (c *writeCountingConn) Write(p []byte) (int, error) {
561         *c.count++
562         return c.Conn.Write(p)
563 }
564
565 // TestClientWrites verifies that client requests are buffered and we
566 // don't send a TCP packet per line of the http request + body.
567 func TestClientWrites(t *testing.T) {
568         defer afterTest(t)
569         ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
570         }))
571         defer ts.Close()
572
573         writes := 0
574         dialer := func(netz string, addr string) (net.Conn, error) {
575                 c, err := net.Dial(netz, addr)
576                 if err == nil {
577                         c = &writeCountingConn{c, &writes}
578                 }
579                 return c, err
580         }
581         c := &Client{Transport: &Transport{Dial: dialer}}
582
583         _, err := c.Get(ts.URL)
584         if err != nil {
585                 t.Fatal(err)
586         }
587         if writes != 1 {
588                 t.Errorf("Get request did %d Write calls, want 1", writes)
589         }
590
591         writes = 0
592         _, err = c.PostForm(ts.URL, url.Values{"foo": {"bar"}})
593         if err != nil {
594                 t.Fatal(err)
595         }
596         if writes != 1 {
597                 t.Errorf("Post request did %d Write calls, want 1", writes)
598         }
599 }
600
601 func TestClientInsecureTransport(t *testing.T) {
602         defer afterTest(t)
603         ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
604                 w.Write([]byte("Hello"))
605         }))
606         errc := make(chanWriter, 10) // but only expecting 1
607         ts.Config.ErrorLog = log.New(errc, "", 0)
608         defer ts.Close()
609
610         // TODO(bradfitz): add tests for skipping hostname checks too?
611         // would require a new cert for testing, and probably
612         // redundant with these tests.
613         for _, insecure := range []bool{true, false} {
614                 tr := &Transport{
615                         TLSClientConfig: &tls.Config{
616                                 InsecureSkipVerify: insecure,
617                         },
618                 }
619                 defer tr.CloseIdleConnections()
620                 c := &Client{Transport: tr}
621                 res, err := c.Get(ts.URL)
622                 if (err == nil) != insecure {
623                         t.Errorf("insecure=%v: got unexpected err=%v", insecure, err)
624                 }
625                 if res != nil {
626                         res.Body.Close()
627                 }
628         }
629
630         select {
631         case v := <-errc:
632                 if !strings.Contains(v, "TLS handshake error") {
633                         t.Errorf("expected an error log message containing 'TLS handshake error'; got %q", v)
634                 }
635         case <-time.After(5 * time.Second):
636                 t.Errorf("timeout waiting for logged error")
637         }
638
639 }
640
641 func TestClientErrorWithRequestURI(t *testing.T) {
642         defer afterTest(t)
643         req, _ := NewRequest("GET", "http://localhost:1234/", nil)
644         req.RequestURI = "/this/field/is/illegal/and/should/error/"
645         _, err := DefaultClient.Do(req)
646         if err == nil {
647                 t.Fatalf("expected an error")
648         }
649         if !strings.Contains(err.Error(), "RequestURI") {
650                 t.Errorf("wanted error mentioning RequestURI; got error: %v", err)
651         }
652 }
653
654 func newTLSTransport(t *testing.T, ts *httptest.Server) *Transport {
655         certs := x509.NewCertPool()
656         for _, c := range ts.TLS.Certificates {
657                 roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
658                 if err != nil {
659                         t.Fatalf("error parsing server's root cert: %v", err)
660                 }
661                 for _, root := range roots {
662                         certs.AddCert(root)
663                 }
664         }
665         return &Transport{
666                 TLSClientConfig: &tls.Config{RootCAs: certs},
667         }
668 }
669
670 func TestClientWithCorrectTLSServerName(t *testing.T) {
671         defer afterTest(t)
672
673         const serverName = "example.com"
674         ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
675                 if r.TLS.ServerName != serverName {
676                         t.Errorf("expected client to set ServerName %q, got: %q", serverName, r.TLS.ServerName)
677                 }
678         }))
679         defer ts.Close()
680
681         trans := newTLSTransport(t, ts)
682         trans.TLSClientConfig.ServerName = serverName
683         c := &Client{Transport: trans}
684         if _, err := c.Get(ts.URL); err != nil {
685                 t.Fatalf("expected successful TLS connection, got error: %v", err)
686         }
687 }
688
689 func TestClientWithIncorrectTLSServerName(t *testing.T) {
690         defer afterTest(t)
691         ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
692         defer ts.Close()
693         errc := make(chanWriter, 10) // but only expecting 1
694         ts.Config.ErrorLog = log.New(errc, "", 0)
695
696         trans := newTLSTransport(t, ts)
697         trans.TLSClientConfig.ServerName = "badserver"
698         c := &Client{Transport: trans}
699         _, err := c.Get(ts.URL)
700         if err == nil {
701                 t.Fatalf("expected an error")
702         }
703         if !strings.Contains(err.Error(), "127.0.0.1") || !strings.Contains(err.Error(), "badserver") {
704                 t.Errorf("wanted error mentioning 127.0.0.1 and badserver; got error: %v", err)
705         }
706         select {
707         case v := <-errc:
708                 if !strings.Contains(v, "TLS handshake error") {
709                         t.Errorf("expected an error log message containing 'TLS handshake error'; got %q", v)
710                 }
711         case <-time.After(5 * time.Second):
712                 t.Errorf("timeout waiting for logged error")
713         }
714 }
715
716 // Test for golang.org/issue/5829; the Transport should respect TLSClientConfig.ServerName
717 // when not empty.
718 //
719 // tls.Config.ServerName (non-empty, set to "example.com") takes
720 // precedence over "some-other-host.tld" which previously incorrectly
721 // took precedence. We don't actually connect to (or even resolve)
722 // "some-other-host.tld", though, because of the Transport.Dial hook.
723 //
724 // The httptest.Server has a cert with "example.com" as its name.
725 func TestTransportUsesTLSConfigServerName(t *testing.T) {
726         defer afterTest(t)
727         ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
728                 w.Write([]byte("Hello"))
729         }))
730         defer ts.Close()
731
732         tr := newTLSTransport(t, ts)
733         tr.TLSClientConfig.ServerName = "example.com" // one of httptest's Server cert names
734         tr.Dial = func(netw, addr string) (net.Conn, error) {
735                 return net.Dial(netw, ts.Listener.Addr().String())
736         }
737         defer tr.CloseIdleConnections()
738         c := &Client{Transport: tr}
739         res, err := c.Get("https://some-other-host.tld/")
740         if err != nil {
741                 t.Fatal(err)
742         }
743         res.Body.Close()
744 }
745
746 func TestResponseSetsTLSConnectionState(t *testing.T) {
747         defer afterTest(t)
748         ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
749                 w.Write([]byte("Hello"))
750         }))
751         defer ts.Close()
752
753         tr := newTLSTransport(t, ts)
754         tr.TLSClientConfig.CipherSuites = []uint16{tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA}
755         tr.Dial = func(netw, addr string) (net.Conn, error) {
756                 return net.Dial(netw, ts.Listener.Addr().String())
757         }
758         defer tr.CloseIdleConnections()
759         c := &Client{Transport: tr}
760         res, err := c.Get("https://example.com/")
761         if err != nil {
762                 t.Fatal(err)
763         }
764         defer res.Body.Close()
765         if res.TLS == nil {
766                 t.Fatal("Response didn't set TLS Connection State.")
767         }
768         if got, want := res.TLS.CipherSuite, tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA; got != want {
769                 t.Errorf("TLS Cipher Suite = %d; want %d", got, want)
770         }
771 }
772
773 // Check that an HTTPS client can interpret a particular TLS error
774 // to determine that the server is speaking HTTP.
775 // See golang.org/issue/11111.
776 func TestHTTPSClientDetectsHTTPServer(t *testing.T) {
777         defer afterTest(t)
778         ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
779         defer ts.Close()
780
781         _, err := Get(strings.Replace(ts.URL, "http", "https", 1))
782         if got := err.Error(); !strings.Contains(got, "HTTP response to HTTPS client") {
783                 t.Fatalf("error = %q; want error indicating HTTP response to HTTPS request", got)
784         }
785 }
786
787 // Verify Response.ContentLength is populated. https://golang.org/issue/4126
788 func TestClientHeadContentLength_h1(t *testing.T) {
789         testClientHeadContentLength(t, h1Mode)
790 }
791
792 func TestClientHeadContentLength_h2(t *testing.T) {
793         testClientHeadContentLength(t, h2Mode)
794 }
795
796 func testClientHeadContentLength(t *testing.T, h2 bool) {
797         defer afterTest(t)
798         cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
799                 if v := r.FormValue("cl"); v != "" {
800                         w.Header().Set("Content-Length", v)
801                 }
802         }))
803         defer cst.close()
804         tests := []struct {
805                 suffix string
806                 want   int64
807         }{
808                 {"/?cl=1234", 1234},
809                 {"/?cl=0", 0},
810                 {"", -1},
811         }
812         for _, tt := range tests {
813                 req, _ := NewRequest("HEAD", cst.ts.URL+tt.suffix, nil)
814                 res, err := cst.c.Do(req)
815                 if err != nil {
816                         t.Fatal(err)
817                 }
818                 if res.ContentLength != tt.want {
819                         t.Errorf("Content-Length = %d; want %d", res.ContentLength, tt.want)
820                 }
821                 bs, err := ioutil.ReadAll(res.Body)
822                 if err != nil {
823                         t.Fatal(err)
824                 }
825                 if len(bs) != 0 {
826                         t.Errorf("Unexpected content: %q", bs)
827                 }
828         }
829 }
830
831 func TestEmptyPasswordAuth(t *testing.T) {
832         defer afterTest(t)
833         gopher := "gopher"
834         ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
835                 auth := r.Header.Get("Authorization")
836                 if strings.HasPrefix(auth, "Basic ") {
837                         encoded := auth[6:]
838                         decoded, err := base64.StdEncoding.DecodeString(encoded)
839                         if err != nil {
840                                 t.Fatal(err)
841                         }
842                         expected := gopher + ":"
843                         s := string(decoded)
844                         if expected != s {
845                                 t.Errorf("Invalid Authorization header. Got %q, wanted %q", s, expected)
846                         }
847                 } else {
848                         t.Errorf("Invalid auth %q", auth)
849                 }
850         }))
851         defer ts.Close()
852         c := &Client{}
853         req, err := NewRequest("GET", ts.URL, nil)
854         if err != nil {
855                 t.Fatal(err)
856         }
857         req.URL.User = url.User(gopher)
858         resp, err := c.Do(req)
859         if err != nil {
860                 t.Fatal(err)
861         }
862         defer resp.Body.Close()
863 }
864
865 func TestBasicAuth(t *testing.T) {
866         defer afterTest(t)
867         tr := &recordingTransport{}
868         client := &Client{Transport: tr}
869
870         url := "http://My%20User:My%20Pass@dummy.faketld/"
871         expected := "My User:My Pass"
872         client.Get(url)
873
874         if tr.req.Method != "GET" {
875                 t.Errorf("got method %q, want %q", tr.req.Method, "GET")
876         }
877         if tr.req.URL.String() != url {
878                 t.Errorf("got URL %q, want %q", tr.req.URL.String(), url)
879         }
880         if tr.req.Header == nil {
881                 t.Fatalf("expected non-nil request Header")
882         }
883         auth := tr.req.Header.Get("Authorization")
884         if strings.HasPrefix(auth, "Basic ") {
885                 encoded := auth[6:]
886                 decoded, err := base64.StdEncoding.DecodeString(encoded)
887                 if err != nil {
888                         t.Fatal(err)
889                 }
890                 s := string(decoded)
891                 if expected != s {
892                         t.Errorf("Invalid Authorization header. Got %q, wanted %q", s, expected)
893                 }
894         } else {
895                 t.Errorf("Invalid auth %q", auth)
896         }
897 }
898
899 func TestBasicAuthHeadersPreserved(t *testing.T) {
900         defer afterTest(t)
901         tr := &recordingTransport{}
902         client := &Client{Transport: tr}
903
904         // If Authorization header is provided, username in URL should not override it
905         url := "http://My%20User@dummy.faketld/"
906         req, err := NewRequest("GET", url, nil)
907         if err != nil {
908                 t.Fatal(err)
909         }
910         req.SetBasicAuth("My User", "My Pass")
911         expected := "My User:My Pass"
912         client.Do(req)
913
914         if tr.req.Method != "GET" {
915                 t.Errorf("got method %q, want %q", tr.req.Method, "GET")
916         }
917         if tr.req.URL.String() != url {
918                 t.Errorf("got URL %q, want %q", tr.req.URL.String(), url)
919         }
920         if tr.req.Header == nil {
921                 t.Fatalf("expected non-nil request Header")
922         }
923         auth := tr.req.Header.Get("Authorization")
924         if strings.HasPrefix(auth, "Basic ") {
925                 encoded := auth[6:]
926                 decoded, err := base64.StdEncoding.DecodeString(encoded)
927                 if err != nil {
928                         t.Fatal(err)
929                 }
930                 s := string(decoded)
931                 if expected != s {
932                         t.Errorf("Invalid Authorization header. Got %q, wanted %q", s, expected)
933                 }
934         } else {
935                 t.Errorf("Invalid auth %q", auth)
936         }
937
938 }
939
940 func TestClientTimeout_h1(t *testing.T) { testClientTimeout(t, h1Mode) }
941 func TestClientTimeout_h2(t *testing.T) { testClientTimeout(t, h2Mode) }
942
943 func testClientTimeout(t *testing.T, h2 bool) {
944         if testing.Short() {
945                 t.Skip("skipping in short mode")
946         }
947         defer afterTest(t)
948         sawRoot := make(chan bool, 1)
949         sawSlow := make(chan bool, 1)
950         cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
951                 if r.URL.Path == "/" {
952                         sawRoot <- true
953                         Redirect(w, r, "/slow", StatusFound)
954                         return
955                 }
956                 if r.URL.Path == "/slow" {
957                         w.Write([]byte("Hello"))
958                         w.(Flusher).Flush()
959                         sawSlow <- true
960                         time.Sleep(2 * time.Second)
961                         return
962                 }
963         }))
964         defer cst.close()
965         const timeout = 500 * time.Millisecond
966         cst.c.Timeout = timeout
967
968         res, err := cst.c.Get(cst.ts.URL)
969         if err != nil {
970                 t.Fatal(err)
971         }
972
973         select {
974         case <-sawRoot:
975                 // good.
976         default:
977                 t.Fatal("handler never got / request")
978         }
979
980         select {
981         case <-sawSlow:
982                 // good.
983         default:
984                 t.Fatal("handler never got /slow request")
985         }
986
987         errc := make(chan error, 1)
988         go func() {
989                 _, err := ioutil.ReadAll(res.Body)
990                 errc <- err
991                 res.Body.Close()
992         }()
993
994         const failTime = timeout * 2
995         select {
996         case err := <-errc:
997                 if err == nil {
998                         t.Fatal("expected error from ReadAll")
999                 }
1000                 ne, ok := err.(net.Error)
1001                 if !ok {
1002                         t.Errorf("error value from ReadAll was %T; expected some net.Error", err)
1003                 } else if !ne.Timeout() {
1004                         t.Errorf("net.Error.Timeout = false; want true")
1005                 }
1006                 if got := ne.Error(); !strings.Contains(got, "Client.Timeout exceeded") {
1007                         t.Errorf("error string = %q; missing timeout substring", got)
1008                 }
1009         case <-time.After(failTime):
1010                 t.Errorf("timeout after %v waiting for timeout of %v", failTime, timeout)
1011         }
1012 }
1013
1014 func TestClientTimeout_Headers_h1(t *testing.T) { testClientTimeout_Headers(t, h1Mode) }
1015 func TestClientTimeout_Headers_h2(t *testing.T) { testClientTimeout_Headers(t, h2Mode) }
1016
1017 // Client.Timeout firing before getting to the body
1018 func testClientTimeout_Headers(t *testing.T, h2 bool) {
1019         if testing.Short() {
1020                 t.Skip("skipping in short mode")
1021         }
1022         defer afterTest(t)
1023         donec := make(chan bool)
1024         cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
1025                 <-donec
1026         }))
1027         defer cst.close()
1028         // Note that we use a channel send here and not a close.
1029         // The race detector doesn't know that we're waiting for a timeout
1030         // and thinks that the waitgroup inside httptest.Server is added to concurrently
1031         // with us closing it. If we timed out immediately, we could close the testserver
1032         // before we entered the handler. We're not timing out immediately and there's
1033         // no way we would be done before we entered the handler, but the race detector
1034         // doesn't know this, so synchronize explicitly.
1035         defer func() { donec <- true }()
1036
1037         cst.c.Timeout = 500 * time.Millisecond
1038         _, err := cst.c.Get(cst.ts.URL)
1039         if err == nil {
1040                 t.Fatal("got response from Get; expected error")
1041         }
1042         if _, ok := err.(*url.Error); !ok {
1043                 t.Fatalf("Got error of type %T; want *url.Error", err)
1044         }
1045         ne, ok := err.(net.Error)
1046         if !ok {
1047                 t.Fatalf("Got error of type %T; want some net.Error", err)
1048         }
1049         if !ne.Timeout() {
1050                 t.Error("net.Error.Timeout = false; want true")
1051         }
1052         if got := ne.Error(); !strings.Contains(got, "Client.Timeout exceeded") {
1053                 t.Errorf("error string = %q; missing timeout substring", got)
1054         }
1055 }
1056
1057 func TestClientRedirectEatsBody_h1(t *testing.T) { testClientRedirectEatsBody(t, h1Mode) }
1058 func TestClientRedirectEatsBody_h2(t *testing.T) { testClientRedirectEatsBody(t, h2Mode) }
1059 func testClientRedirectEatsBody(t *testing.T, h2 bool) {
1060         defer afterTest(t)
1061         saw := make(chan string, 2)
1062         cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
1063                 saw <- r.RemoteAddr
1064                 if r.URL.Path == "/" {
1065                         Redirect(w, r, "/foo", StatusFound) // which includes a body
1066                 }
1067         }))
1068         defer cst.close()
1069
1070         res, err := cst.c.Get(cst.ts.URL)
1071         if err != nil {
1072                 t.Fatal(err)
1073         }
1074         _, err = ioutil.ReadAll(res.Body)
1075         if err != nil {
1076                 t.Fatal(err)
1077         }
1078         res.Body.Close()
1079
1080         var first string
1081         select {
1082         case first = <-saw:
1083         default:
1084                 t.Fatal("server didn't see a request")
1085         }
1086
1087         var second string
1088         select {
1089         case second = <-saw:
1090         default:
1091                 t.Fatal("server didn't see a second request")
1092         }
1093
1094         if first != second {
1095                 t.Fatal("server saw different client ports before & after the redirect")
1096         }
1097 }
1098
1099 // eofReaderFunc is an io.Reader that runs itself, and then returns io.EOF.
1100 type eofReaderFunc func()
1101
1102 func (f eofReaderFunc) Read(p []byte) (n int, err error) {
1103         f()
1104         return 0, io.EOF
1105 }
1106
1107 func TestReferer(t *testing.T) {
1108         tests := []struct {
1109                 lastReq, newReq string // from -> to URLs
1110                 want            string
1111         }{
1112                 // don't send user:
1113                 {"http://gopher@test.com", "http://link.com", "http://test.com"},
1114                 {"https://gopher@test.com", "https://link.com", "https://test.com"},
1115
1116                 // don't send a user and password:
1117                 {"http://gopher:go@test.com", "http://link.com", "http://test.com"},
1118                 {"https://gopher:go@test.com", "https://link.com", "https://test.com"},
1119
1120                 // nothing to do:
1121                 {"http://test.com", "http://link.com", "http://test.com"},
1122                 {"https://test.com", "https://link.com", "https://test.com"},
1123
1124                 // https to http doesn't send a referer:
1125                 {"https://test.com", "http://link.com", ""},
1126                 {"https://gopher:go@test.com", "http://link.com", ""},
1127         }
1128         for _, tt := range tests {
1129                 l, err := url.Parse(tt.lastReq)
1130                 if err != nil {
1131                         t.Fatal(err)
1132                 }
1133                 n, err := url.Parse(tt.newReq)
1134                 if err != nil {
1135                         t.Fatal(err)
1136                 }
1137                 r := ExportRefererForURL(l, n)
1138                 if r != tt.want {
1139                         t.Errorf("refererForURL(%q, %q) = %q; want %q", tt.lastReq, tt.newReq, r, tt.want)
1140                 }
1141         }
1142 }