"net"
"os"
"reflect"
+ "sort"
"strings"
"testing"
"time"
if testing.Short() {
t.Skip("skipping in short mode")
}
- listener := newLocalListener(t)
- addr := listener.Addr().String()
- defer listener.Close()
-
- complete := make(chan bool)
- defer close(complete)
+ timeout := 100 * time.Microsecond
+ for !t.Failed() {
+ acceptc := make(chan net.Conn)
+ listener := newLocalListener(t)
+ go func() {
+ for {
+ conn, err := listener.Accept()
+ if err != nil {
+ close(acceptc)
+ return
+ }
+ acceptc <- conn
+ }
+ }()
- go func() {
- conn, err := listener.Accept()
- if err != nil {
- t.Error(err)
- return
+ addr := listener.Addr().String()
+ dialer := &net.Dialer{
+ Timeout: timeout,
+ }
+ if conn, err := DialWithDialer(dialer, "tcp", addr, nil); err == nil {
+ conn.Close()
+ t.Errorf("DialWithTimeout unexpectedly completed successfully")
+ } else if !isTimeoutError(err) {
+ t.Errorf("resulting error not a timeout: %v\nType %T: %#v", err, err, err)
}
- <-complete
- conn.Close()
- }()
-
- dialer := &net.Dialer{
- Timeout: 10 * time.Millisecond,
- }
- var err error
- if _, err = DialWithDialer(dialer, "tcp", addr, nil); err == nil {
- t.Fatal("DialWithTimeout completed successfully")
- }
+ listener.Close()
+
+ // We're looking for a timeout during the handshake, so check that the
+ // Listener actually accepted the connection to initiate it. (If the server
+ // takes too long to accept the connection, we might cancel before the
+ // underlying net.Conn is ever dialed — without ever attempting a
+ // handshake.)
+ lconn, ok := <-acceptc
+ if ok {
+ // The Listener accepted a connection, so assume that it was from our
+ // Dial: we triggered the timeout at the point where we wanted it!
+ t.Logf("Listener accepted a connection from %s", lconn.RemoteAddr())
+ lconn.Close()
+ }
+ // Close any spurious extra connecitions from the listener. (This is
+ // possible if there are, for example, stray Dial calls from other tests.)
+ for extraConn := range acceptc {
+ t.Logf("spurious extra connection from %s", extraConn.RemoteAddr())
+ extraConn.Close()
+ }
+ if ok {
+ break
+ }
- if !isTimeoutError(err) {
- t.Errorf("resulting error not a timeout: %v\nType %T: %#v", err, err, err)
+ t.Logf("with timeout %v, DialWithDialer returned before listener accepted any connections; retrying", timeout)
+ timeout *= 2
}
}
if !bytes.Equal(conn.ConnectionState().TLSUnique, serverTLSUniquesValue) {
t.Error("client and server channel bindings differ")
}
+ if serverTLSUniquesValue == nil || bytes.Equal(serverTLSUniquesValue, make([]byte, 12)) {
+ t.Error("tls-unique is empty or zero")
+ }
conn.Close()
conn, err = Dial("tcp", ln.Addr().String(), clientConfig)
if !bytes.Equal(conn.ConnectionState().TLSUnique, serverTLSUniquesValue) {
t.Error("client and server channel bindings differ when session resumption is used")
}
+ if serverTLSUniquesValue == nil || bytes.Equal(serverTLSUniquesValue, make([]byte, 12)) {
+ t.Error("resumption tls-unique is empty or zero")
+ }
}
func TestVerifyHostname(t *testing.T) {
}
func TestCloneFuncFields(t *testing.T) {
- const expectedCount = 6
+ const expectedCount = 8
called := 0
c1 := Config{
called |= 1 << 5
return nil
},
+ UnwrapSession: func(identity []byte, cs ConnectionState) (*SessionState, error) {
+ called |= 1 << 6
+ return nil, nil
+ },
+ WrapSession: func(cs ConnectionState, ss *SessionState) ([]byte, error) {
+ called |= 1 << 7
+ return nil, nil
+ },
}
c2 := c1.Clone()
c2.GetConfigForClient(nil)
c2.VerifyPeerCertificate(nil, nil)
c2.VerifyConnection(ConnectionState{})
+ c2.UnwrapSession(nil, ConnectionState{})
+ c2.WrapSession(ConnectionState{}, nil)
if called != (1<<expectedCount)-1 {
t.Fatalf("expected %d calls but saw calls %b", expectedCount, called)
switch fn := typ.Field(i).Name; fn {
case "Rand":
f.Set(reflect.ValueOf(io.Reader(os.Stdin)))
- case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate", "VerifyConnection", "GetClientCertificate":
+ case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate", "VerifyConnection", "GetClientCertificate", "WrapSession", "UnwrapSession":
// DeepEqual can't compare functions. If you add a
// function field to this list, you must also change
// TestCloneFuncFields to ensure that the func field is
SupportedPoints: []uint8{pointFormatUncompressed},
SignatureSchemes: []SignatureScheme{Ed25519},
SupportedVersions: []uint16{VersionTLS10},
+ config: &Config{MinVersion: VersionTLS10},
}, "doesn't support Ed25519"},
{ed25519Cert, &ClientHelloInfo{
CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
SupportedCurves: []CurveID{CurveP256}, // only relevant for ECDHE support
SupportedPoints: []uint8{pointFormatUncompressed},
SupportedVersions: []uint16{VersionTLS10},
+ config: &Config{MinVersion: VersionTLS10},
}, ""},
{rsaCert, &ClientHelloInfo{
CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
}
}
- cipherSuiteByID := func(id uint16) *CipherSuite {
+ CipherSuiteByID := func(id uint16) *CipherSuite {
for _, c := range CipherSuites() {
if c.ID == id {
return c
}
for _, c := range cipherSuites {
- cc := cipherSuiteByID(c.id)
+ cc := CipherSuiteByID(c.id)
if cc == nil {
t.Errorf("%#04x: no CipherSuite entry", c.id)
continue
}
- if defaultOff := c.flags&suiteDefaultOff != 0; defaultOff != cc.Insecure {
- t.Errorf("%#04x: Insecure %v, expected %v", c.id, cc.Insecure, defaultOff)
- }
if tls12Only := c.flags&suiteTLS12 != 0; tls12Only && len(cc.SupportedVersions) != 1 {
t.Errorf("%#04x: suite is TLS 1.2 only, but SupportedVersions is %v", c.id, cc.SupportedVersions)
} else if !tls12Only && len(cc.SupportedVersions) != 3 {
}
}
for _, c := range cipherSuitesTLS13 {
- cc := cipherSuiteByID(c.id)
+ cc := CipherSuiteByID(c.id)
if cc == nil {
t.Errorf("%#04x: no CipherSuite entry", c.id)
continue
if got := CipherSuiteName(0xabc); got != "0x0ABC" {
t.Errorf("unexpected fallback CipherSuiteName: got %q, expected 0x0ABC", got)
}
+
+ if len(cipherSuitesPreferenceOrder) != len(cipherSuites) {
+ t.Errorf("cipherSuitesPreferenceOrder is not the same size as cipherSuites")
+ }
+ if len(cipherSuitesPreferenceOrderNoAES) != len(cipherSuitesPreferenceOrder) {
+ t.Errorf("cipherSuitesPreferenceOrderNoAES is not the same size as cipherSuitesPreferenceOrder")
+ }
+
+ // Check that disabled suites are at the end of the preference lists, and
+ // that they are marked insecure.
+ for i, id := range disabledCipherSuites {
+ offset := len(cipherSuitesPreferenceOrder) - len(disabledCipherSuites)
+ if cipherSuitesPreferenceOrder[offset+i] != id {
+ t.Errorf("disabledCipherSuites[%d]: not at the end of cipherSuitesPreferenceOrder", i)
+ }
+ if cipherSuitesPreferenceOrderNoAES[offset+i] != id {
+ t.Errorf("disabledCipherSuites[%d]: not at the end of cipherSuitesPreferenceOrderNoAES", i)
+ }
+ c := CipherSuiteByID(id)
+ if c == nil {
+ t.Errorf("%#04x: no CipherSuite entry", id)
+ continue
+ }
+ if !c.Insecure {
+ t.Errorf("%#04x: disabled by default but not marked insecure", id)
+ }
+ }
+
+ for i, prefOrder := range [][]uint16{cipherSuitesPreferenceOrder, cipherSuitesPreferenceOrderNoAES} {
+ // Check that insecure and HTTP/2 bad cipher suites are at the end of
+ // the preference lists.
+ var sawInsecure, sawBad bool
+ for _, id := range prefOrder {
+ c := CipherSuiteByID(id)
+ if c == nil {
+ t.Errorf("%#04x: no CipherSuite entry", id)
+ continue
+ }
+
+ if c.Insecure {
+ sawInsecure = true
+ } else if sawInsecure {
+ t.Errorf("%#04x: secure suite after insecure one(s)", id)
+ }
+
+ if http2isBadCipher(id) {
+ sawBad = true
+ } else if sawBad {
+ t.Errorf("%#04x: non-bad suite after bad HTTP/2 one(s)", id)
+ }
+ }
+
+ // Check that the list is sorted according to the documented criteria.
+ isBetter := func(a, b int) bool {
+ aSuite, bSuite := cipherSuiteByID(prefOrder[a]), cipherSuiteByID(prefOrder[b])
+ aName, bName := CipherSuiteName(prefOrder[a]), CipherSuiteName(prefOrder[b])
+ // * < RC4
+ if !strings.Contains(aName, "RC4") && strings.Contains(bName, "RC4") {
+ return true
+ } else if strings.Contains(aName, "RC4") && !strings.Contains(bName, "RC4") {
+ return false
+ }
+ // * < CBC_SHA256
+ if !strings.Contains(aName, "CBC_SHA256") && strings.Contains(bName, "CBC_SHA256") {
+ return true
+ } else if strings.Contains(aName, "CBC_SHA256") && !strings.Contains(bName, "CBC_SHA256") {
+ return false
+ }
+ // * < 3DES
+ if !strings.Contains(aName, "3DES") && strings.Contains(bName, "3DES") {
+ return true
+ } else if strings.Contains(aName, "3DES") && !strings.Contains(bName, "3DES") {
+ return false
+ }
+ // ECDHE < *
+ if aSuite.flags&suiteECDHE != 0 && bSuite.flags&suiteECDHE == 0 {
+ return true
+ } else if aSuite.flags&suiteECDHE == 0 && bSuite.flags&suiteECDHE != 0 {
+ return false
+ }
+ // AEAD < CBC
+ if aSuite.aead != nil && bSuite.aead == nil {
+ return true
+ } else if aSuite.aead == nil && bSuite.aead != nil {
+ return false
+ }
+ // AES < ChaCha20
+ if strings.Contains(aName, "AES") && strings.Contains(bName, "CHACHA20") {
+ return i == 0 // true for cipherSuitesPreferenceOrder
+ } else if strings.Contains(aName, "CHACHA20") && strings.Contains(bName, "AES") {
+ return i != 0 // true for cipherSuitesPreferenceOrderNoAES
+ }
+ // AES-128 < AES-256
+ if strings.Contains(aName, "AES_128") && strings.Contains(bName, "AES_256") {
+ return true
+ } else if strings.Contains(aName, "AES_256") && strings.Contains(bName, "AES_128") {
+ return false
+ }
+ // ECDSA < RSA
+ if aSuite.flags&suiteECSign != 0 && bSuite.flags&suiteECSign == 0 {
+ return true
+ } else if aSuite.flags&suiteECSign == 0 && bSuite.flags&suiteECSign != 0 {
+ return false
+ }
+ t.Fatalf("two ciphersuites are equal by all criteria: %v and %v", aName, bName)
+ panic("unreachable")
+ }
+ if !sort.SliceIsSorted(prefOrder, isBetter) {
+ t.Error("preference order is not sorted according to the rules")
+ }
+ }
+}
+
+func TestVersionName(t *testing.T) {
+ if got, exp := VersionName(VersionTLS13), "TLS 1.3"; got != exp {
+ t.Errorf("unexpected VersionName: got %q, expected %q", got, exp)
+ }
+ if got, exp := VersionName(0x12a), "0x012A"; got != exp {
+ t.Errorf("unexpected fallback VersionName: got %q, expected %q", got, exp)
+ }
+}
+
+// http2isBadCipher is copied from net/http.
+// TODO: if it ends up exposed somewhere, use that instead.
+func http2isBadCipher(cipher uint16) bool {
+ switch cipher {
+ case TLS_RSA_WITH_RC4_128_SHA,
+ TLS_RSA_WITH_3DES_EDE_CBC_SHA,
+ TLS_RSA_WITH_AES_128_CBC_SHA,
+ TLS_RSA_WITH_AES_256_CBC_SHA,
+ TLS_RSA_WITH_AES_128_CBC_SHA256,
+ TLS_RSA_WITH_AES_128_GCM_SHA256,
+ TLS_RSA_WITH_AES_256_GCM_SHA384,
+ TLS_ECDHE_ECDSA_WITH_RC4_128_SHA,
+ TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
+ TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
+ TLS_ECDHE_RSA_WITH_RC4_128_SHA,
+ TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
+ TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
+ TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
+ TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256,
+ TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256:
+ return true
+ default:
+ return false
+ }
}
type brokenSigner struct{ crypto.Signer }
t.Error(err)
}
}
+
+func TestVerifyCertificates(t *testing.T) {
+ // See https://go.dev/issue/31641.
+ t.Run("TLSv12", func(t *testing.T) { testVerifyCertificates(t, VersionTLS12) })
+ t.Run("TLSv13", func(t *testing.T) { testVerifyCertificates(t, VersionTLS13) })
+}
+
+func testVerifyCertificates(t *testing.T, version uint16) {
+ tests := []struct {
+ name string
+
+ InsecureSkipVerify bool
+ ClientAuth ClientAuthType
+ ClientCertificates bool
+ }{
+ {
+ name: "defaults",
+ },
+ {
+ name: "InsecureSkipVerify",
+ InsecureSkipVerify: true,
+ },
+ {
+ name: "RequestClientCert with no certs",
+ ClientAuth: RequestClientCert,
+ },
+ {
+ name: "RequestClientCert with certs",
+ ClientAuth: RequestClientCert,
+ ClientCertificates: true,
+ },
+ {
+ name: "RequireAnyClientCert",
+ ClientAuth: RequireAnyClientCert,
+ ClientCertificates: true,
+ },
+ {
+ name: "VerifyClientCertIfGiven with no certs",
+ ClientAuth: VerifyClientCertIfGiven,
+ },
+ {
+ name: "VerifyClientCertIfGiven with certs",
+ ClientAuth: VerifyClientCertIfGiven,
+ ClientCertificates: true,
+ },
+ {
+ name: "RequireAndVerifyClientCert",
+ ClientAuth: RequireAndVerifyClientCert,
+ ClientCertificates: true,
+ },
+ }
+
+ issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
+ if err != nil {
+ t.Fatal(err)
+ }
+ rootCAs := x509.NewCertPool()
+ rootCAs.AddCert(issuer)
+
+ for _, test := range tests {
+ test := test
+ t.Run(test.name, func(t *testing.T) {
+ t.Parallel()
+
+ var serverVerifyConnection, clientVerifyConnection bool
+ var serverVerifyPeerCertificates, clientVerifyPeerCertificates bool
+
+ clientConfig := testConfig.Clone()
+ clientConfig.Time = func() time.Time { return time.Unix(1476984729, 0) }
+ clientConfig.MaxVersion = version
+ clientConfig.MinVersion = version
+ clientConfig.RootCAs = rootCAs
+ clientConfig.ServerName = "example.golang"
+ clientConfig.ClientSessionCache = NewLRUClientSessionCache(1)
+ serverConfig := clientConfig.Clone()
+ serverConfig.ClientCAs = rootCAs
+
+ clientConfig.VerifyConnection = func(cs ConnectionState) error {
+ clientVerifyConnection = true
+ return nil
+ }
+ clientConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
+ clientVerifyPeerCertificates = true
+ return nil
+ }
+ serverConfig.VerifyConnection = func(cs ConnectionState) error {
+ serverVerifyConnection = true
+ return nil
+ }
+ serverConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
+ serverVerifyPeerCertificates = true
+ return nil
+ }
+
+ clientConfig.InsecureSkipVerify = test.InsecureSkipVerify
+ serverConfig.ClientAuth = test.ClientAuth
+ if !test.ClientCertificates {
+ clientConfig.Certificates = nil
+ }
+
+ if _, _, err := testHandshake(t, clientConfig, serverConfig); err != nil {
+ t.Fatal(err)
+ }
+
+ want := serverConfig.ClientAuth != NoClientCert
+ if serverVerifyPeerCertificates != want {
+ t.Errorf("VerifyPeerCertificates on the server: got %v, want %v",
+ serverVerifyPeerCertificates, want)
+ }
+ if !clientVerifyPeerCertificates {
+ t.Errorf("VerifyPeerCertificates not called on the client")
+ }
+ if !serverVerifyConnection {
+ t.Error("VerifyConnection did not get called on the server")
+ }
+ if !clientVerifyConnection {
+ t.Error("VerifyConnection did not get called on the client")
+ }
+
+ serverVerifyPeerCertificates, clientVerifyPeerCertificates = false, false
+ serverVerifyConnection, clientVerifyConnection = false, false
+ cs, _, err := testHandshake(t, clientConfig, serverConfig)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !cs.DidResume {
+ t.Error("expected resumption")
+ }
+
+ if serverVerifyPeerCertificates {
+ t.Error("VerifyPeerCertificates got called on the server on resumption")
+ }
+ if clientVerifyPeerCertificates {
+ t.Error("VerifyPeerCertificates got called on the client on resumption")
+ }
+ if !serverVerifyConnection {
+ t.Error("VerifyConnection did not get called on the server on resumption")
+ }
+ if !clientVerifyConnection {
+ t.Error("VerifyConnection did not get called on the client on resumption")
+ }
+ })
+ }
+}