]> Cypherpunks.ru repositories - gostls13.git/commitdiff
net/http: switch HTTP1 to ASCII equivalents of string functions
authorRoberto Clapis <roberto@golang.org>
Wed, 7 Apr 2021 12:36:40 +0000 (14:36 +0200)
committerFilippo Valsorda <filippo@golang.org>
Mon, 10 May 2021 23:42:56 +0000 (23:42 +0000)
The current implementation uses UTF-aware functions
like strings.EqualFold and strings.ToLower.

This could, in some cases, cause http smuggling.

Change-Id: I0e76a993470a1e1b1b472f4b2859ea0a2b22ada0
Reviewed-on: https://go-review.googlesource.com/c/go/+/308009
Run-TryBot: Filippo Valsorda <filippo@golang.org>
TryBot-Result: Go Bot <gobot@golang.org>
Trust: Roberto Clapis <roberto@golang.org>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
14 files changed:
src/go/build/deps_test.go
src/net/http/client.go
src/net/http/cookie.go
src/net/http/cookiejar/jar.go
src/net/http/cookiejar/punycode.go
src/net/http/header.go
src/net/http/http.go
src/net/http/httputil/reverseproxy.go
src/net/http/httputil/reverseproxy_test.go
src/net/http/internal/ascii/print.go [new file with mode: 0644]
src/net/http/internal/ascii/print_test.go [new file with mode: 0644]
src/net/http/request.go
src/net/http/transfer.go
src/net/http/transport.go

index 3a0e769284d3f430f49caebd397e61659249f989..5d1cf7f4c97cfd7a3755a7801921321908b81406 100644 (file)
@@ -440,7 +440,7 @@ var depsRules = `
        # HTTP, King of Dependencies.
 
        FMT
-       < golang.org/x/net/http2/hpack, net/http/internal;
+       < golang.org/x/net/http2/hpack, net/http/internal, net/http/internal/ascii;
 
        FMT, NET, container/list, encoding/binary, log
        < golang.org/x/text/transform
@@ -458,6 +458,7 @@ var depsRules = `
        golang.org/x/net/http/httpproxy,
        golang.org/x/net/http2/hpack,
        net/http/internal,
+       net/http/internal/ascii,
        net/http/httptrace,
        mime/multipart,
        log
