var ncopy int64
// Write body. We "unwrap" the body first if it was wrapped in a
- // nopCloser. This is to ensure that we can take advantage of
+ // nopCloser or readTrackingBody. This is to ensure that we can take advantage of
// OS-level optimizations in the event that the body is an
// *os.File.
if t.Body != nil {
if reflect.TypeOf(t.Body) == nopCloserType {
return reflect.ValueOf(t.Body).Field(0).Interface().(io.Reader)
}
-
+ if r, ok := t.Body.(*readTrackingBody); ok {
+ r.didRead = true
+ return r.ReadCloser
+ }
return t.Body
}
if reflect.TypeOf(r) == nopCloserType {
return isKnownInMemoryReader(reflect.ValueOf(r).Field(0).Interface().(io.Reader))
}
+ if r, ok := r.(*readTrackingBody); ok {
+ return isKnownInMemoryReader(r.ReadCloser)
+ }
return false
}
}
}
+ req = setupRewindBody(req)
+
if altRT := t.alternateRoundTripper(req); altRT != nil {
if resp, err := altRT.RoundTrip(req); err != ErrSkipAltProtocol {
return resp, err
}
+ var err error
+ req, err = rewindBody(req)
+ if err != nil {
+ return nil, err
+ }
}
if !isHTTP {
req.closeBody()
testHookRoundTripRetried()
// Rewind the body if we're able to.
- if req.GetBody != nil {
- newReq := *req
- var err error
- newReq.Body, err = req.GetBody()
- if err != nil {
- return nil, err
- }
- req = &newReq
+ req, err = rewindBody(req)
+ if err != nil {
+ return nil, err
}
}
}
+var errCannotRewind = errors.New("net/http: cannot rewind body after connection loss")
+
+type readTrackingBody struct {
+ io.ReadCloser
+ didRead bool
+}
+
+func (r *readTrackingBody) Read(data []byte) (int, error) {
+ r.didRead = true
+ return r.ReadCloser.Read(data)
+}
+
+// setupRewindBody returns a new request with a custom body wrapper
+// that can report whether the body needs rewinding.
+// This lets rewindBody avoid an error result when the request
+// does not have GetBody but the body hasn't been read at all yet.
+func setupRewindBody(req *Request) *Request {
+ if req.Body == nil || req.Body == NoBody {
+ return req
+ }
+ newReq := *req
+ newReq.Body = &readTrackingBody{ReadCloser: req.Body}
+ return &newReq
+}
+
+// rewindBody returns a new request with the body rewound.
+// It returns req unmodified if the body does not need rewinding.
+// rewindBody takes care of closing req.Body when appropriate
+// (in all cases except when rewindBody returns req unmodified).
+func rewindBody(req *Request) (rewound *Request, err error) {
+ if req.Body == nil || req.Body == NoBody || !req.Body.(*readTrackingBody).didRead {
+ return req, nil // nothing to rewind
+ }
+ req.closeBody()
+ if req.GetBody == nil {
+ return nil, errCannotRewind
+ }
+ body, err := req.GetBody()
+ if err != nil {
+ return nil, err
+ }
+ newReq := *req
+ newReq.Body = &readTrackingBody{ReadCloser: body}
+ return &newReq, nil
+}
+
// shouldRetryRequest reports whether we should retry sending a failed
// HTTP request on a new connection. The non-nil input error is the
// error from roundTrip.
return nil, errors.New("request was not canceled")
}
}
+
+type roundTripFunc func(r *Request) (*Response, error)
+
+func (f roundTripFunc) RoundTrip(r *Request) (*Response, error) { return f(r) }
+
+// Issue 32441: body is not reset after ErrSkipAltProtocol
+func TestIssue32441(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ if n, _ := io.Copy(ioutil.Discard, r.Body); n == 0 {
+ t.Error("body length is zero")
+ }
+ }))
+ defer ts.Close()
+ c := ts.Client()
+ c.Transport.(*Transport).RegisterProtocol("http", roundTripFunc(func(r *Request) (*Response, error) {
+ // Draining body to trigger failure condition on actual request to server.
+ if n, _ := io.Copy(ioutil.Discard, r.Body); n == 0 {
+ t.Error("body length is zero during round trip")
+ }
+ return nil, ErrSkipAltProtocol
+ }))
+ if _, err := c.Post(ts.URL, "application/octet-stream", bytes.NewBufferString("data")); err != nil {
+ t.Error(err)
+ }
+}