]> Cypherpunks.ru repositories - gostls13.git/commitdiff
flag: add Func
authorCarl Johnson <me@carlmjohnson.net>
Thu, 27 Aug 2020 13:08:44 +0000 (13:08 +0000)
committerRuss Cox <rsc@golang.org>
Wed, 16 Sep 2020 17:13:14 +0000 (17:13 +0000)
Fixes #39557

Change-Id: Ida578f7484335e8c6bf927255f75377eda63b563
GitHub-Last-Rev: b97294f7669c24011e5b093179d65636512a84cd
GitHub-Pull-Request: golang/go#39880
Reviewed-on: https://go-review.googlesource.com/c/go/+/240014
Reviewed-by: Russ Cox <rsc@golang.org>
Trust: Ian Lance Taylor <iant@golang.org>

src/flag/example_func_test.go [new file with mode: 0644]
src/flag/flag.go
src/flag/flag_test.go

diff --git a/src/flag/example_func_test.go b/src/flag/example_func_test.go
new file mode 100644 (file)
index 0000000..7c30c5e
--- /dev/null
@@ -0,0 +1,41 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package flag_test
+
+import (
+       "errors"
+       "flag"
+       "fmt"
+       "net"
+       "os"
+)
+
+func ExampleFunc() {
+       fs := flag.NewFlagSet("ExampleFunc", flag.ContinueOnError)
+       fs.SetOutput(os.Stdout)
+       var ip net.IP
+       fs.Func("ip", "`IP address` to parse", func(s string) error {
+               ip = net.ParseIP(s)
+               if ip == nil {
+                       return errors.New("could not parse IP")
+               }
+               return nil
+       })
+       fs.Parse([]string{"-ip", "127.0.0.1"})
+       fmt.Printf("{ip: %v, loopback: %t}\n\n", ip, ip.IsLoopback())
+
+       // 256 is not a valid IPv4 component
+       fs.Parse([]string{"-ip", "256.0.0.1"})
+       fmt.Printf("{ip: %v, loopback: %t}\n\n", ip, ip.IsLoopback())
+
+       // Output:
+       // {ip: 127.0.0.1, loopback: true}
+       //
+       // invalid value "256.0.0.1" for flag -ip: could not parse IP
+       // Usage of ExampleFunc:
+       //   -ip IP address
+       //      IP address to parse
+       // {ip: <nil>, loopback: false}
+}
index 286bba687390fee3c272b99b18a67afcef31b08e..a8485f034f30c13cee232c0d5ffcfa15bb36c8e2 100644 (file)
@@ -278,6 +278,12 @@ func (d *durationValue) Get() interface{} { return time.Duration(*d) }
 
 func (d *durationValue) String() string { return (*time.Duration)(d).String() }
 
+type funcValue func(string) error
+
+func (f funcValue) Set(s string) error { return f(s) }
+
+func (f funcValue) String() string { return "" }
+
 // Value is the interface to the dynamic value stored in a flag.
 // (The default value is represented as a string.)
 //
