]> Cypherpunks.ru repositories - gostls13.git/commitdiff
context: add AfterFunc
authorDamien Neil <dneil@google.com>
Wed, 5 Apr 2023 23:06:36 +0000 (16:06 -0700)
committerDamien Neil <dneil@google.com>
Wed, 19 Apr 2023 19:13:01 +0000 (19:13 +0000)
Add an AfterFunc function, which registers a function to run after
a context has been canceled.

Add support for contexts that implement an AfterFunc method, which
can be used to avoid the need to start a new goroutine watching
the Done channel when propagating cancellation signals.

Fixes #57928

Change-Id: If0b2cdcc4332961276a1ff57311338e74916259c
Reviewed-on: https://go-review.googlesource.com/c/go/+/482695
TryBot-Result: Gopher Robot <gobot@golang.org>
Run-TryBot: Damien Neil <dneil@google.com>
Reviewed-by: Sameer Ajmani <sameer@golang.org>
api/next/57928.txt [new file with mode: 0644]
src/context/afterfunc_test.go [new file with mode: 0644]
src/context/context.go
src/context/context_test.go
src/context/example_test.go
src/context/x_test.go

diff --git a/api/next/57928.txt b/api/next/57928.txt
new file mode 100644 (file)
index 0000000..5b85e74
--- /dev/null
@@ -0,0 +1 @@
+pkg context, func AfterFunc(Context, func()) func() bool #57928
diff --git a/src/context/afterfunc_test.go b/src/context/afterfunc_test.go
new file mode 100644 (file)
index 0000000..71f639a
--- /dev/null
@@ -0,0 +1,141 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package context_test
+
+import (
+       "context"
+       "sync"
+       "testing"
+       "time"
+)
+
+// afterFuncContext is a context that's not one of the types
+// defined in context.go, that supports registering AfterFuncs.
+type afterFuncContext struct {
+       mu         sync.Mutex
+       afterFuncs map[*struct{}]func()
+       done       chan struct{}
+       err        error
+}
+
+func newAfterFuncContext() context.Context {
+       return &afterFuncContext{}
+}
+
+func (c *afterFuncContext) Deadline() (time.Time, bool) {
+       return time.Time{}, false
+}
+
+func (c *afterFuncContext) Done() <-chan struct{} {
+       c.mu.Lock()
+       defer c.mu.Unlock()
+       if c.done == nil {
+               c.done = make(chan struct{})
+       }
+       return c.done
+}
+
+func (c *afterFuncContext) Err() error {
+       c.mu.Lock()
+       defer c.mu.Unlock()
+       return c.err
+}
+
+func (c *afterFuncContext) Value(key any) any {
+       return nil
+}
+
+func (c *afterFuncContext) AfterFunc(f func()) func() bool {
+       c.mu.Lock()
+       defer c.mu.Unlock()
+       k := &struct{}{}
+       if c.afterFuncs == nil {
+               c.afterFuncs = make(map[*struct{}]func())
+       }
+       c.afterFuncs[k] = f
+       return func() bool {
+               c.mu.Lock()
+               defer c.mu.Unlock()
+               _, ok := c.afterFuncs[k]
+               delete(c.afterFuncs, k)
+               return ok
+       }
+}
+
+func (c *afterFuncContext) cancel(err error) {
+       c.mu.Lock()
+       defer c.mu.Unlock()
+       if c.err != nil {
+               return
+       }
+       c.err = err
+       for _, f := range c.afterFuncs {
+               go f()
+       }
+       c.afterFuncs = nil
+}
+
+func TestCustomContextAfterFuncCancel(t *testing.T) {
+       ctx0 := &afterFuncContext{}
+       ctx1, cancel := context.WithCancel(ctx0)
+       defer cancel()
+       ctx0.cancel(context.Canceled)
+       <-ctx1.Done()
+}
+
+func TestCustomContextAfterFuncTimeout(t *testing.T) {
+       ctx0 := &afterFuncContext{}
+       ctx1, cancel := context.WithTimeout(ctx0, veryLongDuration)
+       defer cancel()
+       ctx0.cancel(context.Canceled)
+       <-ctx1.Done()
+}
+
+func TestCustomContextAfterFuncAfterFunc(t *testing.T) {
+       ctx0 := &afterFuncContext{}
+       donec := make(chan struct{})
+       stop := context.AfterFunc(ctx0, func() {
+               close(donec)
+       })
+       defer stop()
+       ctx0.cancel(context.Canceled)
+       <-donec
+}
+
+func TestCustomContextAfterFuncUnregisterCancel(t *testing.T) {
+       ctx0 := &afterFuncContext{}
+       _, cancel := context.WithCancel(ctx0)
+       if got, want := len(ctx0.afterFuncs), 1; got != want {
+               t.Errorf("after WithCancel(ctx0): ctx0 has %v afterFuncs, want %v", got, want)
+       }
+       cancel()
+       if got, want := len(ctx0.afterFuncs), 0; got != want {
+               t.Errorf("after canceling WithCancel(ctx0): ctx0 has %v afterFuncs, want %v", got, want)
+       }
+}
+
+func TestCustomContextAfterFuncUnregisterTimeout(t *testing.T) {
+       ctx0 := &afterFuncContext{}
+       _, cancel := context.WithTimeout(ctx0, veryLongDuration)
+       if got, want := len(ctx0.afterFuncs), 1; got != want {
+               t.Errorf("after WithTimeout(ctx0, d): ctx0 has %v afterFuncs, want %v", got, want)
+       }
+       cancel()
+       if got, want := len(ctx0.afterFuncs), 0; got != want {
+               t.Errorf("after canceling WithTimeout(ctx0, d): ctx0 has %v afterFuncs, want %v", got, want)
+       }
+}
+
+func TestCustomContextAfterFuncUnregisterAfterFunc(t *testing.T) {
+       ctx0 := &afterFuncContext{}
+       stop := context.AfterFunc(ctx0, func() {})
+       if got, want := len(ctx0.afterFuncs), 1; got != want {
+               t.Errorf("after AfterFunc(ctx0, f): ctx0 has %v afterFuncs, want %v", got, want)
+       }
+       stop()
+       if got, want := len(ctx0.afterFuncs), 0; got != want {
+               t.Errorf("after stopping AfterFunc(ctx0, f): ctx0 has %v afterFuncs, want %v", got, want)
+       }
+}
index 7227099889a83bcddd89077482cca955de194b1d..4c0ba7c1d756bc923aad2f9da45ca35c0dab37f0 100644 (file)
@@ -269,8 +269,8 @@ func withCancel(parent Context) *cancelCtx {
        if parent == nil {
                panic("cannot create context from nil parent")
        }
-       c := &cancelCtx{Context: parent}
-       propagateCancel(parent, c)
+       c := &cancelCtx{}
+       c.propagateCancel(parent, c)
        return c
 }
 
