]> Cypherpunks.ru repositories - gostls13.git/commitdiff
net/http: handle body rewind in HTTP/2 connection loss better
authorRuss Cox <rsc@golang.org>
Fri, 22 May 2020 00:46:05 +0000 (20:46 -0400)
committerRuss Cox <rsc@golang.org>
Wed, 27 May 2020 16:56:56 +0000 (16:56 +0000)
In certain cases the HTTP/2 stack needs to resend a request.
It obtains a fresh body to send by calling req.GetBody.
This call was missing from the path where the HTTP/2
round tripper returns ErrSkipAltProtocol, meaning fall back
to HTTP/1.1. The result was that the HTTP/1.1 fallback
request was sent with no body at all.

This CL changes that code path to rewind the body before
falling back to HTTP/1.1. But rewinding the body is easier
said than done. Some requests have no GetBody function,
meaning the body can't be rewound. If we need to rewind and
can't, that's an error. But if we didn't read anything, we don't
need to rewind. So we have to track whether we read anything,
with a new ReadCloser wrapper. That in turn requires adding
to the couple places that unwrap Body values to look at the
underlying implementation.

This CL adds the new rewinding code in the main retry loop
as well.

The new rewindBody function also takes care of closing the
old body before abandoning it. That was missing in the old
rewind code.

Thanks to Aleksandr Razumov for CL 210123
and to Jun Chen for CL 234358, both of which informed
this CL.

Fixes #32441.

Change-Id: Id183758526c087c6b179ab73cf3b61ed23a2a46a
Reviewed-on: https://go-review.googlesource.com/c/go/+/234894
Run-TryBot: Russ Cox <rsc@golang.org>
Reviewed-by: Damien Neil <dneil@google.com>
Reviewed-by: Bryan C. Mills <bcmills@google.com>
src/net/http/transfer.go
src/net/http/transport.go
src/net/http/transport_test.go

index 6d5ea05c32aa79eddd3e4210f6dfde491ef4bac1..9019afb61d442216dc5e1c6b1fa734dfe1108e9e 100644 (file)
@@ -335,7 +335,7 @@ func (t *transferWriter) writeBody(w io.Writer) error {
        var ncopy int64
 
        // Write body. We "unwrap" the body first if it was wrapped in a
-       // nopCloser. This is to ensure that we can take advantage of
+       // nopCloser or readTrackingBody. This is to ensure that we can take advantage of
        // OS-level optimizations in the event that the body is an
        // *os.File.
        if t.Body != nil {
@@ -413,7 +413,10 @@ func (t *transferWriter) unwrapBody() io.Reader {
        if reflect.TypeOf(t.Body) == nopCloserType {
                return reflect.ValueOf(t.Body).Field(0).Interface().(io.Reader)
        }
-
+       if r, ok := t.Body.(*readTrackingBody); ok {
+               r.didRead = true
+               return r.ReadCloser
+       }
        return t.Body
 }
 
@@ -1075,6 +1078,9 @@ func isKnownInMemoryReader(r io.Reader) bool {
        if reflect.TypeOf(r) == nopCloserType {
                return isKnownInMemoryReader(reflect.ValueOf(r).Field(0).Interface().(io.Reader))
        }
+       if r, ok := r.(*readTrackingBody); ok {
+               return isKnownInMemoryReader(r.ReadCloser)
+       }
        return false
 }
 
index b1705d54396ade74ead49627d0760f69e3999c8a..da86b26106b85dee72a4d0e55e30452a705762bc 100644 (file)
@@ -511,10 +511,17 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
                }
        }
 
+       req = setupRewindBody(req)
+
        if altRT := t.alternateRoundTripper(req); altRT != nil {
                if resp, err := altRT.RoundTrip(req); err != ErrSkipAltProtocol {
                        return resp, err
                }
+               var err error
+               req, err = rewindBody(req)
+               if err != nil {
+                       return nil, err
+               }
        }
        if !isHTTP {
                req.closeBody()
@@ -584,18 +591,59 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
                testHookRoundTripRetried()
 
                // Rewind the body if we're able to.
-               if req.GetBody != nil {
-                       newReq := *req
-                       var err error
-                       newReq.Body, err = req.GetBody()
-                       if err != nil {
-                               return nil, err
-                       }
-                       req = &newReq
+               req, err = rewindBody(req)
+               if err != nil {
+                       return nil, err
                }
        }
 }
 
+var errCannotRewind = errors.New("net/http: cannot rewind body after connection loss")
+
+type readTrackingBody struct {
+       io.ReadCloser
+       didRead bool
+}
+
+func (r *readTrackingBody) Read(data []byte) (int, error) {
+       r.didRead = true
+       return r.ReadCloser.Read(data)
+}
+
+// setupRewindBody returns a new request with a custom body wrapper
+// that can report whether the body needs rewinding.
+// This lets rewindBody avoid an error result when the request
+// does not have GetBody but the body hasn't been read at all yet.
+func setupRewindBody(req *Request) *Request {
+       if req.Body == nil || req.Body == NoBody {
+               return req
+       }
+       newReq := *req
+       newReq.Body = &readTrackingBody{ReadCloser: req.Body}
+       return &newReq
+}
+
+// rewindBody returns a new request with the body rewound.
+// It returns req unmodified if the body does not need rewinding.
+// rewindBody takes care of closing req.Body when appropriate
+// (in all cases except when rewindBody returns req unmodified).
+func rewindBody(req *Request) (rewound *Request, err error) {
+       if req.Body == nil || req.Body == NoBody || !req.Body.(*readTrackingBody).didRead {
+               return req, nil // nothing to rewind
+       }
+       req.closeBody()
+       if req.GetBody == nil {
+               return nil, errCannotRewind
+       }
+       body, err := req.GetBody()
+       if err != nil {
+               return nil, err
+       }
+       newReq := *req
+       newReq.Body = &readTrackingBody{ReadCloser: body}
+       return &newReq, nil
+}
+
 // shouldRetryRequest reports whether we should retry sending a failed
 // HTTP request on a new connection. The non-nil input error is the
 // error from roundTrip.
index f4014d95bb4c615e1ebfbd22e77124acf1892832..5ccb3d14abd2d0145788e41caebfaae64f65112d 100644 (file)
@@ -6196,3 +6196,29 @@ func (timeoutProto) RoundTrip(req *Request) (*Response, error) {
                return nil, errors.New("request was not canceled")
        }
 }
+
+type roundTripFunc func(r *Request) (*Response, error)
+
+func (f roundTripFunc) RoundTrip(r *Request) (*Response, error) { return f(r) }
+
+// Issue 32441: body is not reset after ErrSkipAltProtocol
+func TestIssue32441(t *testing.T) {
+       defer afterTest(t)
+       ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+               if n, _ := io.Copy(ioutil.Discard, r.Body); n == 0 {
+                       t.Error("body length is zero")
+               }
+       }))
+       defer ts.Close()
+       c := ts.Client()
+       c.Transport.(*Transport).RegisterProtocol("http", roundTripFunc(func(r *Request) (*Response, error) {
+               // Draining body to trigger failure condition on actual request to server.
+               if n, _ := io.Copy(ioutil.Discard, r.Body); n == 0 {
+                       t.Error("body length is zero during round trip")
+               }
+               return nil, ErrSkipAltProtocol
+       }))
+       if _, err := c.Post(ts.URL, "application/octet-stream", bytes.NewBufferString("data")); err != nil {
+               t.Error(err)
+       }
+}