]> Cypherpunks.ru repositories - gostls13.git/blobdiff - src/crypto/tls/tls_test.go
crypto/tls: change default minimum version to 1.2
[gostls13.git] / src / crypto / tls / tls_test.go
index 178b519f1cc7c7fdb4a7aaaa26f926967acc96f0..16f655dd93dbebcca445931aa1e4b64e5eb39b8b 100644 (file)
@@ -6,6 +6,7 @@ package tls
 
 import (
        "bytes"
+       "context"
        "crypto"
        "crypto/x509"
        "encoding/json"
@@ -13,11 +14,11 @@ import (
        "fmt"
        "internal/testenv"
        "io"
-       "io/ioutil"
        "math"
        "net"
        "os"
        "reflect"
+       "sort"
        "strings"
        "testing"
        "time"
@@ -169,35 +170,171 @@ func TestDialTimeout(t *testing.T) {
        if testing.Short() {
                t.Skip("skipping in short mode")
        }
-       listener := newLocalListener(t)
 
-       addr := listener.Addr().String()
-       defer listener.Close()
+       timeout := 100 * time.Microsecond
+       for !t.Failed() {
+               acceptc := make(chan net.Conn)
+               listener := newLocalListener(t)
+               go func() {
+                       for {
+                               conn, err := listener.Accept()
+                               if err != nil {
+                                       close(acceptc)
+                                       return
+                               }
+                               acceptc <- conn
+                       }
+               }()
+
+               addr := listener.Addr().String()
+               dialer := &net.Dialer{
+                       Timeout: timeout,
+               }
+               if conn, err := DialWithDialer(dialer, "tcp", addr, nil); err == nil {
+                       conn.Close()
+                       t.Errorf("DialWithTimeout unexpectedly completed successfully")
+               } else if !isTimeoutError(err) {
+                       t.Errorf("resulting error not a timeout: %v\nType %T: %#v", err, err, err)
+               }
+
+               listener.Close()
+
+               // We're looking for a timeout during the handshake, so check that the
+               // Listener actually accepted the connection to initiate it. (If the server
+               // takes too long to accept the connection, we might cancel before the
+               // underlying net.Conn is ever dialed — without ever attempting a
+               // handshake.)
+               lconn, ok := <-acceptc
+               if ok {
+                       // The Listener accepted a connection, so assume that it was from our
+                       // Dial: we triggered the timeout at the point where we wanted it!
+                       t.Logf("Listener accepted a connection from %s", lconn.RemoteAddr())
+                       lconn.Close()
+               }
+               // Close any spurious extra connecitions from the listener. (This is
+               // possible if there are, for example, stray Dial calls from other tests.)
+               for extraConn := range acceptc {
+                       t.Logf("spurious extra connection from %s", extraConn.RemoteAddr())
+                       extraConn.Close()
+               }
+               if ok {
+                       break
+               }
+
+               t.Logf("with timeout %v, DialWithDialer returned before listener accepted any connections; retrying", timeout)
+               timeout *= 2
+       }
+}
 
-       complete := make(chan bool)
-       defer close(complete)
+func TestDeadlineOnWrite(t *testing.T) {
+       if testing.Short() {
+               t.Skip("skipping in short mode")
+       }
+
+       ln := newLocalListener(t)
+       defer ln.Close()
+
+       srvCh := make(chan *Conn, 1)
 
        go func() {
-               conn, err := listener.Accept()
+               sconn, err := ln.Accept()
                if err != nil {
-                       t.Error(err)
+                       srvCh <- nil
                        return
                }
-               <-complete
-               conn.Close()
+               srv := Server(sconn, testConfig.Clone())
+               if err := srv.Handshake(); err != nil {
+                       srvCh <- nil
+                       return
+               }
+               srvCh <- srv
        }()
 
-       dialer := &net.Dialer{
-               Timeout: 10 * time.Millisecond,
+       clientConfig := testConfig.Clone()
+       clientConfig.MaxVersion = VersionTLS12
+       conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer conn.Close()
+
+       srv := <-srvCh
+       if srv == nil {
+               t.Error(err)
        }
 
-       var err error
-       if _, err = DialWithDialer(dialer, "tcp", addr, nil); err == nil {
-               t.Fatal("DialWithTimeout completed successfully")
+       // Make sure the client/server is setup correctly and is able to do a typical Write/Read
+       buf := make([]byte, 6)
+       if _, err := srv.Write([]byte("foobar")); err != nil {
+               t.Errorf("Write err: %v", err)
+       }
+       if n, err := conn.Read(buf); n != 6 || err != nil || string(buf) != "foobar" {
+               t.Errorf("Read = %d, %v, data %q; want 6, nil, foobar", n, err, buf)
+       }
+
+       // Set a deadline which should cause Write to timeout
+       if err = srv.SetDeadline(time.Now()); err != nil {
+               t.Fatalf("SetDeadline(time.Now()) err: %v", err)
+       }
+       if _, err = srv.Write([]byte("should fail")); err == nil {
+               t.Fatal("Write should have timed out")
+       }
+
+       // Clear deadline and make sure it still times out
+       if err = srv.SetDeadline(time.Time{}); err != nil {
+               t.Fatalf("SetDeadline(time.Time{}) err: %v", err)
+       }
+       if _, err = srv.Write([]byte("This connection is permanently broken")); err == nil {
+               t.Fatal("Write which previously failed should still time out")
        }
 
+       // Verify the error
+       if ne := err.(net.Error); ne.Temporary() != false {
+               t.Error("Write timed out but incorrectly classified the error as Temporary")
+       }
        if !isTimeoutError(err) {
-               t.Errorf("resulting error not a timeout: %v\nType %T: %#v", err, err, err)
+               t.Error("Write timed out but did not classify the error as a Timeout")
+       }
+}
+
+type readerFunc func([]byte) (int, error)
+
+func (f readerFunc) Read(b []byte) (int, error) { return f(b) }
+
+// TestDialer tests that tls.Dialer.DialContext can abort in the middle of a handshake.
+// (The other cases are all handled by the existing dial tests in this package, which
+// all also flow through the same code shared code paths)
+func TestDialer(t *testing.T) {
+       ln := newLocalListener(t)
+       defer ln.Close()
+
+       unblockServer := make(chan struct{}) // close-only
+       defer close(unblockServer)
+       go func() {
+               conn, err := ln.Accept()
+               if err != nil {
+                       return
+               }
+               defer conn.Close()
+               <-unblockServer
+       }()
+
+       ctx, cancel := context.WithCancel(context.Background())
+       d := Dialer{Config: &Config{
+               Rand: readerFunc(func(b []byte) (n int, err error) {
+                       // By the time crypto/tls wants randomness, that means it has a TCP
+                       // connection, so we're past the Dialer's dial and now blocked
+                       // in a handshake. Cancel our context and see if we get unstuck.
+                       // (Our TCP listener above never reads or writes, so the Handshake
+                       // would otherwise be stuck forever)
+                       cancel()
+                       return len(b), nil
+               }),
+               ServerName: "foo",
+       }}
+       _, err := d.DialContext(ctx, "tcp", ln.Addr().String())
+       if err != context.Canceled {
+               t.Errorf("err = %v; want context.Canceled", err)
        }
 }
 
@@ -294,7 +431,11 @@ func TestTLSUniqueMatches(t *testing.T) {
        defer ln.Close()
 
        serverTLSUniques := make(chan []byte)
+       parentDone := make(chan struct{})
+       childDone := make(chan struct{})
+       defer close(parentDone)
        go func() {
+               defer close(childDone)
                for i := 0; i < 2; i++ {
                        sconn, err := ln.Accept()
                        if err != nil {
@@ -308,7 +449,11 @@ func TestTLSUniqueMatches(t *testing.T) {
                                t.Error(err)
                                return
                        }
-                       serverTLSUniques <- srv.ConnectionState().TLSUnique
+                       select {
+                       case <-parentDone:
+                               return
+                       case serverTLSUniques <- srv.ConnectionState().TLSUnique:
+                       }
                }
        }()
 
@@ -318,9 +463,20 @@ func TestTLSUniqueMatches(t *testing.T) {
        if err != nil {
                t.Fatal(err)
        }
-       if !bytes.Equal(conn.ConnectionState().TLSUnique, <-serverTLSUniques) {
+
+       var serverTLSUniquesValue []byte
+       select {
+       case <-childDone:
+               return
+       case serverTLSUniquesValue = <-serverTLSUniques:
+       }
+
+       if !bytes.Equal(conn.ConnectionState().TLSUnique, serverTLSUniquesValue) {
                t.Error("client and server channel bindings differ")
        }
+       if serverTLSUniquesValue == nil || bytes.Equal(serverTLSUniquesValue, make([]byte, 12)) {
+               t.Error("tls-unique is empty or zero")
+       }
        conn.Close()
 
        conn, err = Dial("tcp", ln.Addr().String(), clientConfig)
@@ -331,9 +487,19 @@ func TestTLSUniqueMatches(t *testing.T) {
        if !conn.ConnectionState().DidResume {
                t.Error("second session did not use resumption")
        }
-       if !bytes.Equal(conn.ConnectionState().TLSUnique, <-serverTLSUniques) {
+
+       select {
+       case <-childDone:
+               return
+       case serverTLSUniquesValue = <-serverTLSUniques:
+       }
+
+       if !bytes.Equal(conn.ConnectionState().TLSUnique, serverTLSUniquesValue) {
                t.Error("client and server channel bindings differ when session resumption is used")
        }
+       if serverTLSUniquesValue == nil || bytes.Equal(serverTLSUniquesValue, make([]byte, 12)) {
+               t.Error("resumption tls-unique is empty or zero")
+       }
 }
 
 func TestVerifyHostname(t *testing.T) {
@@ -433,8 +599,8 @@ func TestConnCloseBreakingWrite(t *testing.T) {
        }
 
        <-closeReturned
-       if err := tconn.Close(); err != errClosed {
-               t.Errorf("Close error = %v; want errClosed", err)
+       if err := tconn.Close(); err != net.ErrClosed {
+               t.Errorf("Close error = %v; want net.ErrClosed", err)
        }
 }
 
@@ -458,7 +624,7 @@ func TestConnCloseWrite(t *testing.T) {
                }
                defer srv.Close()
 
-               data, err := ioutil.ReadAll(srv)
+               data, err := io.ReadAll(srv)
                if err != nil {
                        return err
                }
@@ -499,7 +665,7 @@ func TestConnCloseWrite(t *testing.T) {
                        return fmt.Errorf("CloseWrite error = %v; want errShutdown", err)
                }
 
-               data, err := ioutil.ReadAll(conn)
+               data, err := io.ReadAll(conn)
                if err != nil {
                        return err
                }
@@ -562,7 +728,7 @@ func TestWarningAlertFlood(t *testing.T) {
                }
                defer srv.Close()
 
-               _, err = ioutil.ReadAll(srv)
+               _, err = io.ReadAll(srv)
                if err == nil {
                        return errors.New("unexpected lack of error from server")
                }
@@ -598,7 +764,7 @@ func TestWarningAlertFlood(t *testing.T) {
 }
 
 func TestCloneFuncFields(t *testing.T) {
-       const expectedCount = 5
+       const expectedCount = 8
        called := 0
 
        c1 := Config{
@@ -622,6 +788,18 @@ func TestCloneFuncFields(t *testing.T) {
                        called |= 1 << 4
                        return nil
                },
+               VerifyConnection: func(ConnectionState) error {
+                       called |= 1 << 5
+                       return nil
+               },
+               UnwrapSession: func(identity []byte, cs ConnectionState) (*SessionState, error) {
+                       called |= 1 << 6
+                       return nil, nil
+               },
+               WrapSession: func(cs ConnectionState, ss *SessionState) ([]byte, error) {
+                       called |= 1 << 7
+                       return nil, nil
+               },
        }
 
        c2 := c1.Clone()
@@ -631,6 +809,9 @@ func TestCloneFuncFields(t *testing.T) {
        c2.GetClientCertificate(nil)
        c2.GetConfigForClient(nil)
        c2.VerifyPeerCertificate(nil, nil)
+       c2.VerifyConnection(ConnectionState{})
+       c2.UnwrapSession(nil, ConnectionState{})
+       c2.WrapSession(ConnectionState{}, nil)
 
        if called != (1<<expectedCount)-1 {
                t.Fatalf("expected %d calls but saw calls %b", expectedCount, called)
@@ -644,17 +825,12 @@ func TestCloneNonFuncFields(t *testing.T) {
        typ := v.Type()
        for i := 0; i < typ.NumField(); i++ {
                f := v.Field(i)
-               if !f.CanSet() {
-                       // unexported field; not cloned.
-                       continue
-               }
-
                // testing/quick can't handle functions or interfaces and so
                // isn't used here.
                switch fn := typ.Field(i).Name; fn {
                case "Rand":
                        f.Set(reflect.ValueOf(io.Reader(os.Stdin)))
-               case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate", "GetClientCertificate":
+               case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate", "VerifyConnection", "GetClientCertificate", "WrapSession", "UnwrapSession":
                        // DeepEqual can't compare functions. If you add a
                        // function field to this list, you must also change
                        // TestCloneFuncFields to ensure that the func field is
@@ -689,22 +865,29 @@ func TestCloneNonFuncFields(t *testing.T) {
                        f.Set(reflect.ValueOf([]CurveID{CurveP256}))
                case "Renegotiation":
                        f.Set(reflect.ValueOf(RenegotiateOnceAsClient))
+               case "mutex", "autoSessionTicketKeys", "sessionTicketKeys":
+                       continue // these are unexported fields that are handled separately
                default:
                        t.Errorf("all fields must be accounted for, but saw unknown field %q", fn)
                }
        }
+       // Set the unexported fields related to session ticket keys, which are copied with Clone().
+       c1.autoSessionTicketKeys = []ticketKey{c1.ticketKeyFromBytes(c1.SessionTicketKey)}
+       c1.sessionTicketKeys = []ticketKey{c1.ticketKeyFromBytes(c1.SessionTicketKey)}
 
        c2 := c1.Clone()
-       // DeepEqual also compares unexported fields, thus c2 needs to have run
-       // serverInit in order to be DeepEqual to c1. Cloning it and discarding
-       // the result is sufficient.
-       c2.Clone()
-
        if !reflect.DeepEqual(&c1, c2) {
                t.Errorf("clone failed to copy a field")
        }
 }
 
+func TestCloneNilConfig(t *testing.T) {
+       var config *Config
+       if cc := config.Clone(); cc != nil {
+               t.Fatalf("Clone with nil should return nil, got: %+v", cc)
+       }
+}
+
 // changeImplConn is a net.Conn which can change its Write and Close
 // methods.
 type changeImplConn struct {
@@ -980,8 +1163,8 @@ func TestConnectionState(t *testing.T) {
                        if ss.ServerName != serverName {
                                t.Errorf("Got server name %q, expected %q", ss.ServerName, serverName)
                        }
-                       if cs.ServerName != "" {
-                               t.Errorf("Got unexpected server name on the client side")
+                       if cs.ServerName != serverName {
+                               t.Errorf("Got server name on client connection %q, expected %q", cs.ServerName, serverName)
                        }
 
                        if len(ss.PeerCertificates) != 1 || len(cs.PeerCertificates) != 1 {
@@ -1182,6 +1365,7 @@ func TestClientHelloInfo_SupportsCertificate(t *testing.T) {
                        SupportedPoints:   []uint8{pointFormatUncompressed},
                        SignatureSchemes:  []SignatureScheme{Ed25519},
                        SupportedVersions: []uint16{VersionTLS10},
+                       config:            &Config{MinVersion: VersionTLS10},
                }, "doesn't support Ed25519"},
                {ed25519Cert, &ClientHelloInfo{
                        CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
@@ -1196,6 +1380,7 @@ func TestClientHelloInfo_SupportsCertificate(t *testing.T) {
                        SupportedCurves:   []CurveID{CurveP256}, // only relevant for ECDHE support
                        SupportedPoints:   []uint8{pointFormatUncompressed},
                        SupportedVersions: []uint16{VersionTLS10},
+                       config:            &Config{MinVersion: VersionTLS10},
                }, ""},
                {rsaCert, &ClientHelloInfo{
                        CipherSuites:      []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
@@ -1241,7 +1426,7 @@ func TestCipherSuites(t *testing.T) {
                }
        }
 
-       cipherSuiteByID := func(id uint16) *CipherSuite {
+       CipherSuiteByID := func(id uint16) *CipherSuite {
                for _, c := range CipherSuites() {
                        if c.ID == id {
                                return c
@@ -1256,15 +1441,12 @@ func TestCipherSuites(t *testing.T) {
        }
 
        for _, c := range cipherSuites {
-               cc := cipherSuiteByID(c.id)
+               cc := CipherSuiteByID(c.id)
                if cc == nil {
                        t.Errorf("%#04x: no CipherSuite entry", c.id)
                        continue
                }
 
-               if defaultOff := c.flags&suiteDefaultOff != 0; defaultOff != cc.Insecure {
-                       t.Errorf("%#04x: Insecure %v, expected %v", c.id, cc.Insecure, defaultOff)
-               }
                if tls12Only := c.flags&suiteTLS12 != 0; tls12Only && len(cc.SupportedVersions) != 1 {
                        t.Errorf("%#04x: suite is TLS 1.2 only, but SupportedVersions is %v", c.id, cc.SupportedVersions)
                } else if !tls12Only && len(cc.SupportedVersions) != 3 {
@@ -1276,7 +1458,7 @@ func TestCipherSuites(t *testing.T) {
                }
        }
        for _, c := range cipherSuitesTLS13 {
-               cc := cipherSuiteByID(c.id)
+               cc := CipherSuiteByID(c.id)
                if cc == nil {
                        t.Errorf("%#04x: no CipherSuite entry", c.id)
                        continue
@@ -1297,6 +1479,152 @@ func TestCipherSuites(t *testing.T) {
        if got := CipherSuiteName(0xabc); got != "0x0ABC" {
                t.Errorf("unexpected fallback CipherSuiteName: got %q, expected 0x0ABC", got)
        }
+
+       if len(cipherSuitesPreferenceOrder) != len(cipherSuites) {
+               t.Errorf("cipherSuitesPreferenceOrder is not the same size as cipherSuites")
+       }
+       if len(cipherSuitesPreferenceOrderNoAES) != len(cipherSuitesPreferenceOrder) {
+               t.Errorf("cipherSuitesPreferenceOrderNoAES is not the same size as cipherSuitesPreferenceOrder")
+       }
+
+       // Check that disabled suites are at the end of the preference lists, and
+       // that they are marked insecure.
+       for i, id := range disabledCipherSuites {
+               offset := len(cipherSuitesPreferenceOrder) - len(disabledCipherSuites)
+               if cipherSuitesPreferenceOrder[offset+i] != id {
+                       t.Errorf("disabledCipherSuites[%d]: not at the end of cipherSuitesPreferenceOrder", i)
+               }
+               if cipherSuitesPreferenceOrderNoAES[offset+i] != id {
+                       t.Errorf("disabledCipherSuites[%d]: not at the end of cipherSuitesPreferenceOrderNoAES", i)
+               }
+               c := CipherSuiteByID(id)
+               if c == nil {
+                       t.Errorf("%#04x: no CipherSuite entry", id)
+                       continue
+               }
+               if !c.Insecure {
+                       t.Errorf("%#04x: disabled by default but not marked insecure", id)
+               }
+       }
+
+       for i, prefOrder := range [][]uint16{cipherSuitesPreferenceOrder, cipherSuitesPreferenceOrderNoAES} {
+               // Check that insecure and HTTP/2 bad cipher suites are at the end of
+               // the preference lists.
+               var sawInsecure, sawBad bool
+               for _, id := range prefOrder {
+                       c := CipherSuiteByID(id)
+                       if c == nil {
+                               t.Errorf("%#04x: no CipherSuite entry", id)
+                               continue
+                       }
+
+                       if c.Insecure {
+                               sawInsecure = true
+                       } else if sawInsecure {
+                               t.Errorf("%#04x: secure suite after insecure one(s)", id)
+                       }
+
+                       if http2isBadCipher(id) {
+                               sawBad = true
+                       } else if sawBad {
+                               t.Errorf("%#04x: non-bad suite after bad HTTP/2 one(s)", id)
+                       }
+               }
+
+               // Check that the list is sorted according to the documented criteria.
+               isBetter := func(a, b int) bool {
+                       aSuite, bSuite := cipherSuiteByID(prefOrder[a]), cipherSuiteByID(prefOrder[b])
+                       aName, bName := CipherSuiteName(prefOrder[a]), CipherSuiteName(prefOrder[b])
+                       // * < RC4
+                       if !strings.Contains(aName, "RC4") && strings.Contains(bName, "RC4") {
+                               return true
+                       } else if strings.Contains(aName, "RC4") && !strings.Contains(bName, "RC4") {
+                               return false
+                       }
+                       // * < CBC_SHA256
+                       if !strings.Contains(aName, "CBC_SHA256") && strings.Contains(bName, "CBC_SHA256") {
+                               return true
+                       } else if strings.Contains(aName, "CBC_SHA256") && !strings.Contains(bName, "CBC_SHA256") {
+                               return false
+                       }
+                       // * < 3DES
+                       if !strings.Contains(aName, "3DES") && strings.Contains(bName, "3DES") {
+                               return true
+                       } else if strings.Contains(aName, "3DES") && !strings.Contains(bName, "3DES") {
+                               return false
+                       }
+                       // ECDHE < *
+                       if aSuite.flags&suiteECDHE != 0 && bSuite.flags&suiteECDHE == 0 {
+                               return true
+                       } else if aSuite.flags&suiteECDHE == 0 && bSuite.flags&suiteECDHE != 0 {
+                               return false
+                       }
+                       // AEAD < CBC
+                       if aSuite.aead != nil && bSuite.aead == nil {
+                               return true
+                       } else if aSuite.aead == nil && bSuite.aead != nil {
+                               return false
+                       }
+                       // AES < ChaCha20
+                       if strings.Contains(aName, "AES") && strings.Contains(bName, "CHACHA20") {
+                               return i == 0 // true for cipherSuitesPreferenceOrder
+                       } else if strings.Contains(aName, "CHACHA20") && strings.Contains(bName, "AES") {
+                               return i != 0 // true for cipherSuitesPreferenceOrderNoAES
+                       }
+                       // AES-128 < AES-256
+                       if strings.Contains(aName, "AES_128") && strings.Contains(bName, "AES_256") {
+                               return true
+                       } else if strings.Contains(aName, "AES_256") && strings.Contains(bName, "AES_128") {
+                               return false
+                       }
+                       // ECDSA < RSA
+                       if aSuite.flags&suiteECSign != 0 && bSuite.flags&suiteECSign == 0 {
+                               return true
+                       } else if aSuite.flags&suiteECSign == 0 && bSuite.flags&suiteECSign != 0 {
+                               return false
+                       }
+                       t.Fatalf("two ciphersuites are equal by all criteria: %v and %v", aName, bName)
+                       panic("unreachable")
+               }
+               if !sort.SliceIsSorted(prefOrder, isBetter) {
+                       t.Error("preference order is not sorted according to the rules")
+               }
+       }
+}
+
+func TestVersionName(t *testing.T) {
+       if got, exp := VersionName(VersionTLS13), "TLS 1.3"; got != exp {
+               t.Errorf("unexpected VersionName: got %q, expected %q", got, exp)
+       }
+       if got, exp := VersionName(0x12a), "0x012A"; got != exp {
+               t.Errorf("unexpected fallback VersionName: got %q, expected %q", got, exp)
+       }
+}
+
+// http2isBadCipher is copied from net/http.
+// TODO: if it ends up exposed somewhere, use that instead.
+func http2isBadCipher(cipher uint16) bool {
+       switch cipher {
+       case TLS_RSA_WITH_RC4_128_SHA,
+               TLS_RSA_WITH_3DES_EDE_CBC_SHA,
+               TLS_RSA_WITH_AES_128_CBC_SHA,
+               TLS_RSA_WITH_AES_256_CBC_SHA,
+               TLS_RSA_WITH_AES_128_CBC_SHA256,
+               TLS_RSA_WITH_AES_128_GCM_SHA256,
+               TLS_RSA_WITH_AES_256_GCM_SHA384,
+               TLS_ECDHE_ECDSA_WITH_RC4_128_SHA,
+               TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
+               TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
+               TLS_ECDHE_RSA_WITH_RC4_128_SHA,
+               TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
+               TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
+               TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
+               TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256,
+               TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256:
+               return true
+       default:
+               return false
+       }
 }
 
 type brokenSigner struct{ crypto.Signer }
@@ -1307,7 +1635,7 @@ func (s brokenSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts
 }
 
 // TestPKCS1OnlyCert uses a client certificate with a broken crypto.Signer that
-// always makes PKCS#1 v1.5 signatures, so can't be used with RSA-PSS.
+// always makes PKCS #1 v1.5 signatures, so can't be used with RSA-PSS.
 func TestPKCS1OnlyCert(t *testing.T) {
        clientConfig := testConfig.Clone()
        clientConfig.Certificates = []Certificate{{
@@ -1315,7 +1643,7 @@ func TestPKCS1OnlyCert(t *testing.T) {
                PrivateKey:  brokenSigner{testRSAPrivateKey},
        }}
        serverConfig := testConfig.Clone()
-       serverConfig.MaxVersion = VersionTLS12 // TLS 1.3 doesn't support PKCS#1 v1.5
+       serverConfig.MaxVersion = VersionTLS12 // TLS 1.3 doesn't support PKCS #1 v1.5
        serverConfig.ClientAuth = RequireAnyClientCert
 
        // If RSA-PSS is selected, the handshake should fail.
@@ -1332,3 +1660,147 @@ func TestPKCS1OnlyCert(t *testing.T) {
                t.Error(err)
        }
 }
+
+func TestVerifyCertificates(t *testing.T) {
+       // See https://go.dev/issue/31641.
+       t.Run("TLSv12", func(t *testing.T) { testVerifyCertificates(t, VersionTLS12) })
+       t.Run("TLSv13", func(t *testing.T) { testVerifyCertificates(t, VersionTLS13) })
+}
+
+func testVerifyCertificates(t *testing.T, version uint16) {
+       tests := []struct {
+               name string
+
+               InsecureSkipVerify bool
+               ClientAuth         ClientAuthType
+               ClientCertificates bool
+       }{
+               {
+                       name: "defaults",
+               },
+               {
+                       name:               "InsecureSkipVerify",
+                       InsecureSkipVerify: true,
+               },
+               {
+                       name:       "RequestClientCert with no certs",
+                       ClientAuth: RequestClientCert,
+               },
+               {
+                       name:               "RequestClientCert with certs",
+                       ClientAuth:         RequestClientCert,
+                       ClientCertificates: true,
+               },
+               {
+                       name:               "RequireAnyClientCert",
+                       ClientAuth:         RequireAnyClientCert,
+                       ClientCertificates: true,
+               },
+               {
+                       name:       "VerifyClientCertIfGiven with no certs",
+                       ClientAuth: VerifyClientCertIfGiven,
+               },
+               {
+                       name:               "VerifyClientCertIfGiven with certs",
+                       ClientAuth:         VerifyClientCertIfGiven,
+                       ClientCertificates: true,
+               },
+               {
+                       name:               "RequireAndVerifyClientCert",
+                       ClientAuth:         RequireAndVerifyClientCert,
+                       ClientCertificates: true,
+               },
+       }
+
+       issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
+       if err != nil {
+               t.Fatal(err)
+       }
+       rootCAs := x509.NewCertPool()
+       rootCAs.AddCert(issuer)
+
+       for _, test := range tests {
+               test := test
+               t.Run(test.name, func(t *testing.T) {
+                       t.Parallel()
+
+                       var serverVerifyConnection, clientVerifyConnection bool
+                       var serverVerifyPeerCertificates, clientVerifyPeerCertificates bool
+
+                       clientConfig := testConfig.Clone()
+                       clientConfig.Time = func() time.Time { return time.Unix(1476984729, 0) }
+                       clientConfig.MaxVersion = version
+                       clientConfig.MinVersion = version
+                       clientConfig.RootCAs = rootCAs
+                       clientConfig.ServerName = "example.golang"
+                       clientConfig.ClientSessionCache = NewLRUClientSessionCache(1)
+                       serverConfig := clientConfig.Clone()
+                       serverConfig.ClientCAs = rootCAs
+
+                       clientConfig.VerifyConnection = func(cs ConnectionState) error {
+                               clientVerifyConnection = true
+                               return nil
+                       }
+                       clientConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
+                               clientVerifyPeerCertificates = true
+                               return nil
+                       }
+                       serverConfig.VerifyConnection = func(cs ConnectionState) error {
+                               serverVerifyConnection = true
+                               return nil
+                       }
+                       serverConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
+                               serverVerifyPeerCertificates = true
+                               return nil
+                       }
+
+                       clientConfig.InsecureSkipVerify = test.InsecureSkipVerify
+                       serverConfig.ClientAuth = test.ClientAuth
+                       if !test.ClientCertificates {
+                               clientConfig.Certificates = nil
+                       }
+
+                       if _, _, err := testHandshake(t, clientConfig, serverConfig); err != nil {
+                               t.Fatal(err)
+                       }
+
+                       want := serverConfig.ClientAuth != NoClientCert
+                       if serverVerifyPeerCertificates != want {
+                               t.Errorf("VerifyPeerCertificates on the server: got %v, want %v",
+                                       serverVerifyPeerCertificates, want)
+                       }
+                       if !clientVerifyPeerCertificates {
+                               t.Errorf("VerifyPeerCertificates not called on the client")
+                       }
+                       if !serverVerifyConnection {
+                               t.Error("VerifyConnection did not get called on the server")
+                       }
+                       if !clientVerifyConnection {
+                               t.Error("VerifyConnection did not get called on the client")
+                       }
+
+                       serverVerifyPeerCertificates, clientVerifyPeerCertificates = false, false
+                       serverVerifyConnection, clientVerifyConnection = false, false
+                       cs, _, err := testHandshake(t, clientConfig, serverConfig)
+                       if err != nil {
+                               t.Fatal(err)
+                       }
+                       if !cs.DidResume {
+                               t.Error("expected resumption")
+                       }
+
+                       if serverVerifyPeerCertificates {
+                               t.Error("VerifyPeerCertificates got called on the server on resumption")
+                       }
+                       if clientVerifyPeerCertificates {
+                               t.Error("VerifyPeerCertificates got called on the client on resumption")
+                       }
+                       if !serverVerifyConnection {
+                               t.Error("VerifyConnection did not get called on the server on resumption")
+                       }
+                       if !clientVerifyConnection {
+                               t.Error("VerifyConnection did not get called on the client on resumption")
+                       }
+               })
+       }
+}