@@ -296,7 +302,7 @@ type Value interface {
 // Getter is an interface that allows the contents of a Value to be retrieved.
 // It wraps the Value interface, rather than being part of it, because it
 // appeared after Go 1 and its compatibility rules. All Value types provided
-// by this package satisfy the Getter interface.
+// by this package satisfy the Getter interface, except the type used by Func.
 type Getter interface {
        Value
        Get() interface{}
@@ -830,6 +836,20 @@ func Duration(name string, value time.Duration, usage string) *time.Duration {
        return CommandLine.Duration(name, value, usage)
 }
 
+// Func defines a flag with the specified name and usage string.
+// Each time the flag is seen, fn is called with the value of the flag.
+// If fn returns a non-nil error, it will be treated as a flag value parsing error.
+func (f *FlagSet) Func(name, usage string, fn func(string) error) {
+       f.Var(funcValue(fn), name, usage)
+}
+
+// Func defines a flag with the specified name and usage string.
+// Each time the flag is seen, fn is called with the value of the flag.
+// If fn returns a non-nil error, it will be treated as a flag value parsing error.
+func Func(name, usage string, fn func(string) error) {
+       CommandLine.Func(name, usage, fn)
+}
+
 // Var defines a flag with the specified name and usage string. The type and
 // value of the flag are represented by the first argument, of type Value, which
 // typically holds a user-defined implementation of Value. For instance, the
index a01a5e4cea2ffef91bc2bc05cf88edf1aa94a7ee..2793064511616ce41f0639176564647a5e42bcb7 100644 (file)
@@ -38,6 +38,7 @@ func TestEverything(t *testing.T) {
        String("test_string", "0", "string value")
        Float64("test_float64", 0, "float64 value")
        Duration("test_duration", 0, "time.Duration value")
+       Func("test_func", "func value", func(string) error { return nil })
 
        m := make(map[string]*Flag)
        desired := "0"
@@ -52,6 +53,8 @@ func TestEverything(t *testing.T) {
                                ok = true
                        case f.Name == "test_duration" && f.Value.String() == desired+"s":
                                ok = true
+                       case f.Name == "test_func" && f.Value.String() == "":
+                               ok = true
                        }
                        if !ok {
                                t.Error("Visit: bad value", f.Value.String(), "for", f.Name)
@@ -59,7 +62,7 @@ func TestEverything(t *testing.T) {
                }
        }
        VisitAll(visitor)
-       if len(m) != 8 {
+       if len(m) != 9 {
                t.Error("VisitAll misses some flags")
                for k, v := range m {
                        t.Log(k, *v)
@@ -82,9 +85,10 @@ func TestEverything(t *testing.T) {
        Set("test_string", "1")
        Set("test_float64", "1")
        Set("test_duration", "1s")
+       Set("test_func", "1")
        desired = "1"
        Visit(visitor)
-       if len(m) != 8 {
+       if len(m) != 9 {
                t.Error("Visit fails after set")
                for k, v := range m {
                        t.Log(k, *v)
@@ -257,6 +261,48 @@ func TestUserDefined(t *testing.T) {
        }
 }
 
+func TestUserDefinedFunc(t *testing.T) {
+       var flags FlagSet
+       flags.Init("test", ContinueOnError)
+       var ss []string
+       flags.Func("v", "usage", func(s string) error {
+               ss = append(ss, s)
+               return nil
+       })
+       if err := flags.Parse([]string{"-v", "1", "-v", "2", "-v=3"}); err != nil {
+               t.Error(err)
+       }
+       if len(ss) != 3 {
+               t.Fatal("expected 3 args; got ", len(ss))
+       }
+       expect := "[1 2 3]"
+       if got := fmt.Sprint(ss); got != expect {
+               t.Errorf("expected value %q got %q", expect, got)
+       }
+       // test usage
+       var buf strings.Builder
+       flags.SetOutput(&buf)
+       flags.Parse([]string{"-h"})
+       if usage := buf.String(); !strings.Contains(usage, "usage") {
+               t.Errorf("usage string not included: %q", usage)
+       }
+       // test Func error
+       flags = *NewFlagSet("test", ContinueOnError)
+       flags.Func("v", "usage", func(s string) error {
+               return fmt.Errorf("test error")
+       })
+       // flag not set, so no error
+       if err := flags.Parse(nil); err != nil {
+               t.Error(err)
+       }
+       // flag set, expect error
+       if err := flags.Parse([]string{"-v", "1"}); err == nil {
+               t.Error("expected error; got none")
+       } else if errMsg := err.Error(); !strings.Contains(errMsg, "test error") {
+               t.Errorf(`error should contain "test error"; got %q`, errMsg)
+       }
+}
+
 func TestUserDefinedForCommandLine(t *testing.T) {
        const help = "HELP"
        var result string