]> Cypherpunks.ru repositories - gostls13.git/blobdiff - src/flag/flag_test.go
flag: add BoolFunc; FlagSet.BoolFunc
[gostls13.git] / src / flag / flag_test.go
index 17551684055cbdd3ce3b494c6750fe66df51ba3a..14d199d6e950a80f05274402e2c216c8dc7ecbd4 100644 (file)
@@ -38,6 +38,7 @@ func TestEverything(t *testing.T) {
        Float64("test_float64", 0, "float64 value")
        Duration("test_duration", 0, "time.Duration value")
        Func("test_func", "func value", func(string) error { return nil })
+       BoolFunc("test_boolfunc", "func", func(string) error { return nil })
 
        m := make(map[string]*Flag)
        desired := "0"
@@ -54,6 +55,8 @@ func TestEverything(t *testing.T) {
                                ok = true
                        case f.Name == "test_func" && f.Value.String() == "":
                                ok = true
+                       case f.Name == "test_boolfunc" && f.Value.String() == "":
+                               ok = true
                        }
                        if !ok {
                                t.Error("Visit: bad value", f.Value.String(), "for", f.Name)
@@ -61,7 +64,7 @@ func TestEverything(t *testing.T) {
                }
        }
        VisitAll(visitor)
-       if len(m) != 9 {
+       if len(m) != 10 {
                t.Error("VisitAll misses some flags")
                for k, v := range m {
                        t.Log(k, *v)
@@ -85,9 +88,10 @@ func TestEverything(t *testing.T) {
        Set("test_float64", "1")
        Set("test_duration", "1s")
        Set("test_func", "1")
+       Set("test_boolfunc", "")
        desired = "1"
        Visit(visitor)
-       if len(m) != 9 {
+       if len(m) != 10 {
                t.Error("Visit fails after set")
                for k, v := range m {
                        t.Log(k, *v)
@@ -797,3 +801,46 @@ func TestRedefinedFlags(t *testing.T) {
                }
        }
 }
+
+func TestUserDefinedBoolFunc(t *testing.T) {
+       flags := NewFlagSet("test", ContinueOnError)
+       flags.SetOutput(io.Discard)
+       var ss []string
+       flags.BoolFunc("v", "usage", func(s string) error {
+               ss = append(ss, s)
+               return nil
+       })
+       if err := flags.Parse([]string{"-v", "", "-v", "1", "-v=2"}); err != nil {
+               t.Error(err)
+       }
+       if len(ss) != 1 {
+               t.Fatalf("got %d args; want 1 arg", len(ss))
+       }
+       want := "[true]"
+       if got := fmt.Sprint(ss); got != want {
+               t.Errorf("got %q; want %q", got, want)
+       }
+       // test usage
+       var buf strings.Builder
+       flags.SetOutput(&buf)
+       flags.Parse([]string{"-h"})
+       if usage := buf.String(); !strings.Contains(usage, "usage") {
+               t.Errorf("usage string not included: %q", usage)
+       }
+       // test BoolFunc error
+       flags = NewFlagSet("test", ContinueOnError)
+       flags.SetOutput(io.Discard)
+       flags.BoolFunc("v", "usage", func(s string) error {
+               return fmt.Errorf("test error")
+       })
+       // flag not set, so no error
+       if err := flags.Parse(nil); err != nil {
+               t.Error(err)
+       }
+       // flag set, expect error
+       if err := flags.Parse([]string{"-v", ""}); err == nil {
+               t.Error("got err == nil; want err != nil")
+       } else if errMsg := err.Error(); !strings.Contains(errMsg, "test error") {
+               t.Errorf(`got %q; error should contain "test error"`, errMsg)
+       }
+}