]> Cypherpunks.ru repositories - gostls13.git/blob - src/net/http/client_test.go
net/http: deflake TestClientRedirectTypes and maybe some similar ones
[gostls13.git] / src / net / http / client_test.go
1 // Copyright 2009 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 // Tests for client.go
6
7 package http_test
8
9 import (
10         "bytes"
11         "context"
12         "crypto/tls"
13         "crypto/x509"
14         "encoding/base64"
15         "errors"
16         "fmt"
17         "io"
18         "io/ioutil"
19         "log"
20         "net"
21         . "net/http"
22         "net/http/cookiejar"
23         "net/http/httptest"
24         "net/url"
25         "reflect"
26         "strconv"
27         "strings"
28         "sync"
29         "testing"
30         "time"
31 )
32
// robotsTxtHandler answers every request with a fixed robots.txt-style body
// and a Last-Modified header (which the HEAD tests look for). The request
// itself is ignored.
var robotsTxtHandler = HandlerFunc(func(rw ResponseWriter, _ *Request) {
	hdr := rw.Header()
	hdr.Set("Last-Modified", "sometime")
	fmt.Fprintf(rw, "User-agent: go\nDisallow: /something/")
})
37
// pedanticReadAll reads r until io.EOF, like ioutil.ReadAll, while also
// checking two parts of the io.Reader contract: a Read must never return
// (0, nil), and once a Read has reported io.EOF a follow-up Read must report
// (0, io.EOF) again.
//
// On a contract violation it returns a nil slice and a descriptive error.
// On any other read error it returns the bytes gathered so far (including
// those delivered alongside the error) together with that error.
func pedanticReadAll(r io.Reader) (b []byte, err error) {
	var scratch [64]byte
	for {
		got, rerr := r.Read(scratch[:])
		if got == 0 && rerr == nil {
			return nil, fmt.Errorf("Read: n=0 with err=nil")
		}
		b = append(b, scratch[:got]...)
		switch {
		case rerr == io.EOF:
			// A well-behaved reader keeps reporting EOF once it has hit it.
			if extra, again := r.Read(scratch[:]); extra != 0 || again != io.EOF {
				return nil, fmt.Errorf("Read: n=%d err=%#v after EOF", extra, again)
			}
			return b, nil
		case rerr != nil:
			return b, rerr
		}
	}
}
61
// chanWriter is an io.Writer that sends the payload of every Write, as a
// string, on the underlying channel. Each Write blocks until the channel
// accepts the value and always reports full success.
type chanWriter chan string

func (cw chanWriter) Write(p []byte) (n int, err error) {
	msg := string(p)
	cw <- msg
	n = len(p)
	return n, nil
}
68
69 func TestClient(t *testing.T) {
70         setParallel(t)
71         defer afterTest(t)
72         ts := httptest.NewServer(robotsTxtHandler)
73         defer ts.Close()
74
75         c := &Client{Transport: &Transport{DisableKeepAlives: true}}
76         r, err := c.Get(ts.URL)
77         var b []byte
78         if err == nil {
79                 b, err = pedanticReadAll(r.Body)
80                 r.Body.Close()
81         }
82         if err != nil {
83                 t.Error(err)
84         } else if s := string(b); !strings.HasPrefix(s, "User-agent:") {
85                 t.Errorf("Incorrect page body (did not begin with User-agent): %q", s)
86         }
87 }
88
89 func TestClientHead_h1(t *testing.T) { testClientHead(t, h1Mode) }
90 func TestClientHead_h2(t *testing.T) { testClientHead(t, h2Mode) }
91
92 func testClientHead(t *testing.T, h2 bool) {
93         defer afterTest(t)
94         cst := newClientServerTest(t, h2, robotsTxtHandler)
95         defer cst.close()
96
97         r, err := cst.c.Head(cst.ts.URL)
98         if err != nil {
99                 t.Fatal(err)
100         }
101         if _, ok := r.Header["Last-Modified"]; !ok {
102                 t.Error("Last-Modified header not found.")
103         }
104 }
105
// recordingTransport is a RoundTripper that never touches the network. It
// remembers the most recent request it was handed and always fails with a
// "dummy impl" error.
type recordingTransport struct {
	req *Request // last request passed to RoundTrip
}

