]> Cypherpunks.ru repositories - gostls13.git/blob - src/net/http/client_test.go
net/http: testClientHead now in http2 mode
[gostls13.git] / src / net / http / client_test.go
1 // Copyright 2009 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 // Tests for client.go
6
7 package http_test
8
9 import (
10         "bytes"
11         "crypto/tls"
12         "crypto/x509"
13         "encoding/base64"
14         "errors"
15         "fmt"
16         "io"
17         "io/ioutil"
18         "log"
19         "net"
20         . "net/http"
21         "net/http/httptest"
22         "net/url"
23         "reflect"
24         "sort"
25         "strconv"
26         "strings"
27         "sync"
28         "testing"
29         "time"
30 )
31
32 var robotsTxtHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
33         w.Header().Set("Last-Modified", "sometime")
34         fmt.Fprintf(w, "User-agent: go\nDisallow: /something/")
35 })
36
37 // pedanticReadAll works like ioutil.ReadAll but additionally
38 // verifies that r obeys the documented io.Reader contract.
39 func pedanticReadAll(r io.Reader) (b []byte, err error) {
40         var bufa [64]byte
41         buf := bufa[:]
42         for {
43                 n, err := r.Read(buf)
44                 if n == 0 && err == nil {
45                         return nil, fmt.Errorf("Read: n=0 with err=nil")
46                 }
47                 b = append(b, buf[:n]...)
48                 if err == io.EOF {
49                         n, err := r.Read(buf)
50                         if n != 0 || err != io.EOF {
51                                 return nil, fmt.Errorf("Read: n=%d err=%#v after EOF", n, err)
52                         }
53                         return b, nil
54                 }
55                 if err != nil {
56                         return b, err
57                 }
58         }
59 }
60
61 type chanWriter chan string
62
63 func (w chanWriter) Write(p []byte) (n int, err error) {
64         w <- string(p)
65         return len(p), nil
66 }
67
68 func TestClient(t *testing.T) {
69         defer afterTest(t)
70         ts := httptest.NewServer(robotsTxtHandler)
71         defer ts.Close()
72
73         r, err := Get(ts.URL)
74         var b []byte
75         if err == nil {
76                 b, err = pedanticReadAll(r.Body)
77                 r.Body.Close()
78         }
79         if err != nil {
80                 t.Error(err)
81         } else if s := string(b); !strings.HasPrefix(s, "User-agent:") {
82                 t.Errorf("Incorrect page body (did not begin with User-agent): %q", s)
83         }
84 }
85
86 func TestClientHead_h1(t *testing.T) { testClientHead(t, false) }
87 func TestClientHead_h2(t *testing.T) { testClientHead(t, true) }
88
89 func testClientHead(t *testing.T, h2 bool) {
90         defer afterTest(t)
91         cst := newClientServerTest(t, h2, robotsTxtHandler)
92         defer cst.close()
93
94         r, err := cst.c.Head(cst.ts.URL)
95         if err != nil {
96                 t.Fatal(err)
97         }
98         if _, ok := r.Header["Last-Modified"]; !ok {
99                 t.Error("Last-Modified header not found.")
100         }
101 }
102
103 type recordingTransport struct {
104         req *Request
105 }
106
107 func (t *recordingTransport) RoundTrip(req *Request) (resp *Response, err error) {
108         t.req = req
109         return nil, errors.New("dummy impl")
110 }
111
112 func TestGetRequestFormat(t *testing.T) {
113         defer afterTest(t)
114         tr := &recordingTransport{}
115         client := &Client{Transport: tr}
116         url := "http://dummy.faketld/"
117         client.Get(url) // Note: doesn't hit network
118         if tr.req.Method != "GET" {
119                 t.Errorf("expected method %q; got %q", "GET", tr.req.Method)
120         }
121         if tr.req.URL.String() != url {
122                 t.Errorf("expected URL %q; got %q", url, tr.req.URL.String())
123         }
124         if tr.req.Header == nil {
125                 t.Errorf("expected non-nil request Header")
126         }
127 }
128
129 func TestPostRequestFormat(t *testing.T) {
130         defer afterTest(t)
131         tr := &recordingTransport{}
132         client := &Client{Transport: tr}
133
134         url := "http://dummy.faketld/"
135         json := `{"key":"value"}`
136         b := strings.NewReader(json)
137         client.Post(url, "application/json", b) // Note: doesn't hit network
138
139         if tr.req.Method != "POST" {
140                 t.Errorf("got method %q, want %q", tr.req.Method, "POST")
141         }
142         if tr.req.URL.String() != url {
143                 t.Errorf("got URL %q, want %q", tr.req.URL.String(), url)
144         }
145         if tr.req.Header == nil {
146                 t.Fatalf("expected non-nil request Header")
147         }
148         if tr.req.Close {
149                 t.Error("got Close true, want false")
150         }
151         if g, e := tr.req.ContentLength, int64(len(json)); g != e {
152                 t.Errorf("got ContentLength %d, want %d", g, e)
153         }
154 }
155
156 func TestPostFormRequestFormat(t *testing.T) {
157         defer afterTest(t)
158         tr := &recordingTransport{}
159         client := &Client{Transport: tr}
160
161         urlStr := "http://dummy.faketld/"
162         form := make(url.Values)
163         form.Set("foo", "bar")
164         form.Add("foo", "bar2")
165         form.Set("bar", "baz")
166         client.PostForm(urlStr, form) // Note: doesn't hit network
167
168         if tr.req.Method != "POST" {
169                 t.Errorf("got method %q, want %q", tr.req.Method, "POST")
170         }
171         if tr.req.URL.String() != urlStr {
172                 t.Errorf("got URL %q, want %q", tr.req.URL.String(), urlStr)
173         }
174         if tr.req.Header == nil {
175                 t.Fatalf("expected non-nil request Header")
176         }
177         if g, e := tr.req.Header.Get("Content-Type"), "application/x-www-form-urlencoded"; g != e {
178                 t.Errorf("got Content-Type %q, want %q", g, e)
179         }
180         if tr.req.Close {
181                 t.Error("got Close true, want false")
182         }
183         // Depending on map iteration, body can be either of these.
184         expectedBody := "foo=bar&foo=bar2&bar=baz"
185         expectedBody1 := "bar=baz&foo=bar&foo=bar2"
186         if g, e := tr.req.ContentLength, int64(len(expectedBody)); g != e {
187                 t.Errorf("got ContentLength %d, want %d", g, e)
188         }
189         bodyb, err := ioutil.ReadAll(tr.req.Body)
190         if err != nil {
191                 t.Fatalf("ReadAll on req.Body: %v", err)
192         }
193         if g := string(bodyb); g != expectedBody && g != expectedBody1 {
194                 t.Errorf("got body %q, want %q or %q", g, expectedBody, expectedBody1)
195         }
196 }
197
198 func TestClientRedirects(t *testing.T) {
199         defer afterTest(t)
200         var ts *httptest.Server
201         ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
202                 n, _ := strconv.Atoi(r.FormValue("n"))
203                 // Test Referer header. (7 is arbitrary position to test at)
204                 if n == 7 {
205                         if g, e := r.Referer(), ts.URL+"/?n=6"; e != g {
206                                 t.Errorf("on request ?n=7, expected referer of %q; got %q", e, g)
207                         }
208                 }
209                 if n < 15 {
210                         Redirect(w, r, fmt.Sprintf("/?n=%d", n+1), StatusFound)
211                         return
212                 }
213                 fmt.Fprintf(w, "n=%d", n)
214         }))
215         defer ts.Close()
216
217         c := &Client{}
218         _, err := c.Get(ts.URL)
219         if e, g := "Get /?n=10: stopped after 10 redirects", fmt.Sprintf("%v", err); e != g {
220                 t.Errorf("with default client Get, expected error %q, got %q", e, g)
221         }
222
223         // HEAD request should also have the ability to follow redirects.
224         _, err = c.Head(ts.URL)
225         if e, g := "Head /?n=10: stopped after 10 redirects", fmt.Sprintf("%v", err); e != g {
226                 t.Errorf("with default client Head, expected error %q, got %q", e, g)
227         }
228
229         // Do should also follow redirects.
230         greq, _ := NewRequest("GET", ts.URL, nil)
231         _, err = c.Do(greq)
232         if e, g := "Get /?n=10: stopped after 10 redirects", fmt.Sprintf("%v", err); e != g {
233                 t.Errorf("with default client Do, expected error %q, got %q", e, g)
234         }
235
236         // Requests with an empty Method should also redirect (Issue 12705)
237         greq.Method = ""
238         _, err = c.Do(greq)
239         if e, g := "Get /?n=10: stopped after 10 redirects", fmt.Sprintf("%v", err); e != g {
240                 t.Errorf("with default client Do and empty Method, expected error %q, got %q", e, g)
241         }
242
243         var checkErr error
244         var lastVia []*Request
245         c = &Client{CheckRedirect: func(_ *Request, via []*Request) error {
246                 lastVia = via
247                 return checkErr
248         }}
249         res, err := c.Get(ts.URL)
250         if err != nil {
251                 t.Fatalf("Get error: %v", err)
252         }
253         res.Body.Close()
254         finalUrl := res.Request.URL.String()
255         if e, g := "<nil>", fmt.Sprintf("%v", err); e != g {
256                 t.Errorf("with custom client, expected error %q, got %q", e, g)
257         }
258         if !strings.HasSuffix(finalUrl, "/?n=15") {
259                 t.Errorf("expected final url to end in /?n=15; got url %q", finalUrl)
260         }
261         if e, g := 15, len(lastVia); e != g {
262                 t.Errorf("expected lastVia to have contained %d elements; got %d", e, g)
263         }
264
265         checkErr = errors.New("no redirects allowed")
266         res, err = c.Get(ts.URL)
267         if urlError, ok := err.(*url.Error); !ok || urlError.Err != checkErr {
268                 t.Errorf("with redirects forbidden, expected a *url.Error with our 'no redirects allowed' error inside; got %#v (%q)", err, err)
269         }
270         if res == nil {
271                 t.Fatalf("Expected a non-nil Response on CheckRedirect failure (https://golang.org/issue/3795)")
272         }
273         res.Body.Close()
274         if res.Header.Get("Location") == "" {
275                 t.Errorf("no Location header in Response")
276         }
277 }
278
279 func TestPostRedirects(t *testing.T) {
280         defer afterTest(t)
281         var log struct {
282                 sync.Mutex
283                 bytes.Buffer
284         }
285         var ts *httptest.Server
286         ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
287                 log.Lock()
288                 fmt.Fprintf(&log.Buffer, "%s %s ", r.Method, r.RequestURI)
289                 log.Unlock()
290                 if v := r.URL.Query().Get("code"); v != "" {
291                         code, _ := strconv.Atoi(v)
292                         if code/100 == 3 {
293                                 w.Header().Set("Location", ts.URL)
294                         }
295                         w.WriteHeader(code)
296                 }
297         }))
298         defer ts.Close()
299         tests := []struct {
300                 suffix string
301                 want   int // response code
302         }{
303                 {"/", 200},
304                 {"/?code=301", 301},
305                 {"/?code=302", 200},
306                 {"/?code=303", 200},
307                 {"/?code=404", 404},
308         }
309         for _, tt := range tests {
310                 res, err := Post(ts.URL+tt.suffix, "text/plain", strings.NewReader("Some content"))
311                 if err != nil {
312                         t.Fatal(err)
313                 }
314                 if res.StatusCode != tt.want {
315                         t.Errorf("POST %s: status code = %d; want %d", tt.suffix, res.StatusCode, tt.want)
316                 }
317         }
318         log.Lock()
319         got := log.String()
320         log.Unlock()
321         want := "POST / POST /?code=301 POST /?code=302 GET / POST /?code=303 GET / POST /?code=404 "
322         if got != want {
323                 t.Errorf("Log differs.\n Got: %q\nWant: %q", got, want)
324         }
325 }
326
327 var expectedCookies = []*Cookie{
328         {Name: "ChocolateChip", Value: "tasty"},
329         {Name: "First", Value: "Hit"},
330         {Name: "Second", Value: "Hit"},
331 }
332
333 var echoCookiesRedirectHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
334         for _, cookie := range r.Cookies() {
335                 SetCookie(w, cookie)
336         }
337         if r.URL.Path == "/" {
338                 SetCookie(w, expectedCookies[1])
339                 Redirect(w, r, "/second", StatusMovedPermanently)
340         } else {
341                 SetCookie(w, expectedCookies[2])
342                 w.Write([]byte("hello"))
343         }
344 })
345
346 func TestClientSendsCookieFromJar(t *testing.T) {
347         defer afterTest(t)
348         tr := &recordingTransport{}
349         client := &Client{Transport: tr}
350         client.Jar = &TestJar{perURL: make(map[string][]*Cookie)}
351         us := "http://dummy.faketld/"
352         u, _ := url.Parse(us)
353         client.Jar.SetCookies(u, expectedCookies)
354
355         client.Get(us) // Note: doesn't hit network
356         matchReturnedCookies(t, expectedCookies, tr.req.Cookies())
357
358         client.Head(us) // Note: doesn't hit network
359         matchReturnedCookies(t, expectedCookies, tr.req.Cookies())
360
361         client.Post(us, "text/plain", strings.NewReader("body")) // Note: doesn't hit network
362         matchReturnedCookies(t, expectedCookies, tr.req.Cookies())
363
364         client.PostForm(us, url.Values{}) // Note: doesn't hit network
365         matchReturnedCookies(t, expectedCookies, tr.req.Cookies())
366
367         req, _ := NewRequest("GET", us, nil)
368         client.Do(req) // Note: doesn't hit network
369         matchReturnedCookies(t, expectedCookies, tr.req.Cookies())
370
371         req, _ = NewRequest("POST", us, nil)
372         client.Do(req) // Note: doesn't hit network
373         matchReturnedCookies(t, expectedCookies, tr.req.Cookies())
374 }
375
376 // Just enough correctness for our redirect tests. Uses the URL.Host as the
377 // scope of all cookies.
378 type TestJar struct {
379         m      sync.Mutex
380         perURL map[string][]*Cookie
381 }
382
383 func (j *TestJar) SetCookies(u *url.URL, cookies []*Cookie) {
384         j.m.Lock()
385         defer j.m.Unlock()
386         if j.perURL == nil {
387                 j.perURL = make(map[string][]*Cookie)
388         }
389         j.perURL[u.Host] = cookies
390 }
391
392 func (j *TestJar) Cookies(u *url.URL) []*Cookie {
393         j.m.Lock()
394         defer j.m.Unlock()
395         return j.perURL[u.Host]
396 }
397
398 func TestRedirectCookiesJar(t *testing.T) {
399         defer afterTest(t)
400         var ts *httptest.Server
401         ts = httptest.NewServer(echoCookiesRedirectHandler)
402         defer ts.Close()
403         c := &Client{
404                 Jar: new(TestJar),
405         }
406         u, _ := url.Parse(ts.URL)
407         c.Jar.SetCookies(u, []*Cookie{expectedCookies[0]})
408         resp, err := c.Get(ts.URL)
409         if err != nil {
410                 t.Fatalf("Get: %v", err)
411         }
412         resp.Body.Close()
413         matchReturnedCookies(t, expectedCookies, resp.Cookies())
414 }
415
416 func matchReturnedCookies(t *testing.T, expected, given []*Cookie) {
417         if len(given) != len(expected) {
418                 t.Logf("Received cookies: %v", given)
419                 t.Errorf("Expected %d cookies, got %d", len(expected), len(given))
420         }
421         for _, ec := range expected {
422                 foundC := false
423                 for _, c := range given {
424                         if ec.Name == c.Name && ec.Value == c.Value {
425                                 foundC = true
426                                 break
427                         }
428                 }
429                 if !foundC {
430                         t.Errorf("Missing cookie %v", ec)
431                 }
432         }
433 }
434
435 func TestJarCalls(t *testing.T) {
436         defer afterTest(t)
437         ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
438                 pathSuffix := r.RequestURI[1:]
439                 if r.RequestURI == "/nosetcookie" {
440                         return // don't set cookies for this path
441                 }
442                 SetCookie(w, &Cookie{Name: "name" + pathSuffix, Value: "val" + pathSuffix})
443                 if r.RequestURI == "/" {
444                         Redirect(w, r, "http://secondhost.fake/secondpath", 302)
445                 }
446         }))
447         defer ts.Close()
448         jar := new(RecordingJar)
449         c := &Client{
450                 Jar: jar,
451                 Transport: &Transport{
452                         Dial: func(_ string, _ string) (net.Conn, error) {
453                                 return net.Dial("tcp", ts.Listener.Addr().String())
454                         },
455                 },
456         }
457         _, err := c.Get("http://firsthost.fake/")
458         if err != nil {
459                 t.Fatal(err)
460         }
461         _, err = c.Get("http://firsthost.fake/nosetcookie")
462         if err != nil {
463                 t.Fatal(err)
464         }
465         got := jar.log.String()
466         want := `Cookies("http://firsthost.fake/")
467 SetCookie("http://firsthost.fake/", [name=val])
468 Cookies("http://secondhost.fake/secondpath")
469 SetCookie("http://secondhost.fake/secondpath", [namesecondpath=valsecondpath])
470 Cookies("http://firsthost.fake/nosetcookie")
471 `
472         if got != want {
473                 t.Errorf("Got Jar calls:\n%s\nWant:\n%s", got, want)
474         }
475 }
476
477 // RecordingJar keeps a log of calls made to it, without
478 // tracking any cookies.
479 type RecordingJar struct {
480         mu  sync.Mutex
481         log bytes.Buffer
482 }
483
484 func (j *RecordingJar) SetCookies(u *url.URL, cookies []*Cookie) {
485         j.logf("SetCookie(%q, %v)\n", u, cookies)
486 }
487
488 func (j *RecordingJar) Cookies(u *url.URL) []*Cookie {
489         j.logf("Cookies(%q)\n", u)
490         return nil
491 }
492
493 func (j *RecordingJar) logf(format string, args ...interface{}) {
494         j.mu.Lock()
495         defer j.mu.Unlock()
496         fmt.Fprintf(&j.log, format, args...)
497 }
498
499 func TestStreamingGet_h1(t *testing.T) { testStreamingGet(t, false) }
500 func TestStreamingGet_h2(t *testing.T) { testStreamingGet(t, true) }
501
502 func testStreamingGet(t *testing.T, h2 bool) {
503         defer afterTest(t)
504         say := make(chan string)
505         cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
506                 w.(Flusher).Flush()
507                 for str := range say {
508                         w.Write([]byte(str))
509                         w.(Flusher).Flush()
510                 }
511         }))
512         defer cst.close()
513
514         c := cst.c
515         res, err := c.Get(cst.ts.URL)
516         if err != nil {
517                 t.Fatal(err)
518         }
519         var buf [10]byte
520         for _, str := range []string{"i", "am", "also", "known", "as", "comet"} {
521                 say <- str
522                 n, err := io.ReadFull(res.Body, buf[0:len(str)])
523                 if err != nil {
524                         t.Fatalf("ReadFull on %q: %v", str, err)
525                 }
526                 if n != len(str) {
527                         t.Fatalf("Receiving %q, only read %d bytes", str, n)
528                 }
529                 got := string(buf[0:n])
530                 if got != str {
531                         t.Fatalf("Expected %q, got %q", str, got)
532                 }
533         }
534         close(say)
535         _, err = io.ReadFull(res.Body, buf[0:1])
536         if err != io.EOF {
537                 t.Fatalf("at end expected EOF, got %v", err)
538         }
539 }
540
541 type writeCountingConn struct {
542         net.Conn
543         count *int
544 }
545
546 func (c *writeCountingConn) Write(p []byte) (int, error) {
547         *c.count++
548         return c.Conn.Write(p)
549 }
550
551 // TestClientWrites verifies that client requests are buffered and we
552 // don't send a TCP packet per line of the http request + body.
553 func TestClientWrites(t *testing.T) {
554         defer afterTest(t)
555         ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
556         }))
557         defer ts.Close()
558
559         writes := 0
560         dialer := func(netz string, addr string) (net.Conn, error) {
561                 c, err := net.Dial(netz, addr)
562                 if err == nil {
563                         c = &writeCountingConn{c, &writes}
564                 }
565                 return c, err
566         }
567         c := &Client{Transport: &Transport{Dial: dialer}}
568
569         _, err := c.Get(ts.URL)
570         if err != nil {
571                 t.Fatal(err)
572         }
573         if writes != 1 {
574                 t.Errorf("Get request did %d Write calls, want 1", writes)
575         }
576
577         writes = 0
578         _, err = c.PostForm(ts.URL, url.Values{"foo": {"bar"}})
579         if err != nil {
580                 t.Fatal(err)
581         }
582         if writes != 1 {
583                 t.Errorf("Post request did %d Write calls, want 1", writes)
584         }
585 }
586
587 func TestClientInsecureTransport(t *testing.T) {
588         defer afterTest(t)
589         ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
590                 w.Write([]byte("Hello"))
591         }))
592         errc := make(chanWriter, 10) // but only expecting 1
593         ts.Config.ErrorLog = log.New(errc, "", 0)
594         defer ts.Close()
595
596         // TODO(bradfitz): add tests for skipping hostname checks too?
597         // would require a new cert for testing, and probably
598         // redundant with these tests.
599         for _, insecure := range []bool{true, false} {
600                 tr := &Transport{
601                         TLSClientConfig: &tls.Config{
602                                 InsecureSkipVerify: insecure,
603                         },
604                 }
605                 defer tr.CloseIdleConnections()
606                 c := &Client{Transport: tr}
607                 res, err := c.Get(ts.URL)
608                 if (err == nil) != insecure {
609                         t.Errorf("insecure=%v: got unexpected err=%v", insecure, err)
610                 }
611                 if res != nil {
612                         res.Body.Close()
613                 }
614         }
615
616         select {
617         case v := <-errc:
618                 if !strings.Contains(v, "TLS handshake error") {
619                         t.Errorf("expected an error log message containing 'TLS handshake error'; got %q", v)
620                 }
621         case <-time.After(5 * time.Second):
622                 t.Errorf("timeout waiting for logged error")
623         }
624
625 }
626
627 func TestClientErrorWithRequestURI(t *testing.T) {
628         defer afterTest(t)
629         req, _ := NewRequest("GET", "http://localhost:1234/", nil)
630         req.RequestURI = "/this/field/is/illegal/and/should/error/"
631         _, err := DefaultClient.Do(req)
632         if err == nil {
633                 t.Fatalf("expected an error")
634         }
635         if !strings.Contains(err.Error(), "RequestURI") {
636                 t.Errorf("wanted error mentioning RequestURI; got error: %v", err)
637         }
638 }
639
640 func newTLSTransport(t *testing.T, ts *httptest.Server) *Transport {
641         certs := x509.NewCertPool()
642         for _, c := range ts.TLS.Certificates {
643                 roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
644                 if err != nil {
645                         t.Fatalf("error parsing server's root cert: %v", err)
646                 }
647                 for _, root := range roots {
648                         certs.AddCert(root)
649                 }
650         }
651         return &Transport{
652                 TLSClientConfig: &tls.Config{RootCAs: certs},
653         }
654 }
655
656 func TestClientWithCorrectTLSServerName(t *testing.T) {
657         defer afterTest(t)
658
659         const serverName = "example.com"
660         ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
661                 if r.TLS.ServerName != serverName {
662                         t.Errorf("expected client to set ServerName %q, got: %q", serverName, r.TLS.ServerName)
663                 }
664         }))
665         defer ts.Close()
666
667         trans := newTLSTransport(t, ts)
668         trans.TLSClientConfig.ServerName = serverName
669         c := &Client{Transport: trans}
670         if _, err := c.Get(ts.URL); err != nil {
671                 t.Fatalf("expected successful TLS connection, got error: %v", err)
672         }
673 }
674
675 func TestClientWithIncorrectTLSServerName(t *testing.T) {
676         defer afterTest(t)
677         ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
678         defer ts.Close()
679         errc := make(chanWriter, 10) // but only expecting 1
680         ts.Config.ErrorLog = log.New(errc, "", 0)
681
682         trans := newTLSTransport(t, ts)
683         trans.TLSClientConfig.ServerName = "badserver"
684         c := &Client{Transport: trans}
685         _, err := c.Get(ts.URL)
686         if err == nil {
687                 t.Fatalf("expected an error")
688         }
689         if !strings.Contains(err.Error(), "127.0.0.1") || !strings.Contains(err.Error(), "badserver") {
690                 t.Errorf("wanted error mentioning 127.0.0.1 and badserver; got error: %v", err)
691         }
692         select {
693         case v := <-errc:
694                 if !strings.Contains(v, "TLS handshake error") {
695                         t.Errorf("expected an error log message containing 'TLS handshake error'; got %q", v)
696                 }
697         case <-time.After(5 * time.Second):
698                 t.Errorf("timeout waiting for logged error")
699         }
700 }
701
702 // Test for golang.org/issue/5829; the Transport should respect TLSClientConfig.ServerName
703 // when not empty.
704 //
705 // tls.Config.ServerName (non-empty, set to "example.com") takes
706 // precedence over "some-other-host.tld" which previously incorrectly
707 // took precedence. We don't actually connect to (or even resolve)
708 // "some-other-host.tld", though, because of the Transport.Dial hook.
709 //
710 // The httptest.Server has a cert with "example.com" as its name.
711 func TestTransportUsesTLSConfigServerName(t *testing.T) {
712         defer afterTest(t)
713         ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
714                 w.Write([]byte("Hello"))
715         }))
716         defer ts.Close()
717
718         tr := newTLSTransport(t, ts)
719         tr.TLSClientConfig.ServerName = "example.com" // one of httptest's Server cert names
720         tr.Dial = func(netw, addr string) (net.Conn, error) {
721                 return net.Dial(netw, ts.Listener.Addr().String())
722         }
723         defer tr.CloseIdleConnections()
724         c := &Client{Transport: tr}
725         res, err := c.Get("https://some-other-host.tld/")
726         if err != nil {
727                 t.Fatal(err)
728         }
729         res.Body.Close()
730 }
731
732 func TestResponseSetsTLSConnectionState(t *testing.T) {
733         defer afterTest(t)
734         ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
735                 w.Write([]byte("Hello"))
736         }))
737         defer ts.Close()
738
739         tr := newTLSTransport(t, ts)
740         tr.TLSClientConfig.CipherSuites = []uint16{tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA}
741         tr.Dial = func(netw, addr string) (net.Conn, error) {
742                 return net.Dial(netw, ts.Listener.Addr().String())
743         }
744         defer tr.CloseIdleConnections()
745         c := &Client{Transport: tr}
746         res, err := c.Get("https://example.com/")
747         if err != nil {
748                 t.Fatal(err)
749         }
750         defer res.Body.Close()
751         if res.TLS == nil {
752                 t.Fatal("Response didn't set TLS Connection State.")
753         }
754         if got, want := res.TLS.CipherSuite, tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA; got != want {
755                 t.Errorf("TLS Cipher Suite = %d; want %d", got, want)
756         }
757 }
758
759 // Check that an HTTPS client can interpret a particular TLS error
760 // to determine that the server is speaking HTTP.
761 // See golang.org/issue/11111.
762 func TestHTTPSClientDetectsHTTPServer(t *testing.T) {
763         defer afterTest(t)
764         ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
765         defer ts.Close()
766
767         _, err := Get(strings.Replace(ts.URL, "http", "https", 1))
768         if got := err.Error(); !strings.Contains(got, "HTTP response to HTTPS client") {
769                 t.Fatalf("error = %q; want error indicating HTTP response to HTTPS request", got)
770         }
771 }
772
773 // Verify Response.ContentLength is populated. https://golang.org/issue/4126
774 func TestClientHeadContentLength_h1(t *testing.T) {
775         testClientHeadContentLength(t, false)
776 }
777
778 func TestClientHeadContentLength_h2(t *testing.T) {
779         testClientHeadContentLength(t, true)
780 }
781
782 func testClientHeadContentLength(t *testing.T, h2 bool) {
783         defer afterTest(t)
784         cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
785                 if v := r.FormValue("cl"); v != "" {
786                         w.Header().Set("Content-Length", v)
787                 }
788         }))
789         defer cst.close()
790         tests := []struct {
791                 suffix string
792                 want   int64
793         }{
794                 {"/?cl=1234", 1234},
795                 {"/?cl=0", 0},
796                 {"", -1},
797         }
798         for _, tt := range tests {
799                 req, _ := NewRequest("HEAD", cst.ts.URL+tt.suffix, nil)
800                 res, err := cst.c.Do(req)
801                 if err != nil {
802                         t.Fatal(err)
803                 }
804                 if res.ContentLength != tt.want {
805                         t.Errorf("Content-Length = %d; want %d", res.ContentLength, tt.want)
806                 }
807                 bs, err := ioutil.ReadAll(res.Body)
808                 if err != nil {
809                         t.Fatal(err)
810                 }
811                 if len(bs) != 0 {
812                         t.Errorf("Unexpected content: %q", bs)
813                 }
814         }
815 }
816
817 func TestEmptyPasswordAuth(t *testing.T) {
818         defer afterTest(t)
819         gopher := "gopher"
820         ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
821                 auth := r.Header.Get("Authorization")
822                 if strings.HasPrefix(auth, "Basic ") {
823                         encoded := auth[6:]
824                         decoded, err := base64.StdEncoding.DecodeString(encoded)
825                         if err != nil {
826                                 t.Fatal(err)
827                         }
828                         expected := gopher + ":"
829                         s := string(decoded)
830                         if expected != s {
831                                 t.Errorf("Invalid Authorization header. Got %q, wanted %q", s, expected)
832                         }
833                 } else {
834                         t.Errorf("Invalid auth %q", auth)
835                 }
836         }))
837         defer ts.Close()
838         c := &Client{}
839         req, err := NewRequest("GET", ts.URL, nil)
840         if err != nil {
841                 t.Fatal(err)
842         }
843         req.URL.User = url.User(gopher)
844         resp, err := c.Do(req)
845         if err != nil {
846                 t.Fatal(err)
847         }
848         defer resp.Body.Close()
849 }
850
851 func TestBasicAuth(t *testing.T) {
852         defer afterTest(t)
853         tr := &recordingTransport{}
854         client := &Client{Transport: tr}
855
856         url := "http://My%20User:My%20Pass@dummy.faketld/"
857         expected := "My User:My Pass"
858         client.Get(url)
859
860         if tr.req.Method != "GET" {
861                 t.Errorf("got method %q, want %q", tr.req.Method, "GET")
862         }
863         if tr.req.URL.String() != url {
864                 t.Errorf("got URL %q, want %q", tr.req.URL.String(), url)
865         }
866         if tr.req.Header == nil {
867                 t.Fatalf("expected non-nil request Header")
868         }
869         auth := tr.req.Header.Get("Authorization")
870         if strings.HasPrefix(auth, "Basic ") {
871                 encoded := auth[6:]
872                 decoded, err := base64.StdEncoding.DecodeString(encoded)
873                 if err != nil {
874                         t.Fatal(err)
875                 }
876                 s := string(decoded)
877                 if expected != s {
878                         t.Errorf("Invalid Authorization header. Got %q, wanted %q", s, expected)
879                 }
880         } else {
881                 t.Errorf("Invalid auth %q", auth)
882         }
883 }
884
885 func TestBasicAuthHeadersPreserved(t *testing.T) {
886         defer afterTest(t)
887         tr := &recordingTransport{}
888         client := &Client{Transport: tr}
889
890         // If Authorization header is provided, username in URL should not override it
891         url := "http://My%20User@dummy.faketld/"
892         req, err := NewRequest("GET", url, nil)
893         if err != nil {
894                 t.Fatal(err)
895         }
896         req.SetBasicAuth("My User", "My Pass")
897         expected := "My User:My Pass"
898         client.Do(req)
899
900         if tr.req.Method != "GET" {
901                 t.Errorf("got method %q, want %q", tr.req.Method, "GET")
902         }
903         if tr.req.URL.String() != url {
904                 t.Errorf("got URL %q, want %q", tr.req.URL.String(), url)
905         }
906         if tr.req.Header == nil {
907                 t.Fatalf("expected non-nil request Header")
908         }
909         auth := tr.req.Header.Get("Authorization")
910         if strings.HasPrefix(auth, "Basic ") {
911                 encoded := auth[6:]
912                 decoded, err := base64.StdEncoding.DecodeString(encoded)
913                 if err != nil {
914                         t.Fatal(err)
915                 }
916                 s := string(decoded)
917                 if expected != s {
918                         t.Errorf("Invalid Authorization header. Got %q, wanted %q", s, expected)
919                 }
920         } else {
921                 t.Errorf("Invalid auth %q", auth)
922         }
923
924 }
925
926 func TestClientTimeout(t *testing.T) {
927         if testing.Short() {
928                 t.Skip("skipping in short mode")
929         }
930         defer afterTest(t)
931         sawRoot := make(chan bool, 1)
932         sawSlow := make(chan bool, 1)
933         ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
934                 if r.URL.Path == "/" {
935                         sawRoot <- true
936                         Redirect(w, r, "/slow", StatusFound)
937                         return
938                 }
939                 if r.URL.Path == "/slow" {
940                         w.Write([]byte("Hello"))
941                         w.(Flusher).Flush()
942                         sawSlow <- true
943                         time.Sleep(2 * time.Second)
944                         return
945                 }
946         }))
947         defer ts.Close()
948         const timeout = 500 * time.Millisecond
949         c := &Client{
950                 Timeout: timeout,
951         }
952
953         res, err := c.Get(ts.URL)
954         if err != nil {
955                 t.Fatal(err)
956         }
957
958         select {
959         case <-sawRoot:
960                 // good.
961         default:
962                 t.Fatal("handler never got / request")
963         }
964
965         select {
966         case <-sawSlow:
967                 // good.
968         default:
969                 t.Fatal("handler never got /slow request")
970         }
971
972         errc := make(chan error, 1)
973         go func() {
974                 _, err := ioutil.ReadAll(res.Body)
975                 errc <- err
976                 res.Body.Close()
977         }()
978
979         const failTime = timeout * 2
980         select {
981         case err := <-errc:
982                 if err == nil {
983                         t.Fatal("expected error from ReadAll")
984                 }
985                 ne, ok := err.(net.Error)
986                 if !ok {
987                         t.Errorf("error value from ReadAll was %T; expected some net.Error", err)
988                 } else if !ne.Timeout() {
989                         t.Errorf("net.Error.Timeout = false; want true")
990                 }
991                 if got := ne.Error(); !strings.Contains(got, "Client.Timeout exceeded") {
992                         t.Errorf("error string = %q; missing timeout substring", got)
993                 }
994         case <-time.After(failTime):
995                 t.Errorf("timeout after %v waiting for timeout of %v", failTime, timeout)
996         }
997 }
998
999 // Client.Timeout firing before getting to the body
1000 func TestClientTimeout_Headers(t *testing.T) {
1001         if testing.Short() {
1002                 t.Skip("skipping in short mode")
1003         }
1004         defer afterTest(t)
1005         donec := make(chan bool)
1006         ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
1007                 <-donec
1008         }))
1009         defer ts.Close()
1010         // Note that we use a channel send here and not a close.
1011         // The race detector doesn't know that we're waiting for a timeout
1012         // and thinks that the waitgroup inside httptest.Server is added to concurrently
1013         // with us closing it. If we timed out immediately, we could close the testserver
1014         // before we entered the handler. We're not timing out immediately and there's
1015         // no way we would be done before we entered the handler, but the race detector
1016         // doesn't know this, so synchronize explicitly.
1017         defer func() { donec <- true }()
1018
1019         c := &Client{Timeout: 500 * time.Millisecond}
1020
1021         _, err := c.Get(ts.URL)
1022         if err == nil {
1023                 t.Fatal("got response from Get; expected error")
1024         }
1025         if _, ok := err.(*url.Error); !ok {
1026                 t.Fatalf("Got error of type %T; want *url.Error", err)
1027         }
1028         ne, ok := err.(net.Error)
1029         if !ok {
1030                 t.Fatalf("Got error of type %T; want some net.Error", err)
1031         }
1032         if !ne.Timeout() {
1033                 t.Error("net.Error.Timeout = false; want true")
1034         }
1035         if got := ne.Error(); !strings.Contains(got, "Client.Timeout exceeded") {
1036                 t.Errorf("error string = %q; missing timeout substring", got)
1037         }
1038 }
1039
1040 func TestClientRedirectEatsBody_h1(t *testing.T) {
1041         testClientRedirectEatsBody(t, false)
1042 }
1043
1044 func TestClientRedirectEatsBody_h2(t *testing.T) {
1045         testClientRedirectEatsBody(t, true)
1046 }
1047
1048 func testClientRedirectEatsBody(t *testing.T, h2 bool) {
1049         defer afterTest(t)
1050         saw := make(chan string, 2)
1051         cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
1052                 saw <- r.RemoteAddr
1053                 if r.URL.Path == "/" {
1054                         Redirect(w, r, "/foo", StatusFound) // which includes a body
1055                 }
1056         }))
1057         defer cst.close()
1058
1059         res, err := cst.c.Get(cst.ts.URL)
1060         if err != nil {
1061                 t.Fatal(err)
1062         }
1063         _, err = ioutil.ReadAll(res.Body)
1064         if err != nil {
1065                 t.Fatal(err)
1066         }
1067         res.Body.Close()
1068
1069         var first string
1070         select {
1071         case first = <-saw:
1072         default:
1073                 t.Fatal("server didn't see a request")
1074         }
1075
1076         var second string
1077         select {
1078         case second = <-saw:
1079         default:
1080                 t.Fatal("server didn't see a second request")
1081         }
1082
1083         if first != second {
1084                 t.Fatal("server saw different client ports before & after the redirect")
1085         }
1086 }
1087
1088 // eofReaderFunc is an io.Reader that runs itself, and then returns io.EOF.
1089 type eofReaderFunc func()
1090
1091 func (f eofReaderFunc) Read(p []byte) (n int, err error) {
1092         f()
1093         return 0, io.EOF
1094 }
1095
1096 func TestClientTrailers(t *testing.T) {
1097         defer afterTest(t)
1098         ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
1099                 w.Header().Set("Connection", "close")
1100                 w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B")
1101                 w.Header().Add("Trailer", "Server-Trailer-C")
1102
1103                 var decl []string
1104                 for k := range r.Trailer {
1105                         decl = append(decl, k)
1106                 }
1107                 sort.Strings(decl)
1108
1109                 slurp, err := ioutil.ReadAll(r.Body)
1110                 if err != nil {
1111                         t.Errorf("Server reading request body: %v", err)
1112                 }
1113                 if string(slurp) != "foo" {
1114                         t.Errorf("Server read request body %q; want foo", slurp)
1115                 }
1116                 if r.Trailer == nil {
1117                         io.WriteString(w, "nil Trailer")
1118                 } else {
1119                         fmt.Fprintf(w, "decl: %v, vals: %s, %s",
1120                                 decl,
1121                                 r.Trailer.Get("Client-Trailer-A"),
1122                                 r.Trailer.Get("Client-Trailer-B"))
1123                 }
1124
1125                 // How handlers set Trailers: declare it ahead of time
1126                 // with the Trailer header, and then mutate the
1127                 // Header() of those values later, after the response
1128                 // has been written (we wrote to w above).
1129                 w.Header().Set("Server-Trailer-A", "valuea")
1130                 w.Header().Set("Server-Trailer-C", "valuec") // skipping B
1131         }))
1132         defer ts.Close()
1133
1134         var req *Request
1135         req, _ = NewRequest("POST", ts.URL, io.MultiReader(
1136                 eofReaderFunc(func() {
1137                         req.Trailer["Client-Trailer-A"] = []string{"valuea"}
1138                 }),
1139                 strings.NewReader("foo"),
1140                 eofReaderFunc(func() {
1141                         req.Trailer["Client-Trailer-B"] = []string{"valueb"}
1142                 }),
1143         ))
1144         req.Trailer = Header{
1145                 "Client-Trailer-A": nil, //  to be set later
1146                 "Client-Trailer-B": nil, //  to be set later
1147         }
1148         req.ContentLength = -1
1149         res, err := DefaultClient.Do(req)
1150         if err != nil {
1151                 t.Fatal(err)
1152         }
1153         if err := wantBody(res, err, "decl: [Client-Trailer-A Client-Trailer-B], vals: valuea, valueb"); err != nil {
1154                 t.Error(err)
1155         }
1156         want := Header{
1157                 "Server-Trailer-A": []string{"valuea"},
1158                 "Server-Trailer-B": nil,
1159                 "Server-Trailer-C": []string{"valuec"},
1160         }
1161         if !reflect.DeepEqual(res.Trailer, want) {
1162                 t.Errorf("Response trailers = %#v; want %#v", res.Trailer, want)
1163         }
1164 }
1165
1166 func TestReferer(t *testing.T) {
1167         tests := []struct {
1168                 lastReq, newReq string // from -> to URLs
1169                 want            string
1170         }{
1171                 // don't send user:
1172                 {"http://gopher@test.com", "http://link.com", "http://test.com"},
1173                 {"https://gopher@test.com", "https://link.com", "https://test.com"},
1174
1175                 // don't send a user and password:
1176                 {"http://gopher:go@test.com", "http://link.com", "http://test.com"},
1177                 {"https://gopher:go@test.com", "https://link.com", "https://test.com"},
1178
1179                 // nothing to do:
1180                 {"http://test.com", "http://link.com", "http://test.com"},
1181                 {"https://test.com", "https://link.com", "https://test.com"},
1182
1183                 // https to http doesn't send a referer:
1184                 {"https://test.com", "http://link.com", ""},
1185                 {"https://gopher:go@test.com", "http://link.com", ""},
1186         }
1187         for _, tt := range tests {
1188                 l, err := url.Parse(tt.lastReq)
1189                 if err != nil {
1190                         t.Fatal(err)
1191                 }
1192                 n, err := url.Parse(tt.newReq)
1193                 if err != nil {
1194                         t.Fatal(err)
1195                 }
1196                 r := ExportRefererForURL(l, n)
1197                 if r != tt.want {
1198                         t.Errorf("refererForURL(%q, %q) = %q; want %q", tt.lastReq, tt.newReq, r, tt.want)
1199                 }
1200         }
1201 }