]> Cypherpunks.ru repositories - gostls13.git/blobdiff - src/crypto/tls/tls_test.go
crypto/tls: change default minimum version to 1.2
[gostls13.git] / src / crypto / tls / tls_test.go
index d8a43add1796f6ff4a04a6a5136097746effd59a..16f655dd93dbebcca445931aa1e4b64e5eb39b8b 100644 (file)
@@ -170,35 +170,59 @@ func TestDialTimeout(t *testing.T) {
        if testing.Short() {
                t.Skip("skipping in short mode")
        }
-       listener := newLocalListener(t)
 
-       addr := listener.Addr().String()
-       defer listener.Close()
-
-       complete := make(chan bool)
-       defer close(complete)
+       timeout := 100 * time.Microsecond
+       for !t.Failed() {
+               acceptc := make(chan net.Conn)
+               listener := newLocalListener(t)
+               go func() {
+                       for {
+                               conn, err := listener.Accept()
+                               if err != nil {
+                                       close(acceptc)
+                                       return
+                               }
+                               acceptc <- conn
+                       }
+               }()
 
-       go func() {
-               conn, err := listener.Accept()
-               if err != nil {
-                       t.Error(err)
-                       return
+               addr := listener.Addr().String()
+               dialer := &net.Dialer{
+                       Timeout: timeout,
+               }
+               if conn, err := DialWithDialer(dialer, "tcp", addr, nil); err == nil {
+                       conn.Close()
+                       t.Errorf("DialWithTimeout unexpectedly completed successfully")
+               } else if !isTimeoutError(err) {
+                       t.Errorf("resulting error not a timeout: %v\nType %T: %#v", err, err, err)
                }
-               <-complete
-               conn.Close()
-       }()
 
-       dialer := &net.Dialer{
-               Timeout: 10 * time.Millisecond,
-       }
+               listener.Close()
 
-       var err error
-       if _, err = DialWithDialer(dialer, "tcp", addr, nil); err == nil {
-               t.Fatal("DialWithTimeout completed successfully")
-       }
+               // We're looking for a timeout during the handshake, so check that the
+               // Listener actually accepted the connection to initiate it. (If the server
+               // takes too long to accept the connection, we might cancel before the
+               // underlying net.Conn is ever dialed — without ever attempting a
+               // handshake.)
+               lconn, ok := <-acceptc
+               if ok {
+                       // The Listener accepted a connection, so assume that it was from our
+                       // Dial: we triggered the timeout at the point where we wanted it!
+                       t.Logf("Listener accepted a connection from %s", lconn.RemoteAddr())
+                       lconn.Close()
+               }
+               // Close any spurious extra connecitions from the listener. (This is
+               // possible if there are, for example, stray Dial calls from other tests.)
+               for extraConn := range acceptc {
+                       t.Logf("spurious extra connection from %s", extraConn.RemoteAddr())
+                       extraConn.Close()
+               }
+               if ok {
+                       break
+               }
 
-       if !isTimeoutError(err) {
-               t.Errorf("resulting error not a timeout: %v\nType %T: %#v", err, err, err)
+               t.Logf("with timeout %v, DialWithDialer returned before listener accepted any connections; retrying", timeout)
+               timeout *= 2
        }
 }
 
@@ -450,6 +474,9 @@ func TestTLSUniqueMatches(t *testing.T) {
        if !bytes.Equal(conn.ConnectionState().TLSUnique, serverTLSUniquesValue) {
                t.Error("client and server channel bindings differ")
        }
+       if serverTLSUniquesValue == nil || bytes.Equal(serverTLSUniquesValue, make([]byte, 12)) {
+               t.Error("tls-unique is empty or zero")
+       }
        conn.Close()
 
        conn, err = Dial("tcp", ln.Addr().String(), clientConfig)
@@ -470,6 +497,9 @@ func TestTLSUniqueMatches(t *testing.T) {
        if !bytes.Equal(conn.ConnectionState().TLSUnique, serverTLSUniquesValue) {
                t.Error("client and server channel bindings differ when session resumption is used")
        }
+       if serverTLSUniquesValue == nil || bytes.Equal(serverTLSUniquesValue, make([]byte, 12)) {
+               t.Error("resumption tls-unique is empty or zero")
+       }
 }
 
 func TestVerifyHostname(t *testing.T) {
@@ -734,7 +764,7 @@ func TestWarningAlertFlood(t *testing.T) {
 }
 
 func TestCloneFuncFields(t *testing.T) {
-       const expectedCount = 6
+       const expectedCount = 8
        called := 0
 
        c1 := Config{
@@ -762,6 +792,14 @@ func TestCloneFuncFields(t *testing.T) {
                        called |= 1 << 5
                        return nil
                },
+               UnwrapSession: func(identity []byte, cs ConnectionState) (*SessionState, error) {
+                       called |= 1 << 6
+                       return nil, nil
+               },
+               WrapSession: func(cs ConnectionState, ss *SessionState) ([]byte, error) {
+                       called |= 1 << 7
+                       return nil, nil
+               },
        }
 
        c2 := c1.Clone()
@@ -772,6 +810,8 @@ func TestCloneFuncFields(t *testing.T) {
        c2.GetConfigForClient(nil)
        c2.VerifyPeerCertificate(nil, nil)
        c2.VerifyConnection(ConnectionState{})
+       c2.UnwrapSession(nil, ConnectionState{})
+       c2.WrapSession(ConnectionState{}, nil)
 
        if called != (1<<expectedCount)-1 {
                t.Fatalf("expected %d calls but saw calls %b", expectedCount, called)
@@ -790,7 +830,7 @@ func TestCloneNonFuncFields(t *testing.T) {
                switch fn := typ.Field(i).Name; fn {
                case "Rand":
                        f.Set(reflect.ValueOf(io.Reader(os.Stdin)))
-               case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate", "VerifyConnection", "GetClientCertificate":
+               case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate", "VerifyConnection", "GetClientCertificate", "WrapSession", "UnwrapSession":
                        // DeepEqual can't compare functions. If you add a
                        // function field to this list, you must also change
                        // TestCloneFuncFields to ensure that the func field is
@@ -1325,6 +1365,7 @@ func TestClientHelloInfo_SupportsCertificate(t *testing.T) {
                        SupportedPoints:   []uint8{pointFormatUncompressed},
                        SignatureSchemes:  []SignatureScheme{Ed25519},
                        SupportedVersions: []uint16{VersionTLS10},
+                       config:            &Config{MinVersion: VersionTLS10},
                }, "doesn't support Ed25519"},
                {ed25519Cert, &ClientHelloInfo{
                        CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
@@ -1339,6 +1380,7 @@ func TestClientHelloInfo_SupportsCertificate(t *testing.T) {
                        SupportedCurves:   []CurveID{CurveP256}, // only relevant for ECDHE support
                        SupportedPoints:   []uint8{pointFormatUncompressed},
                        SupportedVersions: []uint16{VersionTLS10},
+                       config:            &Config{MinVersion: VersionTLS10},
                }, ""},
                {rsaCert, &ClientHelloInfo{
                        CipherSuites:      []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
@@ -1550,6 +1592,15 @@ func TestCipherSuites(t *testing.T) {
        }
 }
 
+func TestVersionName(t *testing.T) {
+       if got, exp := VersionName(VersionTLS13), "TLS 1.3"; got != exp {
+               t.Errorf("unexpected VersionName: got %q, expected %q", got, exp)
+       }
+       if got, exp := VersionName(0x12a), "0x012A"; got != exp {
+               t.Errorf("unexpected fallback VersionName: got %q, expected %q", got, exp)
+       }
+}
+
 // http2isBadCipher is copied from net/http.
 // TODO: if it ends up exposed somewhere, use that instead.
 func http2isBadCipher(cipher uint16) bool {
@@ -1609,3 +1660,147 @@ func TestPKCS1OnlyCert(t *testing.T) {
                t.Error(err)
        }
 }
+
+func TestVerifyCertificates(t *testing.T) {
+       // See https://go.dev/issue/31641.
+       t.Run("TLSv12", func(t *testing.T) { testVerifyCertificates(t, VersionTLS12) })
+       t.Run("TLSv13", func(t *testing.T) { testVerifyCertificates(t, VersionTLS13) })
+}
+
+func testVerifyCertificates(t *testing.T, version uint16) {
+       tests := []struct {
+               name string
+
+               InsecureSkipVerify bool
+               ClientAuth         ClientAuthType
+               ClientCertificates bool
+       }{
+               {
+                       name: "defaults",
+               },
+               {
+                       name:               "InsecureSkipVerify",
+                       InsecureSkipVerify: true,
+               },
+               {
+                       name:       "RequestClientCert with no certs",
+                       ClientAuth: RequestClientCert,
+               },
+               {
+                       name:               "RequestClientCert with certs",
+                       ClientAuth:         RequestClientCert,
+                       ClientCertificates: true,
+               },
+               {
+                       name:               "RequireAnyClientCert",
+                       ClientAuth:         RequireAnyClientCert,
+                       ClientCertificates: true,
+               },
+               {
+                       name:       "VerifyClientCertIfGiven with no certs",
+                       ClientAuth: VerifyClientCertIfGiven,
+               },
+               {
+                       name:               "VerifyClientCertIfGiven with certs",
+                       ClientAuth:         VerifyClientCertIfGiven,
+                       ClientCertificates: true,
+               },
+               {
+                       name:               "RequireAndVerifyClientCert",
+                       ClientAuth:         RequireAndVerifyClientCert,
+                       ClientCertificates: true,
+               },
+       }
+
+       issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
+       if err != nil {
+               t.Fatal(err)
+       }
+       rootCAs := x509.NewCertPool()
+       rootCAs.AddCert(issuer)
+
+       for _, test := range tests {
+               test := test
+               t.Run(test.name, func(t *testing.T) {
+                       t.Parallel()
+
+                       var serverVerifyConnection, clientVerifyConnection bool
+                       var serverVerifyPeerCertificates, clientVerifyPeerCertificates bool
+
+                       clientConfig := testConfig.Clone()
+                       clientConfig.Time = func() time.Time { return time.Unix(1476984729, 0) }
+                       clientConfig.MaxVersion = version
+                       clientConfig.MinVersion = version
+                       clientConfig.RootCAs = rootCAs
+                       clientConfig.ServerName = "example.golang"
+                       clientConfig.ClientSessionCache = NewLRUClientSessionCache(1)
+                       serverConfig := clientConfig.Clone()
+                       serverConfig.ClientCAs = rootCAs
+
+                       clientConfig.VerifyConnection = func(cs ConnectionState) error {
+                               clientVerifyConnection = true
+                               return nil
+                       }
+                       clientConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
+                               clientVerifyPeerCertificates = true
+                               return nil
+                       }
+                       serverConfig.VerifyConnection = func(cs ConnectionState) error {
+                               serverVerifyConnection = true
+                               return nil
+                       }
+                       serverConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
+                               serverVerifyPeerCertificates = true
+                               return nil
+                       }
+
+                       clientConfig.InsecureSkipVerify = test.InsecureSkipVerify
+                       serverConfig.ClientAuth = test.ClientAuth
+                       if !test.ClientCertificates {
+                               clientConfig.Certificates = nil
+                       }
+
+                       if _, _, err := testHandshake(t, clientConfig, serverConfig); err != nil {
+                               t.Fatal(err)
+                       }
+
+                       want := serverConfig.ClientAuth != NoClientCert
+                       if serverVerifyPeerCertificates != want {
+                               t.Errorf("VerifyPeerCertificates on the server: got %v, want %v",
+                                       serverVerifyPeerCertificates, want)
+                       }
+                       if !clientVerifyPeerCertificates {
+                               t.Errorf("VerifyPeerCertificates not called on the client")
+                       }
+                       if !serverVerifyConnection {
+                               t.Error("VerifyConnection did not get called on the server")
+                       }
+                       if !clientVerifyConnection {
+                               t.Error("VerifyConnection did not get called on the client")
+                       }
+
+                       serverVerifyPeerCertificates, clientVerifyPeerCertificates = false, false
+                       serverVerifyConnection, clientVerifyConnection = false, false
+                       cs, _, err := testHandshake(t, clientConfig, serverConfig)
+                       if err != nil {
+                               t.Fatal(err)
+                       }
+                       if !cs.DidResume {
+                               t.Error("expected resumption")
+                       }
+
+                       if serverVerifyPeerCertificates {
+                               t.Error("VerifyPeerCertificates got called on the server on resumption")
+                       }
+                       if clientVerifyPeerCertificates {
+                               t.Error("VerifyPeerCertificates got called on the client on resumption")
+                       }
+                       if !serverVerifyConnection {
+                               t.Error("VerifyConnection did not get called on the server on resumption")
+                       }
+                       if !clientVerifyConnection {
+                               t.Error("VerifyConnection did not get called on the client on resumption")
+                       }
+               })
+       }
+}