func (rt *recordingTransport) RoundTrip(req *Request) (resp *Response, err error) {
	rt.req = req
	err = errors.New("dummy impl")
	return nil, err
}
114
115 func TestGetRequestFormat(t *testing.T) {
116         setParallel(t)
117         defer afterTest(t)
118         tr := &recordingTransport{}
119         client := &Client{Transport: tr}
120         url := "http://dummy.faketld/"
121         client.Get(url) // Note: doesn't hit network
122         if tr.req.Method != "GET" {
123                 t.Errorf("expected method %q; got %q", "GET", tr.req.Method)
124         }
125         if tr.req.URL.String() != url {
126                 t.Errorf("expected URL %q; got %q", url, tr.req.URL.String())
127         }
128         if tr.req.Header == nil {
129                 t.Errorf("expected non-nil request Header")
130         }
131 }
132
133 func TestPostRequestFormat(t *testing.T) {
134         defer afterTest(t)
135         tr := &recordingTransport{}
136         client := &Client{Transport: tr}
137
138         url := "http://dummy.faketld/"
139         json := `{"key":"value"}`
140         b := strings.NewReader(json)
141         client.Post(url, "application/json", b) // Note: doesn't hit network
142
143         if tr.req.Method != "POST" {
144                 t.Errorf("got method %q, want %q", tr.req.Method, "POST")
145         }
146         if tr.req.URL.String() != url {
147                 t.Errorf("got URL %q, want %q", tr.req.URL.String(), url)
148         }
149         if tr.req.Header == nil {
150                 t.Fatalf("expected non-nil request Header")
151         }
152         if tr.req.Close {
153                 t.Error("got Close true, want false")
154         }
155         if g, e := tr.req.ContentLength, int64(len(json)); g != e {
156                 t.Errorf("got ContentLength %d, want %d", g, e)
157         }
158 }
159
160 func TestPostFormRequestFormat(t *testing.T) {
161         defer afterTest(t)
162         tr := &recordingTransport{}
163         client := &Client{Transport: tr}
164
165         urlStr := "http://dummy.faketld/"
166         form := make(url.Values)
167         form.Set("foo", "bar")
168         form.Add("foo", "bar2")
169         form.Set("bar", "baz")
170         client.PostForm(urlStr, form) // Note: doesn't hit network
171
172         if tr.req.Method != "POST" {
173                 t.Errorf("got method %q, want %q", tr.req.Method, "POST")
174         }
175         if tr.req.URL.String() != urlStr {
176                 t.Errorf("got URL %q, want %q", tr.req.URL.String(), urlStr)
177         }
178         if tr.req.Header == nil {
179                 t.Fatalf("expected non-nil request Header")
180         }
181         if g, e := tr.req.Header.Get("Content-Type"), "application/x-www-form-urlencoded"; g != e {
182                 t.Errorf("got Content-Type %q, want %q", g, e)
183         }
184         if tr.req.Close {
185                 t.Error("got Close true, want false")
186         }
187         // Depending on map iteration, body can be either of these.
188         expectedBody := "foo=bar&foo=bar2&bar=baz"
189         expectedBody1 := "bar=baz&foo=bar&foo=bar2"
190         if g, e := tr.req.ContentLength, int64(len(expectedBody)); g != e {
191                 t.Errorf("got ContentLength %d, want %d", g, e)
192         }
193         bodyb, err := ioutil.ReadAll(tr.req.Body)
194         if err != nil {
195                 t.Fatalf("ReadAll on req.Body: %v", err)
196         }
197         if g := string(bodyb); g != expectedBody && g != expectedBody1 {
198                 t.Errorf("got body %q, want %q or %q", g, expectedBody, expectedBody1)
199         }
200 }
201
202 func TestClientRedirects(t *testing.T) {
203         setParallel(t)
204         defer afterTest(t)
205         var ts *httptest.Server
206         ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
207                 n, _ := strconv.Atoi(r.FormValue("n"))
208                 // Test Referer header. (7 is arbitrary position to test at)
209                 if n == 7 {
210                         if g, e := r.Referer(), ts.URL+"/?n=6"; e != g {
211                                 t.Errorf("on request ?n=7, expected referer of %q; got %q", e, g)
212                         }
213                 }
214                 if n < 15 {
215                         Redirect(w, r, fmt.Sprintf("/?n=%d", n+1), StatusTemporaryRedirect)
216                         return
217                 }
218                 fmt.Fprintf(w, "n=%d", n)
219         }))
220         defer ts.Close()
221
222         tr := &Transport{}
223         defer tr.CloseIdleConnections()
224
225         c := &Client{Transport: tr}
226         _, err := c.Get(ts.URL)
227         if e, g := "Get /?n=10: stopped after 10 redirects", fmt.Sprintf("%v", err); e != g {
228                 t.Errorf("with default client Get, expected error %q, got %q", e, g)
229         }
230
231         // HEAD request should also have the ability to follow redirects.
232         _, err = c.Head(ts.URL)
233         if e, g := "Head /?n=10: stopped after 10 redirects", fmt.Sprintf("%v", err); e != g {
234                 t.Errorf("with default client Head, expected error %q, got %q", e, g)
235         }
236
237         // Do should also follow redirects.
238         greq, _ := NewRequest("GET", ts.URL, nil)
239         _, err = c.Do(greq)
240         if e, g := "Get /?n=10: stopped after 10 redirects", fmt.Sprintf("%v", err); e != g {
241                 t.Errorf("with default client Do, expected error %q, got %q", e, g)
242         }
243
244         // Requests with an empty Method should also redirect (Issue 12705)
245         greq.Method = ""
246         _, err = c.Do(greq)
247         if e, g := "Get /?n=10: stopped after 10 redirects", fmt.Sprintf("%v", err); e != g {
248                 t.Errorf("with default client Do and empty Method, expected error %q, got %q", e, g)
249         }
250
251         var checkErr error
252         var lastVia []*Request
253         var lastReq *Request
254         c = &Client{CheckRedirect: func(req *Request, via []*Request) error {
255                 lastReq = req
256                 lastVia = via
257                 return checkErr
258         }}
259         res, err := c.Get(ts.URL)
260         if err != nil {
261                 t.Fatalf("Get error: %v", err)
262         }
263         res.Body.Close()
264         finalUrl := res.Request.URL.String()
265         if e, g := "<nil>", fmt.Sprintf("%v", err); e != g {
266                 t.Errorf("with custom client, expected error %q, got %q", e, g)
267         }
268         if !strings.HasSuffix(finalUrl, "/?n=15") {
269                 t.Errorf("expected final url to end in /?n=15; got url %q", finalUrl)
270         }
271         if e, g := 15, len(lastVia); e != g {
272                 t.Errorf("expected lastVia to have contained %d elements; got %d", e, g)
273         }
274
275         // Test that Request.Cancel is propagated between requests (Issue 14053)
276         creq, _ := NewRequest("HEAD", ts.URL, nil)
277         cancel := make(chan struct{})
278         creq.Cancel = cancel
279         if _, err := c.Do(creq); err != nil {
280                 t.Fatal(err)
281         }
282         if lastReq == nil {
283                 t.Fatal("didn't see redirect")
284         }
285         if lastReq.Cancel != cancel {
286                 t.Errorf("expected lastReq to have the cancel channel set on the initial req")
287         }
288
289         checkErr = errors.New("no redirects allowed")
290         res, err = c.Get(ts.URL)
291         if urlError, ok := err.(*url.Error); !ok || urlError.Err != checkErr {
292                 t.Errorf("with redirects forbidden, expected a *url.Error with our 'no redirects allowed' error inside; got %#v (%q)", err, err)
293         }
294         if res == nil {
295                 t.Fatalf("Expected a non-nil Response on CheckRedirect failure (https://golang.org/issue/3795)")
296         }
297         res.Body.Close()
298         if res.Header.Get("Location") == "" {
299                 t.Errorf("no Location header in Response")
300         }
301 }
302
303 func TestClientRedirectContext(t *testing.T) {
304         setParallel(t)
305         defer afterTest(t)
306         ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
307                 Redirect(w, r, "/", StatusTemporaryRedirect)
308         }))
309         defer ts.Close()
310
311         tr := &Transport{}
312         defer tr.CloseIdleConnections()
313
314         ctx, cancel := context.WithCancel(context.Background())
315         c := &Client{
316                 Transport: tr,
317                 CheckRedirect: func(req *Request, via []*Request) error {
318                         cancel()
319                         if len(via) > 2 {
320                                 return errors.New("too many redirects")
321                         }
322                         return nil
323                 },
324         }
325         req, _ := NewRequest("GET", ts.URL, nil)
326         req = req.WithContext(ctx)
327         _, err := c.Do(req)
328         ue, ok := err.(*url.Error)
329         if !ok {
330                 t.Fatalf("got error %T; want *url.Error", err)
331         }
332         if ue.Err != context.Canceled {
333                 t.Errorf("url.Error.Err = %v; want %v", ue.Err, context.Canceled)
334         }
335 }
336
// redirectTest is one row of a table passed to testRedirectsByMethod.
// Reordering fields would break any positional composite literals of this
// type.
type redirectTest struct {
	suffix       string // path and query appended to the test server URL
	want         int    // response code
	redirectBody string // request body sent with the first request
}
342
343 func TestPostRedirects(t *testing.T) {
344         postRedirectTests := []redirectTest{
345                 {"/", 200, "first"},
346                 {"/?code=301&next=302", 200, "c301"},
347                 {"/?code=302&next=302", 200, "c302"},
348                 {"/?code=303&next=301", 200, "c303wc301"}, // Issue 9348
349                 {"/?code=304", 304, "c304"},
350                 {"/?code=305", 305, "c305"},
351                 {"/?code=307&next=303,308,302", 200, "c307"},
352                 {"/?code=308&next=302,301", 200, "c308"},
353                 {"/?code=404", 404, "c404"},
354         }
355
356         wantSegments := []string{
357                 `POST / "first"`,
358                 `POST /?code=301&next=302 "c301"`,
359                 `GET /?code=302 "c301"`,
360                 `GET / "c301"`,
361                 `POST /?code=302&next=302 "c302"`,
362                 `GET /?code=302 "c302"`,
363                 `GET / "c302"`,
364                 `POST /?code=303&next=301 "c303wc301"`,
365                 `GET /?code=301 "c303wc301"`,
366                 `GET / "c303wc301"`,
367                 `POST /?code=304 "c304"`,
368                 `POST /?code=305 "c305"`,
369                 `POST /?code=307&next=303,308,302 "c307"`,
370                 `POST /?code=303&next=308,302 "c307"`,
371                 `GET /?code=308&next=302 "c307"`,
372                 `GET /?code=302 "c307"`,
373                 `GET / "c307"`,
374                 `POST /?code=308&next=302,301 "c308"`,
375                 `POST /?code=302&next=301 "c308"`,
376                 `GET /?code=301 "c308"`,
377                 `GET / "c308"`,
378                 `POST /?code=404 "c404"`,
379         }
380         want := strings.Join(wantSegments, "\n")
381         testRedirectsByMethod(t, "POST", postRedirectTests, want)
382 }
383
384 func TestDeleteRedirects(t *testing.T) {
385         deleteRedirectTests := []redirectTest{
386                 {"/", 200, "first"},
387                 {"/?code=301&next=302,308", 200, "c301"},
388                 {"/?code=302&next=302", 200, "c302"},
389                 {"/?code=303", 200, "c303"},
390                 {"/?code=307&next=301,308,303,302,304", 304, "c307"},
391                 {"/?code=308&next=307", 200, "c308"},
392                 {"/?code=404", 404, "c404"},
393         }
394
395         wantSegments := []string{
396                 `DELETE / "first"`,
397                 `DELETE /?code=301&next=302,308 "c301"`,
398                 `GET /?code=302&next=308 "c301"`,
399                 `GET /?code=308 "c301"`,
400                 `GET / "c301"`,
401                 `DELETE /?code=302&next=302 "c302"`,
402                 `GET /?code=302 "c302"`,
403                 `GET / "c302"`,
404                 `DELETE /?code=303 "c303"`,
405                 `GET / "c303"`,
406                 `DELETE /?code=307&next=301,308,303,302,304 "c307"`,
407                 `DELETE /?code=301&next=308,303,302,304 "c307"`,
408                 `GET /?code=308&next=303,302,304 "c307"`,
409                 `GET /?code=303&next=302,304 "c307"`,
410                 `GET /?code=302&next=304 "c307"`,
411                 `GET /?code=304 "c307"`,
412                 `DELETE /?code=308&next=307 "c308"`,
413                 `DELETE /?code=307 "c308"`,
414                 `DELETE / "c308"`,
415                 `DELETE /?code=404 "c404"`,
416         }
417         want := strings.Join(wantSegments, "\n")
418         testRedirectsByMethod(t, "DELETE", deleteRedirectTests, want)
419 }
420
421 func testRedirectsByMethod(t *testing.T, method string, table []redirectTest, want string) {
422         defer afterTest(t)
423         var log struct {
424                 sync.Mutex
425                 bytes.Buffer
426         }
427         var ts *httptest.Server
428         ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
429                 log.Lock()
430                 slurp, _ := ioutil.ReadAll(r.Body)
431                 fmt.Fprintf(&log.Buffer, "%s %s %q\n", r.Method, r.RequestURI, slurp)
432                 log.Unlock()
433                 urlQuery := r.URL.Query()
434                 if v := urlQuery.Get("code"); v != "" {
435                         location := ts.URL
436                         if final := urlQuery.Get("next"); final != "" {
437                                 splits := strings.Split(final, ",")
438                                 first, rest := splits[0], splits[1:]
439                                 location = fmt.Sprintf("%s?code=%s", location, first)
440                                 if len(rest) > 0 {
441                                         location = fmt.Sprintf("%s&next=%s", location, strings.Join(rest, ","))
442                                 }
443                         }
444                         code, _ := strconv.Atoi(v)
445                         if code/100 == 3 {
446                                 w.Header().Set("Location", location)
447                         }
448                         w.WriteHeader(code)
449                 }
450         }))
451         defer ts.Close()
452
453         for _, tt := range table {
454                 content := tt.redirectBody
455                 req, _ := NewRequest(method, ts.URL+tt.suffix, strings.NewReader(content))
456                 req.GetBody = func() (io.ReadCloser, error) { return ioutil.NopCloser(strings.NewReader(content)), nil }
457                 res, err := DefaultClient.Do(req)
458
459                 if err != nil {
460                         t.Fatal(err)
461                 }
462                 if res.StatusCode != tt.want {
463                         t.Errorf("POST %s: status code = %d; want %d", tt.suffix, res.StatusCode, tt.want)
464                 }
465         }
466         log.Lock()
467         got := log.String()
468         log.Unlock()
469
470         got = strings.TrimSpace(got)
471         want = strings.TrimSpace(want)
472
473         if got != want {
474                 t.Errorf("Log differs.\n Got:\n%s\nWant:\n%s\n", got, want)
475         }
476 }
477
478 func TestClientRedirectUseResponse(t *testing.T) {
479         setParallel(t)
480         defer afterTest(t)
481         const body = "Hello, world."
482         var ts *httptest.Server
483         ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
484                 if strings.Contains(r.URL.Path, "/other") {
485                         io.WriteString(w, "wrong body")
486                 } else {
487                         w.Header().Set("Location", ts.URL+"/other")
488                         w.WriteHeader(StatusFound)
489                         io.WriteString(w, body)
490                 }
491         }))
492         defer ts.Close()
493
494         tr := &Transport{}
495         defer tr.CloseIdleConnections()
496
497         c := &Client{
498                 Transport: tr,
499                 CheckRedirect: func(req *Request, via []*Request) error {
500                         if req.Response == nil {
501                                 t.Error("expected non-nil Request.Response")
502                         }
503                         return ErrUseLastResponse
504                 },
505         }
506         res, err := c.Get(ts.URL)
507         if err != nil {
508                 t.Fatal(err)
509         }
510         if res.StatusCode != StatusFound {
511                 t.Errorf("status = %d; want %d", res.StatusCode, StatusFound)
512         }
513         defer res.Body.Close()
514         slurp, err := ioutil.ReadAll(res.Body)
515         if err != nil {
516                 t.Fatal(err)
517         }
518         if string(slurp) != body {
519                 t.Errorf("body = %q; want %q", slurp, body)
520         }
521 }
522
523 // Issue 17773: don't follow a 308 (or 307) if the response doesn't
524 // have a Location header.
525 func TestClientRedirect308NoLocation(t *testing.T) {
526         setParallel(t)
527         defer afterTest(t)
528         ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
529                 w.Header().Set("Foo", "Bar")
530                 w.WriteHeader(308)
531         }))
532         defer ts.Close()
533         res, err := Get(ts.URL)
534         if err != nil {
535                 t.Fatal(err)
536         }
537         res.Body.Close()
538         if res.StatusCode != 308 {
539                 t.Errorf("status = %d; want %d", res.StatusCode, 308)
540         }
541         if got := res.Header.Get("Foo"); got != "Bar" {
542                 t.Errorf("Foo header = %q; want Bar", got)
543         }
544 }
545
546 // Don't follow a 307/308 if we can't resent the request body.
547 func TestClientRedirect308NoGetBody(t *testing.T) {
548         setParallel(t)
549         defer afterTest(t)
550         const fakeURL = "https://localhost:1234/" // won't be hit
551         ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
552                 w.Header().Set("Location", fakeURL)
553                 w.WriteHeader(308)
554         }))
555         defer ts.Close()
556         req, err := NewRequest("POST", ts.URL, strings.NewReader("some body"))
557         if err != nil {
558                 t.Fatal(err)
559         }
560         req.GetBody = nil // so it can't rewind.
561         res, err := DefaultClient.Do(req)
562         if err != nil {
563                 t.Fatal(err)
564         }
565         res.Body.Close()
566         if res.StatusCode != 308 {
567                 t.Errorf("status = %d; want %d", res.StatusCode, 308)
568         }
569         if got := res.Header.Get("Location"); got != fakeURL {
570                 t.Errorf("Location header = %q; want %q", got, fakeURL)
571         }
572 }
573
// expectedCookies are the three cookies the redirect/cookie tests look for.
// Their order matters: the tests index into the slice.
var expectedCookies = []*Cookie{
	{Name: "ChocolateChip", Value: "tasty"}, // [0]: preloaded into a client jar
	{Name: "First", Value: "Hit"},           // [1]: set by echoCookiesRedirectHandler at "/"
	{Name: "Second", Value: "Hit"},          // [2]: set by echoCookiesRedirectHandler elsewhere
}
579
580 var echoCookiesRedirectHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
581         for _, cookie := range r.Cookies() {
582                 SetCookie(w, cookie)
583         }
584         if r.URL.Path == "/" {
585                 SetCookie(w, expectedCookies[1])
586                 Redirect(w, r, "/second", StatusMovedPermanently)
587         } else {
588                 SetCookie(w, expectedCookies[2])
589                 w.Write([]byte("hello"))
590         }
591 })
592
593 func TestClientSendsCookieFromJar(t *testing.T) {
594         defer afterTest(t)
595         tr := &recordingTransport{}
596         client := &Client{Transport: tr}
597         client.Jar = &TestJar{perURL: make(map[string][]*Cookie)}
598         us := "http://dummy.faketld/"
599         u, _ := url.Parse(us)
600         client.Jar.SetCookies(u, expectedCookies)
601
602         client.Get(us) // Note: doesn't hit network
603         matchReturnedCookies(t, expectedCookies, tr.req.Cookies())
604
605         client.Head(us) // Note: doesn't hit network
606         matchReturnedCookies(t, expectedCookies, tr.req.Cookies())
607
608         client.Post(us, "text/plain", strings.NewReader("body")) // Note: doesn't hit network
609         matchReturnedCookies(t, expectedCookies, tr.req.Cookies())
610
611         client.PostForm(us, url.Values{}) // Note: doesn't hit network
612         matchReturnedCookies(t, expectedCookies, tr.req.Cookies())
613
614         req, _ := NewRequest("GET", us, nil)
615         client.Do(req) // Note: doesn't hit network
616         matchReturnedCookies(t, expectedCookies, tr.req.Cookies())
617
618         req, _ = NewRequest("POST", us, nil)
619         client.Do(req) // Note: doesn't hit network
620         matchReturnedCookies(t, expectedCookies, tr.req.Cookies())
621 }
622
// TestJar is a minimal CookieJar for the redirect tests. It scopes cookies
// by URL.Host only. SetCookies replaces, rather than merges, whatever was
// stored for that host. Access to the map is guarded by m.
type TestJar struct {
	m      sync.Mutex
	perURL map[string][]*Cookie
}

