--- /dev/null
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import "math"
+
+// A routingIndex optimizes conflict detection by indexing patterns.
+//
+// The basic idea is to rule out patterns that cannot conflict with a given
+// pattern because they have a different literal in a corresponding segment.
+// See the comments in [routingIndex.possiblyConflictingPatterns] for more details.
+type routingIndex struct {
+ // map from a particular segment position and value to all registered patterns
+ // with that value in that position.
+ // For example, the key {1, "b"} would hold the patterns "/a/b" and "/a/b/c"
+ // but not "/a", "b/a", "/a/c" or "/a/{x}".
+ segments map[routingIndexKey][]*pattern
+ // All patterns that end in a multi wildcard (including trailing slash).
+ // We do not try to be clever about indexing multi patterns, because there
+ // are unlikely to be many of them.
+ multis []*pattern
+}
+
+type routingIndexKey struct {
+ pos int // 0-based segment position
+ s string // literal, or empty for wildcard
+}
+
+func (idx *routingIndex) addPattern(pat *pattern) {
+ if pat.lastSegment().multi {
+ idx.multis = append(idx.multis, pat)
+ } else {
+ if idx.segments == nil {
+ idx.segments = map[routingIndexKey][]*pattern{}
+ }
+ for pos, seg := range pat.segments {
+ key := routingIndexKey{pos: pos, s: ""}
+ if !seg.wild {
+ key.s = seg.s
+ }
+ idx.segments[key] = append(idx.segments[key], pat)
+ }
+ }
+}
+
+// possiblyConflictingPatterns calls f on all patterns that might conflict with
+// pat. If f returns a non-nil error, possiblyConflictingPatterns returns immediately
+// with that error.
+//
+// To be correct, possiblyConflictingPatterns must include all patterns that
+// might conflict. But it may also include patterns that cannot conflict.
+// For instance, an implementation that returns all registered patterns is correct.
+// We use this fact throughout, simplifying the implementation by returning more
+// patterns that we might need to.
+func (idx *routingIndex) possiblyConflictingPatterns(pat *pattern, f func(*pattern) error) (err error) {
+ // Terminology:
+ // dollar pattern: one ending in "{$}"
+ // multi pattern: one ending in a trailing slash or "{x...}" wildcard
+ // ordinary pattern: neither of the above
+
+ // apply f to all the pats, stopping on error.
+ apply := func(pats []*pattern) error {
+ if err != nil {
+ return err
+ }
+ for _, p := range pats {
+ err = f(p)
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+
+ // Our simple indexing scheme doesn't try to prune multi patterns; assume
+ // any of them can match the argument.
+ if err := apply(idx.multis); err != nil {
+ return err
+ }
+ if pat.lastSegment().s == "/" {
+ // All paths that a dollar pattern matches end in a slash; no paths that
+ // an ordinary pattern matches do. So only other dollar or multi
+ // patterns can conflict with a dollar pattern. Furthermore, conflicting
+ // dollar patterns must have the {$} in the same position.
+ return apply(idx.segments[routingIndexKey{s: "/", pos: len(pat.segments) - 1}])
+ }
+ // For ordinary and multi patterns, the only conflicts can be with a multi,
+ // or a pattern that has the same literal or a wildcard at some literal
+ // position.
+ // We could intersect all the possible matches at each position, but we
+ // do something simpler: we find the position with the fewest patterns.
+ var lmin, wmin []*pattern
+ min := math.MaxInt
+ hasLit := false
+ for i, seg := range pat.segments {
+ if seg.multi {
+ break
+ }
+ if !seg.wild {
+ hasLit = true
+ lpats := idx.segments[routingIndexKey{s: seg.s, pos: i}]
+ wpats := idx.segments[routingIndexKey{s: "", pos: i}]
+ if sum := len(lpats) + len(wpats); sum < min {
+ lmin = lpats
+ wmin = wpats
+ min = sum
+ }
+ }
+ }
+ if hasLit {
+ apply(lmin)
+ apply(wmin)
+ return err
+ }
+
+ // This pattern is all wildcards.
+ // Check it against everything.
+ for _, pats := range idx.segments {
+ apply(pats)
+ }
+ return err
+}
--- /dev/null
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "bytes"
+ "fmt"
+ "slices"
+ "sort"
+ "strings"
+ "testing"
+)
+
+func TestIndex(t *testing.T) {
+ pats := []string{"HEAD /", "/a"}
+
+ var patterns []*pattern
+ var idx routingIndex
+ for _, p := range pats {
+ pat := mustParsePattern(t, p)
+ patterns = append(patterns, pat)
+ idx.addPattern(pat)
+ }
+
+ compare := func(pat *pattern) {
+ t.Helper()
+ got := indexConflicts(pat, &idx)
+ want := trueConflicts(pat, patterns)
+ if !slices.Equal(got, want) {
+ t.Errorf("%q:\ngot %q\nwant %q", pat, got, want)
+ }
+ }
+
+ compare(mustParsePattern(t, "GET /foo"))
+ compare(mustParsePattern(t, "GET /{x}"))
+}
+
+// This test works by comparing possiblyConflictingPatterns with
+// an exhaustive loop through all patterns.
+func FuzzIndex(f *testing.F) {
+ inits := []string{"/a", "/a/b", "/{x0}", "/{x0}/b", "/a/{x0}", "/a/{$}", "/a/b/{$}",
+ "/a/", "/a/b/", "/{x}/b/c/{$}", "GET /{x0}/", "HEAD /a"}
+
+ var patterns []*pattern
+ var idx routingIndex
+
+ // compare takes a fatalf function because fuzzing doesn't like
+ // it when the fuzz function calls f.Fatalf.
+ compare := func(pat *pattern, fatalf func(string, ...any)) {
+ got := indexConflicts(pat, &idx)
+ want := trueConflicts(pat, patterns)
+ if !slices.Equal(got, want) {
+ fatalf("%q:\ngot %q\nwant %q", pat, got, want)
+ }
+ }
+
+ for _, p := range inits {
+ pat, err := parsePattern(p)
+ if err != nil {
+ f.Fatal(err)
+ }
+ compare(pat, f.Fatalf)
+ patterns = append(patterns, pat)
+ idx.addPattern(pat)
+ f.Add(bytesFromPattern(pat))
+ }
+
+ f.Fuzz(func(t *testing.T, pb []byte) {
+ pat := bytesToPattern(pb)
+ if pat == nil {
+ return
+ }
+ compare(pat, t.Fatalf)
+ })
+}
+
+func trueConflicts(pat *pattern, pats []*pattern) []string {
+ var s []string
+ for _, p := range pats {
+ if pat.conflictsWith(p) {
+ s = append(s, p.String())
+ }
+ }
+ sort.Strings(s)
+ return s
+}
+
+func indexConflicts(pat *pattern, idx *routingIndex) []string {
+ var s []string
+ idx.possiblyConflictingPatterns(pat, func(p *pattern) error {
+ if pat.conflictsWith(p) {
+ s = append(s, p.String())
+ }
+ return nil
+ })
+ sort.Strings(s)
+ return slices.Compact(s)
+}
+
+// TODO: incorporate host and method; make encoding denser.
+func bytesToPattern(bs []byte) *pattern {
+ if len(bs) == 0 {
+ return nil
+ }
+ var sb strings.Builder
+ wc := 0
+ for _, b := range bs[:len(bs)-1] {
+ sb.WriteByte('/')
+ switch b & 0x3 {
+ case 0:
+ fmt.Fprintf(&sb, "{x%d}", wc)
+ wc++
+ case 1:
+ sb.WriteString("a")
+ case 2:
+ sb.WriteString("b")
+ case 3:
+ sb.WriteString("c")
+ }
+ }
+ sb.WriteByte('/')
+ switch bs[len(bs)-1] & 0x7 {
+ case 0:
+ fmt.Fprintf(&sb, "{x%d}", wc)
+ case 1:
+ sb.WriteString("a")
+ case 2:
+ sb.WriteString("b")
+ case 3:
+ sb.WriteString("c")
+ case 4, 5:
+ fmt.Fprintf(&sb, "{x%d...}", wc)
+ default:
+ sb.WriteString("{$}")
+ }
+ pat, err := parsePattern(sb.String())
+ if err != nil {
+ panic(err)
+ }
+ return pat
+}
+
+func bytesFromPattern(p *pattern) []byte {
+ var bs []byte
+ for _, s := range p.segments {
+ var b byte
+ switch {
+ case s.multi:
+ b = 4
+ case s.wild:
+ b = 0
+ case s.s == "/":
+ b = 7
+ case s.s == "a":
+ b = 1
+ case s.s == "b":
+ b = 2
+ case s.s == "c":
+ b = 3
+ default:
+ panic("bad pattern")
+ }
+ bs = append(bs, b)
+ }
+ return bs
+}
+
+func TestBytesPattern(t *testing.T) {
+ tests := []struct {
+ bs []byte
+ pat string
+ }{
+ {[]byte{0, 1, 2, 3}, "/{x0}/a/b/c"},
+ {[]byte{16, 17, 18, 19}, "/{x0}/a/b/c"},
+ {[]byte{4, 4}, "/{x0}/{x1...}"},
+ {[]byte{6, 7}, "/b/{$}"},
+ }
+ t.Run("To", func(t *testing.T) {
+ for _, test := range tests {
+ p := bytesToPattern(test.bs)
+ got := p.String()
+ if got != test.pat {
+ t.Errorf("%v: got %q, want %q", test.bs, got, test.pat)
+ }
+ }
+ })
+ t.Run("From", func(t *testing.T) {
+ for _, test := range tests {
+ p, err := parsePattern(test.pat)
+ if err != nil {
+ t.Fatal(err)
+ }
+ got := bytesFromPattern(p)
+ var want []byte
+ for _, b := range test.bs[:len(test.bs)-1] {
+ want = append(want, b%4)
+
+ }
+ want = append(want, test.bs[len(test.bs)-1]%8)
+ if !bytes.Equal(got, want) {
+ t.Errorf("%s: got %v, want %v", test.pat, got, want)
+ }
+ }
+ })
+}
type ServeMux struct {
mu sync.RWMutex
tree routingNode
- patterns []*pattern
+ index routingIndex
+ patterns []*pattern // TODO(jba): remove if possible
}
// NewServeMux allocates and returns a new ServeMux.
}
}
-func (mux *ServeMux) registerErr(pattern string, handler Handler) error {
- if pattern == "" {
+func (mux *ServeMux) registerErr(patstr string, handler Handler) error {
+ if patstr == "" {
return errors.New("http: invalid pattern")
}
if handler == nil {
return errors.New("http: nil handler")
}
- pat, err := parsePattern(pattern)
+ pat, err := parsePattern(patstr)
if err != nil {
- return fmt.Errorf("parsing %q: %w", pattern, err)
+ return fmt.Errorf("parsing %q: %w", patstr, err)
}
// Get the caller's location, for better conflict error messages.
mux.mu.Lock()
defer mux.mu.Unlock()
// Check for conflict.
- // This makes a quadratic number of calls to conflictsWith: we check
- // each pattern against every other pattern.
- // TODO(jba): add indexing to speed this up.
- for _, pat2 := range mux.patterns {
+ if err := mux.index.possiblyConflictingPatterns(pat, func(pat2 *pattern) error {
if pat.conflictsWith(pat2) {
return fmt.Errorf("pattern %q (registered at %s) conflicts with pattern %q (registered at %s)",
pat, pat.loc, pat2, pat2.loc)
}
+ return nil
+ }); err != nil {
+ return err
}
mux.tree.addPattern(pat, handler)
+ mux.index.addPattern(pat)
mux.patterns = append(mux.patterns, pat)
return nil
}