]> Cypherpunks.ru repositories - gostls13.git/commitdiff
context: add APIs for writing and reading cancelation cause
authorSameer Ajmani <sameer@golang.org>
Thu, 6 Jan 2022 18:57:05 +0000 (13:57 -0500)
committerGopher Robot <gobot@golang.org>
Tue, 8 Nov 2022 13:51:16 +0000 (13:51 +0000)
Extend the context package to allow users to specify why a context was
canceled in the form of an error, the "cause". Users write the cause
by calling WithCancelCause to construct a derived context, then
calling cancel(cause) to cancel the context with the provided cause.
Users retrieve the cause by calling context.Cause(ctx), which returns
the cause of the first cancelation for ctx or any of its parents.

The cause is implemented as a field of cancelCtx, since only cancelCtx
can be canceled. Calling cancel copies the cause to all derived (child)
cancelCtxs. Calling Cause(ctx) finds the nearest parent cancelCtx by
looking up the context value keyed by cancelCtxKey.

API changes:
+pkg context, func Cause(Context) error
+pkg context, func WithCancelCause(Context) (Context, CancelCauseFunc)
+pkg context, type CancelCauseFunc func(error)

Fixes #26356
Fixes #51365

Change-Id: I15b62bd454c014db3f4f1498b35204451509e641
Reviewed-on: https://go-review.googlesource.com/c/go/+/375977
Reviewed-by: Damien Neil <dneil@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Run-TryBot: Sameer Ajmani <sameer@golang.org>
Auto-Submit: Sameer Ajmani <sameer@golang.org>

api/next/51365.txt [new file with mode: 0644]
src/context/context.go
src/context/context_test.go
src/context/x_test.go

diff --git a/api/next/51365.txt b/api/next/51365.txt
new file mode 100644 (file)
index 0000000..df629f1
--- /dev/null
@@ -0,0 +1,3 @@
+pkg context, func Cause(Context) error #51365
+pkg context, func WithCancelCause(Context) (Context, CancelCauseFunc) #51365
+pkg context, type CancelCauseFunc func(error) #51365
index 7eace57893fc6dd2b5599f7364efd9f8b62d540e..a0b5edc524a86a15ce917892879354df8e095fa7 100644 (file)
 // fires. The go vet tool checks that CancelFuncs are used on all
 // control-flow paths.
 //
+// The WithCancelCause function returns a CancelCauseFunc, which
+// takes an error and records it as the cancelation cause. Calling
+// Cause on the canceled context or any of its children retrieves
+// the cause. If no cause is specified, Cause(ctx) returns the same
+// value as ctx.Err().
+//
 // Programs that use Contexts should follow these rules to keep interfaces
 // consistent across packages and enable static analysis tools to check context
 // propagation:
@@ -230,17 +236,63 @@ type CancelFunc func()
 // Canceling this context releases resources associated with it, so code should
 // call cancel as soon as the operations running in this Context complete.
 func WithCancel(parent Context) (ctx Context, cancel CancelFunc) {
+       c := withCancel(parent)
+       return c, func() { c.cancel(true, Canceled, nil) }
+}
+
+// A CancelCauseFunc behaves like a CancelFunc but additionally sets the cancelation cause.
+// This cause can be retrieved by calling Cause on the canceled Context or on
+// any of its derived Contexts.
+//
+// If the context has already been canceled, CancelCauseFunc does not set the cause.
+// For example, if childContext is derived from parentContext:
+//   - if parentContext is canceled with cause1 before childContext is canceled with cause2,
+//     then Cause(parentContext) == Cause(childContext) == cause1
+//   - if childContext is canceled with cause2 before parentContext is canceled with cause1,
+//     then Cause(parentContext) == cause1 and Cause(childContext) == cause2
+type CancelCauseFunc func(cause error)
+
+// WithCancelCause behaves like WithCancel but returns a CancelCauseFunc instead of a CancelFunc.
+// Calling cancel with a non-nil error (the "cause") records that error in ctx;
+// it can then be retrieved using Cause(ctx).
+// Calling cancel with nil sets the cause to Canceled.
+//
+// Example use:
+//
+//     ctx, cancel := context.WithCancelCause(parent)
+//     cancel(myError)
+//     ctx.Err() // returns context.Canceled
+//     context.Cause(ctx) // returns myError
+func WithCancelCause(parent Context) (ctx Context, cancel CancelCauseFunc) {
+       c := withCancel(parent)
+       return c, func(cause error) { c.cancel(true, Canceled, cause) }
+}
+
+func withCancel(parent Context) *cancelCtx {
        if parent == nil {
                panic("cannot create context from nil parent")
        }
        c := newCancelCtx(parent)
-       propagateCancel(parent, &c)
-       return &c, func() { c.cancel(true, Canceled) }
+       propagateCancel(parent, c)
+       return c
+}
+
+// Cause returns a non-nil error explaining why c was canceled.
+// The first cancelation of c or one of its parents sets the cause.
+// If that cancelation happened via a call to CancelCauseFunc(err),
+// then Cause returns err.
+// Otherwise Cause(c) returns the same value as c.Err().
+// Cause returns nil if c has not been canceled yet.
+func Cause(c Context) error {
+       if cc, ok := c.Value(&cancelCtxKey).(*cancelCtx); ok {
+               return cc.cause
+       }
+       return nil
 }
 
 // newCancelCtx returns an initialized cancelCtx.
