import (
"bytes"
+ "context"
"crypto"
"crypto/x509"
"encoding/json"
"fmt"
"internal/testenv"
"io"
- "io/ioutil"
"math"
"net"
"os"
"reflect"
+ "sort"
"strings"
"testing"
"time"
if testing.Short() {
t.Skip("skipping in short mode")
}
- listener := newLocalListener(t)
- addr := listener.Addr().String()
- defer listener.Close()
-
- complete := make(chan bool)
- defer close(complete)
+ timeout := 100 * time.Microsecond
+ for !t.Failed() {
+ acceptc := make(chan net.Conn)
+ listener := newLocalListener(t)
+ go func() {
+ for {
+ conn, err := listener.Accept()
+ if err != nil {
+ close(acceptc)
+ return
+ }
+ acceptc <- conn
+ }
+ }()
- go func() {
- conn, err := listener.Accept()
- if err != nil {
- t.Error(err)
- return
+ addr := listener.Addr().String()
+ dialer := &net.Dialer{
+ Timeout: timeout,
+ }
+ if conn, err := DialWithDialer(dialer, "tcp", addr, nil); err == nil {
+ conn.Close()
+ t.Errorf("DialWithTimeout unexpectedly completed successfully")
+ } else if !isTimeoutError(err) {
+ t.Errorf("resulting error not a timeout: %v\nType %T: %#v", err, err, err)
}
- <-complete
- conn.Close()
- }()
-
- dialer := &net.Dialer{
- Timeout: 10 * time.Millisecond,
- }
- var err error
- if _, err = DialWithDialer(dialer, "tcp", addr, nil); err == nil {
- t.Fatal("DialWithTimeout completed successfully")
- }
+ listener.Close()
+
+ // We're looking for a timeout during the handshake, so check that the
+ // Listener actually accepted the connection to initiate it. (If the server
+ // takes too long to accept the connection, we might cancel before the
+ // underlying net.Conn is ever dialed — without ever attempting a
+ // handshake.)
+ lconn, ok := <-acceptc
+ if ok {
+ // The Listener accepted a connection, so assume that it was from our
+ // Dial: we triggered the timeout at the point where we wanted it!
+ t.Logf("Listener accepted a connection from %s", lconn.RemoteAddr())
+ lconn.Close()
+ }
+ // Close any spurious extra connecitions from the listener. (This is
+ // possible if there are, for example, stray Dial calls from other tests.)
+ for extraConn := range acceptc {
+ t.Logf("spurious extra connection from %s", extraConn.RemoteAddr())
+ extraConn.Close()
+ }
+ if ok {
+ break
+ }
- if !isTimeoutError(err) {
- t.Errorf("resulting error not a timeout: %v\nType %T: %#v", err, err, err)
+ t.Logf("with timeout %v, DialWithDialer returned before listener accepted any connections; retrying", timeout)
+ timeout *= 2
}
}
}
}
+type readerFunc func([]byte) (int, error)
+
+func (f readerFunc) Read(b []byte) (int, error) { return f(b) }
+
+// TestDialer tests that tls.Dialer.DialContext can abort in the middle of a handshake.
+// (The other cases are all handled by the existing dial tests in this package, which
+// all also flow through the same code shared code paths)
+func TestDialer(t *testing.T) {
+ ln := newLocalListener(t)
+ defer ln.Close()
+
+ unblockServer := make(chan struct{}) // close-only
+ defer close(unblockServer)
+ go func() {
+ conn, err := ln.Accept()
+ if err != nil {
+ return
+ }
+ defer conn.Close()
+ <-unblockServer
+ }()
+
+ ctx, cancel := context.WithCancel(context.Background())
+ d := Dialer{Config: &Config{
+ Rand: readerFunc(func(b []byte) (n int, err error) {
+ // By the time crypto/tls wants randomness, that means it has a TCP
+ // connection, so we're past the Dialer's dial and now blocked
+ // in a handshake. Cancel our context and see if we get unstuck.
+ // (Our TCP listener above never reads or writes, so the Handshake
+ // would otherwise be stuck forever)
+ cancel()
+ return len(b), nil
+ }),
+ ServerName: "foo",
+ }}
+ _, err := d.DialContext(ctx, "tcp", ln.Addr().String())
+ if err != context.Canceled {
+ t.Errorf("err = %v; want context.Canceled", err)
+ }
+}
+
func isTimeoutError(err error) bool {
if ne, ok := err.(net.Error); ok {
return ne.Timeout()
if !bytes.Equal(conn.ConnectionState().TLSUnique, serverTLSUniquesValue) {
t.Error("client and server channel bindings differ")
}
+ if serverTLSUniquesValue == nil || bytes.Equal(serverTLSUniquesValue, make([]byte, 12)) {
+ t.Error("tls-unique is empty or zero")
+ }
conn.Close()
conn, err = Dial("tcp", ln.Addr().String(), clientConfig)
if !bytes.Equal(conn.ConnectionState().TLSUnique, serverTLSUniquesValue) {
t.Error("client and server channel bindings differ when session resumption is used")
}
+ if serverTLSUniquesValue == nil || bytes.Equal(serverTLSUniquesValue, make([]byte, 12)) {
+ t.Error("resumption tls-unique is empty or zero")
+ }
}
func TestVerifyHostname(t *testing.T) {
}
<-closeReturned
- if err := tconn.Close(); err != errClosed {
- t.Errorf("Close error = %v; want errClosed", err)
+ if err := tconn.Close(); err != net.ErrClosed {
+ t.Errorf("Close error = %v; want net.ErrClosed", err)
}
}
}
defer srv.Close()
- data, err := ioutil.ReadAll(srv)
+ data, err := io.ReadAll(srv)
if err != nil {
return err
}
return fmt.Errorf("CloseWrite error = %v; want errShutdown", err)
}
- data, err := ioutil.ReadAll(conn)
+ data, err := io.ReadAll(conn)
if err != nil {
return err
}
}
defer srv.Close()
- _, err = ioutil.ReadAll(srv)
+ _, err = io.ReadAll(srv)
if err == nil {
return errors.New("unexpected lack of error from server")
}
}
func TestCloneFuncFields(t *testing.T) {
- const expectedCount = 5
+ const expectedCount = 8
called := 0
c1 := Config{
called |= 1 << 4
return nil
},
+ VerifyConnection: func(ConnectionState) error {
+ called |= 1 << 5
+ return nil
+ },
+ UnwrapSession: func(identity []byte, cs ConnectionState) (*SessionState, error) {
+ called |= 1 << 6
+ return nil, nil
+ },
+ WrapSession: func(cs ConnectionState, ss *SessionState) ([]byte, error) {
+ called |= 1 << 7
+ return nil, nil
+ },
}
c2 := c1.Clone()
c2.GetClientCertificate(nil)
c2.GetConfigForClient(nil)
c2.VerifyPeerCertificate(nil, nil)
+ c2.VerifyConnection(ConnectionState{})
+ c2.UnwrapSession(nil, ConnectionState{})
+ c2.WrapSession(ConnectionState{}, nil)
if called != (1<<expectedCount)-1 {
t.Fatalf("expected %d calls but saw calls %b", expectedCount, called)
typ := v.Type()
for i := 0; i < typ.NumField(); i++ {
f := v.Field(i)
- if !f.CanSet() {
- // unexported field; not cloned.
- continue
- }
-
// testing/quick can't handle functions or interfaces and so
// isn't used here.
switch fn := typ.Field(i).Name; fn {
case "Rand":
f.Set(reflect.ValueOf(io.Reader(os.Stdin)))
- case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate", "GetClientCertificate":
+ case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate", "VerifyConnection", "GetClientCertificate", "WrapSession", "UnwrapSession":
// DeepEqual can't compare functions. If you add a
// function field to this list, you must also change
// TestCloneFuncFields to ensure that the func field is
f.Set(reflect.ValueOf([]CurveID{CurveP256}))
case "Renegotiation":
f.Set(reflect.ValueOf(RenegotiateOnceAsClient))
+ case "mutex", "autoSessionTicketKeys", "sessionTicketKeys":
+ continue // these are unexported fields that are handled separately
default:
t.Errorf("all fields must be accounted for, but saw unknown field %q", fn)
}
}
+ // Set the unexported fields related to session ticket keys, which are copied with Clone().
+ c1.autoSessionTicketKeys = []ticketKey{c1.ticketKeyFromBytes(c1.SessionTicketKey)}
+ c1.sessionTicketKeys = []ticketKey{c1.ticketKeyFromBytes(c1.SessionTicketKey)}
c2 := c1.Clone()
- // DeepEqual also compares unexported fields, thus c2 needs to have run
- // serverInit in order to be DeepEqual to c1. Cloning it and discarding
- // the result is sufficient.
- c2.Clone()
-
if !reflect.DeepEqual(&c1, c2) {
t.Errorf("clone failed to copy a field")
}
}
+func TestCloneNilConfig(t *testing.T) {
+ var config *Config
+ if cc := config.Clone(); cc != nil {
+ t.Fatalf("Clone with nil should return nil, got: %+v", cc)
+ }
+}
+
// changeImplConn is a net.Conn which can change its Write and Close
// methods.
type changeImplConn struct {
if ss.ServerName != serverName {
t.Errorf("Got server name %q, expected %q", ss.ServerName, serverName)
}
- if cs.ServerName != "" {
- t.Errorf("Got unexpected server name on the client side")
+ if cs.ServerName != serverName {
+ t.Errorf("Got server name on client connection %q, expected %q", cs.ServerName, serverName)
}
if len(ss.PeerCertificates) != 1 || len(cs.PeerCertificates) != 1 {
SupportedPoints: []uint8{pointFormatUncompressed},
SignatureSchemes: []SignatureScheme{Ed25519},
SupportedVersions: []uint16{VersionTLS10},
+ config: &Config{MinVersion: VersionTLS10},
}, "doesn't support Ed25519"},
{ed25519Cert, &ClientHelloInfo{
CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
SupportedCurves: []CurveID{CurveP256}, // only relevant for ECDHE support
SupportedPoints: []uint8{pointFormatUncompressed},
SupportedVersions: []uint16{VersionTLS10},
+ config: &Config{MinVersion: VersionTLS10},
}, ""},
{rsaCert, &ClientHelloInfo{
CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
}
}
- cipherSuiteByID := func(id uint16) *CipherSuite {
+ CipherSuiteByID := func(id uint16) *CipherSuite {
for _, c := range CipherSuites() {
if c.ID == id {
return c
}
for _, c := range cipherSuites {
- cc := cipherSuiteByID(c.id)
+ cc := CipherSuiteByID(c.id)
if cc == nil {
t.Errorf("%#04x: no CipherSuite entry", c.id)
continue
}
- if defaultOff := c.flags&suiteDefaultOff != 0; defaultOff != cc.Insecure {
- t.Errorf("%#04x: Insecure %v, expected %v", c.id, cc.Insecure, defaultOff)
- }
if tls12Only := c.flags&suiteTLS12 != 0; tls12Only && len(cc.SupportedVersions) != 1 {
t.Errorf("%#04x: suite is TLS 1.2 only, but SupportedVersions is %v", c.id, cc.SupportedVersions)
} else if !tls12Only && len(cc.SupportedVersions) != 3 {
}
}
for _, c := range cipherSuitesTLS13 {
- cc := cipherSuiteByID(c.id)
+ cc := CipherSuiteByID(c.id)
if cc == nil {
t.Errorf("%#04x: no CipherSuite entry", c.id)
continue
if got := CipherSuiteName(0xabc); got != "0x0ABC" {
t.Errorf("unexpected fallback CipherSuiteName: got %q, expected 0x0ABC", got)
}
+
+ if len(cipherSuitesPreferenceOrder) != len(cipherSuites) {
+ t.Errorf("cipherSuitesPreferenceOrder is not the same size as cipherSuites")
+ }
+ if len(cipherSuitesPreferenceOrderNoAES) != len(cipherSuitesPreferenceOrder) {
+ t.Errorf("cipherSuitesPreferenceOrderNoAES is not the same size as cipherSuitesPreferenceOrder")
+ }
+
+ // Check that disabled suites are at the end of the preference lists, and
+ // that they are marked insecure.
+ for i, id := range disabledCipherSuites {
+ offset := len(cipherSuitesPreferenceOrder) - len(disabledCipherSuites)
+ if cipherSuitesPreferenceOrder[offset+i] != id {
+ t.Errorf("disabledCipherSuites[%d]: not at the end of cipherSuitesPreferenceOrder", i)
+ }
+ if cipherSuitesPreferenceOrderNoAES[offset+i] != id {
+ t.Errorf("disabledCipherSuites[%d]: not at the end of cipherSuitesPreferenceOrderNoAES", i)
+ }
+ c := CipherSuiteByID(id)
+ if c == nil {
+ t.Errorf("%#04x: no CipherSuite entry", id)
+ continue
+ }
+ if !c.Insecure {
+ t.Errorf("%#04x: disabled by default but not marked insecure", id)
+ }
+ }
+
+ for i, prefOrder := range [][]uint16{cipherSuitesPreferenceOrder, cipherSuitesPreferenceOrderNoAES} {
+ // Check that insecure and HTTP/2 bad cipher suites are at the end of
+ // the preference lists.
+ var sawInsecure, sawBad bool
+ for _, id := range prefOrder {
+ c := CipherSuiteByID(id)
+ if c == nil {
+ t.Errorf("%#04x: no CipherSuite entry", id)
+ continue
+ }
+
+ if c.Insecure {
+ sawInsecure = true
+ } else if sawInsecure {
+ t.Errorf("%#04x: secure suite after insecure one(s)", id)
+ }
+
+ if http2isBadCipher(id) {
+ sawBad = true
+ } else if sawBad {
+ t.Errorf("%#04x: non-bad suite after bad HTTP/2 one(s)", id)
+ }
+ }
+
+ // Check that the list is sorted according to the documented criteria.
+ isBetter := func(a, b int) bool {
+ aSuite, bSuite := cipherSuiteByID(prefOrder[a]), cipherSuiteByID(prefOrder[b])
+ aName, bName := CipherSuiteName(prefOrder[a]), CipherSuiteName(prefOrder[b])
+ // * < RC4
+ if !strings.Contains(aName, "RC4") && strings.Contains(bName, "RC4") {
+ return true
+ } else if strings.Contains(aName, "RC4") && !strings.Contains(bName, "RC4") {
+ return false
+ }
+ // * < CBC_SHA256
+ if !strings.Contains(aName, "CBC_SHA256") && strings.Contains(bName, "CBC_SHA256") {
+ return true
+ } else if strings.Contains(aName, "CBC_SHA256") && !strings.Contains(bName, "CBC_SHA256") {
+ return false
+ }
+ // * < 3DES
+ if !strings.Contains(aName, "3DES") && strings.Contains(bName, "3DES") {
+ return true
+ } else if strings.Contains(aName, "3DES") && !strings.Contains(bName, "3DES") {
+ return false
+ }
+ // ECDHE < *
+ if aSuite.flags&suiteECDHE != 0 && bSuite.flags&suiteECDHE == 0 {
+ return true
+ } else if aSuite.flags&suiteECDHE == 0 && bSuite.flags&suiteECDHE != 0 {
+ return false
+ }
+ // AEAD < CBC
+ if aSuite.aead != nil && bSuite.aead == nil {
+ return true
+ } else if aSuite.aead == nil && bSuite.aead != nil {
+ return false
+ }
+ // AES < ChaCha20
+ if strings.Contains(aName, "AES") && strings.Contains(bName, "CHACHA20") {
+ return i == 0 // true for cipherSuitesPreferenceOrder
+ } else if strings.Contains(aName, "CHACHA20") && strings.Contains(bName, "AES") {
+ return i != 0 // true for cipherSuitesPreferenceOrderNoAES
+ }
+ // AES-128 < AES-256
+ if strings.Contains(aName, "AES_128") && strings.Contains(bName, "AES_256") {
+ return true
+ } else if strings.Contains(aName, "AES_256") && strings.Contains(bName, "AES_128") {
+ return false
+ }
+ // ECDSA < RSA
+ if aSuite.flags&suiteECSign != 0 && bSuite.flags&suiteECSign == 0 {
+ return true
+ } else if aSuite.flags&suiteECSign == 0 && bSuite.flags&suiteECSign != 0 {
+ return false
+ }
+ t.Fatalf("two ciphersuites are equal by all criteria: %v and %v", aName, bName)
+ panic("unreachable")
+ }
+ if !sort.SliceIsSorted(prefOrder, isBetter) {
+ t.Error("preference order is not sorted according to the rules")
+ }
+ }
+}
+
+func TestVersionName(t *testing.T) {
+ if got, exp := VersionName(VersionTLS13), "TLS 1.3"; got != exp {
+ t.Errorf("unexpected VersionName: got %q, expected %q", got, exp)
+ }
+ if got, exp := VersionName(0x12a), "0x012A"; got != exp {
+ t.Errorf("unexpected fallback VersionName: got %q, expected %q", got, exp)
+ }
+}
+
+// http2isBadCipher is copied from net/http.
+// TODO: if it ends up exposed somewhere, use that instead.
+func http2isBadCipher(cipher uint16) bool {
+ switch cipher {
+ case TLS_RSA_WITH_RC4_128_SHA,
+ TLS_RSA_WITH_3DES_EDE_CBC_SHA,
+ TLS_RSA_WITH_AES_128_CBC_SHA,
+ TLS_RSA_WITH_AES_256_CBC_SHA,
+ TLS_RSA_WITH_AES_128_CBC_SHA256,
+ TLS_RSA_WITH_AES_128_GCM_SHA256,
+ TLS_RSA_WITH_AES_256_GCM_SHA384,
+ TLS_ECDHE_ECDSA_WITH_RC4_128_SHA,
+ TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
+ TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
+ TLS_ECDHE_RSA_WITH_RC4_128_SHA,
+ TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
+ TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
+ TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
+ TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256,
+ TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256:
+ return true
+ default:
+ return false
+ }
}
type brokenSigner struct{ crypto.Signer }
}
// TestPKCS1OnlyCert uses a client certificate with a broken crypto.Signer that
-// always makes PKCS#1 v1.5 signatures, so can't be used with RSA-PSS.
+// always makes PKCS #1 v1.5 signatures, so can't be used with RSA-PSS.
func TestPKCS1OnlyCert(t *testing.T) {
clientConfig := testConfig.Clone()
clientConfig.Certificates = []Certificate{{
PrivateKey: brokenSigner{testRSAPrivateKey},
}}
serverConfig := testConfig.Clone()
- serverConfig.MaxVersion = VersionTLS12 // TLS 1.3 doesn't support PKCS#1 v1.5
+ serverConfig.MaxVersion = VersionTLS12 // TLS 1.3 doesn't support PKCS #1 v1.5
serverConfig.ClientAuth = RequireAnyClientCert
// If RSA-PSS is selected, the handshake should fail.
t.Error(err)
}
}
+
+func TestVerifyCertificates(t *testing.T) {
+ // See https://go.dev/issue/31641.
+ t.Run("TLSv12", func(t *testing.T) { testVerifyCertificates(t, VersionTLS12) })
+ t.Run("TLSv13", func(t *testing.T) { testVerifyCertificates(t, VersionTLS13) })
+}
+
+func testVerifyCertificates(t *testing.T, version uint16) {
+ tests := []struct {
+ name string
+
+ InsecureSkipVerify bool
+ ClientAuth ClientAuthType
+ ClientCertificates bool
+ }{
+ {
+ name: "defaults",
+ },
+ {
+ name: "InsecureSkipVerify",
+ InsecureSkipVerify: true,
+ },
+ {
+ name: "RequestClientCert with no certs",
+ ClientAuth: RequestClientCert,
+ },
+ {
+ name: "RequestClientCert with certs",
+ ClientAuth: RequestClientCert,
+ ClientCertificates: true,
+ },
+ {
+ name: "RequireAnyClientCert",
+ ClientAuth: RequireAnyClientCert,
+ ClientCertificates: true,
+ },
+ {
+ name: "VerifyClientCertIfGiven with no certs",
+ ClientAuth: VerifyClientCertIfGiven,
+ },
+ {
+ name: "VerifyClientCertIfGiven with certs",
+ ClientAuth: VerifyClientCertIfGiven,
+ ClientCertificates: true,
+ },
+ {
+ name: "RequireAndVerifyClientCert",
+ ClientAuth: RequireAndVerifyClientCert,
+ ClientCertificates: true,
+ },
+ }
+
+ issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
+ if err != nil {
+ t.Fatal(err)
+ }
+ rootCAs := x509.NewCertPool()
+ rootCAs.AddCert(issuer)
+
+ for _, test := range tests {
+ test := test
+ t.Run(test.name, func(t *testing.T) {
+ t.Parallel()
+
+ var serverVerifyConnection, clientVerifyConnection bool
+ var serverVerifyPeerCertificates, clientVerifyPeerCertificates bool
+
+ clientConfig := testConfig.Clone()
+ clientConfig.Time = func() time.Time { return time.Unix(1476984729, 0) }
+ clientConfig.MaxVersion = version
+ clientConfig.MinVersion = version
+ clientConfig.RootCAs = rootCAs
+ clientConfig.ServerName = "example.golang"
+ clientConfig.ClientSessionCache = NewLRUClientSessionCache(1)
+ serverConfig := clientConfig.Clone()
+ serverConfig.ClientCAs = rootCAs
+
+ clientConfig.VerifyConnection = func(cs ConnectionState) error {
+ clientVerifyConnection = true
+ return nil
+ }
+ clientConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
+ clientVerifyPeerCertificates = true
+ return nil
+ }
+ serverConfig.VerifyConnection = func(cs ConnectionState) error {
+ serverVerifyConnection = true
+ return nil
+ }
+ serverConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
+ serverVerifyPeerCertificates = true
+ return nil
+ }
+
+ clientConfig.InsecureSkipVerify = test.InsecureSkipVerify
+ serverConfig.ClientAuth = test.ClientAuth
+ if !test.ClientCertificates {
+ clientConfig.Certificates = nil
+ }
+
+ if _, _, err := testHandshake(t, clientConfig, serverConfig); err != nil {
+ t.Fatal(err)
+ }
+
+ want := serverConfig.ClientAuth != NoClientCert
+ if serverVerifyPeerCertificates != want {
+ t.Errorf("VerifyPeerCertificates on the server: got %v, want %v",
+ serverVerifyPeerCertificates, want)
+ }
+ if !clientVerifyPeerCertificates {
+ t.Errorf("VerifyPeerCertificates not called on the client")
+ }
+ if !serverVerifyConnection {
+ t.Error("VerifyConnection did not get called on the server")
+ }
+ if !clientVerifyConnection {
+ t.Error("VerifyConnection did not get called on the client")
+ }
+
+ serverVerifyPeerCertificates, clientVerifyPeerCertificates = false, false
+ serverVerifyConnection, clientVerifyConnection = false, false
+ cs, _, err := testHandshake(t, clientConfig, serverConfig)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !cs.DidResume {
+ t.Error("expected resumption")
+ }
+
+ if serverVerifyPeerCertificates {
+ t.Error("VerifyPeerCertificates got called on the server on resumption")
+ }
+ if clientVerifyPeerCertificates {
+ t.Error("VerifyPeerCertificates got called on the client on resumption")
+ }
+ if !serverVerifyConnection {
+ t.Error("VerifyConnection did not get called on the server on resumption")
+ }
+ if !clientVerifyConnection {
+ t.Error("VerifyConnection did not get called on the client on resumption")
+ }
+ })
+ }
+}