@@ -289,48 +289,72 @@ func Cause(c Context) error {
        return nil
 }
 
-// goroutines counts the number of goroutines ever created; for testing.
-var goroutines atomic.Int32
-
-// propagateCancel arranges for child to be canceled when parent is.
-func propagateCancel(parent Context, child canceler) {
-       done := parent.Done()
-       if done == nil {
-               return // parent is never canceled
+// AfterFunc arranges to call f in its own goroutine after ctx is done
+// (cancelled or timed out).
+// If ctx is already done, AfterFunc calls f immediately in its own goroutine.
+//
+// Multiple calls to AfterFunc on a context operate independently;
+// one does not replace another.
+//
+// Calling the returned stop function stops the association of ctx with f.
+// It returns true if the call stopped f from being run.
+// If stop returns false,
+// either the context is done and f has been started in its own goroutine;
+// or f was already stopped.
+// The stop function does not wait for f to complete before returning.
+// If the caller needs to know whether f is completed,
+// it must coordinate with f explicitly.
+//
+// If ctx has a "AfterFunc(func()) func() bool" method,
+// AfterFunc will use it to schedule the call.
+func AfterFunc(ctx Context, f func()) (stop func() bool) {
+       a := &afterFuncCtx{
+               f: f,
+       }
+       a.cancelCtx.propagateCancel(ctx, a)
+       return func() bool {
+               stopped := false
+               a.once.Do(func() {
+                       stopped = true
+               })
+               if stopped {
+                       a.cancel(true, Canceled, nil)
+               }
+               return stopped
        }
+}
 
-       select {
-       case <-done:
-               // parent is already canceled
-               child.cancel(false, parent.Err(), Cause(parent))
-               return
-       default:
-       }
+type afterFuncer interface {
+       AfterFunc(func()) func() bool
+}
 
-       if p, ok := parentCancelCtx(parent); ok {
-               p.mu.Lock()
-               if p.err != nil {
-                       // parent has already been canceled
-                       child.cancel(false, p.err, p.cause)
-               } else {
-                       if p.children == nil {
-                               p.children = make(map[canceler]struct{})
-                       }
-                       p.children[child] = struct{}{}
-               }
-               p.mu.Unlock()
-       } else {
-               goroutines.Add(1)
-               go func() {
-                       select {
-                       case <-parent.Done():
-                               child.cancel(false, parent.Err(), Cause(parent))
-                       case <-child.Done():
-                       }
-               }()
+type afterFuncCtx struct {
+       cancelCtx
+       once sync.Once // either starts running f or stops f from running
+       f    func()
+}
+
+func (a *afterFuncCtx) cancel(removeFromParent bool, err, cause error) {
+       a.cancelCtx.cancel(false, err, cause)
+       if removeFromParent {
+               removeChild(a.Context, a)
        }
+       a.once.Do(func() {
+               go a.f()
+       })
 }
 
+// A stopCtx is used as the parent context of a cancelCtx when
+// an AfterFunc has been registered with the parent.
+// It holds the stop function used to unregister the AfterFunc.
+type stopCtx struct {
+       Context
+       stop func() bool
+}
+
+// goroutines counts the number of goroutines ever created; for testing.
+var goroutines atomic.Int32
+
 // &cancelCtxKey is the key that a cancelCtx returns itself for.
 var cancelCtxKey int
 
@@ -358,6 +382,10 @@ func parentCancelCtx(parent Context) (*cancelCtx, bool) {
 
 // removeChild removes a context from its parent.
 func removeChild(parent Context, child canceler) {
+       if s, ok := parent.(stopCtx); ok {
+               s.stop()
+               return
+       }
        p, ok := parentCancelCtx(parent)
        if !ok {
                return
@@ -424,6 +452,64 @@ func (c *cancelCtx) Err() error {
        return err
 }
 
+// propagateCancel arranges for child to be canceled when parent is.
+// It sets the parent context of cancelCtx.
+func (c *cancelCtx) propagateCancel(parent Context, child canceler) {
+       c.Context = parent
+
+       done := parent.Done()
+       if done == nil {
+               return // parent is never canceled
+       }
+
+       select {
+       case <-done:
+               // parent is already canceled
+               child.cancel(false, parent.Err(), Cause(parent))
+               return
+       default:
+       }
+
+       if p, ok := parentCancelCtx(parent); ok {
+               // parent is a *cancelCtx, or derives from one.
+               p.mu.Lock()
+               if p.err != nil {
+                       // parent has already been canceled
+                       child.cancel(false, p.err, p.cause)
+               } else {
+                       if p.children == nil {
+                               p.children = make(map[canceler]struct{})
+                       }
+                       p.children[child] = struct{}{}
+               }
+               p.mu.Unlock()
+               return
+       }
+
+       if a, ok := parent.(afterFuncer); ok {
+               // parent implements an AfterFunc method.
+               c.mu.Lock()
+               stop := a.AfterFunc(func() {
+                       child.cancel(false, parent.Err(), Cause(parent))
+               })
+               c.Context = stopCtx{
+                       Context: parent,
+                       stop:    stop,
+               }
+               c.mu.Unlock()
+               return
+       }
+
+       goroutines.Add(1)
+       go func() {
+               select {
+               case <-parent.Done():
+                       child.cancel(false, parent.Err(), Cause(parent))
+               case <-child.Done():
+               }
+       }()
+}
+
 type stringer interface {
        String() string
 }
@@ -533,10 +619,9 @@ func WithDeadlineCause(parent Context, d time.Time, cause error) (Context, Cance
                return WithCancel(parent)
        }
        c := &timerCtx{
-               cancelCtx: cancelCtx{Context: parent},
-               deadline:  d,
+               deadline: d,
        }
-       propagateCancel(parent, c)
+       c.cancelCtx.propagateCancel(parent, c)
        dur := time.Until(d)
        if dur <= 0 {
                c.cancel(true, DeadlineExceeded, cause) // deadline has already passed
index 74738fd3160b2c11a084380e2978b318f42461dc..57066c968596f631f63b811f35e84d97b3570ca6 100644 (file)
@@ -44,12 +44,15 @@ func XTestParentFinishesChild(t testingT) {
        // Context tree:
        // parent -> cancelChild
        // parent -> valueChild -> timerChild
+       // parent -> afterChild
        parent, cancel := WithCancel(Background())
        cancelChild, stop := WithCancel(parent)
        defer stop()
        valueChild := WithValue(parent, "key", "value")
        timerChild, stop := WithTimeout(valueChild, veryLongDuration)
        defer stop()
+       afterStop := AfterFunc(parent, func() {})
+       defer afterStop()
 
        select {
        case x := <-parent.Done():
@@ -63,13 +66,20 @@ func XTestParentFinishesChild(t testingT) {
        default:
        }
 
-       // The parent's children should contain the two cancelable children.
+       // The parent's children should contain the three cancelable children.
        pc := parent.(*cancelCtx)
        cc := cancelChild.(*cancelCtx)
        tc := timerChild.(*timerCtx)
        pc.mu.Lock()
-       if len(pc.children) != 2 || !contains(pc.children, cc) || !contains(pc.children, tc) {
-               t.Errorf("bad linkage: pc.children = %v, want %v and %v",
+       var ac *afterFuncCtx
+       for c := range pc.children {
+               if a, ok := c.(*afterFuncCtx); ok {
+                       ac = a
+                       break
+               }
+       }
+       if len(pc.children) != 3 || !contains(pc.children, cc) || !contains(pc.children, tc) || ac == nil {
+               t.Errorf("bad linkage: pc.children = %v, want %v, %v, and an afterFunc",
                        pc.children, cc, tc)
        }
        pc.mu.Unlock()
@@ -80,6 +90,9 @@ func XTestParentFinishesChild(t testingT) {
        if p, ok := parentCancelCtx(tc.Context); !ok || p != pc {
                t.Errorf("bad linkage: parentCancelCtx(timerChild.Context) = %v, %v want %v, true", p, ok, pc)
        }
+       if p, ok := parentCancelCtx(ac.Context); !ok || p != pc {
+               t.Errorf("bad linkage: parentCancelCtx(afterChild.Context) = %v, %v want %v, true", p, ok, pc)
+       }
 
        cancel()
 
@@ -197,6 +210,13 @@ func XTestCancelRemoves(t testingT) {
        checkChildren("with WithTimeout child ", ctx, 1)
        cancel()
        checkChildren("after canceling WithTimeout child", ctx, 0)
+
+       ctx, _ = WithCancel(Background())
+       checkChildren("after creation", ctx, 0)
+       stop := AfterFunc(ctx, func() {})
+       checkChildren("with AfterFunc child ", ctx, 1)
+       stop()
+       checkChildren("after stopping AfterFunc child ", ctx, 0)
 }
 
 type myCtx struct {
index 8dad0a4220bccb490bb203b48857094c4485aca8..38549a12de898e19f39448bf01f37209acdb74c9 100644 (file)
@@ -6,7 +6,10 @@ package context_test
 
 import (
        "context"
+       "errors"
        "fmt"
+       "net"
+       "sync"
        "time"
 )
 
@@ -118,3 +121,104 @@ func ExampleWithValue() {
        // found value: Go
        // key not found: color
 }
+
+// This example uses AfterFunc to define a function which waits on a sync.Cond,
+// stopping the wait when a context is canceled.
+func ExampleAfterFunc_cond() {
+       waitOnCond := func(ctx context.Context, cond *sync.Cond) error {
+               stopf := context.AfterFunc(ctx, cond.Broadcast)
+               defer stopf()
+               cond.Wait()
+               return ctx.Err()
+       }
+
+       ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
+       defer cancel()
+
+       var mu sync.Mutex
+       cond := sync.NewCond(&mu)
+
+       mu.Lock()
+       err := waitOnCond(ctx, cond)
+       fmt.Println(err)
+
+       // Output:
+       // context deadline exceeded
+}
+
+// This example uses AfterFunc to define a function which reads from a net.Conn,
+// stopping the read when a context is canceled.
+func ExampleAfterFunc_connection() {
+       readFromConn := func(ctx context.Context, conn net.Conn, b []byte) (n int, err error) {
+               stopc := make(chan struct{})
+               stop := context.AfterFunc(ctx, func() {
+                       conn.SetReadDeadline(time.Now())
+                       close(stopc)
+               })
+               n, err = conn.Read(b)
+               if !stop() {
+                       // The AfterFunc was started.
+                       // Wait for it to complete, and reset the Conn's deadline.
+                       <-stopc
+                       conn.SetReadDeadline(time.Time{})
+                       return n, ctx.Err()
+               }
+               return n, err
+       }
+
+       listener, err := net.Listen("tcp", ":0")
+       if err != nil {
+               fmt.Println(err)
+               return
+       }
+       defer listener.Close()
+
+       conn, err := net.Dial(listener.Addr().Network(), listener.Addr().String())
+       if err != nil {
+               fmt.Println(err)
+               return
+       }
+       defer conn.Close()
+
+       ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
+       defer cancel()
+
+       b := make([]byte, 1024)
+       _, err = readFromConn(ctx, conn, b)
+       fmt.Println(err)
+
+       // Output:
+       // context deadline exceeded
+}
+
+// This example uses AfterFunc to define a function which combines
+// the cancellation signals of two Contexts.
+func ExampleAfterFunc_merge() {
+       // mergeCancel returns a context that contains the values of ctx,
+       // and which is canceled when either ctx or cancelCtx is canceled.
+       mergeCancel := func(ctx, cancelCtx context.Context) (context.Context, context.CancelFunc) {
+               ctx, cancel := context.WithCancelCause(ctx)
+               stop := context.AfterFunc(cancelCtx, func() {
+                       cancel(context.Cause(cancelCtx))
+               })
+               return ctx, func() {
+                       stop()
+                       cancel(context.Canceled)
+               }
+       }
+
+       ctx1, cancel1 := context.WithCancelCause(context.Background())
+       defer cancel1(errors.New("ctx1 canceled"))
+
+       ctx2, cancel2 := context.WithCancelCause(context.Background())
+
+       mergedCtx, mergedCancel := mergeCancel(ctx1, ctx2)
+       defer mergedCancel()
+
+       cancel2(errors.New("ctx2 canceled"))
+       <-mergedCtx.Done()
+       fmt.Println(context.Cause(mergedCtx))
+
+       // Output:
+       // ctx2 canceled
+}
index 957590d01b19ea89bb8079014666afbf9c89a209..bf0af674c188aa7e4d2504dd5ec3de896642f086 100644 (file)
@@ -502,6 +502,24 @@ func TestWithCancelCanceledParent(t *testing.T) {
        }
 }
 
+func TestWithCancelSimultaneouslyCanceledParent(t *testing.T) {
+       // Cancel the parent goroutine concurrently with creating a child.
+       for i := 0; i < 100; i++ {
+               parent, pcancel := WithCancelCause(Background())
+               cause := fmt.Errorf("Because!")
+               go pcancel(cause)
+
+               c, _ := WithCancel(parent)
+               <-c.Done()
+               if got, want := c.Err(), Canceled; got != want {
+                       t.Errorf("child not canceled; got = %v, want = %v", got, want)
+               }
+               if got, want := Cause(c), cause; got != want {
+                       t.Errorf("child has wrong cause; got = %v, want = %v", got, want)
+               }
+       }
+}
+
 func TestWithValueChecksKey(t *testing.T) {
        panicVal := recoveredValue(func() { WithValue(Background(), []byte("foo"), "bar") })
        if panicVal == nil {
@@ -816,3 +834,123 @@ func TestWithoutCancel(t *testing.T) {
                t.Errorf("ctx.Value(%q) = %q want %q", key, v, value)
        }
 }
+
+type customDoneContext struct {
+       Context
+       donec chan struct{}
+}
+
+func (c *customDoneContext) Done() <-chan struct{} {
+       return c.donec
+}
+
+func TestCustomContextPropagation(t *testing.T) {
+       cause := errors.New("TestCustomContextPropagation")
+       donec := make(chan struct{})
+       ctx1, cancel1 := WithCancelCause(Background())
+       ctx2 := &customDoneContext{
+               Context: ctx1,
+               donec:   donec,
+       }
+       ctx3, cancel3 := WithCancel(ctx2)
+       defer cancel3()
+
+       cancel1(cause)
+       close(donec)
+
+       <-ctx3.Done()
+       if got, want := ctx3.Err(), Canceled; got != want {
+               t.Errorf("child not canceled; got = %v, want = %v", got, want)
+       }
+       if got, want := Cause(ctx3), cause; got != want {
+               t.Errorf("child has wrong cause; got = %v, want = %v", got, want)
+       }
+}
+
+func TestAfterFuncCalledAfterCancel(t *testing.T) {
+       ctx, cancel := WithCancel(Background())
+       donec := make(chan struct{})
+       stop := AfterFunc(ctx, func() {
+               close(donec)
+       })
+       select {
+       case <-donec:
+               t.Fatalf("AfterFunc called before context is done")
+       case <-time.After(shortDuration):
+       }
+       cancel()
+       select {
+       case <-donec:
+       case <-time.After(veryLongDuration):
+               t.Fatalf("AfterFunc not called after context is canceled")
+       }
+       if stop() {
+               t.Fatalf("stop() = true, want false")
+       }
+}
+
+func TestAfterFuncCalledAfterTimeout(t *testing.T) {
+       ctx, cancel := WithTimeout(Background(), shortDuration)
+       defer cancel()
+       donec := make(chan struct{})
+       AfterFunc(ctx, func() {
+               close(donec)
+       })
+       select {
+       case <-donec:
+       case <-time.After(veryLongDuration):
+               t.Fatalf("AfterFunc not called after context is canceled")
+       }
+}
+
+func TestAfterFuncCalledImmediately(t *testing.T) {
+       ctx, cancel := WithCancel(Background())
+       cancel()
+       donec := make(chan struct{})
+       AfterFunc(ctx, func() {
+               close(donec)
+       })
+       select {
+       case <-donec:
+       case <-time.After(veryLongDuration):
+               t.Fatalf("AfterFunc not called for already-canceled context")
+       }
+}
+
+func TestAfterFuncNotCalledAfterStop(t *testing.T) {
+       ctx, cancel := WithCancel(Background())
+       donec := make(chan struct{})
+       stop := AfterFunc(ctx, func() {
+               close(donec)
+       })
+       if !stop() {
+               t.Fatalf("stop() = false, want true")
+       }
+       cancel()
+       select {
+       case <-donec:
+               t.Fatalf("AfterFunc called for already-canceled context")
+       case <-time.After(shortDuration):
+       }
+       if stop() {
+               t.Fatalf("stop() = true, want false")
+       }
+}
+
+// This test verifies that cancelling a context does not block waiting for AfterFuncs to finish.
+func TestAfterFuncCalledAsynchronously(t *testing.T) {
+       ctx, cancel := WithCancel(Background())
+       donec := make(chan struct{})
+       stop := AfterFunc(ctx, func() {
+               // The channel send blocks until donec is read from.
+               donec <- struct{}{}
+       })
+       defer stop()
+       cancel()
+       // After cancel returns, read from donec and unblock the AfterFunc.
+       select {
+       case <-donec:
+       case <-time.After(veryLongDuration):
+               t.Fatalf("AfterFunc not called after context is canceled")
+       }
+}