]> Cypherpunks.ru repositories - gostls13.git/commitdiff
net/http: index patterns for faster conflict detection
authorJonathan Amsterdam <jba@google.com>
Mon, 18 Sep 2023 19:51:04 +0000 (15:51 -0400)
committerJonathan Amsterdam <jba@google.com>
Tue, 19 Sep 2023 18:35:22 +0000 (18:35 +0000)
Add an index so that pattern registration isn't always quadratic.

If there were no index, then every pattern that was registered would
have to be compared to every existing pattern for conflicts. This
would make registration quadratic in the number of patterns, in every
case.

The index in this CL should help most of the time. If a pattern has a
literal segment, it will weed out all other patterns that have a
different literal in that position.

The worst case will still be quadratic, but it is unlikely that a set
of such patterns would arise naturally.

One novel (to me) aspect of the CL is the use of fuzz testing on data
that is neither a string nor a byte slice. The test uses fuzzing to
generate a byte slice, then decodes the byte slice into a valid
pattern (most of the time). This test actually caught a bug: see
https://go.dev/cl/529119.

Change-Id: Ice0be6547decb5ce75a8062e4e17227815d5d0b0
Reviewed-on: https://go-review.googlesource.com/c/go/+/529121
Run-TryBot: Jonathan Amsterdam <jba@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Damien Neil <dneil@google.com>
src/net/http/routing_index.go [new file with mode: 0644]
src/net/http/routing_index_test.go [new file with mode: 0644]
src/net/http/server.go
src/net/http/testdata/fuzz/FuzzIndex/48161038f0c8b2da [new file with mode: 0644]
src/net/http/testdata/fuzz/FuzzIndex/716514f590ce7ab3 [new file with mode: 0644]

