]> Cypherpunks.ru repositories - gostls13.git/commitdiff
cmd/cgo: add #cgo noescape/nocallback annotations
authordoujiang24 <doujiang24@gmail.com>
Fri, 25 Aug 2023 17:06:31 +0000 (17:06 +0000)
committerGopher Robot <gobot@golang.org>
Fri, 25 Aug 2023 17:39:23 +0000 (17:39 +0000)
When passing pointers of Go objects from Go to C, the cgo command generate _Cgo_use(pN) for the unsafe.Pointer type arguments, so that the Go compiler will escape these object to heap.

Since the C function may callback to Go, then the Go stack might grow/shrink, that means the pointers that the C function have will be invalid.

After adding the #cgo noescape annotation for a C function, the cgo command won't generate _Cgo_use(pN), and the Go compiler won't force the object escape to heap.

After adding the #cgo nocallback annotation for a C function, which means the C function won't callback to Go, if it do callback to Go, the Go process will crash.

Fixes #56378

Change-Id: Ifdca070584e0d349c7b12276270e50089e481f7a
GitHub-Last-Rev: f1a17b08b0590eca2670e404bbfedad5461df72f
GitHub-Pull-Request: golang/go#60399
Reviewed-on: https://go-review.googlesource.com/c/go/+/497837
Reviewed-by: Ian Lance Taylor <iant@google.com>
Reviewed-by: Bryan Mills <bcmills@google.com>
Run-TryBot: Bryan Mills <bcmills@google.com>
Auto-Submit: Bryan Mills <bcmills@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>

15 files changed:
src/cmd/cgo/doc.go
src/cmd/cgo/gcc.go
src/cmd/cgo/internal/testerrors/errors_test.go
src/cmd/cgo/internal/testerrors/testdata/notmatchedcfunction.go [new file with mode: 0644]
src/cmd/cgo/main.go
src/cmd/cgo/out.go
src/cmd/go/internal/modindex/build.go
src/go/build/build.go
src/runtime/cgo.go
src/runtime/cgocall.go
src/runtime/crash_cgo_test.go
src/runtime/runtime2.go
src/runtime/testdata/testprogcgo/cgonocallback.c [new file with mode: 0644]
src/runtime/testdata/testprogcgo/cgonocallback.go [new file with mode: 0644]
src/runtime/testdata/testprogcgo/cgonoescape.go [new file with mode: 0644]

index b1a288f57392aaf0524aed207e219b447e9006ec..894df2d836de9c74a3d0f76f003263cea678b20d 100644 (file)
@@ -420,6 +420,30 @@ passing uninitialized C memory to Go code if the Go code is going to
 store pointer values in it. Zero out the memory in C before passing it
 to Go.
 
+# Optimizing calls of C code
+
+When passing a Go pointer to a C function the compiler normally ensures
+that the Go object lives on the heap. If the C function does not keep
+a copy of the Go pointer, and never passes the Go pointer back to Go code,
+then this is unnecessary. The #cgo noescape directive may be used to tell
+the compiler that no Go pointers escape via the named C function.
+If the noescape directive is used and the C function does not handle the
+pointer safely, the program may crash or see memory corruption.
+
+For example:
+
+       // #cgo noescape cFunctionName
+
+When a Go function calls a C function, it prepares for the C function to
+call back to a Go function. the #cgo nocallback directive may be used to
+tell the compiler that these preparations are not necessary.
+If the nocallback directive is used and the C function does call back into
+Go code, the program will panic.
+
+For example:
+
+       // #cgo nocallback cFunctionName
+
 # Special cases
 
 A few special C types which would normally be represented by a pointer
index 78a44d33a247b0a6267aa8da538a3c81652a2c27..28dc2a9bf8998123cfa59a916f3b9332c3518ec5 100644 (file)
@@ -73,18 +73,32 @@ func cname(s string) string {
        return s
 }
 
