import (
"bufio"
+ "bytes"
"io"
"io/ioutil"
"log"
if g, e := res.Trailer.Get("X-Trailer"), "trailer_value"; g != e {
t.Errorf("Trailer(X-Trailer) = %q ; want %q", g, e)
}
-
}
func TestXForwardedFor(t *testing.T) {
t.Errorf("Log events = %q; want %q", log, wantLog)
}
}
+
+func TestReverseProxy_Post(t *testing.T) {
+ const backendResponse = "I am the backend"
+ const backendStatus = 200
+ var requestBody = bytes.Repeat([]byte("a"), 1<<20)
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ slurp, err := ioutil.ReadAll(r.Body)
+ if err != nil {
+ t.Error("Backend body read = %v", err)
+ }
+ if len(slurp) != len(requestBody) {
+ t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody))
+ }
+ if !bytes.Equal(slurp, requestBody) {
+ t.Error("Backend read wrong request body.") // 1MB; omitting details
+ }
+ w.Write([]byte(backendResponse))
+ }))
+ defer backend.Close()
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+
+ postReq, _ := http.NewRequest("POST", frontend.URL, bytes.NewReader(requestBody))
+ res, err := http.DefaultClient.Do(postReq)
+ if err != nil {
+ t.Fatalf("Do: %v", err)
+ }
+ if g, e := res.StatusCode, backendStatus; g != e {
+ t.Errorf("got res.StatusCode %d; expected %d", g, e)
+ }
+ bodyBytes, _ := ioutil.ReadAll(res.Body)
+ if g, e := string(bodyBytes), backendResponse; g != e {
+ t.Errorf("got body %q; expected %q", g, e)
+ }
+}
}
// ReadRequest reads and parses an incoming request from b.
-func ReadRequest(b *bufio.Reader) (req *Request, err error) { return readRequest(b, true) }
+func ReadRequest(b *bufio.Reader) (req *Request, err error) { return readRequest(b, deleteHostHeader) }
+
+// Constants for readRequest's deleteHostHeader parameter.
+const (
+ deleteHostHeader = true
+ keepHostHeader = false
+)
func readRequest(b *bufio.Reader, deleteHostHeader bool) (req *Request, err error) {
tp := newTextprotoReader(b)
}
}
+func TestHijackBeforeRequestBodyRead(t *testing.T) {
+ defer afterTest(t)
+ var requestBody = bytes.Repeat([]byte("a"), 1<<20)
+ bodyOkay := make(chan bool, 1)
+ gotCloseNotify := make(chan bool, 1)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ defer close(bodyOkay) // caller will read false if nothing else
+
+ reqBody := r.Body
+ r.Body = nil // to test that server.go doesn't use this value.
+
+ gone := w.(CloseNotifier).CloseNotify()
+ slurp, err := ioutil.ReadAll(reqBody)
+ if err != nil {
+ t.Error("Body read: %v", err)
+ return
+ }
+ if len(slurp) != len(requestBody) {
+ t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody))
+ return
+ }
+ if !bytes.Equal(slurp, requestBody) {
+ t.Error("Backend read wrong request body.") // 1MB; omitting details
+ return
+ }
+ bodyOkay <- true
+ select {
+ case <-gone:
+ gotCloseNotify <- true
+ case <-time.After(5 * time.Second):
+ gotCloseNotify <- false
+ }
+ }))
+ defer ts.Close()
+
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conn.Close()
+
+ fmt.Fprintf(conn, "POST / HTTP/1.1\r\nHost: foo\r\nContent-Length: %d\r\n\r\n%s",
+ len(requestBody), requestBody)
+ if !<-bodyOkay {
+ // already failed.
+ return
+ }
+ conn.Close()
+ if !<-gotCloseNotify {
+ t.Error("timeout waiting for CloseNotify")
+ }
+}
+
func TestOptions(t *testing.T) {
uric := make(chan string, 2) // only expect 1, but leave space for 2
mux := NewServeMux()
// single value (true) when the client connection has gone
// away.
//
- // CloseNotify is undefined before Request.Body has been
+ // CloseNotify may wait to notify until Request.Body has been
// fully read.
//
// After the Handler has returned, there is no guarantee
// If the protocol is HTTP/1.1 and CloseNotify is called while
// processing an idempotent request (such a GET) while
// HTTP/1.1 pipelining is in use, the arrival of a subsequent
- // pipelined request will cause a value to be sent on the
+ // pipelined request may cause a value to be sent on the
// returned channel. In practice HTTP/1.1 pipelining is not
// enabled in browsers and not seen often in the wild. If this
// is a problem, use HTTP/2 or only use CloseNotify on methods
dateBuf [len(TimeFormat)]byte
clenBuf [10]byte
- closeNotifyCh <-chan bool // guarded by conn.mu
+ // closeNotifyCh is non-nil once CloseNotify is called.
+ // Guarded by conn.mu
+ closeNotifyCh <-chan bool
}
// declareTrailer is called for each Trailer header when the
peek, _ := c.bufr.Peek(4) // ReadRequest will get err below
c.bufr.Discard(numLeadingCRorLF(peek))
}
- req, err := readRequest(c.bufr, false)
+ req, err := readRequest(c.bufr, keepHostHeader)
c.mu.Unlock()
if err != nil {
if c.r.hitReadLimit() {
}
if discard {
- _, err := io.CopyN(ioutil.Discard, w.req.Body, maxPostHandlerReadBytes+1)
+ _, err := io.CopyN(ioutil.Discard, w.reqBody, maxPostHandlerReadBytes+1)
switch err {
case nil:
// There must be even more data left over.
// Body was already consumed and closed.
case io.EOF:
// The remaining body was just consumed, close it.
- err = w.req.Body.Close()
+ err = w.reqBody.Close()
if err != nil {
w.closeAfterReply = true
}
var once sync.Once
notify := func() { once.Do(func() { ch <- true }) }
+ if requestBodyRemains(w.reqBody) {
+ // They're still consuming the request body, so we
+ // shouldn't notify yet.
+ registerOnHitEOF(w.reqBody, func() {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ startCloseNotifyBackgroundRead(c, notify)
+ })
+ } else {
+ startCloseNotifyBackgroundRead(c, notify)
+ }
+ return ch
+}
+
+// c.mu must be held.
+func startCloseNotifyBackgroundRead(c *conn, notify func()) {
if c.bufr.Buffered() > 0 {
- // A pipelined request or unread request body data is available
- // unread. Per the CloseNotifier docs, fire immediately.
+ // They've consumed the request body, so anything
+ // remaining is a pipelined request, which we
+ // document as firing on.
notify()
} else {
c.r.startBackgroundRead(notify)
}
- return ch
+}
+
+func registerOnHitEOF(rc io.ReadCloser, fn func()) {
+ switch v := rc.(type) {
+ case *expectContinueReader:
+ registerOnHitEOF(v.readCloser, fn)
+ case *body:
+ v.registerOnHitEOF(fn)
+ default:
+ panic("unexpected type " + fmt.Sprintf("%T", rc))
+ }
+}
+
+// requestBodyRemains reports whether future calls to Read
+// on rc might yield more data.
+func requestBodyRemains(rc io.ReadCloser) bool {
+ if rc == eofReader {
+ return false
+ }
+ switch v := rc.(type) {
+ case *expectContinueReader:
+ return requestBodyRemains(v.readCloser)
+ case *body:
+ return v.bodyRemains()
+ default:
+ panic("unexpected type " + fmt.Sprintf("%T", rc))
+ }
}
// The HandlerFunc type is an adapter to allow the use of
closing bool // is the connection to be closed after reading body?
doEarlyClose bool // whether Close should stop early
- mu sync.Mutex // guards closed, and calls to Read and Close
+ mu sync.Mutex // guards following, and calls to Read and Close
sawEOF bool
closed bool
- earlyClose bool // Close called and we didn't read to the end of src
+ earlyClose bool // Close called and we didn't read to the end of src
+ onHitEOF func() // if non-nil, func to call when EOF is Read
}
// ErrBodyReadAfterClose is returned when reading a Request or Response
}
}
+ if b.sawEOF && b.onHitEOF != nil {
+ b.onHitEOF()
+ }
+
return n, err
}
return b.earlyClose
}
+// bodyRemains reports whether future Read calls might
+// yield data.
+func (b *body) bodyRemains() bool {
+ b.mu.Lock()
+ defer b.mu.Unlock()
+ return !b.sawEOF
+}
+
+func (b *body) registerOnHitEOF(fn func()) {
+ b.mu.Lock()
+ defer b.mu.Unlock()
+ b.onHitEOF = fn
+}
+
// bodyLocked is a io.Reader reading from a *body when its mutex is
// already held.
type bodyLocked struct {