]> Cypherpunks.ru repositories - gostls13.git/commitdiff
errors, fmt: add support for wrapping multiple errors
authorDamien Neil <dneil@google.com>
Thu, 22 Sep 2022 17:43:26 +0000 (10:43 -0700)
committerDamien Neil <dneil@google.com>
Thu, 29 Sep 2022 18:40:40 +0000 (18:40 +0000)
An error which implements an "Unwrap() []error" method wraps all the
non-nil errors in the returned []error.

We replace the concept of the "error chain" inspected by errors.Is
and errors.As with the "error tree". Is and As perform a pre-order,
depth-first traversal of an error's tree. As returns the first
matching result, if any.

The new errors.Join function returns an error wrapping a list of errors.

The fmt.Errorf function now supports multiple instances of the %w verb.

For #53435.

Change-Id: Ib7402e70b68e28af8f201d2b66bd8e87ccfb5283
Reviewed-on: https://go-review.googlesource.com/c/go/+/432898
Reviewed-by: Cherry Mui <cherryyz@google.com>
Reviewed-by: Rob Pike <r@golang.org>
Run-TryBot: Damien Neil <dneil@google.com>
Reviewed-by: Joseph Tsai <joetsai@digital-static.net>
api/next/53435.txt [new file with mode: 0644]
src/errors/errors.go
src/errors/errors_test.go
src/errors/join.go [new file with mode: 0644]
src/errors/join_test.go [new file with mode: 0644]
src/errors/wrap.go
src/errors/wrap_test.go
src/fmt/errors.go
src/fmt/errors_test.go
src/fmt/print.go

diff --git a/api/next/53435.txt b/api/next/53435.txt
new file mode 100644 (file)
index 0000000..8f295fc
--- /dev/null
@@ -0,0 +1 @@
+pkg errors, func Join(...error) error #53435
index f2fabacd4e9d85961f299ec869ee721bbad7ce5f..8436f812a6e5acc675b03cacae534d0a765cfb05 100644 (file)
@@ -6,26 +6,29 @@
 //
 // The New function creates errors whose only content is a text message.
 //
-// The Unwrap, Is and As functions work on errors that may wrap other errors.
-// An error wraps another error if its type has the method
+// An error e wraps another error if e's type has one of the methods
 //
 //     Unwrap() error
+//     Unwrap() []error
 //
-// If e.Unwrap() returns a non-nil error w, then we say that e wraps w.
+// If e.Unwrap() returns a non-nil error w or a slice containing w,
+// then we say that e wraps w. A nil error returned from e.Unwrap()
+// indicates that e does not wrap any error. It is invalid for an
+// Unwrap method to return an []error containing a nil error value.
 //
-// Unwrap unpacks wrapped errors. If its argument's type has an
-// Unwrap method, it calls the method once. Otherwise, it returns nil.
+// An easy way to create wrapped errors is to call fmt.Errorf and apply
+// the %w verb to the error argument:
 //
-// A simple way to create wrapped errors is to call fmt.Errorf and apply the %w verb
-// to the error argument:
+//     wrapsErr := fmt.Errorf("... %w ...", ..., err, ...)
 //
-//     errors.Unwrap(fmt.Errorf("... %w ...", ..., err, ...))
+// Successive unwrapping of an error creates a tree. The Is and As
+// functions inspect an error's tree by examining first the error
+// itself followed by the tree of each of its children in turn
+// (pre-order, depth-first traversal).
 //
-// returns err.
-//
-// Is unwraps its first argument sequentially looking for an error that matches the
-// second. It reports whether it finds a match. It should be used in preference to
-// simple equality checks:
+// Is examines the tree of its first argument looking for an error that
+// matches the second. It reports whether it finds a match. It should be
+// used in preference to simple equality checks:
 //
 //     if errors.Is(err, fs.ErrExist)
 //
@@ -35,7 +38,7 @@
 //
 // because the former will succeed if err wraps fs.ErrExist.
 //