diff --git a/src/net/http/routing_index.go b/src/net/http/routing_index.go
new file mode 100644 (file)
index 0000000..9ac42c9
--- /dev/null
@@ -0,0 +1,124 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import "math"
+
+// A routingIndex optimizes conflict detection by indexing patterns.
+//
+// The basic idea is to rule out patterns that cannot conflict with a given
+// pattern because they have a different literal in a corresponding segment.
+// See the comments in [routingIndex.possiblyConflictingPatterns] for more details.
+type routingIndex struct {
+       // map from a particular segment position and value to all registered patterns
+       // with that value in that position.
+       // For example, the key {1, "b"} would hold the patterns "/a/b" and "/a/b/c"
+       // but not "/a", "b/a", "/a/c" or "/a/{x}".
+       segments map[routingIndexKey][]*pattern
+       // All patterns that end in a multi wildcard (including trailing slash).
+       // We do not try to be clever about indexing multi patterns, because there
+       // are unlikely to be many of them.
+       multis []*pattern
+}
+
+type routingIndexKey struct {
+       pos int    // 0-based segment position
+       s   string // literal, or empty for wildcard
+}
+
+func (idx *routingIndex) addPattern(pat *pattern) {
+       if pat.lastSegment().multi {
+               idx.multis = append(idx.multis, pat)
+       } else {
+               if idx.segments == nil {
+                       idx.segments = map[routingIndexKey][]*pattern{}
+               }
+               for pos, seg := range pat.segments {
+                       key := routingIndexKey{pos: pos, s: ""}
+                       if !seg.wild {
+                               key.s = seg.s
+                       }
+                       idx.segments[key] = append(idx.segments[key], pat)
+               }
+       }
+}
+
+// possiblyConflictingPatterns calls f on all patterns that might conflict with
+// pat. If f returns a non-nil error, possiblyConflictingPatterns returns immediately
+// with that error.
+//
+// To be correct, possiblyConflictingPatterns must include all patterns that
+// might conflict. But it may also include patterns that cannot conflict.
+// For instance, an implementation that returns all registered patterns is correct.
+// We use this fact throughout, simplifying the implementation by returning more
+// patterns that we might need to.
+func (idx *routingIndex) possiblyConflictingPatterns(pat *pattern, f func(*pattern) error) (err error) {
+       // Terminology:
+       //   dollar pattern: one ending in "{$}"
+       //   multi pattern: one ending in a trailing slash or "{x...}" wildcard
+       //   ordinary pattern: neither of the above
+
+       // apply f to all the pats, stopping on error.
+       apply := func(pats []*pattern) error {
+               if err != nil {
+                       return err
+               }
+               for _, p := range pats {
+                       err = f(p)
+                       if err != nil {
+                               return err
+                       }
+               }
+               return nil
+       }
+
+       // Our simple indexing scheme doesn't try to prune multi patterns; assume
+       // any of them can match the argument.
+       if err := apply(idx.multis); err != nil {
+               return err
+       }
+       if pat.lastSegment().s == "/" {
+               // All paths that a dollar pattern matches end in a slash; no paths that
+               // an ordinary pattern matches do. So only other dollar or multi
+               // patterns can conflict with a dollar pattern. Furthermore, conflicting
+               // dollar patterns must have the {$} in the same position.
+               return apply(idx.segments[routingIndexKey{s: "/", pos: len(pat.segments) - 1}])
+       }
+       // For ordinary and multi patterns, the only conflicts can be with a multi,
+       // or a pattern that has the same literal or a wildcard at some literal
+       // position.
+       // We could intersect all the possible matches at each position, but we
+       // do something simpler: we find the position with the fewest patterns.
+       var lmin, wmin []*pattern
+       min := math.MaxInt
+       hasLit := false
+       for i, seg := range pat.segments {
+               if seg.multi {
+                       break
+               }
+               if !seg.wild {
+                       hasLit = true
+                       lpats := idx.segments[routingIndexKey{s: seg.s, pos: i}]
+                       wpats := idx.segments[routingIndexKey{s: "", pos: i}]
+                       if sum := len(lpats) + len(wpats); sum < min {
+                               lmin = lpats
+                               wmin = wpats
+                               min = sum
+                       }
+               }
+       }
+       if hasLit {
+               apply(lmin)
+               apply(wmin)
+               return err
+       }
+
+       // This pattern is all wildcards.
+       // Check it against everything.
+       for _, pats := range idx.segments {
+               apply(pats)
+       }
+       return err
+}
diff --git a/src/net/http/routing_index_test.go b/src/net/http/routing_index_test.go
new file mode 100644 (file)
index 0000000..7030fc8
--- /dev/null
@@ -0,0 +1,207 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+       "bytes"
+       "fmt"
+       "slices"
+       "sort"
+       "strings"
+       "testing"
+)
+
+func TestIndex(t *testing.T) {
+       pats := []string{"HEAD /", "/a"}
+
+       var patterns []*pattern
+       var idx routingIndex
+       for _, p := range pats {
+               pat := mustParsePattern(t, p)
+               patterns = append(patterns, pat)
+               idx.addPattern(pat)
+       }
+
+       compare := func(pat *pattern) {
+               t.Helper()
+               got := indexConflicts(pat, &idx)
+               want := trueConflicts(pat, patterns)
+               if !slices.Equal(got, want) {
+                       t.Errorf("%q:\ngot  %q\nwant %q", pat, got, want)
+               }
+       }
+
+       compare(mustParsePattern(t, "GET /foo"))
+       compare(mustParsePattern(t, "GET /{x}"))
+}
+
+// This test works by comparing possiblyConflictingPatterns with
+// an exhaustive loop through all patterns.
+func FuzzIndex(f *testing.F) {
+       inits := []string{"/a", "/a/b", "/{x0}", "/{x0}/b", "/a/{x0}", "/a/{$}", "/a/b/{$}",
+               "/a/", "/a/b/", "/{x}/b/c/{$}", "GET /{x0}/", "HEAD /a"}
+
+       var patterns []*pattern
+       var idx routingIndex
+
+       // compare takes a fatalf function because fuzzing doesn't like
+       // it when the fuzz function calls f.Fatalf.
+       compare := func(pat *pattern, fatalf func(string, ...any)) {
+               got := indexConflicts(pat, &idx)
+               want := trueConflicts(pat, patterns)
+               if !slices.Equal(got, want) {
+                       fatalf("%q:\ngot  %q\nwant %q", pat, got, want)
+               }
+       }
+
+       for _, p := range inits {
+               pat, err := parsePattern(p)
+               if err != nil {
+                       f.Fatal(err)
+               }
+               compare(pat, f.Fatalf)
+               patterns = append(patterns, pat)
+               idx.addPattern(pat)
+               f.Add(bytesFromPattern(pat))
+       }
+
+       f.Fuzz(func(t *testing.T, pb []byte) {
+               pat := bytesToPattern(pb)
+               if pat == nil {
+                       return
+               }
+               compare(pat, t.Fatalf)
+       })
+}
+
+func trueConflicts(pat *pattern, pats []*pattern) []string {
+       var s []string
+       for _, p := range pats {
+               if pat.conflictsWith(p) {
+                       s = append(s, p.String())
+               }
+       }
+       sort.Strings(s)
+       return s
+}
+
+func indexConflicts(pat *pattern, idx *routingIndex) []string {
+       var s []string
+       idx.possiblyConflictingPatterns(pat, func(p *pattern) error {
+               if pat.conflictsWith(p) {
+                       s = append(s, p.String())
+               }
+               return nil
+       })
+       sort.Strings(s)
+       return slices.Compact(s)
+}
+
+// TODO: incorporate host and method; make encoding denser.
+func bytesToPattern(bs []byte) *pattern {
+       if len(bs) == 0 {
+               return nil
+       }
+       var sb strings.Builder
+       wc := 0
+       for _, b := range bs[:len(bs)-1] {
+               sb.WriteByte('/')
+               switch b & 0x3 {
+               case 0:
+                       fmt.Fprintf(&sb, "{x%d}", wc)
+                       wc++
+               case 1:
+                       sb.WriteString("a")
+               case 2:
+                       sb.WriteString("b")
+               case 3:
+                       sb.WriteString("c")
+               }
+       }
+       sb.WriteByte('/')
+       switch bs[len(bs)-1] & 0x7 {
+       case 0:
+               fmt.Fprintf(&sb, "{x%d}", wc)
+       case 1:
+               sb.WriteString("a")
+       case 2:
+               sb.WriteString("b")
+       case 3:
+               sb.WriteString("c")
+       case 4, 5:
+               fmt.Fprintf(&sb, "{x%d...}", wc)
+       default:
+               sb.WriteString("{$}")
+       }
+       pat, err := parsePattern(sb.String())
+       if err != nil {
+               panic(err)
+       }
+       return pat
+}
+
+func bytesFromPattern(p *pattern) []byte {
+       var bs []byte
+       for _, s := range p.segments {
+               var b byte
+               switch {
+               case s.multi:
+                       b = 4
+               case s.wild:
+                       b = 0
+               case s.s == "/":
+                       b = 7
+               case s.s == "a":
+                       b = 1
+               case s.s == "b":
+                       b = 2
+               case s.s == "c":
+                       b = 3
+               default:
+                       panic("bad pattern")
+               }
+               bs = append(bs, b)
+       }
+       return bs
+}
+
+func TestBytesPattern(t *testing.T) {
+       tests := []struct {
+               bs  []byte
+               pat string
+       }{
+               {[]byte{0, 1, 2, 3}, "/{x0}/a/b/c"},
+               {[]byte{16, 17, 18, 19}, "/{x0}/a/b/c"},
+               {[]byte{4, 4}, "/{x0}/{x1...}"},
+               {[]byte{6, 7}, "/b/{$}"},
+       }
+       t.Run("To", func(t *testing.T) {
+               for _, test := range tests {
+                       p := bytesToPattern(test.bs)
+                       got := p.String()
+                       if got != test.pat {
+                               t.Errorf("%v: got %q, want %q", test.bs, got, test.pat)
+                       }
+               }
+       })
+       t.Run("From", func(t *testing.T) {
+               for _, test := range tests {
+                       p, err := parsePattern(test.pat)
+                       if err != nil {
+                               t.Fatal(err)
+                       }
+                       got := bytesFromPattern(p)
+                       var want []byte
+                       for _, b := range test.bs[:len(test.bs)-1] {
+                               want = append(want, b%4)
+
+                       }
+                       want = append(want, test.bs[len(test.bs)-1]%8)
+                       if !bytes.Equal(got, want) {
+                               t.Errorf("%s: got %v, want %v", test.pat, got, want)
+                       }
+               }
+       })
+}
index bc5bcb9a7171859cbf3f6e395ae43f8048998e5f..629d8d3c627bb7556aa0ed87be2eb1358e19196e 100644 (file)
@@ -2347,7 +2347,8 @@ func RedirectHandler(url string, code int) Handler {
 type ServeMux struct {
        mu       sync.RWMutex
        tree     routingNode
-       patterns []*pattern
+       index    routingIndex
+       patterns []*pattern // TODO(jba): remove if possible
 }
 
 // NewServeMux allocates and returns a new ServeMux.
@@ -2624,8 +2625,8 @@ func (mux *ServeMux) register(pattern string, handler Handler) {
        }
 }
 