// SetCookies records cookies as the complete set for u.Host. It allocates
// the map on first use, so a zero TestJar is ready to use.
func (j *TestJar) SetCookies(u *url.URL, cookies []*Cookie) {
	j.m.Lock()
	defer j.m.Unlock()
	if j.perURL == nil {
		j.perURL = map[string][]*Cookie{}
	}
	host := u.Host
	j.perURL[host] = cookies
}

// Cookies returns the cookies last stored for u.Host, or nil if none were.
func (j *TestJar) Cookies(u *url.URL) []*Cookie {
	j.m.Lock()
	defer j.m.Unlock()
	stored := j.perURL[u.Host]
	return stored
}
644
645 func TestRedirectCookiesJar(t *testing.T) {
646         setParallel(t)
647         defer afterTest(t)
648         var ts *httptest.Server
649         ts = httptest.NewServer(echoCookiesRedirectHandler)
650         defer ts.Close()
651         tr := &Transport{}
652         defer tr.CloseIdleConnections()
653         c := &Client{
654                 Transport: tr,
655                 Jar:       new(TestJar),
656         }
657         u, _ := url.Parse(ts.URL)
658         c.Jar.SetCookies(u, []*Cookie{expectedCookies[0]})
659         resp, err := c.Get(ts.URL)
660         if err != nil {
661                 t.Fatalf("Get: %v", err)
662         }
663         resp.Body.Close()
664         matchReturnedCookies(t, expectedCookies, resp.Cookies())
665 }
666
// matchReturnedCookies reports a test error if given and expected differ in
// length. It also reports one error for every expected cookie that has no
// Name/Value match in given. Other cookie attributes are ignored.
func matchReturnedCookies(t *testing.T, expected, given []*Cookie) {
	if len(given) != len(expected) {
		t.Logf("Received cookies: %v", given)
		t.Errorf("Expected %d cookies, got %d", len(expected), len(given))
	}
nextExpected:
	for _, ec := range expected {
		for _, gc := range given {
			if gc.Name == ec.Name && gc.Value == ec.Value {
				continue nextExpected
			}
		}
		t.Errorf("Missing cookie %v", ec)
	}
}
685
686 func TestJarCalls(t *testing.T) {
687         defer afterTest(t)
688         ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
689                 pathSuffix := r.RequestURI[1:]
690                 if r.RequestURI == "/nosetcookie" {
691                         return // don't set cookies for this path
692                 }
693                 SetCookie(w, &Cookie{Name: "name" + pathSuffix, Value: "val" + pathSuffix})
694                 if r.RequestURI == "/" {
695                         Redirect(w, r, "http://secondhost.fake/secondpath", 302)
696                 }
697         }))
698         defer ts.Close()
699         jar := new(RecordingJar)
700         c := &Client{
701                 Jar: jar,
702                 Transport: &Transport{
703                         Dial: func(_ string, _ string) (net.Conn, error) {
704                                 return net.Dial("tcp", ts.Listener.Addr().String())
705                         },
706                 },
707         }
708         _, err := c.Get("http://firsthost.fake/")
709         if err != nil {
710                 t.Fatal(err)
711         }
712         _, err = c.Get("http://firsthost.fake/nosetcookie")
713         if err != nil {
714                 t.Fatal(err)
715         }
716         got := jar.log.String()
717         want := `Cookies("http://firsthost.fake/")
718 SetCookie("http://firsthost.fake/", [name=val])
719 Cookies("http://secondhost.fake/secondpath")
720 SetCookie("http://secondhost.fake/secondpath", [namesecondpath=valsecondpath])
721 Cookies("http://firsthost.fake/nosetcookie")
722 `
723         if got != want {
724                 t.Errorf("Got Jar calls:\n%s\nWant:\n%s", got, want)
725         }
726 }
727
// RecordingJar is a CookieJar that stores nothing. Each call appends one
// descriptive line to log, and Cookies always reports no cookies. Appends to
// log happen under mu.
type RecordingJar struct {
	mu  sync.Mutex
	log bytes.Buffer
}