-// DiscardCgoDirectives processes the import C preamble, and discards
-// all #cgo CFLAGS and LDFLAGS directives, so they don't make their
-// way into _cgo_export.h.
-func (f *File) DiscardCgoDirectives() {
+// ProcessCgoDirectives processes the import C preamble:
+//  1. discards all #cgo CFLAGS, LDFLAGS, nocallback and noescape directives,
+//     so they don't make their way into _cgo_export.h.
+//  2. parse the nocallback and noescape directives.
+func (f *File) ProcessCgoDirectives() {
        linesIn := strings.Split(f.Preamble, "\n")
        linesOut := make([]string, 0, len(linesIn))
+       f.NoCallbacks = make(map[string]bool)
+       f.NoEscapes = make(map[string]bool)
        for _, line := range linesIn {
                l := strings.TrimSpace(line)
                if len(l) < 5 || l[:4] != "#cgo" || !unicode.IsSpace(rune(l[4])) {
                        linesOut = append(linesOut, line)
                } else {
                        linesOut = append(linesOut, "")
+
+                       // #cgo (nocallback|noescape) <function name>
+                       if fields := strings.Fields(l); len(fields) == 3 {
+                               directive := fields[1]
+                               funcName := fields[2]
+                               if directive == "nocallback" {
+                                       f.NoCallbacks[funcName] = true
+                               } else if directive == "noescape" {
+                                       f.NoEscapes[funcName] = true
+                               }
+                       }
                }
        }
        f.Preamble = strings.Join(linesOut, "\n")
index 486530e186c610fc96d26cfc058e6ac4305593d5..fd522ba4745a83ac296619cd5b1e456889a1dea5 100644 (file)
@@ -39,16 +39,23 @@ func check(t *testing.T, file string) {
                                continue
                        }
 
-                       _, frag, ok := bytes.Cut(line, []byte("ERROR HERE: "))
-                       if !ok {
-                               continue
+                       if _, frag, ok := bytes.Cut(line, []byte("ERROR HERE: ")); ok {
+                               re, err := regexp.Compile(fmt.Sprintf(":%d:.*%s", i+1, frag))
+                               if err != nil {
+                                       t.Errorf("Invalid regexp after `ERROR HERE: `: %#q", frag)
+                                       continue
+                               }
+                               errors = append(errors, re)
                        }
-                       re, err := regexp.Compile(fmt.Sprintf(":%d:.*%s", i+1, frag))
-                       if err != nil {
-                               t.Errorf("Invalid regexp after `ERROR HERE: `: %#q", frag)
-                               continue
+
+                       if _, frag, ok := bytes.Cut(line, []byte("ERROR MESSAGE: ")); ok {
+                               re, err := regexp.Compile(string(frag))
+                               if err != nil {
+                                       t.Errorf("Invalid regexp after `ERROR MESSAGE: `: %#q", frag)
+                                       continue
+                               }
+                               errors = append(errors, re)
                        }
-                       errors = append(errors, re)
                }
                if len(errors) == 0 {
                        t.Fatalf("cannot find ERROR HERE")
@@ -165,3 +172,8 @@ func TestMallocCrashesOnNil(t *testing.T) {
                t.Fatalf("succeeded unexpectedly")
        }
 }
+
+func TestNotMatchedCFunction(t *testing.T) {
+       file := "notmatchedcfunction.go"
+       check(t, file)
+}
diff --git a/src/cmd/cgo/internal/testerrors/testdata/notmatchedcfunction.go b/src/cmd/cgo/internal/testerrors/testdata/notmatchedcfunction.go
new file mode 100644 (file)
index 0000000..46afeef
--- /dev/null
@@ -0,0 +1,14 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+/*
+// ERROR MESSAGE: #cgo noescape noMatchedCFunction: no matched C function
+#cgo noescape noMatchedCFunction
+*/
+import "C"
+
+func main() {
+}
index 78020aedbe5aa696397fdc1e9a77d0c55e5b8b29..55f9cdc318b08e1f66b0e3d0236c52989428fba4 100644 (file)
@@ -48,6 +48,8 @@ type Package struct {
        Preamble    string          // collected preamble for _cgo_export.h
        typedefs    map[string]bool // type names that appear in the types of the objects we're interested in
        typedefList []typedefInfo
+       noCallbacks map[string]bool // C function names with #cgo nocallback directive
+       noEscapes   map[string]bool // C function names with #cgo noescape directive
 }
 
 // A typedefInfo is an element on Package.typedefList: a typedef name
@@ -59,16 +61,18 @@ type typedefInfo struct {
 
 // A File collects information about a single Go input file.
 type File struct {
-       AST      *ast.File           // parsed AST
-       Comments []*ast.CommentGroup // comments from file
-       Package  string              // Package name
-       Preamble string              // C preamble (doc comment on import "C")
-       Ref      []*Ref              // all references to C.xxx in AST
-       Calls    []*Call             // all calls to C.xxx in AST
-       ExpFunc  []*ExpFunc          // exported functions for this file
-       Name     map[string]*Name    // map from Go name to Name
-       NamePos  map[*Name]token.Pos // map from Name to position of the first reference
-       Edit     *edit.Buffer
+       AST         *ast.File           // parsed AST
+       Comments    []*ast.CommentGroup // comments from file
+       Package     string              // Package name
+       Preamble    string              // C preamble (doc comment on import "C")
+       Ref         []*Ref              // all references to C.xxx in AST
+       Calls       []*Call             // all calls to C.xxx in AST
+       ExpFunc     []*ExpFunc          // exported functions for this file
+       Name        map[string]*Name    // map from Go name to Name
+       NamePos     map[*Name]token.Pos // map from Name to position of the first reference
+       NoCallbacks map[string]bool     // C function names that with #cgo nocallback directive
+       NoEscapes   map[string]bool     // C function names that with #cgo noescape directive
+       Edit        *edit.Buffer
 }
 
 func (f *File) offset(p token.Pos) int {
@@ -374,7 +378,7 @@ func main() {
                f := new(File)
                f.Edit = edit.NewBuffer(b)
                f.ParseGo(input, b)
-               f.DiscardCgoDirectives()
+               f.ProcessCgoDirectives()
                fs[i] = f
        }
 
@@ -413,6 +417,25 @@ func main() {
                        p.writeOutput(f, input)
                }
        }
+       cFunctions := make(map[string]bool)
+       for _, key := range nameKeys(p.Name) {
+               n := p.Name[key]
+               if n.FuncType != nil {
+                       cFunctions[n.C] = true
+               }
+       }
+
+       for funcName := range p.noEscapes {
+               if _, found := cFunctions[funcName]; !found {
+                       error_(token.NoPos, "#cgo noescape %s: no matched C function", funcName)
+               }
+       }
+
+       for funcName := range p.noCallbacks {
+               if _, found := cFunctions[funcName]; !found {
+                       error_(token.NoPos, "#cgo nocallback %s: no matched C function", funcName)
+               }
+       }
 
        if !*godefs {
                p.writeDefs()
@@ -450,10 +473,12 @@ func newPackage(args []string) *Package {
        os.Setenv("LC_ALL", "C")
 
        p := &Package{
-               PtrSize:  ptrSize,
-               IntSize:  intSize,
-               CgoFlags: make(map[string][]string),
-               Written:  make(map[string]bool),
+               PtrSize:     ptrSize,
+               IntSize:     intSize,
+               CgoFlags:    make(map[string][]string),
+               Written:     make(map[string]bool),
+               noCallbacks: make(map[string]bool),
+               noEscapes:   make(map[string]bool),
        }
        p.addToFlag("CFLAGS", args)
        return p
@@ -487,6 +512,14 @@ func (p *Package) Record(f *File) {
                }
        }
 
+       // merge nocallback & noescape
+       for k, v := range f.NoCallbacks {
+               p.noCallbacks[k] = v
+       }
+       for k, v := range f.NoEscapes {
+               p.noEscapes[k] = v
+       }
+
        if f.ExpFunc != nil {
                p.ExpFunc = append(p.ExpFunc, f.ExpFunc...)
                p.Preamble += "\n" + f.Preamble
index b2933e2d82e80e381f57dc909eaeb72b3ae5c25b..947c61b5c5970e0b14d1bd77722c1e5f9df120c8 100644 (file)
@@ -106,6 +106,8 @@ func (p *Package) writeDefs() {
                fmt.Fprintf(fgo2, "//go:linkname _Cgo_use runtime.cgoUse\n")
                fmt.Fprintf(fgo2, "func _Cgo_use(interface{})\n")
        }
+       fmt.Fprintf(fgo2, "//go:linkname _Cgo_no_callback runtime.cgoNoCallback\n")
+       fmt.Fprintf(fgo2, "func _Cgo_no_callback(bool)\n")
 
        typedefNames := make([]string, 0, len(typedef))
        for name := range typedef {
@@ -612,6 +614,12 @@ func (p *Package) writeDefsFunc(fgo2 io.Writer, n *Name, callsMalloc *bool) {
                arg = "uintptr(unsafe.Pointer(&r1))"
        }
 
+       noCallback := p.noCallbacks[n.C]
+       if noCallback {
+               // disable cgocallback, will check it in runtime.
+               fmt.Fprintf(fgo2, "\t_Cgo_no_callback(true)\n")
+       }
+
        prefix := ""
        if n.AddError {
                prefix = "errno := "
@@ -620,13 +628,21 @@ func (p *Package) writeDefsFunc(fgo2 io.Writer, n *Name, callsMalloc *bool) {
        if n.AddError {
                fmt.Fprintf(fgo2, "\tif errno != 0 { r2 = syscall.Errno(errno) }\n")
        }
-       fmt.Fprintf(fgo2, "\tif _Cgo_always_false {\n")
-       if d.Type.Params != nil {
-               for i := range d.Type.Params.List {
-                       fmt.Fprintf(fgo2, "\t\t_Cgo_use(p%d)\n", i)
+       if noCallback {
+               fmt.Fprintf(fgo2, "\t_Cgo_no_callback(false)\n")
+       }
+
+       // skip _Cgo_use when noescape exist,
+       // so that the compiler won't force to escape them to heap.
+       if !p.noEscapes[n.C] {
+               fmt.Fprintf(fgo2, "\tif _Cgo_always_false {\n")
+               if d.Type.Params != nil {
+                       for i := range d.Type.Params.List {
+                               fmt.Fprintf(fgo2, "\t\t_Cgo_use(p%d)\n", i)
+                       }
                }
+               fmt.Fprintf(fgo2, "\t}\n")
        }
-       fmt.Fprintf(fgo2, "\t}\n")
        fmt.Fprintf(fgo2, "\treturn\n")
        fmt.Fprintf(fgo2, "}\n")
 }
@@ -1612,9 +1628,11 @@ const goProlog = `
 func _cgo_runtime_cgocall(unsafe.Pointer, uintptr) int32
 
 //go:linkname _cgoCheckPointer runtime.cgoCheckPointer
+//go:noescape
 func _cgoCheckPointer(interface{}, interface{})
 
 //go:linkname _cgoCheckResult runtime.cgoCheckResult
+//go:noescape
 func _cgoCheckResult(interface{})
 `
 
index b57f2f6368f0fedf3158a88999db83ca6b30ea12..0b06373984190a53f1b2dc1f30e5f96662e0d404 100644 (file)
@@ -622,6 +622,11 @@ func (ctxt *Context) saveCgo(filename string, di *build.Package, text string) er
                        continue
                }
 
+               // #cgo (nocallback|noescape) <function name>
+               if fields := strings.Fields(line); len(fields) == 3 && (fields[1] == "nocallback" || fields[1] == "noescape") {
+                       continue
+               }
+
                // Split at colon.
                line, argstr, ok := strings.Cut(strings.TrimSpace(line[4:]), ":")
                if !ok {
index dd6cdc903a21a816bd007e51c427bc19be2b1603..f51713806119e106f2ab6d79daad7fa3d05c7569 100644 (file)
@@ -1687,6 +1687,11 @@ func (ctxt *Context) saveCgo(filename string, di *Package, cg *ast.CommentGroup)
                        continue
                }
 
+               // #cgo (nocallback|noescape) <function name>
+               if fields := strings.Fields(line); len(fields) == 3 && (fields[1] == "nocallback" || fields[1] == "noescape") {
+                       continue
+               }
+
                // Split at colon.
                line, argstr, ok := strings.Cut(strings.TrimSpace(line[4:]), ":")
                if !ok {
index 395303552c9bd1e3bc8f3cc11a1db75e963e2401..40c8c748d3e56ed5c1b2fb0d6c2f2fb200e55827 100644 (file)
@@ -61,3 +61,11 @@ func cgoUse(any) { throw("cgoUse should not be called") }
 var cgoAlwaysFalse bool
 
 var cgo_yield = &_cgo_yield
+
+func cgoNoCallback(v bool) {
+       g := getg()
+       if g.nocgocallback && v {
+               panic("runtime: unexpected setting cgoNoCallback")
+       }
+       g.nocgocallback = v
+}
index f6e2f6381382b56309abca907cf3f12c705828a9..802d6f2084067bfd898201f420ebd1f22fbbdf21 100644 (file)
@@ -242,6 +242,10 @@ func cgocallbackg(fn, frame unsafe.Pointer, ctxt uintptr) {
 
        osPreemptExtExit(gp.m)
 
+       if gp.nocgocallback {
+               panic("runtime: function marked with #cgo nocallback called back into Go")
+       }
+
        cgocallbackg1(fn, frame, ctxt) // will call unlockOSThread
 
        // At this point unlockOSThread has been called.
index e1851808f3194ed4e60836e89d119274c5b41e64..88044caacf7442e129b8a40f43d975b03b7fbe42 100644 (file)
@@ -753,6 +753,22 @@ func TestNeedmDeadlock(t *testing.T) {
        }
 }
 
+func TestCgoNoCallback(t *testing.T) {
+       got := runTestProg(t, "testprogcgo", "CgoNoCallback")
+       want := "function marked with #cgo nocallback called back into Go"
+       if !strings.Contains(got, want) {
+               t.Fatalf("did not see %q in output:\n%s", want, got)
+       }
+}
+
+func TestCgoNoEscape(t *testing.T) {
+       got := runTestProg(t, "testprogcgo", "CgoNoEscape")
+       want := "OK\n"
+       if got != want {
+               t.Fatalf("want %s, got %s\n", want, got)
+       }
+}
+
 func TestCgoTracebackGoroutineProfile(t *testing.T) {
        output := runTestProg(t, "testprogcgo", "GoroutineProfile")
        want := "OK\n"
index c3a367930275ac5fdc8c183731bf67cfdb26f26a..8809b5d569c787924f06e75190e61b3506caee2a 100644 (file)
@@ -481,6 +481,7 @@ type g struct {
        parkingOnChan atomic.Bool
 
        raceignore    int8  // ignore race detection events
+       nocgocallback bool  // whether disable callback from C
        tracking      bool  // whether we're tracking this G for sched latency statistics
        trackingSeq   uint8 // used to decide whether to track this G
        trackingStamp int64 // timestamp of when the G last started being tracked
diff --git a/src/runtime/testdata/testprogcgo/cgonocallback.c b/src/runtime/testdata/testprogcgo/cgonocallback.c
new file mode 100644 (file)
index 0000000..465a484
--- /dev/null
@@ -0,0 +1,9 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#include "_cgo_export.h"
+
+void runCShouldNotCallback() {
+       CallbackToGo();
+}
diff --git a/src/runtime/testdata/testprogcgo/cgonocallback.go b/src/runtime/testdata/testprogcgo/cgonocallback.go
new file mode 100644 (file)
index 0000000..8cbbfd1
--- /dev/null
@@ -0,0 +1,32 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+// #cgo nocallback annotations for a C function means it should not callback to Go.
+// But it do callback to go in this test, Go should crash here.
+
+/*
+#cgo nocallback runCShouldNotCallback
+
+extern void runCShouldNotCallback();
+*/
+import "C"
+
+import (
+       "fmt"
+)
+
+func init() {
+       register("CgoNoCallback", CgoNoCallback)
+}
+
+//export CallbackToGo
+func CallbackToGo() {
+}
+
+func CgoNoCallback() {
+       C.runCShouldNotCallback()
+       fmt.Println("OK")
+}
diff --git a/src/runtime/testdata/testprogcgo/cgonoescape.go b/src/runtime/testdata/testprogcgo/cgonoescape.go
new file mode 100644 (file)
index 0000000..056be44
--- /dev/null
@@ -0,0 +1,84 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+// #cgo noescape annotations for a C function means its arguments won't escape to heap.
+
+// We assume that there won't be 100 new allocated heap objects in other places,
+// i.e. runtime.ReadMemStats or other runtime background works.
+// So, the tests are:
+// 1. at least 100 new allocated heap objects after invoking withoutNoEscape 100 times.
+// 2. less than 100 new allocated heap objects after invoking withoutNoEscape 100 times.
+
+/*
+#cgo noescape runCWithNoEscape
+
+void runCWithNoEscape(void *p) {
+}
+void runCWithoutNoEscape(void *p) {
+}
+*/
+import "C"
+
+import (
+       "fmt"
+       "runtime"
+       "runtime/debug"
+       "unsafe"
+)
+
+const num = 100
+
+func init() {
+       register("CgoNoEscape", CgoNoEscape)
+}
+
+//go:noinline
+func withNoEscape() {
+       var str string
+       C.runCWithNoEscape(unsafe.Pointer(&str))
+}
+
+//go:noinline
+func withoutNoEscape() {
+       var str string
+       C.runCWithoutNoEscape(unsafe.Pointer(&str))
+}
+
+func CgoNoEscape() {
+       // make GC stop to see the heap objects allocated
+       debug.SetGCPercent(-1)
+
+       var stats runtime.MemStats
+       runtime.ReadMemStats(&stats)
+       preHeapObjects := stats.HeapObjects
+
+       for i := 0; i < num; i++ {
+               withNoEscape()
+       }
+
+       runtime.ReadMemStats(&stats)
+       nowHeapObjects := stats.HeapObjects
+
+       if nowHeapObjects-preHeapObjects >= num {
+               fmt.Printf("too many heap objects allocated, pre: %v, now: %v\n", preHeapObjects, nowHeapObjects)
+       }
+
+       runtime.ReadMemStats(&stats)
+       preHeapObjects = stats.HeapObjects
+
+       for i := 0; i < num; i++ {
+               withoutNoEscape()
+       }
+
+       runtime.ReadMemStats(&stats)
+       nowHeapObjects = stats.HeapObjects
+
+       if nowHeapObjects-preHeapObjects < num {
+               fmt.Printf("too few heap objects allocated, pre: %v, now: %v\n", preHeapObjects, nowHeapObjects)
+       }
+
+       fmt.Println("OK")
+}