-func (mux *ServeMux) registerErr(pattern string, handler Handler) error {
-       if pattern == "" {
+func (mux *ServeMux) registerErr(patstr string, handler Handler) error {
+       if patstr == "" {
                return errors.New("http: invalid pattern")
        }
        if handler == nil {
@@ -2635,9 +2636,9 @@ func (mux *ServeMux) registerErr(pattern string, handler Handler) error {
                return errors.New("http: nil handler")
        }
 
-       pat, err := parsePattern(pattern)
+       pat, err := parsePattern(patstr)
        if err != nil {
-               return fmt.Errorf("parsing %q: %w", pattern, err)
+               return fmt.Errorf("parsing %q: %w", patstr, err)
        }
 
        // Get the caller's location, for better conflict error messages.
@@ -2652,16 +2653,17 @@ func (mux *ServeMux) registerErr(pattern string, handler Handler) error {
        mux.mu.Lock()
        defer mux.mu.Unlock()
        // Check for conflict.
-       // This makes a quadratic number of calls to conflictsWith: we check
-       // each pattern against every other pattern.
-       // TODO(jba): add indexing to speed this up.
-       for _, pat2 := range mux.patterns {
+       if err := mux.index.possiblyConflictingPatterns(pat, func(pat2 *pattern) error {
                if pat.conflictsWith(pat2) {
                        return fmt.Errorf("pattern %q (registered at %s) conflicts with pattern %q (registered at %s)",
                                pat, pat.loc, pat2, pat2.loc)
                }
+               return nil
+       }); err != nil {
+               return err
        }
        mux.tree.addPattern(pat, handler)
+       mux.index.addPattern(pat)
        mux.patterns = append(mux.patterns, pat)
        return nil
 }
diff --git a/src/net/http/testdata/fuzz/FuzzIndex/48161038f0c8b2da b/src/net/http/testdata/fuzz/FuzzIndex/48161038f0c8b2da
new file mode 100644 (file)
index 0000000..06a7336
--- /dev/null
@@ -0,0 +1,2 @@
+go test fuzz v1
+[]byte("101$")
diff --git a/src/net/http/testdata/fuzz/FuzzIndex/716514f590ce7ab3 b/src/net/http/testdata/fuzz/FuzzIndex/716514f590ce7ab3
new file mode 100644 (file)
index 0000000..520bff1
--- /dev/null
@@ -0,0 +1,2 @@
+go test fuzz v1
+[]byte("1010")