// SetCookies logs the call as SetCookie("<url>", <cookies>).
func (j *RecordingJar) SetCookies(u *url.URL, cookies []*Cookie) {
	j.logf("SetCookie(%q, %v)\n", u, cookies)
}

// Cookies logs the call as Cookies("<url>") and returns nil.
func (j *RecordingJar) Cookies(u *url.URL) []*Cookie {
	j.logf("Cookies(%q)\n", u)
	return nil
}

// logf formats a line and appends it to the log while holding mu.
func (j *RecordingJar) logf(format string, args ...interface{}) {
	line := fmt.Sprintf(format, args...)
	j.mu.Lock()
	defer j.mu.Unlock()
	j.log.WriteString(line)
}
749
750 func TestStreamingGet_h1(t *testing.T) { testStreamingGet(t, h1Mode) }
751 func TestStreamingGet_h2(t *testing.T) { testStreamingGet(t, h2Mode) }
752
753 func testStreamingGet(t *testing.T, h2 bool) {
754         defer afterTest(t)
755         say := make(chan string)
756         cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
757                 w.(Flusher).Flush()
758                 for str := range say {
759                         w.Write([]byte(str))
760                         w.(Flusher).Flush()
761                 }
762         }))
763         defer cst.close()
764
765         c := cst.c
766         res, err := c.Get(cst.ts.URL)
767         if err != nil {
768                 t.Fatal(err)
769         }
770         var buf [10]byte
771         for _, str := range []string{"i", "am", "also", "known", "as", "comet"} {
772                 say <- str
773                 n, err := io.ReadFull(res.Body, buf[0:len(str)])
774                 if err != nil {
775                         t.Fatalf("ReadFull on %q: %v", str, err)
776                 }
777                 if n != len(str) {
778                         t.Fatalf("Receiving %q, only read %d bytes", str, n)
779                 }
780                 got := string(buf[0:n])
781                 if got != str {
782                         t.Fatalf("Expected %q, got %q", str, got)
783                 }
784         }
785         close(say)
786         _, err = io.ReadFull(res.Body, buf[0:1])
787         if err != io.EOF {
788                 t.Fatalf("at end expected EOF, got %v", err)
789         }
790 }
791
// writeCountingConn wraps a net.Conn and increments *count once per Write
// call, whatever its size, before passing the write through. All other
// methods come from the embedded Conn.
type writeCountingConn struct {
	net.Conn
	count *int
}

