--- /dev/null
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This file implements a decision tree for fast matching of requests to
+// patterns.
+//
+// The root of the tree branches on the host of the request.
+// The next level branches on the method.
+// The remaining levels branch on consecutive segments of the path.
+//
+// The "more specific wins" precedence rule can result in backtracking.
+// For example, given the patterns
+// /a/b/z
+// /a/{x}/c
+// we will first try to match the path "/a/b/c" with /a/b/z, and
+// when that fails we will try against /a/{x}/c.
+
+package http
+
+import (
+ "net/url"
+ "strings"
+)
+
+// A routingNode is a node in the decision tree.
+// The same struct is used for leaf and interior nodes.
+type routingNode struct {
+ // A leaf node holds a single pattern and the Handler it was registered
+ // with.
+ pattern *pattern
+ handler Handler
+
+ // An interior node maps parts of the incoming request to child nodes.
+ // special children keys:
+ // "/" trailing slash (resulting from {$})
+ // "" single wildcard
+ // "*" multi wildcard
+ children mapping[string, *routingNode]
+ emptyChild *routingNode // optimization: child with key ""
+}
+
+// addPattern adds a pattern and its associated Handler to the tree
+// at root.
+func (root *routingNode) addPattern(p *pattern, h Handler) {
+ // First level of tree is host.
+ n := root.addChild(p.host)
+ // Second level of tree is method.
+ n = n.addChild(p.method)
+ // Remaining levels are path.
+ n.addSegments(p.segments, p, h)
+}
+
+// addSegments adds the given segments to the tree rooted at n.
+// If there are no segments, then n is a leaf node that holds
+// the given pattern and handler.
+func (n *routingNode) addSegments(segs []segment, p *pattern, h Handler) {
+ if len(segs) == 0 {
+ n.set(p, h)
+ return
+ }
+ seg := segs[0]
+ if seg.multi {
+ if len(segs) != 1 {
+ panic("multi wildcard not last")
+ }
+ n.addChild("*").set(p, h)
+ } else if seg.wild {
+ n.addChild("").addSegments(segs[1:], p, h)
+ } else {
+ n.addChild(seg.s).addSegments(segs[1:], p, h)
+ }
+}
+
+// set sets the pattern and handler for n, which
+// must be a leaf node.
+func (n *routingNode) set(p *pattern, h Handler) {
+ if n.pattern != nil || n.handler != nil {
+ panic("non-nil leaf fields")
+ }
+ n.pattern = p
+ n.handler = h
+}
+
+// addChild adds a child node with the given key to n
+// if one does not exist, and returns the child.
+func (n *routingNode) addChild(key string) *routingNode {
+ if key == "" {
+ if n.emptyChild == nil {
+ n.emptyChild = &routingNode{}
+ }
+ return n.emptyChild
+ }
+ if c := n.findChild(key); c != nil {
+ return c
+ }
+ c := &routingNode{}
+ n.children.add(key, c)
+ return c
+}
+
+// findChild returns the child of n with the given key, or nil
+// if there is no child with that key.
+func (n *routingNode) findChild(key string) *routingNode {
+ if key == "" {
+ return n.emptyChild
+ }
+ r, _ := n.children.find(key)
+ return r
+}
+
+// match returns the leaf node under root that matches the arguments, and a list
+// of values for pattern wildcards in the order that the wildcards appear.
+// For example, if the request path is "/a/b/c" and the pattern is "/{x}/b/{y}",
+// then the second return value will be []string{"a", "c"}.
+func (root *routingNode) match(host, method, path string) (*routingNode, []string) {
+ if host != "" {
+ // There is a host. If there is a pattern that specifies that host and it
+ // matches, we are done. If the pattern doesn't match, fall through to
+ // try patterns with no host.
+ if l, m := root.findChild(host).matchMethodAndPath(method, path); l != nil {
+ return l, m
+ }
+ }
+ return root.emptyChild.matchMethodAndPath(method, path)
+}
+
+// matchMethodAndPath matches the method and path.
+// Its return values are the same as [routingNode.match].
+// The receiver should be a child of the root.
+func (n *routingNode) matchMethodAndPath(method, path string) (*routingNode, []string) {
+ if n == nil {
+ return nil, nil
+ }
+ if l, m := n.findChild(method).matchPath(path, nil); l != nil {
+ // Exact match of method name.
+ return l, m
+ }
+ if method == "HEAD" {
+ // GET matches HEAD too.
+ if l, m := n.findChild("GET").matchPath(path, nil); l != nil {
+ return l, m
+ }
+ }
+ // No exact match; try patterns with no method.
+ return n.emptyChild.matchPath(path, nil)
+}
+
+// matchPath matches a path.
+// Its return values are the same as [routingNode.match].
+// matchPath calls itself recursively. The matches argument holds the wildcard matches
+// found so far.
+func (n *routingNode) matchPath(path string, matches []string) (*routingNode, []string) {
+ if n == nil {
+ return nil, nil
+ }
+ // If path is empty, then we are done.
+ // If n is a leaf node, we found a match; return it.
+ // If n is an interior node (which means it has a nil pattern),
+ // then we failed to match.
+ if path == "" {
+ if n.pattern == nil {
+ return nil, nil
+ }
+ return n, matches
+ }
+ // Get the first segment of path.
+ seg, rest := firstSegment(path)
+ // First try matching against patterns that have a literal for this position.
+ // We know by construction that such patterns are more specific than those
+ // with a wildcard at this position (they are either more specific, equivalent,
+ // or overlap, and we ruled out the first two when the patterns were registered).
+ if n, m := n.findChild(seg).matchPath(rest, matches); n != nil {
+ return n, m
+ }
+ // If matching a literal fails, try again with patterns that have a single
+ // wildcard (represented by an empty string in the child mapping).
+ // Again, by construction, patterns with a single wildcard must be more specific than
+ // those with a multi wildcard.
+ // We skip this step if the segment is a trailing slash, because single wildcards
+ // don't match trailing slashes.
+ if seg != "/" {
+ if n, m := n.emptyChild.matchPath(rest, append(matches, matchValue(seg))); n != nil {
+ return n, m
+ }
+ }
+ // Lastly, match the pattern (there can be at most one) that has a multi
+ // wildcard in this position to the rest of the path.
+ if c := n.findChild("*"); c != nil {
+ // Don't record a match for a nameless wildcard (which arises from a
+ // trailing slash in the pattern).
+ if c.pattern.lastSegment().s != "" {
+ matches = append(matches, matchValue(path[1:])) // remove initial slash
+ }
+ return c, matches
+ }
+ return nil, nil
+}
+
+func matchValue(path string) string {
+ m, err := url.PathUnescape(path)
+ if err != nil {
+ // Path is not properly escaped, so use the original.
+ return path
+ }
+ return m
+}
+
+// firstSegment splits path into its first segment, and the rest.
+// The path must begin with "/".
+// If path consists of only a slash, firstSegment returns ("/", "").
+func firstSegment(path string) (seg, rest string) {
+ if path == "/" {
+ return "/", ""
+ }
+ path = path[1:] // drop initial slash
+ i := strings.IndexByte(path, '/')
+ if i < 0 {
+ return path, ""
+ }
+ return path[:i], path[i:]
+}
--- /dev/null
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "fmt"
+ "io"
+ "sort"
+ "strings"
+ "testing"
+
+ "slices"
+)
+
+func TestRoutingFirstSegment(t *testing.T) {
+ for _, test := range []struct {
+ in string
+ want []string
+ }{
+ {"/a/b/c", []string{"a", "b", "c"}},
+ {"/a/b/", []string{"a", "b", "/"}},
+ {"/", []string{"/"}},
+ } {
+ var got []string
+ rest := test.in
+ for len(rest) > 0 {
+ var seg string
+ seg, rest = firstSegment(rest)
+ got = append(got, seg)
+ }
+ if !slices.Equal(got, test.want) {
+ t.Errorf("%q: got %v, want %v", test.in, got, test.want)
+ }
+ }
+}
+
+// TODO: test host and method
+var testTree *routingNode
+
+func getTestTree() *routingNode {
+ if testTree == nil {
+ testTree = buildTree("/a", "/a/b", "/a/{x}",
+ "/g/h/i", "/g/{x}/j",
+ "/a/b/{x...}", "/a/b/{y}", "/a/b/{$}")
+ }
+ return testTree
+}
+
+func buildTree(pats ...string) *routingNode {
+ root := &routingNode{}
+ for _, p := range pats {
+ pat, err := parsePattern(p)
+ if err != nil {
+ panic(err)
+ }
+ root.addPattern(pat, nil)
+ }
+ return root
+}
+
+func TestRoutingAddPattern(t *testing.T) {
+ want := `"":
+ "":
+ "a":
+ "/a"
+ "":
+ "/a/{x}"
+ "b":
+ "/a/b"
+ "":
+ "/a/b/{y}"
+ "*":
+ "/a/b/{x...}"
+ "/":
+ "/a/b/{$}"
+ "g":
+ "":
+ "j":
+ "/g/{x}/j"
+ "h":
+ "i":
+ "/g/h/i"
+`
+
+ var b strings.Builder
+ getTestTree().print(&b, 0)
+ got := b.String()
+ if got != want {
+ t.Errorf("got\n%s\nwant\n%s", got, want)
+ }
+}
+
+type testCase struct {
+ method, host, path string
+ wantPat string // "" for nil (no match)
+ wantMatches []string
+}
+
+func TestRoutingNodeMatch(t *testing.T) {
+
+ test := func(tree *routingNode, tests []testCase) {
+ t.Helper()
+ for _, test := range tests {
+ gotNode, gotMatches := tree.match(test.host, test.method, test.path)
+ got := ""
+ if gotNode != nil {
+ got = gotNode.pattern.String()
+ }
+ if got != test.wantPat {
+ t.Errorf("%s, %s, %s: got %q, want %q", test.host, test.method, test.path, got, test.wantPat)
+ }
+ if !slices.Equal(gotMatches, test.wantMatches) {
+ t.Errorf("%s, %s, %s: got matches %v, want %v", test.host, test.method, test.path, gotMatches, test.wantMatches)
+ }
+ }
+ }
+
+ test(getTestTree(), []testCase{
+ {"GET", "", "/a", "/a", nil},
+ {"Get", "", "/b", "", nil},
+ {"Get", "", "/a/b", "/a/b", nil},
+ {"Get", "", "/a/c", "/a/{x}", []string{"c"}},
+ {"Get", "", "/a/b/", "/a/b/{$}", nil},
+ {"Get", "", "/a/b/c", "/a/b/{y}", []string{"c"}},
+ {"Get", "", "/a/b/c/d", "/a/b/{x...}", []string{"c/d"}},
+ {"Get", "", "/g/h/i", "/g/h/i", nil},
+ {"Get", "", "/g/h/j", "/g/{x}/j", []string{"h"}},
+ })
+
+ tree := buildTree(
+ "/item/",
+ "POST /item/{user}",
+ "GET /item/{user}",
+ "/item/{user}",
+ "/item/{user}/{id}",
+ "/item/{user}/new",
+ "/item/{$}",
+ "POST alt.com/item/{user}",
+ "GET /headwins",
+ "HEAD /headwins",
+ "/path/{p...}")
+
+ test(tree, []testCase{
+ {"GET", "", "/item/jba",
+ "GET /item/{user}", []string{"jba"}},
+ {"POST", "", "/item/jba",
+ "POST /item/{user}", []string{"jba"}},
+ {"HEAD", "", "/item/jba",
+ "GET /item/{user}", []string{"jba"}},
+ {"get", "", "/item/jba",
+ "/item/{user}", []string{"jba"}}, // method matches are case-sensitive
+ {"POST", "", "/item/jba/17",
+ "/item/{user}/{id}", []string{"jba", "17"}},
+ {"GET", "", "/item/jba/new",
+ "/item/{user}/new", []string{"jba"}},
+ {"GET", "", "/item/",
+ "/item/{$}", []string{}},
+ {"GET", "", "/item/jba/17/line2",
+ "/item/", nil},
+ {"POST", "alt.com", "/item/jba",
+ "POST alt.com/item/{user}", []string{"jba"}},
+ {"GET", "alt.com", "/item/jba",
+ "GET /item/{user}", []string{"jba"}},
+ {"GET", "", "/item",
+ "", nil}, // does not match
+ {"GET", "", "/headwins",
+ "GET /headwins", nil},
+ {"HEAD", "", "/headwins", // HEAD is more specific than GET
+ "HEAD /headwins", nil},
+ {"GET", "", "/path/to/file",
+ "/path/{p...}", []string{"to/file"}},
+ })
+
+ // A pattern ending in {$} should only match URLS with a trailing slash.
+ pat1 := "/a/b/{$}"
+ test(buildTree(pat1), []testCase{
+ {"GET", "", "/a/b", "", nil},
+ {"GET", "", "/a/b/", pat1, nil},
+ {"GET", "", "/a/b/c", "", nil},
+ {"GET", "", "/a/b/c/d", "", nil},
+ })
+
+ // A pattern ending in a single wildcard should not match a trailing slash URL.
+ pat2 := "/a/b/{w}"
+ test(buildTree(pat2), []testCase{
+ {"GET", "", "/a/b", "", nil},
+ {"GET", "", "/a/b/", "", nil},
+ {"GET", "", "/a/b/c", pat2, []string{"c"}},
+ {"GET", "", "/a/b/c/d", "", nil},
+ })
+
+ // A pattern ending in a multi wildcard should match both URLs.
+ pat3 := "/a/b/{w...}"
+ test(buildTree(pat3), []testCase{
+ {"GET", "", "/a/b", "", nil},
+ {"GET", "", "/a/b/", pat3, []string{""}},
+ {"GET", "", "/a/b/c", pat3, []string{"c"}},
+ {"GET", "", "/a/b/c/d", pat3, []string{"c/d"}},
+ })
+
+ // All three of the above should work together.
+ test(buildTree(pat1, pat2, pat3), []testCase{
+ {"GET", "", "/a/b", "", nil},
+ {"GET", "", "/a/b/", pat1, nil},
+ {"GET", "", "/a/b/c", pat2, []string{"c"}},
+ {"GET", "", "/a/b/c/d", pat3, []string{"c/d"}},
+ })
+}
+
+func (n *routingNode) print(w io.Writer, level int) {
+ indent := strings.Repeat(" ", level)
+ if n.pattern != nil {
+ fmt.Fprintf(w, "%s%q\n", indent, n.pattern)
+ }
+ if n.emptyChild != nil {
+ fmt.Fprintf(w, "%s%q:\n", indent, "")
+ n.emptyChild.print(w, level+1)
+ }
+
+ var keys []string
+ n.children.eachPair(func(k string, _ *routingNode) bool {
+ keys = append(keys, k)
+ return true
+ })
+ sort.Strings(keys)
+
+ for _, k := range keys {
+ fmt.Fprintf(w, "%s%q:\n", indent, k)
+ n, _ := n.children.find(k)
+ n.print(w, level+1)
+ }
+}