-func newCancelCtx(parent Context) cancelCtx {
-       return cancelCtx{Context: parent}
+func newCancelCtx(parent Context) *cancelCtx {
+       return &cancelCtx{Context: parent}
 }
 
 // goroutines counts the number of goroutines ever created; for testing.
@@ -256,7 +308,7 @@ func propagateCancel(parent Context, child canceler) {
        select {
        case <-done:
                // parent is already canceled
-               child.cancel(false, parent.Err())
+               child.cancel(false, parent.Err(), Cause(parent))
                return
        default:
        }
@@ -265,7 +317,7 @@ func propagateCancel(parent Context, child canceler) {
                p.mu.Lock()
                if p.err != nil {
                        // parent has already been canceled
-                       child.cancel(false, p.err)
+                       child.cancel(false, p.err, p.cause)
                } else {
                        if p.children == nil {
                                p.children = make(map[canceler]struct{})
@@ -278,7 +330,7 @@ func propagateCancel(parent Context, child canceler) {
                go func() {
                        select {
                        case <-parent.Done():
-                               child.cancel(false, parent.Err())
+                               child.cancel(false, parent.Err(), Cause(parent))
                        case <-child.Done():
                        }
                }()
@@ -326,7 +378,7 @@ func removeChild(parent Context, child canceler) {
 // A canceler is a context type that can be canceled directly. The
 // implementations are *cancelCtx and *timerCtx.
 type canceler interface {
-       cancel(removeFromParent bool, err error)
+       cancel(removeFromParent bool, err, cause error)
        Done() <-chan struct{}
 }
 
@@ -346,6 +398,7 @@ type cancelCtx struct {
        done     atomic.Value          // of chan struct{}, created lazily, closed by first cancel call
        children map[canceler]struct{} // set to nil by the first cancel call
        err      error                 // set to non-nil by the first cancel call
+       cause    error                 // set to non-nil by the first cancel call
 }
 
 func (c *cancelCtx) Value(key any) any {
@@ -394,16 +447,21 @@ func (c *cancelCtx) String() string {
 
 // cancel closes c.done, cancels each of c's children, and, if
 // removeFromParent is true, removes c from its parent's children.
-func (c *cancelCtx) cancel(removeFromParent bool, err error) {
+// cancel sets c.cause to cause if this is the first time c is canceled.
+func (c *cancelCtx) cancel(removeFromParent bool, err, cause error) {
        if err == nil {
                panic("context: internal error: missing cancel error")
        }
+       if cause == nil {
+               cause = err
+       }
        c.mu.Lock()
        if c.err != nil {
                c.mu.Unlock()
                return // already canceled
        }
        c.err = err
+       c.cause = cause
        d, _ := c.done.Load().(chan struct{})
        if d == nil {
                c.done.Store(closedchan)
@@ -412,7 +470,7 @@ func (c *cancelCtx) cancel(removeFromParent bool, err error) {
        }
        for child := range c.children {
                // NOTE: acquiring the child's lock while holding parent's lock.
-               child.cancel(false, err)
+               child.cancel(false, err, cause)
        }
        c.children = nil
        c.mu.Unlock()
@@ -446,24 +504,24 @@ func WithDeadline(parent Context, d time.Time) (Context, CancelFunc) {
        propagateCancel(parent, c)
        dur := time.Until(d)
        if dur <= 0 {
-               c.cancel(true, DeadlineExceeded) // deadline has already passed
-               return c, func() { c.cancel(false, Canceled) }
+               c.cancel(true, DeadlineExceeded, nil) // deadline has already passed
+               return c, func() { c.cancel(false, Canceled, nil) }
        }
        c.mu.Lock()
        defer c.mu.Unlock()
        if c.err == nil {
                c.timer = time.AfterFunc(dur, func() {
-                       c.cancel(true, DeadlineExceeded)
+                       c.cancel(true, DeadlineExceeded, nil)
                })
        }
-       return c, func() { c.cancel(true, Canceled) }
+       return c, func() { c.cancel(true, Canceled, nil) }
 }
 
 // A timerCtx carries a timer and a deadline. It embeds a cancelCtx to
 // implement Done and Err. It implements cancel by stopping its timer then
 // delegating to cancelCtx.cancel.
 type timerCtx struct {
-       cancelCtx
+       *cancelCtx
        timer *time.Timer // Under cancelCtx.mu.
 
        deadline time.Time
@@ -479,8 +537,8 @@ func (c *timerCtx) String() string {
                time.Until(c.deadline).String() + "])"
 }
 
-func (c *timerCtx) cancel(removeFromParent bool, err error) {
-       c.cancelCtx.cancel(false, err)
+func (c *timerCtx) cancel(removeFromParent bool, err, cause error) {
+       c.cancelCtx.cancel(false, err, cause)
        if removeFromParent {
                // Remove this timerCtx from its parent cancelCtx's children.
                removeChild(c.cancelCtx.Context, c)
@@ -581,7 +639,7 @@ func value(c Context, key any) any {
                        c = ctx.Context
                case *timerCtx:
                        if key == &cancelCtxKey {
-                               return &ctx.cancelCtx
+                               return ctx.cancelCtx
                        }
                        c = ctx.Context
                case *emptyCtx:
index 099188090767d9a573d2fbeeeaaa522085dc9d7d..593a7b152187a093a1e81978b71bd674d0334592 100644 (file)
@@ -650,8 +650,9 @@ func XTestCancelRemoves(t testingT) {
 }
 
 func XTestWithCancelCanceledParent(t testingT) {
-       parent, pcancel := WithCancel(Background())
-       pcancel()
+       parent, pcancel := WithCancelCause(Background())
+       cause := fmt.Errorf("Because!")
+       pcancel(cause)
 
        c, _ := WithCancel(parent)
        select {
@@ -662,6 +663,9 @@ func XTestWithCancelCanceledParent(t testingT) {
        if got, want := c.Err(), Canceled; got != want {
                t.Errorf("child not canceled; got = %v, want = %v", got, want)
        }
+       if got, want := Cause(c), cause; got != want {
+               t.Errorf("child has wrong cause; got = %v, want = %v", got, want)
+       }
 }
 
 func XTestWithValueChecksKey(t testingT) {
@@ -785,3 +789,148 @@ func XTestCustomContextGoroutines(t testingT) {
        defer cancel7()
        checkNoGoroutine()
 }
+
+func XTestCause(t testingT) {
+       var (
+               parentCause = fmt.Errorf("parentCause")
+               childCause  = fmt.Errorf("childCause")
+       )
+       for _, test := range []struct {
+               name  string
+               ctx   Context
+               err   error
+               cause error
+       }{
+               {
+                       name:  "Background",
+                       ctx:   Background(),
+                       err:   nil,
+                       cause: nil,
+               },
+               {
+                       name:  "TODO",
+                       ctx:   TODO(),
+                       err:   nil,
+                       cause: nil,
+               },
+               {
+                       name: "WithCancel",
+                       ctx: func() Context {
+                               ctx, cancel := WithCancel(Background())
+                               cancel()
+                               return ctx
+                       }(),
+                       err:   Canceled,
+                       cause: Canceled,
+               },
+               {
+                       name: "WithCancelCause",
+                       ctx: func() Context {
+                               ctx, cancel := WithCancelCause(Background())
+                               cancel(parentCause)
+                               return ctx
+                       }(),
+                       err:   Canceled,
+                       cause: parentCause,
+               },
+               {
+                       name: "WithCancelCause nil",
+                       ctx: func() Context {
+                               ctx, cancel := WithCancelCause(Background())
+                               cancel(nil)
+                               return ctx
+                       }(),
+                       err:   Canceled,
+                       cause: Canceled,
+               },
+               {
+                       name: "WithCancelCause: parent cause before child",
+                       ctx: func() Context {
+                               ctx, cancelParent := WithCancelCause(Background())
+                               ctx, cancelChild := WithCancelCause(ctx)
+                               cancelParent(parentCause)
+                               cancelChild(childCause)
+                               return ctx
+                       }(),
+                       err:   Canceled,
+                       cause: parentCause,
+               },
+               {
+                       name: "WithCancelCause: parent cause after child",
+                       ctx: func() Context {
+                               ctx, cancelParent := WithCancelCause(Background())
+                               ctx, cancelChild := WithCancelCause(ctx)
+                               cancelChild(childCause)
+                               cancelParent(parentCause)
+                               return ctx
+                       }(),
+                       err:   Canceled,
+                       cause: childCause,
+               },
+               {
+                       name: "WithCancelCause: parent cause before nil",
+                       ctx: func() Context {
+                               ctx, cancelParent := WithCancelCause(Background())
+                               ctx, cancelChild := WithCancel(ctx)
+                               cancelParent(parentCause)
+                               cancelChild()
+                               return ctx
+                       }(),
+                       err:   Canceled,
+                       cause: parentCause,
+               },
+               {
+                       name: "WithCancelCause: parent cause after nil",
+                       ctx: func() Context {
+                               ctx, cancelParent := WithCancelCause(Background())
+                               ctx, cancelChild := WithCancel(ctx)
+                               cancelChild()
+                               cancelParent(parentCause)
+                               return ctx
+                       }(),
+                       err:   Canceled,
+                       cause: Canceled,
+               },
+               {
+                       name: "WithCancelCause: child cause after nil",
+                       ctx: func() Context {
+                               ctx, cancelParent := WithCancel(Background())
+                               ctx, cancelChild := WithCancelCause(ctx)
+                               cancelParent()
+                               cancelChild(childCause)
+                               return ctx
+                       }(),
+                       err:   Canceled,
+                       cause: Canceled,
+               },
+               {
+                       name: "WithCancelCause: child cause before nil",
+                       ctx: func() Context {
+                               ctx, cancelParent := WithCancel(Background())
+                               ctx, cancelChild := WithCancelCause(ctx)
+                               cancelChild(childCause)
+                               cancelParent()
+                               return ctx
+                       }(),
+                       err:   Canceled,
+                       cause: childCause,
+               },
+               {
+                       name: "WithTimeout",
+                       ctx: func() Context {
+                               ctx, cancel := WithTimeout(Background(), 0)
+                               cancel()
+                               return ctx
+                       }(),
+                       err:   DeadlineExceeded,
+                       cause: DeadlineExceeded,
+               },
+       } {
+               if got, want := test.ctx.Err(), test.err; want != got {
+                       t.Errorf("%s: ctx.Err() = %v want %v", test.name, got, want)
+               }
+               if got, want := Cause(test.ctx), test.cause; want != got {
+                       t.Errorf("%s: Cause(ctx) = %v want %v", test.name, got, want)
+               }
+       }
+}
index 00eca72d5aff0a3237da8d0f832d1d5f202b0951..d3adb381d655231f436545d54caaeebfc4415dcb 100644 (file)
@@ -29,3 +29,4 @@ func TestWithValueChecksKey(t *testing.T)              { XTestWithValueChecksKey
 func TestInvalidDerivedFail(t *testing.T)              { XTestInvalidDerivedFail(t) }
 func TestDeadlineExceededSupportsTimeout(t *testing.T) { XTestDeadlineExceededSupportsTimeout(t) }
 func TestCustomContextGoroutines(t *testing.T)         { XTestCustomContextGoroutines(t) }
+func TestCause(t *testing.T)                           { XTestCause(t) }