func (c *writeCountingConn) Write(p []byte) (int, error) {
	counter := c.count
	*counter = *counter + 1
	return c.Conn.Write(p)
}
801
802 // TestClientWrites verifies that client requests are buffered and we
803 // don't send a TCP packet per line of the http request + body.
804 func TestClientWrites(t *testing.T) {
805         defer afterTest(t)
806         ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
807         }))
808         defer ts.Close()
809
810         writes := 0
811         dialer := func(netz string, addr string) (net.Conn, error) {
812                 c, err := net.Dial(netz, addr)
813                 if err == nil {
814                         c = &writeCountingConn{c, &writes}
815                 }
816                 return c, err
817         }
818         c := &Client{Transport: &Transport{Dial: dialer}}
819
820         _, err := c.Get(ts.URL)
821         if err != nil {
822                 t.Fatal(err)
823         }
824         if writes != 1 {
825                 t.Errorf("Get request did %d Write calls, want 1", writes)
826         }
827
828         writes = 0
829         _, err = c.PostForm(ts.URL, url.Values{"foo": {"bar"}})
830         if err != nil {
831                 t.Fatal(err)
832         }
833         if writes != 1 {
834                 t.Errorf("Post request did %d Write calls, want 1", writes)
835         }
836 }
837
838 func TestClientInsecureTransport(t *testing.T) {
839         setParallel(t)
840         defer afterTest(t)
841         ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
842                 w.Write([]byte("Hello"))
843         }))
844         errc := make(chanWriter, 10) // but only expecting 1
845         ts.Config.ErrorLog = log.New(errc, "", 0)
846         defer ts.Close()
847
848         // TODO(bradfitz): add tests for skipping hostname checks too?
849         // would require a new cert for testing, and probably
850         // redundant with these tests.
851         for _, insecure := range []bool{true, false} {
852                 tr := &Transport{
853                         TLSClientConfig: &tls.Config{
854                                 InsecureSkipVerify: insecure,
855                         },
856                 }
857                 defer tr.CloseIdleConnections()
858                 c := &Client{Transport: tr}
859                 res, err := c.Get(ts.URL)
860                 if (err == nil) != insecure {
861                         t.Errorf("insecure=%v: got unexpected err=%v", insecure, err)
862                 }
863                 if res != nil {
864                         res.Body.Close()
865                 }
866         }
867
868         select {
869         case v := <-errc:
870                 if !strings.Contains(v, "TLS handshake error") {
871                         t.Errorf("expected an error log message containing 'TLS handshake error'; got %q", v)
872                 }
873         case <-time.After(5 * time.Second):
874                 t.Errorf("timeout waiting for logged error")
875         }
876
877 }
878
879 func TestClientErrorWithRequestURI(t *testing.T) {
880         defer afterTest(t)
881         req, _ := NewRequest("GET", "http://localhost:1234/", nil)
882         req.RequestURI = "/this/field/is/illegal/and/should/error/"
883         _, err := DefaultClient.Do(req)
884         if err == nil {
885                 t.Fatalf("expected an error")
886         }
887         if !strings.Contains(err.Error(), "RequestURI") {
888                 t.Errorf("wanted error mentioning RequestURI; got error: %v", err)
889         }
890 }
891
892 func newTLSTransport(t *testing.T, ts *httptest.Server) *Transport {
893         certs := x509.NewCertPool()
894         for _, c := range ts.TLS.Certificates {
895                 roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
896                 if err != nil {
897                         t.Fatalf("error parsing server's root cert: %v", err)
898                 }
899                 for _, root := range roots {
900                         certs.AddCert(root)
901                 }
902         }
903         return &Transport{
904                 TLSClientConfig: &tls.Config{RootCAs: certs},
905         }
906 }
907
// TestClientWithCorrectTLSServerName checks that the ServerName set in
// the client's TLS config is the one the server observes in r.TLS, and
// that the connection succeeds when that name matches the certificate.
func TestClientWithCorrectTLSServerName(t *testing.T) {
	defer afterTest(t)

	const serverName = "example.com"
	ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.TLS.ServerName != serverName {
			t.Errorf("expected client to set ServerName %q, got: %q", serverName, r.TLS.ServerName)
		}
	}))
	defer ts.Close()

	trans := newTLSTransport(t, ts)
	trans.TLSClientConfig.ServerName = serverName
	// Release the Transport's idle connection when the test ends.
	defer trans.CloseIdleConnections()
	c := &Client{Transport: trans}
	res, err := c.Get(ts.URL)
	if err != nil {
		t.Fatalf("expected successful TLS connection, got error: %v", err)
	}
	res.Body.Close()
}
926
927 func TestClientWithIncorrectTLSServerName(t *testing.T) {
928         defer afterTest(t)
929         ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
930         defer ts.Close()
931         errc := make(chanWriter, 10) // but only expecting 1
932         ts.Config.ErrorLog = log.New(errc, "", 0)
933
934         trans := newTLSTransport(t, ts)
935         trans.TLSClientConfig.ServerName = "badserver"
936         c := &Client{Transport: trans}
937         _, err := c.Get(ts.URL)
938         if err == nil {
939                 t.Fatalf("expected an error")
940         }
941         if !strings.Contains(err.Error(), "127.0.0.1") || !strings.Contains(err.Error(), "badserver") {
942                 t.Errorf("wanted error mentioning 127.0.0.1 and badserver; got error: %v", err)
943         }
944         select {
945         case v := <-errc:
946                 if !strings.Contains(v, "TLS handshake error") {
947                         t.Errorf("expected an error log message containing 'TLS handshake error'; got %q", v)
948                 }
949         case <-time.After(5 * time.Second):
950                 t.Errorf("timeout waiting for logged error")
951         }
952 }
953
954 // Test for golang.org/issue/5829; the Transport should respect TLSClientConfig.ServerName
955 // when not empty.
956 //
957 // tls.Config.ServerName (non-empty, set to "example.com") takes
958 // precedence over "some-other-host.tld" which previously incorrectly
959 // took precedence. We don't actually connect to (or even resolve)
960 // "some-other-host.tld", though, because of the Transport.Dial hook.
961 //
962 // The httptest.Server has a cert with "example.com" as its name.
963 func TestTransportUsesTLSConfigServerName(t *testing.T) {
964         defer afterTest(t)
965         ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
966                 w.Write([]byte("Hello"))
967         }))
968         defer ts.Close()
969
970         tr := newTLSTransport(t, ts)
971         tr.TLSClientConfig.ServerName = "example.com" // one of httptest's Server cert names
972         tr.Dial = func(netw, addr string) (net.Conn, error) {
973                 return net.Dial(netw, ts.Listener.Addr().String())
974         }
975         defer tr.CloseIdleConnections()
976         c := &Client{Transport: tr}
977         res, err := c.Get("https://some-other-host.tld/")
978         if err != nil {
979                 t.Fatal(err)
980         }
981         res.Body.Close()
982 }
983
// TestResponseSetsTLSConnectionState checks that Response.TLS is
// populated for an HTTPS response and reports the negotiated cipher
// suite.
func TestResponseSetsTLSConnectionState(t *testing.T) {
	defer afterTest(t)
	ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Write([]byte("Hello"))
	}))
	defer ts.Close()

	// Offering a single suite makes the negotiated suite predictable.
	// Cap at TLS 1.2: Config.CipherSuites does not apply to TLS 1.3.
	// Use an ECDHE AES-GCM suite rather than 3DES, which newer crypto/tls
	// releases no longer enable by default on the server side.
	const wantSuite = tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256
	tr := newTLSTransport(t, ts)
	tr.TLSClientConfig.CipherSuites = []uint16{wantSuite}
	tr.TLSClientConfig.MaxVersion = tls.VersionTLS12
	tr.Dial = func(netw, addr string) (net.Conn, error) {
		return net.Dial(netw, ts.Listener.Addr().String())
	}
	defer tr.CloseIdleConnections()
	c := &Client{Transport: tr}
	res, err := c.Get("https://example.com/")
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()
	if res.TLS == nil {
		t.Fatal("Response didn't set TLS Connection State.")
	}
	if got, want := res.TLS.CipherSuite, wantSuite; got != want {
		t.Errorf("TLS Cipher Suite = %d; want %d", got, want)
	}
}
1010
1011 // Check that an HTTPS client can interpret a particular TLS error
1012 // to determine that the server is speaking HTTP.
1013 // See golang.org/issue/11111.
1014 func TestHTTPSClientDetectsHTTPServer(t *testing.T) {
1015         defer afterTest(t)
1016         ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
1017         defer ts.Close()
1018
1019         _, err := Get(strings.Replace(ts.URL, "http", "https", 1))
1020         if got := err.Error(); !strings.Contains(got, "HTTP response to HTTPS client") {
1021                 t.Fatalf("error = %q; want error indicating HTTP response to HTTPS request", got)
1022         }
1023 }
1024
1025 // Verify Response.ContentLength is populated. https://golang.org/issue/4126
1026 func TestClientHeadContentLength_h1(t *testing.T) {
1027         testClientHeadContentLength(t, h1Mode)
1028 }
1029
1030 func TestClientHeadContentLength_h2(t *testing.T) {
1031         testClientHeadContentLength(t, h2Mode)
1032 }
1033
1034 func testClientHeadContentLength(t *testing.T, h2 bool) {
1035         defer afterTest(t)
1036         cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
1037                 if v := r.FormValue("cl"); v != "" {
1038                         w.Header().Set("Content-Length", v)
1039                 }
1040         }))
1041         defer cst.close()
1042         tests := []struct {
1043                 suffix string
1044                 want   int64
1045         }{
1046                 {"/?cl=1234", 1234},
1047                 {"/?cl=0", 0},
1048                 {"", -1},
1049         }
1050         for _, tt := range tests {
1051                 req, _ := NewRequest("HEAD", cst.ts.URL+tt.suffix, nil)
1052                 res, err := cst.c.Do(req)
1053                 if err != nil {
1054                         t.Fatal(err)
1055                 }
1056                 if res.ContentLength != tt.want {
1057                         t.Errorf("Content-Length = %d; want %d", res.ContentLength, tt.want)
1058                 }
1059                 bs, err := ioutil.ReadAll(res.Body)
1060                 if err != nil {
1061                         t.Fatal(err)
1062                 }
1063                 if len(bs) != 0 {
1064                         t.Errorf("Unexpected content: %q", bs)
1065                 }
1066         }
1067 }
1068
1069 func TestEmptyPasswordAuth(t *testing.T) {
1070         setParallel(t)
1071         defer afterTest(t)
1072         gopher := "gopher"
1073         ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
1074                 auth := r.Header.Get("Authorization")
1075                 if strings.HasPrefix(auth, "Basic ") {
1076                         encoded := auth[6:]
1077                         decoded, err := base64.StdEncoding.DecodeString(encoded)
1078                         if err != nil {
1079                                 t.Fatal(err)
1080                         }
1081                         expected := gopher + ":"
1082                         s := string(decoded)
1083                         if expected != s {
1084                                 t.Errorf("Invalid Authorization header. Got %q, wanted %q", s, expected)
1085                         }
1086                 } else {
1087                         t.Errorf("Invalid auth %q", auth)
1088                 }
1089         }))
1090         defer ts.Close()
1091         tr := &Transport{}
1092         defer tr.CloseIdleConnections()
1093         c := &Client{Transport: tr}
1094         req, err := NewRequest("GET", ts.URL, nil)
1095         if err != nil {
1096                 t.Fatal(err)
1097         }
1098         req.URL.User = url.User(gopher)
1099         resp, err := c.Do(req)
1100         if err != nil {
1101                 t.Fatal(err)
1102         }
1103         defer resp.Body.Close()
1104 }
1105
// TestBasicAuth checks that percent-encoded userinfo in the request URL
// is sent as a Basic Authorization header carrying the decoded
// credentials, and that the request URL itself is left unchanged.
func TestBasicAuth(t *testing.T) {
	defer afterTest(t)
	tr := &recordingTransport{}
	client := &Client{Transport: tr}

	const (
		rawURL   = "http://My%20User:My%20Pass@dummy.faketld/"
		expected = "My User:My Pass"
	)
	client.Get(rawURL)

	sent := tr.req
	if sent.Method != "GET" {
		t.Errorf("got method %q, want %q", sent.Method, "GET")
	}
	if got := sent.URL.String(); got != rawURL {
		t.Errorf("got URL %q, want %q", got, rawURL)
	}
	if sent.Header == nil {
		t.Fatalf("expected non-nil request Header")
	}

	auth := sent.Header.Get("Authorization")
	if !strings.HasPrefix(auth, "Basic ") {
		t.Errorf("Invalid auth %q", auth)
		return
	}
	decoded, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(auth, "Basic "))
	if err != nil {
		t.Fatal(err)
	}
	if s := string(decoded); s != expected {
		t.Errorf("Invalid Authorization header. Got %q, wanted %q", s, expected)
	}
}
1139
// TestBasicAuthHeadersPreserved checks that an Authorization header set
// explicitly with SetBasicAuth is kept as-is, rather than being
// replaced by credentials derived from the username in the request URL.
func TestBasicAuthHeadersPreserved(t *testing.T) {
	defer afterTest(t)
	tr := &recordingTransport{}
	client := &Client{Transport: tr}

	// If Authorization header is provided, username in URL should not override it
	const (
		rawURL   = "http://My%20User@dummy.faketld/"
		expected = "My User:My Pass"
	)
	req, err := NewRequest("GET", rawURL, nil)
	if err != nil {
		t.Fatal(err)
	}
	req.SetBasicAuth("My User", "My Pass")
	client.Do(req)

	sent := tr.req
	if sent.Method != "GET" {
		t.Errorf("got method %q, want %q", sent.Method, "GET")
	}
	if got := sent.URL.String(); got != rawURL {
		t.Errorf("got URL %q, want %q", got, rawURL)
	}
	if sent.Header == nil {
		t.Fatalf("expected non-nil request Header")
	}

	auth := sent.Header.Get("Authorization")
	if !strings.HasPrefix(auth, "Basic ") {
		t.Errorf("Invalid auth %q", auth)
		return
	}
	decoded, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(auth, "Basic "))
	if err != nil {
		t.Fatal(err)
	}
	if s := string(decoded); s != expected {
		t.Errorf("Invalid Authorization header. Got %q, wanted %q", s, expected)
	}
}
1180
func TestClientTimeout_h1(t *testing.T) { testClientTimeout(t, h1Mode) }
func TestClientTimeout_h2(t *testing.T) { testClientTimeout(t, h2Mode) }