@@ -468,7 +469,7 @@ var depsRules = `
        encoding/json, net/http
        < expvar;
 
-       net/http
+       net/http, net/http/internal/ascii
        < net/http/cookiejar, net/http/httputil;
 
        net/http, flag
index 82e665829e9ffba67b0828c85c000368402c0370..03c9155fbdd79e4bd85824de1fd92a2e29c3dcf4 100644 (file)
@@ -17,6 +17,7 @@ import (
        "fmt"
        "io"
        "log"
+       "net/http/internal/ascii"
        "net/url"
        "reflect"
        "sort"
@@ -547,7 +548,10 @@ func urlErrorOp(method string) string {
        if method == "" {
                return "Get"
        }
-       return method[:1] + strings.ToLower(method[1:])
+       if lowerMethod, ok := ascii.ToLower(method); ok {
+               return method[:1] + lowerMethod[1:]
+       }
+       return method
 }
 
 // Do sends an HTTP request and returns an HTTP response, following
index 141bc947f6e960f327045895896d1321027b6075..ca2c1c2506694cc7c636c09af3b7cf9921352287 100644 (file)
@@ -7,6 +7,7 @@ package http
 import (
        "log"
        "net"
+       "net/http/internal/ascii"
        "net/textproto"
        "strconv"
        "strings"
@@ -93,15 +94,23 @@ func readSetCookies(h Header) []*Cookie {
                        if j := strings.Index(attr, "="); j >= 0 {
                                attr, val = attr[:j], attr[j+1:]
                        }
-                       lowerAttr := strings.ToLower(attr)
+                       lowerAttr, isASCII := ascii.ToLower(attr)
+                       if !isASCII {
+                               continue
+                       }
                        val, ok = parseCookieValue(val, false)
                        if !ok {
                                c.Unparsed = append(c.Unparsed, parts[i])
                                continue
                        }
+
                        switch lowerAttr {
                        case "samesite":
-                               lowerVal := strings.ToLower(val)
+                               lowerVal, ascii := ascii.ToLower(val)
+                               if !ascii {
+                                       c.SameSite = SameSiteDefaultMode
+                                       continue
+                               }
                                switch lowerVal {
                                case "lax":
                                        c.SameSite = SameSiteLaxMode
index 9f1991708472b9a0a2438c2a0334c23be352b20b..e6583da7fe67df21caf82a4f0030c07ed8869b76 100644 (file)
@@ -10,6 +10,7 @@ import (
        "fmt"
        "net"
        "net/http"
+       "net/http/internal/ascii"
        "net/url"
        "sort"
        "strings"
@@ -296,7 +297,6 @@ func (j *Jar) setCookies(u *url.URL, cookies []*http.Cookie, now time.Time) {
 // host name.
 func canonicalHost(host string) (string, error) {
        var err error
-       host = strings.ToLower(host)
        if hasPort(host) {
                host, _, err = net.SplitHostPort(host)
                if err != nil {
@@ -307,7 +307,13 @@ func canonicalHost(host string) (string, error) {
                // Strip trailing dot from fully qualified domain names.
                host = host[:len(host)-1]
        }
-       return toASCII(host)
+       encoded, err := toASCII(host)
+       if err != nil {
+               return "", err
+       }
+       // We know this is ascii, no need to check.
+       lower, _ := ascii.ToLower(encoded)
+       return lower, nil
 }
 
 // hasPort reports whether host contains a port number. host may be a host
@@ -469,7 +475,12 @@ func (j *Jar) domainAndType(host, domain string) (string, bool, error) {
                // both are illegal.
                return "", false, errMalformedDomain
        }
-       domain = strings.ToLower(domain)
+
+       domain, isASCII := ascii.ToLower(domain)
+       if !isASCII {
+               // Received non-ASCII domain, e.g. "perché.com" instead of "xn--perch-fsa.com"
+               return "", false, errMalformedDomain
+       }
 
        if domain[len(domain)-1] == '.' {
                // We received stuff like "Domain=www.example.com.".
index a9cc666e8c9637e9806b22581ece20aef131735f..c7f438dd00707a87014d801592be92c2e498a811 100644 (file)
@@ -8,6 +8,7 @@ package cookiejar
 
 import (
        "fmt"
+       "net/http/internal/ascii"
        "strings"
        "unicode/utf8"
 )
@@ -133,12 +134,12 @@ const acePrefix = "xn--"
 // toASCII("bücher.example.com") is "xn--bcher-kva.example.com", and
 // toASCII("golang") is "golang".
 func toASCII(s string) (string, error) {
-       if ascii(s) {
+       if ascii.Is(s) {
                return s, nil
        }
        labels := strings.Split(s, ".")
        for i, label := range labels {
-               if !ascii(label) {
+               if !ascii.Is(label) {
                        a, err := encode(acePrefix, label)
                        if err != nil {
                                return "", err
@@ -148,12 +149,3 @@ func toASCII(s string) (string, error) {
        }
        return strings.Join(labels, "."), nil
 }
-
-func ascii(s string) bool {
-       for i := 0; i < len(s); i++ {
-               if s[i] >= utf8.RuneSelf {
-                       return false
-               }
-       }
-       return true
-}
index b9b53911f384410425323e30efe391f9bfa02c78..4c72dcb2c88d254f3e8b81a483bb6b40079f4b37 100644 (file)
@@ -7,6 +7,7 @@ package http
 import (
        "io"
        "net/http/httptrace"
+       "net/http/internal/ascii"
        "net/textproto"
        "sort"
        "strings"
@@ -251,7 +252,7 @@ func hasToken(v, token string) bool {
                if endPos := sp + len(token); endPos != len(v) && !isTokenBoundary(v[endPos]) {
                        continue
                }
-               if strings.EqualFold(v[sp:sp+len(token)], token) {
+               if ascii.EqualFold(v[sp:sp+len(token)], token) {
                        return true
                }
        }
index 4c5054b399410dd22b02bb1661203e0bb3160d38..101799f574f4655610c2ec49d013b66c76b7a434 100644 (file)
@@ -62,15 +62,6 @@ func isNotToken(r rune) bool {
        return !httpguts.IsTokenRune(r)
 }
 
-func isASCII(s string) bool {
-       for i := 0; i < len(s); i++ {
-               if s[i] >= utf8.RuneSelf {
-                       return false
-               }
-       }
-       return true
-}
-
 // stringContainsCTLByte reports whether s contains any ASCII control character.
 func stringContainsCTLByte(s string) bool {
        for i := 0; i < len(s); i++ {
index db42ac6ba552e91ad6b32f98948eba61b30b737e..1432ee26d37f19ef57380dac19be899b584bc375 100644 (file)
@@ -13,6 +13,7 @@ import (
        "log"
        "net"
        "net/http"
+       "net/http/internal/ascii"
        "net/textproto"
        "net/url"
        "strings"
@@ -242,6 +243,10 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
        outreq.Close = false
 
        reqUpType := upgradeType(outreq.Header)
+       if !ascii.IsPrint(reqUpType) {
+               p.getErrorHandler()(rw, req, fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType))
+               return
+       }
        removeConnectionHeaders(outreq.Header)
 
        // Remove hop-by-hop headers to the backend. Especially
@@ -538,13 +543,16 @@ func upgradeType(h http.Header) string {
        if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") {
                return ""
        }
-       return strings.ToLower(h.Get("Upgrade"))
+       return h.Get("Upgrade")
 }
 
 func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) {
        reqUpType := upgradeType(req.Header)
        resUpType := upgradeType(res.Header)
-       if reqUpType != resUpType {
+       if !ascii.IsPrint(resUpType) { // We know reqUpType is ASCII, it's checked by the caller.
+               p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType))
+       }
+       if !ascii.EqualFold(reqUpType, resUpType) {
                p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType))
                return
        }
index 3acbd940e4b8cee91cfb76058325344f6d2373e7..b89eb90ad64b9279cfce0fb33b8eed0d67c497d6 100644 (file)
@@ -16,6 +16,7 @@ import (
        "log"
        "net/http"
        "net/http/httptest"
+       "net/http/internal/ascii"
        "net/url"
        "os"
        "reflect"
@@ -1183,7 +1184,7 @@ func TestReverseProxyWebSocket(t *testing.T) {
                t.Errorf("Header(XHeader) = %q; want %q", got, want)
        }
 
-       if upgradeType(res.Header) != "websocket" {
+       if !ascii.EqualFold(upgradeType(res.Header), "websocket") {
                t.Fatalf("not websocket upgrade; got %#v", res.Header)
        }
        rwc, ok := res.Body.(io.ReadWriteCloser)
@@ -1300,7 +1301,7 @@ func TestReverseProxyWebSocketCancelation(t *testing.T) {
                t.Errorf("X-Header mismatch\n\tgot:  %q\n\twant: %q", g, w)
        }
 
-       if g, w := upgradeType(res.Header), "websocket"; g != w {
+       if g, w := upgradeType(res.Header), "websocket"; !ascii.EqualFold(g, w) {
                t.Fatalf("Upgrade header mismatch\n\tgot:  %q\n\twant: %q", g, w)
        }
 
diff --git a/src/net/http/internal/ascii/print.go b/src/net/http/internal/ascii/print.go
new file mode 100644 (file)
index 0000000..585e5ba
--- /dev/null
@@ -0,0 +1,61 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ascii
+
+import (
+       "strings"
+       "unicode"
+)
+
+// EqualFold is strings.EqualFold, ASCII only. It reports whether s and t
+// are equal, ASCII-case-insensitively.
+func EqualFold(s, t string) bool {
+       if len(s) != len(t) {
+               return false
+       }
+       for i := 0; i < len(s); i++ {
+               if lower(s[i]) != lower(t[i]) {
+                       return false
+               }
+       }
+       return true
+}
+
+// lower returns the ASCII lowercase version of b.
+func lower(b byte) byte {
+       if 'A' <= b && b <= 'Z' {
+               return b + ('a' - 'A')
+       }
+       return b
+}
+
+// IsPrint returns whether s is ASCII and printable according to
+// https://tools.ietf.org/html/rfc20#section-4.2.
+func IsPrint(s string) bool {
+       for i := 0; i < len(s); i++ {
+               if s[i] < ' ' || s[i] > '~' {
+                       return false
+               }
+       }
+       return true
+}
+
+// Is returns whether s is ASCII.
+func Is(s string) bool {
+       for i := 0; i < len(s); i++ {
+               if s[i] > unicode.MaxASCII {
+                       return false
+               }
+       }
+       return true
+}
+
+// ToLower returns the lowercase version of s if s is ASCII and printable.
+func ToLower(s string) (lower string, ok bool) {
+       if !IsPrint(s) {
+               return "", false
+       }
+       return strings.ToLower(s), true
+}
diff --git a/src/net/http/internal/ascii/print_test.go b/src/net/http/internal/ascii/print_test.go
new file mode 100644 (file)
index 0000000..0b7767c
--- /dev/null
@@ -0,0 +1,95 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ascii
+
+import "testing"
+
+func TestEqualFold(t *testing.T) {
+       var tests = []struct {
+               name string
+               a, b string
+               want bool
+       }{
+               {
+                       name: "empty",
+                       want: true,
+               },
+               {
+                       name: "simple match",
+                       a:    "CHUNKED",
+                       b:    "chunked",
+                       want: true,
+               },
+               {
+                       name: "same string",
+                       a:    "chunked",
+                       b:    "chunked",
+                       want: true,
+               },
+               {
+                       name: "Unicode Kelvin symbol",
+                       a:    "chunKed", // This "K" is 'KELVIN SIGN' (\u212A)
+                       b:    "chunked",
+                       want: false,
+               },
+       }
+       for _, tt := range tests {
+               t.Run(tt.name, func(t *testing.T) {
+                       if got := EqualFold(tt.a, tt.b); got != tt.want {
+                               t.Errorf("AsciiEqualFold(%q,%q): got %v want %v", tt.a, tt.b, got, tt.want)
+                       }
+               })
+       }
+}
+
+func TestIsPrint(t *testing.T) {
+       var tests = []struct {
+               name string
+               in   string
+               want bool
+       }{
+               {
+                       name: "empty",
+                       want: true,
+               },
+               {
+                       name: "ASCII low",
+                       in:   "This is a space: ' '",
+                       want: true,
+               },
+               {
+                       name: "ASCII high",
+                       in:   "This is a tilde: '~'",
+                       want: true,
+               },
+               {
+                       name: "ASCII low non-print",
+                       in:   "This is a unit separator: \x1F",
+                       want: false,
+               },
+               {
+                       name: "Ascii high non-print",
+                       in:   "This is a Delete: \x7F",
+                       want: false,
+               },
+               {
+                       name: "Unicode letter",
+                       in:   "Today it's 280K outside: it's freezing!", // This "K" is 'KELVIN SIGN' (\u212A)
+                       want: false,
+               },
+               {
+                       name: "Unicode emoji",
+                       in:   "Gophers like 🧀",
+                       want: false,
+               },
+       }
+       for _, tt := range tests {
+               t.Run(tt.name, func(t *testing.T) {
+                       if got := IsPrint(tt.in); got != tt.want {
+                               t.Errorf("IsASCIIPrint(%q): got %v want %v", tt.in, got, tt.want)
+                       }
+               })
+       }
+}
index 4a07eb1c7982b8a1fb69b9b48e6604c9ab44a638..7895417af5065cc782353350b531caf44661ce6e 100644 (file)
@@ -19,6 +19,7 @@ import (
        "mime/multipart"
        "net"
        "net/http/httptrace"
+       "net/http/internal/ascii"
        "net/textproto"
        "net/url"
        urlpkg "net/url"
@@ -723,7 +724,7 @@ func idnaASCII(v string) (string, error) {
        // version does not.
        // Note that for correct ASCII IDNs ToASCII will only do considerably more
        // work, but it will not cause an allocation.
-       if isASCII(v) {
+       if ascii.Is(v) {
                return v, nil
        }
        return idna.Lookup.ToASCII(v)
@@ -948,7 +949,7 @@ func (r *Request) BasicAuth() (username, password string, ok bool) {
 func parseBasicAuth(auth string) (username, password string, ok bool) {
        const prefix = "Basic "
        // Case insensitive prefix match. See Issue 22736.
-       if len(auth) < len(prefix) || !strings.EqualFold(auth[:len(prefix)], prefix) {
+       if len(auth) < len(prefix) || !ascii.EqualFold(auth[:len(prefix)], prefix) {
                return
        }
        c, err := base64.StdEncoding.DecodeString(auth[len(prefix):])
@@ -1456,5 +1457,5 @@ func requestMethodUsuallyLacksBody(method string) bool {
 // an HTTP/1 connection.
 func (r *Request) requiresHTTP1() bool {
        return hasToken(r.Header.Get("Connection"), "upgrade") &&
-               strings.EqualFold(r.Header.Get("Upgrade"), "websocket")
+               ascii.EqualFold(r.Header.Get("Upgrade"), "websocket")
 }
index fbb0c39829d729ca246c5011210b5af245abf355..85c2e5a360d68676f2f6c4574b0e5833fc68ca24 100644 (file)
@@ -12,6 +12,7 @@ import (
        "io"
        "net/http/httptrace"
        "net/http/internal"
+       "net/http/internal/ascii"
        "net/textproto"
        "reflect"
        "sort"
@@ -638,7 +639,7 @@ func (t *transferReader) parseTransferEncoding() error {
        if len(raw) != 1 {
                return &unsupportedTEError{fmt.Sprintf("too many transfer encodings: %q", raw)}
        }
-       if strings.ToLower(textproto.TrimString(raw[0])) != "chunked" {
+       if !ascii.EqualFold(textproto.TrimString(raw[0]), "chunked") {
                return &unsupportedTEError{fmt.Sprintf("unsupported transfer encoding: %q", raw[0])}
        }
 
index 57018d2392452fe46db5c1339382a704f3c67b6f..47cb992a50215b0f241b76274f9c771d05651c2b 100644 (file)
@@ -21,6 +21,7 @@ import (
        "log"
        "net"
        "net/http/httptrace"
+       "net/http/internal/ascii"
        "net/textproto"
        "net/url"
        "os"
@@ -2185,7 +2186,7 @@ func (pc *persistConn) readLoop() {
                }
 
                resp.Body = body
-               if rc.addedGzip && strings.EqualFold(resp.Header.Get("Content-Encoding"), "gzip") {
+               if rc.addedGzip && ascii.EqualFold(resp.Header.Get("Content-Encoding"), "gzip") {
                        resp.Body = &gzipReader{body: body}
                        resp.Header.Del("Content-Encoding")
                        resp.Header.Del("Content-Length")