-// As unwraps its first argument sequentially looking for an error that can be
+// As examines the tree of its first argument looking for an error that can be
 // assigned to its second argument, which must be a pointer. If it succeeds, it
 // performs the assignment and returns true. Otherwise, it returns false. The form
 //
index cf4df90b69503ced9aa5dd6eeb66596cbd3569eb..8b93f530d59aad4a3d73756e96fc8946537609ab 100644 (file)
@@ -51,3 +51,21 @@ func ExampleNew_errorf() {
        }
        // Output: user "bimmler" (id 17) not found
 }
+
+func ExampleJoin() {
+       err1 := errors.New("err1")
+       err2 := errors.New("err2")
+       err := errors.Join(err1, err2)
+       fmt.Println(err)
+       if errors.Is(err, err1) {
+               fmt.Println("err is err1")
+       }
+       if errors.Is(err, err2) {
+               fmt.Println("err is err2")
+       }
+       // Output:
+       // err1
+       // err2
+       // err is err1
+       // err is err2
+}
diff --git a/src/errors/join.go b/src/errors/join.go
new file mode 100644 (file)
index 0000000..dc5a716
--- /dev/null
@@ -0,0 +1,51 @@
+// Copyright 2022 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package errors
+
+// Join returns an error that wraps the given errors.
+// Any nil error values are discarded.
+// Join returns nil if errs contains no non-nil values.
+// The error formats as the concatenation of the strings obtained
+// by calling the Error method of each element of errs, with a newline
+// between each string.
+func Join(errs ...error) error {
+       n := 0
+       for _, err := range errs {
+               if err != nil {
+                       n++
+               }
+       }
+       if n == 0 {
+               return nil
+       }
+       e := &joinError{
+               errs: make([]error, 0, n),
+       }
+       for _, err := range errs {
+               if err != nil {
+                       e.errs = append(e.errs, err)
+               }
+       }
+       return e
+}
+
+type joinError struct {
+       errs []error
+}
+
+func (e *joinError) Error() string {
+       var b []byte
+       for i, err := range e.errs {
+               if i > 0 {
+                       b = append(b, '\n')
+               }
+               b = append(b, err.Error()...)
+       }
+       return string(b)
+}
+
+func (e *joinError) Unwrap() []error {
+       return e.errs
+}
diff --git a/src/errors/join_test.go b/src/errors/join_test.go
new file mode 100644 (file)
index 0000000..ee69314
--- /dev/null
@@ -0,0 +1,49 @@
+// Copyright 2022 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package errors_test
+
+import (
+       "errors"
+       "reflect"
+       "testing"
+)
+
+func TestJoinReturnsNil(t *testing.T) {
+       if err := errors.Join(); err != nil {
+               t.Errorf("errors.Join() = %v, want nil", err)
+       }
+       if err := errors.Join(nil); err != nil {
+               t.Errorf("errors.Join(nil) = %v, want nil", err)
+       }
+       if err := errors.Join(nil, nil); err != nil {
+               t.Errorf("errors.Join(nil, nil) = %v, want nil", err)
+       }
+}
+
+func TestJoin(t *testing.T) {
+       err1 := errors.New("err1")
+       err2 := errors.New("err2")
+       for _, test := range []struct {
+               errs []error
+               want []error
+       }{{
+               errs: []error{err1},
+               want: []error{err1},
+       }, {
+               errs: []error{err1, err2},
+               want: []error{err1, err2},
+       }, {
+               errs: []error{err1, nil, err2},
+               want: []error{err1, err2},
+       }} {
+               got := errors.Join(test.errs...).(interface{ Unwrap() []error }).Unwrap()
+               if !reflect.DeepEqual(got, test.want) {
+                       t.Errorf("Join(%v) = %v; want %v", test.errs, got, test.want)
+               }
+               if len(got) != cap(got) {
+                       t.Errorf("Join(%v) returns errors with len=%v, cap=%v; want len==cap", test.errs, len(got), cap(got))
+               }
+       }
+}
index 263ae16b48dc0c599590b31f4e2cd370a2f97dc2..a719655b10dbbf2de2f910fbf12c715000be2951 100644 (file)
@@ -11,6 +11,8 @@ import (
 // Unwrap returns the result of calling the Unwrap method on err, if err's
 // type contains an Unwrap method returning error.
 // Otherwise, Unwrap returns nil.
+//
+// Unwrap returns nil if the Unwrap method returns []error.
 func Unwrap(err error) error {
        u, ok := err.(interface {
                Unwrap() error
@@ -21,10 +23,11 @@ func Unwrap(err error) error {
        return u.Unwrap()
 }
 
-// Is reports whether any error in err's chain matches target.
+// Is reports whether any error in err's tree matches target.
 //
-// The chain consists of err itself followed by the sequence of errors obtained by
-// repeatedly calling Unwrap.
+// The tree consists of err itself, followed by the errors obtained by repeatedly
+// calling Unwrap. When err wraps multiple errors, Is examines err followed by a
+// depth-first traversal of its children.
 //
 // An error is considered to match a target if it is equal to that target or if
 // it implements a method Is(error) bool such that Is(target) returns true.
@@ -50,20 +53,31 @@ func Is(err, target error) bool {
                if x, ok := err.(interface{ Is(error) bool }); ok && x.Is(target) {
                        return true
                }
-               // TODO: consider supporting target.Is(err). This would allow
-               // user-definable predicates, but also may allow for coping with sloppy
-               // APIs, thereby making it easier to get away with them.
-               if err = Unwrap(err); err == nil {
+               switch x := err.(type) {
+               case interface{ Unwrap() error }:
+                       err = x.Unwrap()
+                       if err == nil {
+                               return false
+                       }
+               case interface{ Unwrap() []error }:
+                       for _, err := range x.Unwrap() {
+                               if Is(err, target) {
+                                       return true
+                               }
+                       }
+                       return false
+               default:
                        return false
                }
        }
 }
 
-// As finds the first error in err's chain that matches target, and if one is found, sets
+// As finds the first error in err's tree that matches target, and if one is found, sets
 // target to that error value and returns true. Otherwise, it returns false.
 //
-// The chain consists of err itself followed by the sequence of errors obtained by
-// repeatedly calling Unwrap.
+// The tree consists of err itself, followed by the errors obtained by repeatedly
+// calling Unwrap. When err wraps multiple errors, As examines err followed by a
+// depth-first traversal of its children.
 //
 // An error matches target if the error's concrete value is assignable to the value
 // pointed to by target, or if the error has a method As(interface{}) bool such that
@@ -76,6 +90,9 @@ func Is(err, target error) bool {
 // As panics if target is not a non-nil pointer to either a type that implements
 // error, or to any interface type.
 func As(err error, target any) bool {
+       if err == nil {
+               return false
+       }
        if target == nil {
                panic("errors: target cannot be nil")
        }
@@ -88,7 +105,7 @@ func As(err error, target any) bool {
        if targetType.Kind() != reflectlite.Interface && !targetType.Implements(errorType) {
                panic("errors: *target must be interface or implement error")
        }
-       for err != nil {
+       for {
                if reflectlite.TypeOf(err).AssignableTo(targetType) {
                        val.Elem().Set(reflectlite.ValueOf(err))
                        return true
@@ -96,9 +113,23 @@ func As(err error, target any) bool {
                if x, ok := err.(interface{ As(any) bool }); ok && x.As(target) {
                        return true
                }
-               err = Unwrap(err)
+               switch x := err.(type) {
+               case interface{ Unwrap() error }:
+                       err = x.Unwrap()
+                       if err == nil {
+                               return false
+                       }
+               case interface{ Unwrap() []error }:
+                       for _, err := range x.Unwrap() {
+                               if As(err, target) {
+                                       return true
+                               }
+                       }
+                       return false
+               default:
+                       return false
+               }
        }
-       return false
 }
 
 var errorType = reflectlite.TypeOf((*error)(nil)).Elem()
index eb8314b04bb67b8b88a341962fd84f54b7cdc594..9efbe45ee0b38ac44631c5a6eaf79ef0e10d33a8 100644 (file)
@@ -47,6 +47,17 @@ func TestIs(t *testing.T) {
                {&errorUncomparable{}, &errorUncomparable{}, false},
                {errorUncomparable{}, err1, false},
                {&errorUncomparable{}, err1, false},
+               {multiErr{}, err1, false},
+               {multiErr{err1, err3}, err1, true},
+               {multiErr{err3, err1}, err1, true},
+               {multiErr{err1, err3}, errors.New("x"), false},
+               {multiErr{err3, errb}, errb, true},
+               {multiErr{err3, errb}, erra, true},
+               {multiErr{err3, errb}, err1, true},
+               {multiErr{errb, err3}, err1, true},
+               {multiErr{poser}, err1, true},
+               {multiErr{poser}, err3, true},
+               {multiErr{nil}, nil, false},
        }
        for _, tc := range testCases {
                t.Run("", func(t *testing.T) {
@@ -148,6 +159,41 @@ func TestAs(t *testing.T) {
                &timeout,
                true,
                errF,
+       }, {
+               multiErr{},
+               &errT,
+               false,
+               nil,
+       }, {
+               multiErr{errors.New("a"), errorT{"T"}},
+               &errT,
+               true,
+               errorT{"T"},
+       }, {
+               multiErr{errorT{"T"}, errors.New("a")},
+               &errT,
+               true,
+               errorT{"T"},
+       }, {
+               multiErr{errorT{"a"}, errorT{"b"}},
+               &errT,
+               true,
+               errorT{"a"},
+       }, {
+               multiErr{multiErr{errors.New("a"), errorT{"a"}}, errorT{"b"}},
+               &errT,
+               true,
+               errorT{"a"},
+       }, {
+               multiErr{wrapped{"path error", errF}},
+               &timeout,
+               true,
+               errF,
+       }, {
+               multiErr{nil},
+               &errT,
+               false,
+               nil,
        }}
        for i, tc := range testCases {
                name := fmt.Sprintf("%d:As(Errorf(..., %v), %v)", i, tc.err, tc.target)
@@ -223,9 +269,13 @@ type wrapped struct {
 }
 
 func (e wrapped) Error() string { return e.msg }
-
 func (e wrapped) Unwrap() error { return e.err }
 
+type multiErr []error
+
+func (m multiErr) Error() string   { return "multiError" }
+func (m multiErr) Unwrap() []error { return []error(m) }
+
 type errorUncomparable struct {
        f []string
 }
index 4f4daf19e1448b98e13e0f93da7ef4ab65809e40..1fbd39f8f17bf3601cff438202e4c875de073b1a 100644 (file)
@@ -4,26 +4,48 @@
 
 package fmt
 
-import "errors"
+import (
+       "errors"
+       "sort"
+)
 
 // Errorf formats according to a format specifier and returns the string as a
 // value that satisfies error.
 //
 // If the format specifier includes a %w verb with an error operand,
-// the returned error will implement an Unwrap method returning the operand. It is
-// invalid to include more than one %w verb or to supply it with an operand
-// that does not implement the error interface. The %w verb is otherwise
-// a synonym for %v.
+// the returned error will implement an Unwrap method returning the operand.
+// If there is more than one %w verb, the returned error will implement an
+// Unwrap method returning a []error containing all the %w operands in the
+// order they appear in the arguments.
+// It is invalid to supply the %w verb with an operand that does not implement
+// the error interface. The %w verb is otherwise a synonym for %v.
 func Errorf(format string, a ...any) error {
        p := newPrinter()
        p.wrapErrs = true
        p.doPrintf(format, a)
        s := string(p.buf)
        var err error
-       if p.wrappedErr == nil {
+       switch len(p.wrappedErrs) {
+       case 0:
                err = errors.New(s)
-       } else {
-               err = &wrapError{s, p.wrappedErr}
+       case 1:
+               w := &wrapError{msg: s}
+               w.err, _ = a[p.wrappedErrs[0]].(error)
+               err = w
+       default:
+               if p.reordered {
+                       sort.Ints(p.wrappedErrs)
+               }
+               var errs []error
+               for i, argNum := range p.wrappedErrs {
+                       if i > 0 && p.wrappedErrs[i-1] == argNum {
+                               continue
+                       }
+                       if e, ok := a[argNum].(error); ok {
+                               errs = append(errs, e)
+                       }
+               }
+               err = &wrapErrors{s, errs}
        }
        p.free()
        return err
@@ -41,3 +63,16 @@ func (e *wrapError) Error() string {
 func (e *wrapError) Unwrap() error {
        return e.err
 }
+
+type wrapErrors struct {
+       msg  string
+       errs []error
+}
+
+func (e *wrapErrors) Error() string {
+       return e.msg
+}
+
+func (e *wrapErrors) Unwrap() []error {
+       return e.errs
+}
index 481a7b840302f40026b423e7a641b364f0e5c49a..4eb55faffe7a181777ae1961d35f3a9f66a5042b 100644 (file)
@@ -7,6 +7,7 @@ package fmt_test
 import (
        "errors"
        "fmt"
+       "reflect"
        "testing"
 )
 
@@ -20,6 +21,7 @@ func TestErrorf(t *testing.T) {
                err        error
                wantText   string
                wantUnwrap error
+               wantSplit  []error
        }{{
                err:        fmt.Errorf("%w", wrapped),
                wantText:   "inner error",
@@ -53,11 +55,29 @@ func TestErrorf(t *testing.T) {
                err:      noVetErrorf("%w is not an error", "not-an-error"),
                wantText: "%!w(string=not-an-error) is not an error",
        }, {
-               err:      noVetErrorf("wrapped two errors: %w %w", errString("1"), errString("2")),
-               wantText: "wrapped two errors: 1 %!w(fmt_test.errString=2)",
+               err:       noVetErrorf("wrapped two errors: %w %w", errString("1"), errString("2")),
+               wantText:  "wrapped two errors: 1 2",
+               wantSplit: []error{errString("1"), errString("2")},
        }, {
-               err:      noVetErrorf("wrapped three errors: %w %w %w", errString("1"), errString("2"), errString("3")),
-               wantText: "wrapped three errors: 1 %!w(fmt_test.errString=2) %!w(fmt_test.errString=3)",
+               err:       noVetErrorf("wrapped three errors: %w %w %w", errString("1"), errString("2"), errString("3")),
+               wantText:  "wrapped three errors: 1 2 3",
+               wantSplit: []error{errString("1"), errString("2"), errString("3")},
+       }, {
+               err:       noVetErrorf("wrapped nil error: %w %w %w", errString("1"), nil, errString("2")),
+               wantText:  "wrapped nil error: 1 %!w(<nil>) 2",
+               wantSplit: []error{errString("1"), errString("2")},
+       }, {
+               err:       noVetErrorf("wrapped one non-error: %w %w %w", errString("1"), "not-an-error", errString("3")),
+               wantText:  "wrapped one non-error: 1 %!w(string=not-an-error) 3",
+               wantSplit: []error{errString("1"), errString("3")},
+       }, {
+               err:       fmt.Errorf("wrapped errors out of order: %[3]w %[2]w %[1]w", errString("1"), errString("2"), errString("3")),
+               wantText:  "wrapped errors out of order: 3 2 1",
+               wantSplit: []error{errString("1"), errString("2"), errString("3")},
+       }, {
+               err:       fmt.Errorf("wrapped several times: %[1]w %[1]w %[2]w %[1]w", errString("1"), errString("2")),
+               wantText:  "wrapped several times: 1 1 2 1",
+               wantSplit: []error{errString("1"), errString("2")},
        }, {
                err:        fmt.Errorf("%w", nil),
                wantText:   "%!w(<nil>)",
@@ -66,12 +86,22 @@ func TestErrorf(t *testing.T) {
                if got, want := errors.Unwrap(test.err), test.wantUnwrap; got != want {
                        t.Errorf("Formatted error: %v\nerrors.Unwrap() = %v, want %v", test.err, got, want)
                }
+               if got, want := splitErr(test.err), test.wantSplit; !reflect.DeepEqual(got, want) {
+                       t.Errorf("Formatted error: %v\nUnwrap() []error = %v, want %v", test.err, got, want)
+               }
                if got, want := test.err.Error(), test.wantText; got != want {
                        t.Errorf("err.Error() = %q, want %q", got, want)
                }
        }
 }
 
+func splitErr(err error) []error {
+       if e, ok := err.(interface{ Unwrap() []error }); ok {
+               return e.Unwrap()
+       }
+       return nil
+}
+
 type errString string
 
 func (e errString) Error() string { return string(e) }
index 4eabda1ce862c2ddcf1e264a773bcb9e26d6368d..b3dd43ce04eeb11c83c35d3e395c006d14e42da8 100644 (file)
@@ -139,8 +139,8 @@ type pp struct {
        erroring bool
        // wrapErrs is set when the format string may contain a %w verb.
        wrapErrs bool
-       // wrappedErr records the target of the %w verb.
-       wrappedErr error
+       // wrappedErrs records the targets of the %w verb.
+       wrappedErrs []int
 }
 
 var ppFree = sync.Pool{
@@ -171,10 +171,13 @@ func (p *pp) free() {
        } else {
                p.buf = p.buf[:0]
        }
+       if cap(p.wrappedErrs) > 8 {
+               p.wrappedErrs = nil
+       }
 
        p.arg = nil
        p.value = reflect.Value{}
-       p.wrappedErr = nil
+       p.wrappedErrs = p.wrappedErrs[:0]
        ppFree.Put(p)
 }
 
@@ -620,16 +623,12 @@ func (p *pp) handleMethods(verb rune) (handled bool) {
                return
        }
        if verb == 'w' {
-               // It is invalid to use %w other than with Errorf, more than once,
-               // or with a non-error arg.
-               err, ok := p.arg.(error)
-               if !ok || !p.wrapErrs || p.wrappedErr != nil {
-                       p.wrappedErr = nil
-                       p.wrapErrs = false
+               // It is invalid to use %w other than with Errorf or with a non-error arg.
+               _, ok := p.arg.(error)
+               if !ok || !p.wrapErrs {
                        p.badVerb(verb)
                        return true
                }
-               p.wrappedErr = err
                // If the arg is a Formatter, pass 'v' as the verb to it.
                verb = 'v'
        }
@@ -1063,7 +1062,11 @@ formatLoop:
                                // Fast path for common case of ascii lower case simple verbs
                                // without precision or width or argument indices.
                                if 'a' <= c && c <= 'z' && argNum < len(a) {
-                                       if c == 'v' {
+                                       switch c {
+                                       case 'w':
+                                               p.wrappedErrs = append(p.wrappedErrs, argNum)
+                                               fallthrough
+                                       case 'v':
                                                // Go syntax
                                                p.fmt.sharpV = p.fmt.sharp
                                                p.fmt.sharp = false
@@ -1158,6 +1161,9 @@ formatLoop:
                        p.badArgNum(verb)
                case argNum >= len(a): // No argument left over to print for the current verb.
                        p.missingArg(verb)
+               case verb == 'w':
+                       p.wrappedErrs = append(p.wrappedErrs, argNum)
+                       fallthrough
                case verb == 'v':
                        // Go syntax
                        p.fmt.sharpV = p.fmt.sharp