import (
"bytes"
+ "context"
+ "crypto"
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"internal/testenv"
"io"
- "io/ioutil"
"math"
"net"
"os"
"reflect"
+ "sort"
"strings"
"testing"
"time"
-----END CERTIFICATE-----
`
-var rsaKeyPEM = `-----BEGIN RSA PRIVATE KEY-----
+var rsaKeyPEM = testingKey(`-----BEGIN RSA TESTING KEY-----
MIIBOwIBAAJBANLJhPHhITqQbPklG3ibCVxwGMRfp/v4XqhfdQHdcVfHap6NQ5Wo
k/4xIA+ui35/MmNartNuC+BdZ1tMuVCPFZcCAwEAAQJAEJ2N+zsR0Xn8/Q6twa4G
6OB1M1WO+k+ztnX/1SvNeWu8D6GImtupLTYgjZcHufykj09jiHmjHx8u8ZZB/o1N
SmUwbPw8fnTcpqDWE3yTO3vKcebqMSsCIBF3UmVue8YU3jybC3NxuXq3wNm34R8T
xVLHwDXh/6NJAiEAl2oHGGLz64BuAfjKrqwz7qMYr9HCLIe/YsoWq/olzScCIQDi
D2lWusoe2/nEqfDVVWGWlyJ7yOmqaVm/iNUN9B2N2g==
------END RSA PRIVATE KEY-----
-`
+-----END RSA TESTING KEY-----
+`)
// keyPEM is the same as rsaKeyPEM, but declares itself as just
// "PRIVATE KEY", not "RSA PRIVATE KEY". https://golang.org/issue/4477
-var keyPEM = `-----BEGIN PRIVATE KEY-----
+var keyPEM = testingKey(`-----BEGIN TESTING KEY-----
MIIBOwIBAAJBANLJhPHhITqQbPklG3ibCVxwGMRfp/v4XqhfdQHdcVfHap6NQ5Wo
k/4xIA+ui35/MmNartNuC+BdZ1tMuVCPFZcCAwEAAQJAEJ2N+zsR0Xn8/Q6twa4G
6OB1M1WO+k+ztnX/1SvNeWu8D6GImtupLTYgjZcHufykj09jiHmjHx8u8ZZB/o1N
SmUwbPw8fnTcpqDWE3yTO3vKcebqMSsCIBF3UmVue8YU3jybC3NxuXq3wNm34R8T
xVLHwDXh/6NJAiEAl2oHGGLz64BuAfjKrqwz7qMYr9HCLIe/YsoWq/olzScCIQDi
D2lWusoe2/nEqfDVVWGWlyJ7yOmqaVm/iNUN9B2N2g==
------END PRIVATE KEY-----
-`
+-----END TESTING KEY-----
+`)
var ecdsaCertPEM = `-----BEGIN CERTIFICATE-----
MIIB/jCCAWICCQDscdUxw16XFDAJBgcqhkjOPQQBMEUxCzAJBgNVBAYTAkFVMRMw
-----END CERTIFICATE-----
`
-var ecdsaKeyPEM = `-----BEGIN EC PARAMETERS-----
+var ecdsaKeyPEM = testingKey(`-----BEGIN EC PARAMETERS-----
BgUrgQQAIw==
-----END EC PARAMETERS-----
------BEGIN EC PRIVATE KEY-----
+-----BEGIN EC TESTING KEY-----
MIHcAgEBBEIBrsoKp0oqcv6/JovJJDoDVSGWdirrkgCWxrprGlzB9o0X8fV675X0
NwuBenXFfeZvVcwluO7/Q9wkYoPd/t3jGImgBwYFK4EEACOhgYkDgYYABAFj36bL
06h5JRGUNB1X/Hwuw64uKW2GGJLVPPhoYMcg/ALWaW+d/t+DmV5xikwKssuFq4Bz
VQldyCXTXGgu7OC0AQCC/Y/+ODK3NFKlRi+AsG3VQDSV4tgHLqZBBus0S6pPcg1q
kohxS/xfFg/TEwRSSws+roJr4JFKpO2t3/be5OdqmQ==
------END EC PRIVATE KEY-----
-`
+-----END EC TESTING KEY-----
+`)
var keyPairTests = []struct {
algo string
if testing.Short() {
t.Skip("skipping in short mode")
}
- listener := newLocalListener(t)
- addr := listener.Addr().String()
- defer listener.Close()
+ timeout := 100 * time.Microsecond
+ for !t.Failed() {
+ acceptc := make(chan net.Conn)
+ listener := newLocalListener(t)
+ go func() {
+ for {
+ conn, err := listener.Accept()
+ if err != nil {
+ close(acceptc)
+ return
+ }
+ acceptc <- conn
+ }
+ }()
+
+ addr := listener.Addr().String()
+ dialer := &net.Dialer{
+ Timeout: timeout,
+ }
+ if conn, err := DialWithDialer(dialer, "tcp", addr, nil); err == nil {
+ conn.Close()
+ t.Errorf("DialWithTimeout unexpectedly completed successfully")
+ } else if !isTimeoutError(err) {
+ t.Errorf("resulting error not a timeout: %v\nType %T: %#v", err, err, err)
+ }
+
+ listener.Close()
+
+ // We're looking for a timeout during the handshake, so check that the
+ // Listener actually accepted the connection to initiate it. (If the server
+ // takes too long to accept the connection, we might cancel before the
+ // underlying net.Conn is ever dialed — without ever attempting a
+ // handshake.)
+ lconn, ok := <-acceptc
+ if ok {
+ // The Listener accepted a connection, so assume that it was from our
+ // Dial: we triggered the timeout at the point where we wanted it!
+ t.Logf("Listener accepted a connection from %s", lconn.RemoteAddr())
+ lconn.Close()
+ }
+ // Close any spurious extra connecitions from the listener. (This is
+ // possible if there are, for example, stray Dial calls from other tests.)
+ for extraConn := range acceptc {
+ t.Logf("spurious extra connection from %s", extraConn.RemoteAddr())
+ extraConn.Close()
+ }
+ if ok {
+ break
+ }
- complete := make(chan bool)
- defer close(complete)
+ t.Logf("with timeout %v, DialWithDialer returned before listener accepted any connections; retrying", timeout)
+ timeout *= 2
+ }
+}
+
+func TestDeadlineOnWrite(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping in short mode")
+ }
+
+ ln := newLocalListener(t)
+ defer ln.Close()
+
+ srvCh := make(chan *Conn, 1)
go func() {
- conn, err := listener.Accept()
+ sconn, err := ln.Accept()
if err != nil {
- t.Error(err)
+ srvCh <- nil
return
}
- <-complete
- conn.Close()
+ srv := Server(sconn, testConfig.Clone())
+ if err := srv.Handshake(); err != nil {
+ srvCh <- nil
+ return
+ }
+ srvCh <- srv
}()
- dialer := &net.Dialer{
- Timeout: 10 * time.Millisecond,
+ clientConfig := testConfig.Clone()
+ clientConfig.MaxVersion = VersionTLS12
+ conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
+ if err != nil {
+ t.Fatal(err)
}
+ defer conn.Close()
- var err error
- if _, err = DialWithDialer(dialer, "tcp", addr, nil); err == nil {
- t.Fatal("DialWithTimeout completed successfully")
+ srv := <-srvCh
+ if srv == nil {
+ t.Error(err)
+ }
+
+ // Make sure the client/server is setup correctly and is able to do a typical Write/Read
+ buf := make([]byte, 6)
+ if _, err := srv.Write([]byte("foobar")); err != nil {
+ t.Errorf("Write err: %v", err)
+ }
+ if n, err := conn.Read(buf); n != 6 || err != nil || string(buf) != "foobar" {
+ t.Errorf("Read = %d, %v, data %q; want 6, nil, foobar", n, err, buf)
+ }
+
+ // Set a deadline which should cause Write to timeout
+ if err = srv.SetDeadline(time.Now()); err != nil {
+ t.Fatalf("SetDeadline(time.Now()) err: %v", err)
+ }
+ if _, err = srv.Write([]byte("should fail")); err == nil {
+ t.Fatal("Write should have timed out")
+ }
+
+ // Clear deadline and make sure it still times out
+ if err = srv.SetDeadline(time.Time{}); err != nil {
+ t.Fatalf("SetDeadline(time.Time{}) err: %v", err)
+ }
+ if _, err = srv.Write([]byte("This connection is permanently broken")); err == nil {
+ t.Fatal("Write which previously failed should still time out")
}
+ // Verify the error
+ if ne := err.(net.Error); ne.Temporary() != false {
+ t.Error("Write timed out but incorrectly classified the error as Temporary")
+ }
if !isTimeoutError(err) {
- t.Errorf("resulting error not a timeout: %v\nType %T: %#v", err, err, err)
+ t.Error("Write timed out but did not classify the error as a Timeout")
+ }
+}
+
+type readerFunc func([]byte) (int, error)
+
+func (f readerFunc) Read(b []byte) (int, error) { return f(b) }
+
+// TestDialer tests that tls.Dialer.DialContext can abort in the middle of a handshake.
+// (The other cases are all handled by the existing dial tests in this package, which
+// all also flow through the same code shared code paths)
+func TestDialer(t *testing.T) {
+ ln := newLocalListener(t)
+ defer ln.Close()
+
+ unblockServer := make(chan struct{}) // close-only
+ defer close(unblockServer)
+ go func() {
+ conn, err := ln.Accept()
+ if err != nil {
+ return
+ }
+ defer conn.Close()
+ <-unblockServer
+ }()
+
+ ctx, cancel := context.WithCancel(context.Background())
+ d := Dialer{Config: &Config{
+ Rand: readerFunc(func(b []byte) (n int, err error) {
+ // By the time crypto/tls wants randomness, that means it has a TCP
+ // connection, so we're past the Dialer's dial and now blocked
+ // in a handshake. Cancel our context and see if we get unstuck.
+ // (Our TCP listener above never reads or writes, so the Handshake
+ // would otherwise be stuck forever)
+ cancel()
+ return len(b), nil
+ }),
+ ServerName: "foo",
+ }}
+ _, err := d.DialContext(ctx, "tcp", ln.Addr().String())
+ if err != context.Canceled {
+ t.Errorf("err = %v; want context.Canceled", err)
}
}
defer ln.Close()
serverTLSUniques := make(chan []byte)
+ parentDone := make(chan struct{})
+ childDone := make(chan struct{})
+ defer close(parentDone)
go func() {
+ defer close(childDone)
for i := 0; i < 2; i++ {
sconn, err := ln.Accept()
if err != nil {
t.Error(err)
return
}
- serverTLSUniques <- srv.ConnectionState().TLSUnique
+ select {
+ case <-parentDone:
+ return
+ case serverTLSUniques <- srv.ConnectionState().TLSUnique:
+ }
}
}()
if err != nil {
t.Fatal(err)
}
- if !bytes.Equal(conn.ConnectionState().TLSUnique, <-serverTLSUniques) {
+
+ var serverTLSUniquesValue []byte
+ select {
+ case <-childDone:
+ return
+ case serverTLSUniquesValue = <-serverTLSUniques:
+ }
+
+ if !bytes.Equal(conn.ConnectionState().TLSUnique, serverTLSUniquesValue) {
t.Error("client and server channel bindings differ")
}
+ if serverTLSUniquesValue == nil || bytes.Equal(serverTLSUniquesValue, make([]byte, 12)) {
+ t.Error("tls-unique is empty or zero")
+ }
conn.Close()
conn, err = Dial("tcp", ln.Addr().String(), clientConfig)
if !conn.ConnectionState().DidResume {
t.Error("second session did not use resumption")
}
- if !bytes.Equal(conn.ConnectionState().TLSUnique, <-serverTLSUniques) {
+
+ select {
+ case <-childDone:
+ return
+ case serverTLSUniquesValue = <-serverTLSUniques:
+ }
+
+ if !bytes.Equal(conn.ConnectionState().TLSUnique, serverTLSUniquesValue) {
t.Error("client and server channel bindings differ when session resumption is used")
}
+ if serverTLSUniquesValue == nil || bytes.Equal(serverTLSUniquesValue, make([]byte, 12)) {
+ t.Error("resumption tls-unique is empty or zero")
+ }
}
func TestVerifyHostname(t *testing.T) {
}
}
-func TestVerifyHostnameResumed(t *testing.T) {
- t.Run("TLSv12", func(t *testing.T) { testVerifyHostnameResumed(t, VersionTLS12) })
- t.Run("TLSv13", func(t *testing.T) { testVerifyHostnameResumed(t, VersionTLS13) })
-}
-
-func testVerifyHostnameResumed(t *testing.T, version uint16) {
- testenv.MustHaveExternalNetwork(t)
-
- config := &Config{
- MaxVersion: version,
- ClientSessionCache: NewLRUClientSessionCache(32),
- }
- for i := 0; i < 2; i++ {
- c, err := Dial("tcp", "mail.google.com:https", config)
- if err != nil {
- t.Fatalf("Dial #%d: %v", i, err)
- }
- cs := c.ConnectionState()
- if i > 0 && !cs.DidResume {
- t.Fatalf("Subsequent connection unexpectedly didn't resume")
- }
- if cs.Version != version {
- t.Fatalf("Unexpectedly negotiated version %x", cs.Version)
- }
- if cs.VerifiedChains == nil {
- t.Fatalf("Dial #%d: cs.VerifiedChains == nil", i)
- }
- if err := c.VerifyHostname("mail.google.com"); err != nil {
- t.Fatalf("verify mail.google.com #%d: %v", i, err)
- }
- // Give the client a chance to read the server session tickets.
- c.SetReadDeadline(time.Now().Add(500 * time.Millisecond))
- if _, err := c.Read(make([]byte, 1)); err != nil {
- if err, ok := err.(net.Error); !ok || !err.Timeout() {
- t.Fatal(err)
- }
- }
- c.Close()
- }
-}
-
func TestConnCloseBreakingWrite(t *testing.T) {
ln := newLocalListener(t)
defer ln.Close()
}
<-closeReturned
- if err := tconn.Close(); err != errClosed {
- t.Errorf("Close error = %v; want errClosed", err)
+ if err := tconn.Close(); err != net.ErrClosed {
+ t.Errorf("Close error = %v; want net.ErrClosed", err)
}
}
}
defer srv.Close()
- data, err := ioutil.ReadAll(srv)
+ data, err := io.ReadAll(srv)
if err != nil {
return err
}
return fmt.Errorf("CloseWrite error = %v; want errShutdown", err)
}
- data, err := ioutil.ReadAll(conn)
+ data, err := io.ReadAll(conn)
if err != nil {
return err
}
}
defer srv.Close()
- _, err = ioutil.ReadAll(srv)
+ _, err = io.ReadAll(srv)
if err == nil {
return errors.New("unexpected lack of error from server")
}
}
func TestCloneFuncFields(t *testing.T) {
- const expectedCount = 5
+ const expectedCount = 8
called := 0
c1 := Config{
called |= 1 << 4
return nil
},
+ VerifyConnection: func(ConnectionState) error {
+ called |= 1 << 5
+ return nil
+ },
+ UnwrapSession: func(identity []byte, cs ConnectionState) (*SessionState, error) {
+ called |= 1 << 6
+ return nil, nil
+ },
+ WrapSession: func(cs ConnectionState, ss *SessionState) ([]byte, error) {
+ called |= 1 << 7
+ return nil, nil
+ },
}
c2 := c1.Clone()
c2.GetClientCertificate(nil)
c2.GetConfigForClient(nil)
c2.VerifyPeerCertificate(nil, nil)
+ c2.VerifyConnection(ConnectionState{})
+ c2.UnwrapSession(nil, ConnectionState{})
+ c2.WrapSession(ConnectionState{}, nil)
if called != (1<<expectedCount)-1 {
t.Fatalf("expected %d calls but saw calls %b", expectedCount, called)
typ := v.Type()
for i := 0; i < typ.NumField(); i++ {
f := v.Field(i)
- if !f.CanSet() {
- // unexported field; not cloned.
- continue
- }
-
// testing/quick can't handle functions or interfaces and so
// isn't used here.
switch fn := typ.Field(i).Name; fn {
case "Rand":
f.Set(reflect.ValueOf(io.Reader(os.Stdin)))
- case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate", "GetClientCertificate":
+ case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate", "VerifyConnection", "GetClientCertificate", "WrapSession", "UnwrapSession":
// DeepEqual can't compare functions. If you add a
// function field to this list, you must also change
// TestCloneFuncFields to ensure that the func field is
f.Set(reflect.ValueOf([]CurveID{CurveP256}))
case "Renegotiation":
f.Set(reflect.ValueOf(RenegotiateOnceAsClient))
+ case "mutex", "autoSessionTicketKeys", "sessionTicketKeys":
+ continue // these are unexported fields that are handled separately
default:
t.Errorf("all fields must be accounted for, but saw unknown field %q", fn)
}
}
+ // Set the unexported fields related to session ticket keys, which are copied with Clone().
+ c1.autoSessionTicketKeys = []ticketKey{c1.ticketKeyFromBytes(c1.SessionTicketKey)}
+ c1.sessionTicketKeys = []ticketKey{c1.ticketKeyFromBytes(c1.SessionTicketKey)}
c2 := c1.Clone()
- // DeepEqual also compares unexported fields, thus c2 needs to have run
- // serverInit in order to be DeepEqual to c1. Cloning it and discarding
- // the result is sufficient.
- c2.Clone()
-
if !reflect.DeepEqual(&c1, c2) {
t.Errorf("clone failed to copy a field")
}
}
+func TestCloneNilConfig(t *testing.T) {
+ var config *Config
+ if cc := config.Clone(); cc != nil {
+ t.Fatalf("Clone with nil should return nil, got: %+v", cc)
+ }
+}
+
// changeImplConn is a net.Conn which can change its Write and Close
// methods.
type changeImplConn struct {
if ss.ServerName != serverName {
t.Errorf("Got server name %q, expected %q", ss.ServerName, serverName)
}
- if cs.ServerName != "" {
- t.Errorf("Got unexpected server name on the client side")
+ if cs.ServerName != serverName {
+ t.Errorf("Got server name on client connection %q, expected %q", cs.ServerName, serverName)
}
if len(ss.PeerCertificates) != 1 || len(cs.PeerCertificates) != 1 {
}
}
-// TestEscapeRoute tests that the library will still work if support for TLS 1.3
-// is dropped later in the Go 1.12 cycle.
-func TestEscapeRoute(t *testing.T) {
- defer func(savedSupportedVersions []uint16) {
- supportedVersions = savedSupportedVersions
- }(supportedVersions)
- supportedVersions = []uint16{
- VersionTLS12,
- VersionTLS11,
- VersionTLS10,
- VersionSSL30,
+// Issue 28744: Ensure that we don't modify memory
+// that Config doesn't own such as Certificates.
+func TestBuildNameToCertificate_doesntModifyCertificates(t *testing.T) {
+ c0 := Certificate{
+ Certificate: [][]byte{testRSACertificate},
+ PrivateKey: testRSAPrivateKey,
+ }
+ c1 := Certificate{
+ Certificate: [][]byte{testSNICertificate},
+ PrivateKey: testRSAPrivateKey,
}
+ config := testConfig.Clone()
+ config.Certificates = []Certificate{c0, c1}
- ss, cs, err := testHandshake(t, testConfig, testConfig)
- if err != nil {
- t.Fatalf("Handshake failed when support for TLS 1.3 was dropped: %v", err)
+ config.BuildNameToCertificate()
+ got := config.Certificates
+ want := []Certificate{c0, c1}
+ if !reflect.DeepEqual(got, want) {
+ t.Fatalf("Certificates were mutated by BuildNameToCertificate\nGot: %#v\nWant: %#v\n", got, want)
}
- if ss.Version != VersionTLS12 {
- t.Errorf("Server negotiated version %x, expected %x", cs.Version, VersionTLS12)
+}
+
+func testingKey(s string) string { return strings.ReplaceAll(s, "TESTING KEY", "PRIVATE KEY") }
+
+func TestClientHelloInfo_SupportsCertificate(t *testing.T) {
+ rsaCert := &Certificate{
+ Certificate: [][]byte{testRSACertificate},
+ PrivateKey: testRSAPrivateKey,
+ }
+ pkcs1Cert := &Certificate{
+ Certificate: [][]byte{testRSACertificate},
+ PrivateKey: testRSAPrivateKey,
+ SupportedSignatureAlgorithms: []SignatureScheme{PKCS1WithSHA1, PKCS1WithSHA256},
+ }
+ ecdsaCert := &Certificate{
+ // ECDSA P-256 certificate
+ Certificate: [][]byte{testP256Certificate},
+ PrivateKey: testP256PrivateKey,
+ }
+ ed25519Cert := &Certificate{
+ Certificate: [][]byte{testEd25519Certificate},
+ PrivateKey: testEd25519PrivateKey,
+ }
+
+ tests := []struct {
+ c *Certificate
+ chi *ClientHelloInfo
+ wantErr string
+ }{
+ {rsaCert, &ClientHelloInfo{
+ ServerName: "example.golang",
+ SignatureSchemes: []SignatureScheme{PSSWithSHA256},
+ SupportedVersions: []uint16{VersionTLS13},
+ }, ""},
+ {ecdsaCert, &ClientHelloInfo{
+ SignatureSchemes: []SignatureScheme{PSSWithSHA256, ECDSAWithP256AndSHA256},
+ SupportedVersions: []uint16{VersionTLS13, VersionTLS12},
+ }, ""},
+ {rsaCert, &ClientHelloInfo{
+ ServerName: "example.com",
+ SignatureSchemes: []SignatureScheme{PSSWithSHA256},
+ SupportedVersions: []uint16{VersionTLS13},
+ }, "not valid for requested server name"},
+ {ecdsaCert, &ClientHelloInfo{
+ SignatureSchemes: []SignatureScheme{ECDSAWithP384AndSHA384},
+ SupportedVersions: []uint16{VersionTLS13},
+ }, "signature algorithms"},
+ {pkcs1Cert, &ClientHelloInfo{
+ SignatureSchemes: []SignatureScheme{PSSWithSHA256, ECDSAWithP256AndSHA256},
+ SupportedVersions: []uint16{VersionTLS13},
+ }, "signature algorithms"},
+
+ {rsaCert, &ClientHelloInfo{
+ CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
+ SignatureSchemes: []SignatureScheme{PKCS1WithSHA1},
+ SupportedVersions: []uint16{VersionTLS13, VersionTLS12},
+ }, "signature algorithms"},
+ {rsaCert, &ClientHelloInfo{
+ CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
+ SignatureSchemes: []SignatureScheme{PKCS1WithSHA1},
+ SupportedVersions: []uint16{VersionTLS13, VersionTLS12},
+ config: &Config{
+ MaxVersion: VersionTLS12,
+ },
+ }, ""}, // Check that mutual version selection works.
+
+ {ecdsaCert, &ClientHelloInfo{
+ CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
+ SupportedCurves: []CurveID{CurveP256},
+ SupportedPoints: []uint8{pointFormatUncompressed},
+ SignatureSchemes: []SignatureScheme{ECDSAWithP256AndSHA256},
+ SupportedVersions: []uint16{VersionTLS12},
+ }, ""},
+ {ecdsaCert, &ClientHelloInfo{
+ CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
+ SupportedCurves: []CurveID{CurveP256},
+ SupportedPoints: []uint8{pointFormatUncompressed},
+ SignatureSchemes: []SignatureScheme{ECDSAWithP384AndSHA384},
+ SupportedVersions: []uint16{VersionTLS12},
+ }, ""}, // TLS 1.2 does not restrict curves based on the SignatureScheme.
+ {ecdsaCert, &ClientHelloInfo{
+ CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
+ SupportedCurves: []CurveID{CurveP256},
+ SupportedPoints: []uint8{pointFormatUncompressed},
+ SignatureSchemes: nil,
+ SupportedVersions: []uint16{VersionTLS12},
+ }, ""}, // TLS 1.2 comes with default signature schemes.
+ {ecdsaCert, &ClientHelloInfo{
+ CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
+ SupportedCurves: []CurveID{CurveP256},
+ SupportedPoints: []uint8{pointFormatUncompressed},
+ SignatureSchemes: []SignatureScheme{ECDSAWithP256AndSHA256},
+ SupportedVersions: []uint16{VersionTLS12},
+ }, "cipher suite"},
+ {ecdsaCert, &ClientHelloInfo{
+ CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
+ SupportedCurves: []CurveID{CurveP256},
+ SupportedPoints: []uint8{pointFormatUncompressed},
+ SignatureSchemes: []SignatureScheme{ECDSAWithP256AndSHA256},
+ SupportedVersions: []uint16{VersionTLS12},
+ config: &Config{
+ CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
+ },
+ }, "cipher suite"},
+ {ecdsaCert, &ClientHelloInfo{
+ CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
+ SupportedCurves: []CurveID{CurveP384},
+ SupportedPoints: []uint8{pointFormatUncompressed},
+ SignatureSchemes: []SignatureScheme{ECDSAWithP256AndSHA256},
+ SupportedVersions: []uint16{VersionTLS12},
+ }, "certificate curve"},
+ {ecdsaCert, &ClientHelloInfo{
+ CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
+ SupportedCurves: []CurveID{CurveP256},
+ SupportedPoints: []uint8{1},
+ SignatureSchemes: []SignatureScheme{ECDSAWithP256AndSHA256},
+ SupportedVersions: []uint16{VersionTLS12},
+ }, "doesn't support ECDHE"},
+ {ecdsaCert, &ClientHelloInfo{
+ CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
+ SupportedCurves: []CurveID{CurveP256},
+ SupportedPoints: []uint8{pointFormatUncompressed},
+ SignatureSchemes: []SignatureScheme{PSSWithSHA256},
+ SupportedVersions: []uint16{VersionTLS12},
+ }, "signature algorithms"},
+
+ {ed25519Cert, &ClientHelloInfo{
+ CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
+ SupportedCurves: []CurveID{CurveP256}, // only relevant for ECDHE support
+ SupportedPoints: []uint8{pointFormatUncompressed},
+ SignatureSchemes: []SignatureScheme{Ed25519},
+ SupportedVersions: []uint16{VersionTLS12},
+ }, ""},
+ {ed25519Cert, &ClientHelloInfo{
+ CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
+ SupportedCurves: []CurveID{CurveP256}, // only relevant for ECDHE support
+ SupportedPoints: []uint8{pointFormatUncompressed},
+ SignatureSchemes: []SignatureScheme{Ed25519},
+ SupportedVersions: []uint16{VersionTLS10},
+ config: &Config{MinVersion: VersionTLS10},
+ }, "doesn't support Ed25519"},
+ {ed25519Cert, &ClientHelloInfo{
+ CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
+ SupportedCurves: []CurveID{},
+ SupportedPoints: []uint8{pointFormatUncompressed},
+ SignatureSchemes: []SignatureScheme{Ed25519},
+ SupportedVersions: []uint16{VersionTLS12},
+ }, "doesn't support ECDHE"},
+
+ {rsaCert, &ClientHelloInfo{
+ CipherSuites: []uint16{TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA},
+ SupportedCurves: []CurveID{CurveP256}, // only relevant for ECDHE support
+ SupportedPoints: []uint8{pointFormatUncompressed},
+ SupportedVersions: []uint16{VersionTLS10},
+ config: &Config{MinVersion: VersionTLS10},
+ }, ""},
+ {rsaCert, &ClientHelloInfo{
+ CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
+ SupportedVersions: []uint16{VersionTLS12},
+ }, ""}, // static RSA fallback
+ }
+ for i, tt := range tests {
+ err := tt.chi.SupportsCertificate(tt.c)
+ switch {
+ case tt.wantErr == "" && err != nil:
+ t.Errorf("%d: unexpected error: %v", i, err)
+ case tt.wantErr != "" && err == nil:
+ t.Errorf("%d: unexpected success", i)
+ case tt.wantErr != "" && !strings.Contains(err.Error(), tt.wantErr):
+ t.Errorf("%d: got error %q, expected %q", i, err, tt.wantErr)
+ }
}
- if cs.Version != VersionTLS12 {
- t.Errorf("Client negotiated version %x, expected %x", cs.Version, VersionTLS12)
+}
+
+func TestCipherSuites(t *testing.T) {
+ var lastID uint16
+ for _, c := range CipherSuites() {
+ if lastID > c.ID {
+ t.Errorf("CipherSuites are not ordered by ID: got %#04x after %#04x", c.ID, lastID)
+ } else {
+ lastID = c.ID
+ }
+
+ if c.Insecure {
+ t.Errorf("%#04x: Insecure CipherSuite returned by CipherSuites()", c.ID)
+ }
+ }
+ lastID = 0
+ for _, c := range InsecureCipherSuites() {
+ if lastID > c.ID {
+ t.Errorf("InsecureCipherSuites are not ordered by ID: got %#04x after %#04x", c.ID, lastID)
+ } else {
+ lastID = c.ID
+ }
+
+ if !c.Insecure {
+ t.Errorf("%#04x: not Insecure CipherSuite returned by InsecureCipherSuites()", c.ID)
+ }
+ }
+
+ CipherSuiteByID := func(id uint16) *CipherSuite {
+ for _, c := range CipherSuites() {
+ if c.ID == id {
+ return c
+ }
+ }
+ for _, c := range InsecureCipherSuites() {
+ if c.ID == id {
+ return c
+ }
+ }
+ return nil
+ }
+
+ for _, c := range cipherSuites {
+ cc := CipherSuiteByID(c.id)
+ if cc == nil {
+ t.Errorf("%#04x: no CipherSuite entry", c.id)
+ continue
+ }
+
+ if tls12Only := c.flags&suiteTLS12 != 0; tls12Only && len(cc.SupportedVersions) != 1 {
+ t.Errorf("%#04x: suite is TLS 1.2 only, but SupportedVersions is %v", c.id, cc.SupportedVersions)
+ } else if !tls12Only && len(cc.SupportedVersions) != 3 {
+ t.Errorf("%#04x: suite TLS 1.0-1.2, but SupportedVersions is %v", c.id, cc.SupportedVersions)
+ }
+
+ if got := CipherSuiteName(c.id); got != cc.Name {
+ t.Errorf("%#04x: unexpected CipherSuiteName: got %q, expected %q", c.id, got, cc.Name)
+ }
+ }
+ for _, c := range cipherSuitesTLS13 {
+ cc := CipherSuiteByID(c.id)
+ if cc == nil {
+ t.Errorf("%#04x: no CipherSuite entry", c.id)
+ continue
+ }
+
+ if cc.Insecure {
+ t.Errorf("%#04x: Insecure %v, expected false", c.id, cc.Insecure)
+ }
+ if len(cc.SupportedVersions) != 1 || cc.SupportedVersions[0] != VersionTLS13 {
+ t.Errorf("%#04x: suite is TLS 1.3 only, but SupportedVersions is %v", c.id, cc.SupportedVersions)
+ }
+
+ if got := CipherSuiteName(c.id); got != cc.Name {
+ t.Errorf("%#04x: unexpected CipherSuiteName: got %q, expected %q", c.id, got, cc.Name)
+ }
+ }
+
+ if got := CipherSuiteName(0xabc); got != "0x0ABC" {
+ t.Errorf("unexpected fallback CipherSuiteName: got %q, expected 0x0ABC", got)
+ }
+
+ if len(cipherSuitesPreferenceOrder) != len(cipherSuites) {
+ t.Errorf("cipherSuitesPreferenceOrder is not the same size as cipherSuites")
+ }
+ if len(cipherSuitesPreferenceOrderNoAES) != len(cipherSuitesPreferenceOrder) {
+ t.Errorf("cipherSuitesPreferenceOrderNoAES is not the same size as cipherSuitesPreferenceOrder")
+ }
+
+ // Check that disabled suites are at the end of the preference lists, and
+ // that they are marked insecure.
+ for i, id := range disabledCipherSuites {
+ offset := len(cipherSuitesPreferenceOrder) - len(disabledCipherSuites)
+ if cipherSuitesPreferenceOrder[offset+i] != id {
+ t.Errorf("disabledCipherSuites[%d]: not at the end of cipherSuitesPreferenceOrder", i)
+ }
+ if cipherSuitesPreferenceOrderNoAES[offset+i] != id {
+ t.Errorf("disabledCipherSuites[%d]: not at the end of cipherSuitesPreferenceOrderNoAES", i)
+ }
+ c := CipherSuiteByID(id)
+ if c == nil {
+ t.Errorf("%#04x: no CipherSuite entry", id)
+ continue
+ }
+ if !c.Insecure {
+ t.Errorf("%#04x: disabled by default but not marked insecure", id)
+ }
+ }
+
+ for i, prefOrder := range [][]uint16{cipherSuitesPreferenceOrder, cipherSuitesPreferenceOrderNoAES} {
+ // Check that insecure and HTTP/2 bad cipher suites are at the end of
+ // the preference lists.
+ var sawInsecure, sawBad bool
+ for _, id := range prefOrder {
+ c := CipherSuiteByID(id)
+ if c == nil {
+ t.Errorf("%#04x: no CipherSuite entry", id)
+ continue
+ }
+
+ if c.Insecure {
+ sawInsecure = true
+ } else if sawInsecure {
+ t.Errorf("%#04x: secure suite after insecure one(s)", id)
+ }
+
+ if http2isBadCipher(id) {
+ sawBad = true
+ } else if sawBad {
+ t.Errorf("%#04x: non-bad suite after bad HTTP/2 one(s)", id)
+ }
+ }
+
+ // Check that the list is sorted according to the documented criteria.
+ isBetter := func(a, b int) bool {
+ aSuite, bSuite := cipherSuiteByID(prefOrder[a]), cipherSuiteByID(prefOrder[b])
+ aName, bName := CipherSuiteName(prefOrder[a]), CipherSuiteName(prefOrder[b])
+ // * < RC4
+ if !strings.Contains(aName, "RC4") && strings.Contains(bName, "RC4") {
+ return true
+ } else if strings.Contains(aName, "RC4") && !strings.Contains(bName, "RC4") {
+ return false
+ }
+ // * < CBC_SHA256
+ if !strings.Contains(aName, "CBC_SHA256") && strings.Contains(bName, "CBC_SHA256") {
+ return true
+ } else if strings.Contains(aName, "CBC_SHA256") && !strings.Contains(bName, "CBC_SHA256") {
+ return false
+ }
+ // * < 3DES
+ if !strings.Contains(aName, "3DES") && strings.Contains(bName, "3DES") {
+ return true
+ } else if strings.Contains(aName, "3DES") && !strings.Contains(bName, "3DES") {
+ return false
+ }
+ // ECDHE < *
+ if aSuite.flags&suiteECDHE != 0 && bSuite.flags&suiteECDHE == 0 {
+ return true
+ } else if aSuite.flags&suiteECDHE == 0 && bSuite.flags&suiteECDHE != 0 {
+ return false
+ }
+ // AEAD < CBC
+ if aSuite.aead != nil && bSuite.aead == nil {
+ return true
+ } else if aSuite.aead == nil && bSuite.aead != nil {
+ return false
+ }
+ // AES < ChaCha20
+ if strings.Contains(aName, "AES") && strings.Contains(bName, "CHACHA20") {
+ return i == 0 // true for cipherSuitesPreferenceOrder
+ } else if strings.Contains(aName, "CHACHA20") && strings.Contains(bName, "AES") {
+ return i != 0 // true for cipherSuitesPreferenceOrderNoAES
+ }
+ // AES-128 < AES-256
+ if strings.Contains(aName, "AES_128") && strings.Contains(bName, "AES_256") {
+ return true
+ } else if strings.Contains(aName, "AES_256") && strings.Contains(bName, "AES_128") {
+ return false
+ }
+ // ECDSA < RSA
+ if aSuite.flags&suiteECSign != 0 && bSuite.flags&suiteECSign == 0 {
+ return true
+ } else if aSuite.flags&suiteECSign == 0 && bSuite.flags&suiteECSign != 0 {
+ return false
+ }
+ t.Fatalf("two ciphersuites are equal by all criteria: %v and %v", aName, bName)
+ panic("unreachable")
+ }
+ if !sort.SliceIsSorted(prefOrder, isBetter) {
+ t.Error("preference order is not sorted according to the rules")
+ }
+ }
+}
+
+func TestVersionName(t *testing.T) {
+ if got, exp := VersionName(VersionTLS13), "TLS 1.3"; got != exp {
+ t.Errorf("unexpected VersionName: got %q, expected %q", got, exp)
+ }
+ if got, exp := VersionName(0x12a), "0x012A"; got != exp {
+ t.Errorf("unexpected fallback VersionName: got %q, expected %q", got, exp)
+ }
+}
+
+// http2isBadCipher is copied from net/http.
+// TODO: if it ends up exposed somewhere, use that instead.
+func http2isBadCipher(cipher uint16) bool {
+ switch cipher {
+ case TLS_RSA_WITH_RC4_128_SHA,
+ TLS_RSA_WITH_3DES_EDE_CBC_SHA,
+ TLS_RSA_WITH_AES_128_CBC_SHA,
+ TLS_RSA_WITH_AES_256_CBC_SHA,
+ TLS_RSA_WITH_AES_128_CBC_SHA256,
+ TLS_RSA_WITH_AES_128_GCM_SHA256,
+ TLS_RSA_WITH_AES_256_GCM_SHA384,
+ TLS_ECDHE_ECDSA_WITH_RC4_128_SHA,
+ TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
+ TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
+ TLS_ECDHE_RSA_WITH_RC4_128_SHA,
+ TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
+ TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
+ TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
+ TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256,
+ TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256:
+ return true
+ default:
+ return false
+ }
+}
+
+type brokenSigner struct{ crypto.Signer }
+
+func (s brokenSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) {
+ // Replace opts with opts.HashFunc(), so rsa.PSSOptions are discarded.
+ return s.Signer.Sign(rand, digest, opts.HashFunc())
+}
+
+// TestPKCS1OnlyCert uses a client certificate with a broken crypto.Signer that
+// always makes PKCS #1 v1.5 signatures, so can't be used with RSA-PSS.
+func TestPKCS1OnlyCert(t *testing.T) {
+ clientConfig := testConfig.Clone()
+ clientConfig.Certificates = []Certificate{{
+ Certificate: [][]byte{testRSACertificate},
+ PrivateKey: brokenSigner{testRSAPrivateKey},
+ }}
+ serverConfig := testConfig.Clone()
+ serverConfig.MaxVersion = VersionTLS12 // TLS 1.3 doesn't support PKCS #1 v1.5
+ serverConfig.ClientAuth = RequireAnyClientCert
+
+ // If RSA-PSS is selected, the handshake should fail.
+ if _, _, err := testHandshake(t, clientConfig, serverConfig); err == nil {
+ t.Fatal("expected broken certificate to cause connection to fail")
+ }
+
+ clientConfig.Certificates[0].SupportedSignatureAlgorithms =
+ []SignatureScheme{PKCS1WithSHA1, PKCS1WithSHA256}
+
+ // But if the certificate restricts supported algorithms, RSA-PSS should not
+ // be selected, and the handshake should succeed.
+ if _, _, err := testHandshake(t, clientConfig, serverConfig); err != nil {
+ t.Error(err)
+ }
+}
+
+func TestVerifyCertificates(t *testing.T) {
+ // See https://go.dev/issue/31641.
+ t.Run("TLSv12", func(t *testing.T) { testVerifyCertificates(t, VersionTLS12) })
+ t.Run("TLSv13", func(t *testing.T) { testVerifyCertificates(t, VersionTLS13) })
+}
+
+func testVerifyCertificates(t *testing.T, version uint16) {
+ tests := []struct {
+ name string
+
+ InsecureSkipVerify bool
+ ClientAuth ClientAuthType
+ ClientCertificates bool
+ }{
+ {
+ name: "defaults",
+ },
+ {
+ name: "InsecureSkipVerify",
+ InsecureSkipVerify: true,
+ },
+ {
+ name: "RequestClientCert with no certs",
+ ClientAuth: RequestClientCert,
+ },
+ {
+ name: "RequestClientCert with certs",
+ ClientAuth: RequestClientCert,
+ ClientCertificates: true,
+ },
+ {
+ name: "RequireAnyClientCert",
+ ClientAuth: RequireAnyClientCert,
+ ClientCertificates: true,
+ },
+ {
+ name: "VerifyClientCertIfGiven with no certs",
+ ClientAuth: VerifyClientCertIfGiven,
+ },
+ {
+ name: "VerifyClientCertIfGiven with certs",
+ ClientAuth: VerifyClientCertIfGiven,
+ ClientCertificates: true,
+ },
+ {
+ name: "RequireAndVerifyClientCert",
+ ClientAuth: RequireAndVerifyClientCert,
+ ClientCertificates: true,
+ },
+ }
+
+ issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
+ if err != nil {
+ t.Fatal(err)
+ }
+ rootCAs := x509.NewCertPool()
+ rootCAs.AddCert(issuer)
+
+ for _, test := range tests {
+ test := test
+ t.Run(test.name, func(t *testing.T) {
+ t.Parallel()
+
+ var serverVerifyConnection, clientVerifyConnection bool
+ var serverVerifyPeerCertificates, clientVerifyPeerCertificates bool
+
+ clientConfig := testConfig.Clone()
+ clientConfig.Time = func() time.Time { return time.Unix(1476984729, 0) }
+ clientConfig.MaxVersion = version
+ clientConfig.MinVersion = version
+ clientConfig.RootCAs = rootCAs
+ clientConfig.ServerName = "example.golang"
+ clientConfig.ClientSessionCache = NewLRUClientSessionCache(1)
+ serverConfig := clientConfig.Clone()
+ serverConfig.ClientCAs = rootCAs
+
+ clientConfig.VerifyConnection = func(cs ConnectionState) error {
+ clientVerifyConnection = true
+ return nil
+ }
+ clientConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
+ clientVerifyPeerCertificates = true
+ return nil
+ }
+ serverConfig.VerifyConnection = func(cs ConnectionState) error {
+ serverVerifyConnection = true
+ return nil
+ }
+ serverConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
+ serverVerifyPeerCertificates = true
+ return nil
+ }
+
+ clientConfig.InsecureSkipVerify = test.InsecureSkipVerify
+ serverConfig.ClientAuth = test.ClientAuth
+ if !test.ClientCertificates {
+ clientConfig.Certificates = nil
+ }
+
+ if _, _, err := testHandshake(t, clientConfig, serverConfig); err != nil {
+ t.Fatal(err)
+ }
+
+ want := serverConfig.ClientAuth != NoClientCert
+ if serverVerifyPeerCertificates != want {
+ t.Errorf("VerifyPeerCertificates on the server: got %v, want %v",
+ serverVerifyPeerCertificates, want)
+ }
+ if !clientVerifyPeerCertificates {
+ t.Errorf("VerifyPeerCertificates not called on the client")
+ }
+ if !serverVerifyConnection {
+ t.Error("VerifyConnection did not get called on the server")
+ }
+ if !clientVerifyConnection {
+ t.Error("VerifyConnection did not get called on the client")
+ }
+
+ serverVerifyPeerCertificates, clientVerifyPeerCertificates = false, false
+ serverVerifyConnection, clientVerifyConnection = false, false
+ cs, _, err := testHandshake(t, clientConfig, serverConfig)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !cs.DidResume {
+ t.Error("expected resumption")
+ }
+
+ if serverVerifyPeerCertificates {
+ t.Error("VerifyPeerCertificates got called on the server on resumption")
+ }
+ if clientVerifyPeerCertificates {
+ t.Error("VerifyPeerCertificates got called on the client on resumption")
+ }
+ if !serverVerifyConnection {
+ t.Error("VerifyConnection did not get called on the server on resumption")
+ }
+ if !clientVerifyConnection {
+ t.Error("VerifyConnection did not get called on the client on resumption")
+ }
+ })
}
}