]> Cypherpunks.ru repositories - gostls13.git/commitdiff
net/netip: add AddrPort.Compare and Prefix.Compare
authorDavid Anderson <danderson@tailscale.com>
Thu, 31 Aug 2023 22:46:45 +0000 (22:46 +0000)
committerBrad Fitzpatrick <bradfitz@golang.org>
Mon, 11 Sep 2023 20:26:41 +0000 (20:26 +0000)
Fixes #61642

Change-Id: I2262855dbe75135f70008e5df4634d2cfff76550
GitHub-Last-Rev: 949685a9e426f50f37753045d7527ebcccb082e7
GitHub-Pull-Request: golang/go#62387
Reviewed-on: https://go-review.googlesource.com/c/go/+/524616
Reviewed-by: Joseph Tsai <joetsai@digital-static.net>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Matthew Dempsky <mdempsky@google.com>
Reviewed-by: Heschi Kreinick <heschi@google.com>
api/next/61642.txt [new file with mode: 0644]
src/net/netip/netip.go
src/net/netip/netip_test.go

diff --git a/api/next/61642.txt b/api/next/61642.txt
new file mode 100644 (file)
index 0000000..4c8bf25
--- /dev/null
@@ -0,0 +1,2 @@
+pkg net/netip, method (AddrPort) Compare(AddrPort) int #61642
+pkg net/netip, method (Prefix) Compare(Prefix) int #61642
index 0c9dc3246ccfd493378715d368e6f9cdfb1163af..99cb754fae87762b22c42b2014d6989558ee8019 100644 (file)
@@ -12,6 +12,7 @@
 package netip
 
 import (
+       "cmp"
        "errors"
        "math"
        "strconv"
@@ -1102,6 +1103,16 @@ func MustParseAddrPort(s string) AddrPort {
 // All ports are valid, including zero.
 func (p AddrPort) IsValid() bool { return p.ip.IsValid() }
 
+// Compare returns an integer comparing two AddrPorts.
+// The result will be 0 if p == p2, -1 if p < p2, and +1 if p > p2.
+// AddrPorts sort first by IP address, then port.
+func (p AddrPort) Compare(p2 AddrPort) int {
+       if c := p.Addr().Compare(p2.Addr()); c != 0 {
+               return c
+       }
+       return cmp.Compare(p.Port(), p2.Port())
+}
+
 func (p AddrPort) String() string {
        switch p.ip.z {
        case z0:
@@ -1261,6 +1272,21 @@ func (p Prefix) isZero() bool { return p == Prefix{} }
 // IsSingleIP reports whether p contains exactly one IP.
 func (p Prefix) IsSingleIP() bool { return p.IsValid() && p.Bits() == p.ip.BitLen() }
 
+// Compare returns an integer comparing two prefixes.
+// The result will be 0 if p == p2, -1 if p < p2, and +1 if p > p2.
+// Prefixes sort first by validity (invalid before valid), then
+// address family (IPv4 before IPv6), then prefix length, then
+// address.
+func (p Prefix) Compare(p2 Prefix) int {
+       if c := cmp.Compare(p.Addr().BitLen(), p2.Addr().BitLen()); c != 0 {
+               return c
+       }
+       if c := cmp.Compare(p.Bits(), p2.Bits()); c != 0 {
+               return c
+       }
+       return p.Addr().Compare(p2.Addr())
+}
+
 // ParsePrefix parses s as an IP address prefix.
 // The string can be in the form "192.168.1.0/24" or "2001:db8::/32",
 // the CIDR notation defined in RFC 4632 and RFC 4291.
index 0f80bb0ab0e56f81470f17ff3ffabbd423bc15e4..39893e0f6df5567bde7207fa459b2159a9cd0e4c 100644 (file)
@@ -14,6 +14,7 @@ import (
        "net"
        . "net/netip"
        "reflect"
+       "slices"
        "sort"
        "strings"
        "testing"
@@ -812,7 +813,7 @@ func TestAddrWellKnown(t *testing.T) {
        }
 }
 
-func TestLessCompare(t *testing.T) {
+func TestAddrLessCompare(t *testing.T) {
        tests := []struct {
                a, b Addr
                want bool
@@ -882,6 +883,109 @@ func TestLessCompare(t *testing.T) {
        }
 }
 
+func TestAddrPortCompare(t *testing.T) {
+       tests := []struct {
+               a, b AddrPort
+               want int
+       }{
+               {AddrPort{}, AddrPort{}, 0},
+               {AddrPort{}, mustIPPort("1.2.3.4:80"), -1},
+
+               {mustIPPort("1.2.3.4:80"), mustIPPort("1.2.3.4:80"), 0},
+               {mustIPPort("[::1]:80"), mustIPPort("[::1]:80"), 0},
+
+               {mustIPPort("1.2.3.4:80"), mustIPPort("2.3.4.5:22"), -1},
+               {mustIPPort("[::1]:80"), mustIPPort("[::2]:22"), -1},
+
+               {mustIPPort("1.2.3.4:80"), mustIPPort("1.2.3.4:443"), -1},
+               {mustIPPort("[::1]:80"), mustIPPort("[::1]:443"), -1},
+
+               {mustIPPort("1.2.3.4:80"), mustIPPort("[0102:0304::0]:80"), -1},
+       }
+       for _, tt := range tests {
+               got := tt.a.Compare(tt.b)
+               if got != tt.want {
+                       t.Errorf("Compare(%q, %q) = %v; want %v", tt.a, tt.b, got, tt.want)
+               }
+
+               // Also check inverse.
+               if got == tt.want {
+                       got2 := tt.b.Compare(tt.a)
+                       if want2 := -1 * tt.want; got2 != want2 {
+                               t.Errorf("Compare(%q, %q) was correctly %v, but Compare(%q, %q) was %v", tt.a, tt.b, got, tt.b, tt.a, got2)
+                       }
+               }
+       }
+
+       // And just sort.
+       values := []AddrPort{
+               mustIPPort("[::1]:80"),
+               mustIPPort("[::2]:80"),
+               AddrPort{},
+               mustIPPort("1.2.3.4:443"),
+               mustIPPort("8.8.8.8:8080"),
+               mustIPPort("[::1%foo]:1024"),
+       }
+       slices.SortFunc(values, func(a, b AddrPort) int { return a.Compare(b) })
+       got := fmt.Sprintf("%s", values)
+       want := `[invalid AddrPort 1.2.3.4:443 8.8.8.8:8080 [::1]:80 [::1%foo]:1024 [::2]:80]`
+       if got != want {
+               t.Errorf("unexpected sort\n got: %s\nwant: %s\n", got, want)
+       }
+}
+
+func TestPrefixCompare(t *testing.T) {
+       tests := []struct {
+               a, b Prefix
+               want int
+       }{
+               {Prefix{}, Prefix{}, 0},
+               {Prefix{}, mustPrefix("1.2.3.0/24"), -1},
+
+               {mustPrefix("1.2.3.0/24"), mustPrefix("1.2.3.0/24"), 0},
+               {mustPrefix("fe80::/64"), mustPrefix("fe80::/64"), 0},
+
+               {mustPrefix("1.2.3.0/24"), mustPrefix("1.2.4.0/24"), -1},
+               {mustPrefix("fe80::/64"), mustPrefix("fe90::/64"), -1},
+
+               {mustPrefix("1.2.0.0/16"), mustPrefix("1.2.0.0/24"), -1},
+               {mustPrefix("fe80::/48"), mustPrefix("fe80::/64"), -1},
+
+               {mustPrefix("1.2.3.0/24"), mustPrefix("fe80::/8"), -1},
+       }
+       for _, tt := range tests {
+               got := tt.a.Compare(tt.b)
+               if got != tt.want {
+                       t.Errorf("Compare(%q, %q) = %v; want %v", tt.a, tt.b, got, tt.want)
+               }
+
+               // Also check inverse.
+               if got == tt.want {
+                       got2 := tt.b.Compare(tt.a)
+                       if want2 := -1 * tt.want; got2 != want2 {
+                               t.Errorf("Compare(%q, %q) was correctly %v, but Compare(%q, %q) was %v", tt.a, tt.b, got, tt.b, tt.a, got2)
+                       }
+               }
+       }
+
+       // And just sort.
+       values := []Prefix{
+               mustPrefix("1.2.3.0/24"),
+               mustPrefix("fe90::/64"),
+               mustPrefix("fe80::/64"),
+               mustPrefix("1.2.0.0/16"),
+               Prefix{},
+               mustPrefix("fe80::/48"),
+               mustPrefix("1.2.0.0/24"),
+       }
+       slices.SortFunc(values, func(a, b Prefix) int { return a.Compare(b) })
+       got := fmt.Sprintf("%s", values)
+       want := `[invalid Prefix 1.2.0.0/16 1.2.0.0/24 1.2.3.0/24 fe80::/48 fe80::/64 fe90::/64]`
+       if got != want {
+               t.Errorf("unexpected sort\n got: %s\nwant: %s\n", got, want)
+       }
+}
+
 func TestIPStringExpanded(t *testing.T) {
        tests := []struct {
                ip Addr