// testClientTimeout checks that Client.Timeout also bounds reading the
// response body. The request to "/" is redirected to "/slow", whose
// handler writes and flushes "Hello" and then stalls. Reading the rest
// of the body must fail with a timeout net.Error mentioning
// "Client.Timeout exceeded".
func testClientTimeout(t *testing.T, h2 bool) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}
	defer afterTest(t)
	sawRoot := make(chan bool, 1)
	sawSlow := make(chan bool, 1)
	unblock := make(chan bool)
	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.URL.Path == "/" {
			sawRoot <- true
			Redirect(w, r, "/slow", StatusFound)
			return
		}
		if r.URL.Path == "/slow" {
			// Signal before writing: once the headers are flushed, Get
			// below can return and poll sawSlow. Sending after the Flush
			// raced with that non-blocking check.
			sawSlow <- true
			w.Write([]byte("Hello"))
			w.(Flusher).Flush()
			// Stall until the test is over instead of sleeping a fixed
			// time, so server shutdown does not wait on this handler.
			<-unblock
			return
		}
	}))
	defer cst.close()
	// Defers run LIFO, so this releases the /slow handler before cst.close.
	defer close(unblock)
	const timeout = 500 * time.Millisecond
	cst.c.Timeout = timeout

	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}

	select {
	case <-sawRoot:
		// good.
	default:
		t.Fatal("handler never got / request")
	}

	select {
	case <-sawSlow:
		// good.
	default:
		t.Fatal("handler never got /slow request")
	}

	errc := make(chan error, 1)
	go func() {
		_, err := ioutil.ReadAll(res.Body)
		errc <- err
		res.Body.Close()
	}()

	const failTime = timeout * 2
	select {
	case err := <-errc:
		if err == nil {
			t.Fatal("expected error from ReadAll")
		}
		if ne, ok := err.(net.Error); !ok {
			t.Errorf("error value from ReadAll was %T; expected some net.Error", err)
		} else if !ne.Timeout() {
			t.Errorf("net.Error.Timeout = false; want true")
		}
		// Use err, not the asserted net.Error: that value is nil when the
		// assertion fails, and calling Error on it would panic.
		if got := err.Error(); !strings.Contains(got, "Client.Timeout exceeded") {
			t.Errorf("error string = %q; missing timeout substring", got)
		}
	case <-time.After(failTime):
		t.Errorf("timeout after %v waiting for timeout of %v", failTime, timeout)
	}
}
1254
1255 func TestClientTimeout_Headers_h1(t *testing.T) { testClientTimeout_Headers(t, h1Mode) }
1256 func TestClientTimeout_Headers_h2(t *testing.T) { testClientTimeout_Headers(t, h2Mode) }
1257
1258 // Client.Timeout firing before getting to the body
1259 func testClientTimeout_Headers(t *testing.T, h2 bool) {
1260         if testing.Short() {
1261                 t.Skip("skipping in short mode")
1262         }
1263         defer afterTest(t)
1264         donec := make(chan bool)
1265         cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
1266                 <-donec
1267         }))
1268         defer cst.close()
1269         // Note that we use a channel send here and not a close.
1270         // The race detector doesn't know that we're waiting for a timeout
1271         // and thinks that the waitgroup inside httptest.Server is added to concurrently
1272         // with us closing it. If we timed out immediately, we could close the testserver
1273         // before we entered the handler. We're not timing out immediately and there's
1274         // no way we would be done before we entered the handler, but the race detector
1275         // doesn't know this, so synchronize explicitly.
1276         defer func() { donec <- true }()
1277
1278         cst.c.Timeout = 500 * time.Millisecond
1279         _, err := cst.c.Get(cst.ts.URL)
1280         if err == nil {
1281                 t.Fatal("got response from Get; expected error")
1282         }
1283         if _, ok := err.(*url.Error); !ok {
1284                 t.Fatalf("Got error of type %T; want *url.Error", err)
1285         }
1286         ne, ok := err.(net.Error)
1287         if !ok {
1288                 t.Fatalf("Got error of type %T; want some net.Error", err)
1289         }
1290         if !ne.Timeout() {
1291                 t.Error("net.Error.Timeout = false; want true")
1292         }
1293         if got := ne.Error(); !strings.Contains(got, "Client.Timeout exceeded") {
1294                 t.Errorf("error string = %q; missing timeout substring", got)
1295         }
1296 }
1297
1298 func TestClientRedirectEatsBody_h1(t *testing.T) { testClientRedirectEatsBody(t, h1Mode) }
1299 func TestClientRedirectEatsBody_h2(t *testing.T) { testClientRedirectEatsBody(t, h2Mode) }
1300 func testClientRedirectEatsBody(t *testing.T, h2 bool) {
1301         setParallel(t)
1302         defer afterTest(t)
1303         saw := make(chan string, 2)
1304         cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
1305                 saw <- r.RemoteAddr
1306                 if r.URL.Path == "/" {
1307                         Redirect(w, r, "/foo", StatusFound) // which includes a body
1308                 }
1309         }))
1310         defer cst.close()
1311
1312         res, err := cst.c.Get(cst.ts.URL)
1313         if err != nil {
1314                 t.Fatal(err)
1315         }
1316         _, err = ioutil.ReadAll(res.Body)
1317         if err != nil {
1318                 t.Fatal(err)
1319         }
1320         res.Body.Close()
1321
1322         var first string
1323         select {
1324         case first = <-saw:
1325         default:
1326                 t.Fatal("server didn't see a request")
1327         }
1328
1329         var second string
1330         select {
1331         case second = <-saw:
1332         default:
1333                 t.Fatal("server didn't see a second request")
1334         }
1335
1336         if first != second {
1337                 t.Fatal("server saw different client ports before & after the redirect")
1338         }
1339 }
1340
1341 // eofReaderFunc is an io.Reader that runs itself, and then returns io.EOF.
1342 type eofReaderFunc func()
1343
1344 func (f eofReaderFunc) Read(p []byte) (n int, err error) {
1345         f()
1346         return 0, io.EOF
1347 }
1348
1349 func TestReferer(t *testing.T) {
1350         tests := []struct {
1351                 lastReq, newReq string // from -> to URLs
1352                 want            string
1353         }{
1354                 // don't send user:
1355                 {"http://gopher@test.com", "http://link.com", "http://test.com"},
1356                 {"https://gopher@test.com", "https://link.com", "https://test.com"},
1357
1358                 // don't send a user and password:
1359                 {"http://gopher:go@test.com", "http://link.com", "http://test.com"},
1360                 {"https://gopher:go@test.com", "https://link.com", "https://test.com"},
1361
1362                 // nothing to do:
1363                 {"http://test.com", "http://link.com", "http://test.com"},
1364                 {"https://test.com", "https://link.com", "https://test.com"},
1365
1366                 // https to http doesn't send a referer:
1367                 {"https://test.com", "http://link.com", ""},
1368                 {"https://gopher:go@test.com", "http://link.com", ""},
1369         }
1370         for _, tt := range tests {
1371                 l, err := url.Parse(tt.lastReq)
1372                 if err != nil {
1373                         t.Fatal(err)
1374                 }
1375                 n, err := url.Parse(tt.newReq)
1376                 if err != nil {
1377                         t.Fatal(err)
1378                 }
1379                 r := ExportRefererForURL(l, n)
1380                 if r != tt.want {
1381                         t.Errorf("refererForURL(%q, %q) = %q; want %q", tt.lastReq, tt.newReq, r, tt.want)
1382                 }
1383         }
1384 }
1385
1386 // issue15577Tripper returns a Response with a redirect response
1387 // header and doesn't populate its Response.Request field.
1388 type issue15577Tripper struct{}
1389
1390 func (issue15577Tripper) RoundTrip(*Request) (*Response, error) {
1391         resp := &Response{
1392                 StatusCode: 303,
1393                 Header:     map[string][]string{"Location": {"http://www.example.com/"}},
1394                 Body:       ioutil.NopCloser(strings.NewReader("")),
1395         }
1396         return resp, nil
1397 }
1398
1399 // Issue 15577: don't assume the roundtripper's response populates its Request field.
1400 func TestClientRedirectResponseWithoutRequest(t *testing.T) {
1401         c := &Client{
1402                 CheckRedirect: func(*Request, []*Request) error { return fmt.Errorf("no redirects!") },
1403                 Transport:     issue15577Tripper{},
1404         }
1405         // Check that this doesn't crash:
1406         c.Get("http://dummy.tld")
1407 }
1408
// Issue 4800: copy (some) headers when Client follows a redirect.
//
// ts2 redirects to ts1. Both CheckRedirect and ts1 must see User-Agent,
// X-Foo and a Referer of ts2's URL. ts1 additionally sees
// Accept-Encoding. The Cookie and Authorization headers set on the
// original request must not appear in either header set.
func TestClientCopyHeadersOnRedirect(t *testing.T) {
	const (
		ua   = "some-agent/1.2"
		xfoo = "foo-val"
	)
	var ts2URL string
	ts1 := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		want := Header{
			"User-Agent":      []string{ua},
			"X-Foo":           []string{xfoo},
			"Referer":         []string{ts2URL},
			"Accept-Encoding": []string{"gzip"},
		}
		if !reflect.DeepEqual(r.Header, want) {
			t.Errorf("Request.Header = %#v; want %#v", r.Header, want)
		}
		// Report whether the test has failed so far back to the client.
		result := "ok"
		if t.Failed() {
			result = "got errors"
		}
		w.Header().Set("Result", result)
	}))
	defer ts1.Close()
	ts2 := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		Redirect(w, r, ts1.URL, StatusFound)
	}))
	defer ts2.Close()
	ts2URL = ts2.URL

	checkRedirect := func(r *Request, via []*Request) error {
		want := Header{
			"User-Agent": []string{ua},
			"X-Foo":      []string{xfoo},
			"Referer":    []string{ts2URL},
		}
		if !reflect.DeepEqual(r.Header, want) {
			t.Errorf("CheckRedirect Request.Header = %#v; want %#v", r.Header, want)
		}
		return nil
	}
	tr := &Transport{}
	defer tr.CloseIdleConnections()
	c := &Client{Transport: tr, CheckRedirect: checkRedirect}

	req, _ := NewRequest("GET", ts2.URL, nil)
	for _, kv := range [][2]string{
		{"User-Agent", ua},
		{"X-Foo", xfoo},
		{"Cookie", "foo=bar"},
		{"Authorization", "secretpassword"},
	} {
		req.Header.Add(kv[0], kv[1])
	}
	res, err := c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()
	if res.StatusCode != 200 {
		t.Fatal(res.Status)
	}
	if got := res.Header.Get("Result"); got != "ok" {
		t.Errorf("result = %q; want ok", got)
	}
}
1473
// Issue 17494: cookies should be altered when Client follows redirects.
//
// The handler drives a redirect loop keyed by the "Cycle" cookie. Each
// step deletes, modifies or inserts cookies and asserts the exact set
// it received. Cookies come both from the original request's header and
// from the Jar.
func TestClientAltersCookiesOnRedirect(t *testing.T) {
	cookieMap := func(cs []*Cookie) map[string][]string {
		m := make(map[string][]string)
		for _, c := range cs {
			m[c.Name] = append(m[c.Name], c.Value)
		}
		return m
	}

	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		var want map[string][]string
		got := cookieMap(r.Cookies())

		c, err := r.Cookie("Cycle")
		if err != nil {
			// Without this guard c would be nil and c.Value below would
			// panic inside the handler.
			t.Errorf("request is missing the Cycle cookie: %v", err)
			return
		}
		switch c.Value {
		case "0":
			want = map[string][]string{
				"Cookie1": {"OldValue1a", "OldValue1b"},
				"Cookie2": {"OldValue2"},
				"Cookie3": {"OldValue3a", "OldValue3b"},
				"Cookie4": {"OldValue4"},
				"Cycle":   {"0"},
			}
			SetCookie(w, &Cookie{Name: "Cycle", Value: "1", Path: "/"})
			SetCookie(w, &Cookie{Name: "Cookie2", Path: "/", MaxAge: -1}) // Delete cookie from Header
			Redirect(w, r, "/", StatusFound)
		case "1":
			want = map[string][]string{
				"Cookie1": {"OldValue1a", "OldValue1b"},
				"Cookie3": {"OldValue3a", "OldValue3b"},
				"Cookie4": {"OldValue4"},
				"Cycle":   {"1"},
			}
			SetCookie(w, &Cookie{Name: "Cycle", Value: "2", Path: "/"})
			SetCookie(w, &Cookie{Name: "Cookie3", Value: "NewValue3", Path: "/"}) // Modify cookie in Header
			SetCookie(w, &Cookie{Name: "Cookie4", Value: "NewValue4", Path: "/"}) // Modify cookie in Jar
			Redirect(w, r, "/", StatusFound)
		case "2":
			want = map[string][]string{
				"Cookie1": {"OldValue1a", "OldValue1b"},
				"Cookie3": {"NewValue3"},
				"Cookie4": {"NewValue4"},
				"Cycle":   {"2"},
			}
			SetCookie(w, &Cookie{Name: "Cycle", Value: "3", Path: "/"})
			SetCookie(w, &Cookie{Name: "Cookie5", Value: "NewValue5", Path: "/"}) // Insert cookie into Jar
			Redirect(w, r, "/", StatusFound)
		case "3":
			want = map[string][]string{
				"Cookie1": {"OldValue1a", "OldValue1b"},
				"Cookie3": {"NewValue3"},
				"Cookie4": {"NewValue4"},
				"Cookie5": {"NewValue5"},
				"Cycle":   {"3"},
			}
			// Don't redirect to ensure the loop ends.
		default:
			t.Errorf("unexpected redirect cycle")
			return
		}

		if !reflect.DeepEqual(got, want) {
			t.Errorf("redirect %s, Cookie = %v, want %v", c.Value, got, want)
		}
	}))
	defer ts.Close()

	tr := &Transport{}
	defer tr.CloseIdleConnections()
	jar, err := cookiejar.New(nil)
	if err != nil {
		t.Fatal(err)
	}
	c := &Client{
		Transport: tr,
		Jar:       jar,
	}

	u, err := url.Parse(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	req, err := NewRequest("GET", ts.URL, nil)
	if err != nil {
		t.Fatal(err)
	}
	req.AddCookie(&Cookie{Name: "Cookie1", Value: "OldValue1a"})
	req.AddCookie(&Cookie{Name: "Cookie1", Value: "OldValue1b"})
	req.AddCookie(&Cookie{Name: "Cookie2", Value: "OldValue2"})
	req.AddCookie(&Cookie{Name: "Cookie3", Value: "OldValue3a"})
	req.AddCookie(&Cookie{Name: "Cookie3", Value: "OldValue3b"})
	jar.SetCookies(u, []*Cookie{{Name: "Cookie4", Value: "OldValue4", Path: "/"}})
	jar.SetCookies(u, []*Cookie{{Name: "Cycle", Value: "0", Path: "/"}})
	res, err := c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()
	if res.StatusCode != 200 {
		t.Fatal(res.Status)
	}
}
1568
1569 // Part of Issue 4800
1570 func TestShouldCopyHeaderOnRedirect(t *testing.T) {
1571         tests := []struct {
1572                 header     string
1573                 initialURL string
1574                 destURL    string
1575                 want       bool
1576         }{
1577                 {"User-Agent", "http://foo.com/", "http://bar.com/", true},
1578                 {"X-Foo", "http://foo.com/", "http://bar.com/", true},
1579
1580                 // Sensitive headers:
1581                 {"cookie", "http://foo.com/", "http://bar.com/", false},
1582                 {"cookie2", "http://foo.com/", "http://bar.com/", false},
1583                 {"authorization", "http://foo.com/", "http://bar.com/", false},
1584                 {"www-authenticate", "http://foo.com/", "http://bar.com/", false},
1585
1586                 // But subdomains should work:
1587                 {"www-authenticate", "http://foo.com/", "http://foo.com/", true},
1588                 {"www-authenticate", "http://foo.com/", "http://sub.foo.com/", true},
1589                 {"www-authenticate", "http://foo.com/", "http://notfoo.com/", false},
1590                 // TODO(bradfitz): make this test work, once issue 16142 is fixed:
1591                 // {"www-authenticate", "http://foo.com:80/", "http://foo.com/", true},
1592         }
1593         for i, tt := range tests {
1594                 u0, err := url.Parse(tt.initialURL)
1595                 if err != nil {
1596                         t.Errorf("%d. initial URL %q parse error: %v", i, tt.initialURL, err)
1597                         continue
1598                 }
1599                 u1, err := url.Parse(tt.destURL)
1600                 if err != nil {
1601                         t.Errorf("%d. dest URL %q parse error: %v", i, tt.destURL, err)
1602                         continue
1603                 }
1604                 got := Export_shouldCopyHeaderOnRedirect(tt.header, u0, u1)
1605                 if got != tt.want {
1606                         t.Errorf("%d. shouldCopyHeaderOnRedirect(%q, %q => %q) = %v; want %v",
1607                                 i, tt.header, tt.initialURL, tt.destURL, got, tt.want)
1608                 }
1609         }
1610 }
1611
1612 func TestClientRedirectTypes(t *testing.T) {
1613         setParallel(t)
1614         defer afterTest(t)
1615
1616         tests := [...]struct {
1617                 method       string
1618                 serverStatus int
1619                 wantMethod   string // desired subsequent client method
1620         }{
1621                 0: {method: "POST", serverStatus: 301, wantMethod: "GET"},
1622                 1: {method: "POST", serverStatus: 302, wantMethod: "GET"},
1623                 2: {method: "POST", serverStatus: 303, wantMethod: "GET"},
1624                 3: {method: "POST", serverStatus: 307, wantMethod: "POST"},
1625                 4: {method: "POST", serverStatus: 308, wantMethod: "POST"},
1626
1627                 5: {method: "HEAD", serverStatus: 301, wantMethod: "GET"},
1628                 6: {method: "HEAD", serverStatus: 302, wantMethod: "GET"},
1629                 7: {method: "HEAD", serverStatus: 303, wantMethod: "GET"},
1630                 8: {method: "HEAD", serverStatus: 307, wantMethod: "HEAD"},
1631                 9: {method: "HEAD", serverStatus: 308, wantMethod: "HEAD"},
1632
1633                 10: {method: "GET", serverStatus: 301, wantMethod: "GET"},
1634                 11: {method: "GET", serverStatus: 302, wantMethod: "GET"},
1635                 12: {method: "GET", serverStatus: 303, wantMethod: "GET"},
1636                 13: {method: "GET", serverStatus: 307, wantMethod: "GET"},
1637                 14: {method: "GET", serverStatus: 308, wantMethod: "GET"},
1638
1639                 15: {method: "DELETE", serverStatus: 301, wantMethod: "GET"},
1640                 16: {method: "DELETE", serverStatus: 302, wantMethod: "GET"},
1641                 17: {method: "DELETE", serverStatus: 303, wantMethod: "GET"},
1642                 18: {method: "DELETE", serverStatus: 307, wantMethod: "DELETE"},
1643                 19: {method: "DELETE", serverStatus: 308, wantMethod: "DELETE"},
1644
1645                 20: {method: "PUT", serverStatus: 301, wantMethod: "GET"},
1646                 21: {method: "PUT", serverStatus: 302, wantMethod: "GET"},
1647                 22: {method: "PUT", serverStatus: 303, wantMethod: "GET"},
1648                 23: {method: "PUT", serverStatus: 307, wantMethod: "PUT"},
1649                 24: {method: "PUT", serverStatus: 308, wantMethod: "PUT"},
1650
1651                 25: {method: "MADEUPMETHOD", serverStatus: 301, wantMethod: "GET"},
1652                 26: {method: "MADEUPMETHOD", serverStatus: 302, wantMethod: "GET"},
1653                 27: {method: "MADEUPMETHOD", serverStatus: 303, wantMethod: "GET"},
1654                 28: {method: "MADEUPMETHOD", serverStatus: 307, wantMethod: "MADEUPMETHOD"},
1655                 29: {method: "MADEUPMETHOD", serverStatus: 308, wantMethod: "MADEUPMETHOD"},
1656         }
1657
1658         handlerc := make(chan HandlerFunc, 1)
1659
1660         ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
1661                 h := <-handlerc
1662                 h(rw, req)
1663         }))
1664         defer ts.Close()
1665
1666         tr := &Transport{}
1667         defer tr.CloseIdleConnections()
1668
1669         for i, tt := range tests {
1670                 handlerc <- func(w ResponseWriter, r *Request) {
1671                         w.Header().Set("Location", ts.URL)
1672                         w.WriteHeader(tt.serverStatus)
1673                 }
1674
1675                 req, err := NewRequest(tt.method, ts.URL, nil)
1676                 if err != nil {
1677                         t.Errorf("#%d: NewRequest: %v", i, err)
1678                         continue
1679                 }
1680
1681                 c := &Client{Transport: tr}
1682                 c.CheckRedirect = func(req *Request, via []*Request) error {
1683                         if got, want := req.Method, tt.wantMethod; got != want {
1684                                 return fmt.Errorf("#%d: got next method %q; want %q", i, got, want)
1685                         }
1686                         handlerc <- func(rw ResponseWriter, req *Request) {
1687                                 // TODO: Check that the body is valid when we do 307 and 308 support
1688                         }
1689                         return nil
1690                 }
1691
1692                 res, err := c.Do(req)
1693                 if err != nil {
1694                         t.Errorf("#%d: Response: %v", i, err)
1695                         continue
1696                 }
1697
1698                 res.Body.Close()
1699         }
1700 }