return len(p), nil
}
-type countReader struct {
- r io.Reader
- n *int64
+type bodyLimitReader struct {
+ mu sync.Mutex
+ count int
+ limit int
+ closed chan struct{}
}
-func (cr countReader) Read(p []byte) (n int, err error) {
- n, err = cr.r.Read(p)
- atomic.AddInt64(cr.n, int64(n))
- return
+func (r *bodyLimitReader) Read(p []byte) (int, error) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ select {
+ case <-r.closed:
+ return 0, errors.New("closed")
+ default:
+ }
+ if r.count > r.limit {
+ return 0, errors.New("at limit")
+ }
+ r.count += len(p)
+ for i := range p {
+ p[i] = 'a'
+ }
+ return len(p), nil
+}
+
+func (r *bodyLimitReader) Close() error {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ close(r.closed)
+ return nil
}
func TestRequestBodyLimit(t *testing.T) { run(t, testRequestBodyLimit) }
}
}))
- nWritten := new(int64)
- req, _ := NewRequest("POST", cst.ts.URL, io.LimitReader(countReader{neverEnding('a'), nWritten}, limit*200))
+ body := &bodyLimitReader{
+ closed: make(chan struct{}),
+ limit: limit * 200,
+ }
+ req, _ := NewRequest("POST", cst.ts.URL, body)
// Send the POST, but don't care it succeeds or not. The
// remote side is going to reply and then close the TCP
if err == nil {
resp.Body.Close()
}
+ // Wait for the Transport to finish writing the request body.
+ // It will close the body when done.
+ <-body.closed
- if atomic.LoadInt64(nWritten) > limit*100 {
+ if body.count > limit*100 {
t.Errorf("handler restricted the request body to %d bytes, but client managed to write %d",
- limit, nWritten)
+ limit, body.count)
}
}