]> Cypherpunks.ru repositories - gostls13.git/blobdiff - src/crypto/tls/handshake_messages_test.go
crypto/tls: add SessionState and use it on the server side
[gostls13.git] / src / crypto / tls / handshake_messages_test.go
index 1ef6c432ffad51192778c760a89b8e02abf9921c..b280f0967468a84c91a55b341c77bb8a303bbfbc 100644 (file)
@@ -15,7 +15,7 @@ import (
        "time"
 )
 
-var tests = []any{
+var tests = []handshakeMessage{
        &clientHelloMsg{},
        &serverHelloMsg{},
        &finishedMsg{},
@@ -28,14 +28,13 @@ var tests = []any{
        &certificateStatusMsg{},
        &clientKeyExchangeMsg{},
        &newSessionTicketMsg{},
-       &sessionState{},
-       &sessionStateTLS13{},
        &encryptedExtensionsMsg{},
        &endOfEarlyDataMsg{},
        &keyUpdateMsg{},
        &newSessionTicketMsgTLS13{},
        &certificateRequestMsgTLS13{},
        &certificateMsgTLS13{},
+       &SessionState{},
 }
 
 func mustMarshal(t *testing.T, msg handshakeMessage) []byte {
@@ -50,8 +49,8 @@ func mustMarshal(t *testing.T, msg handshakeMessage) []byte {
 func TestMarshalUnmarshal(t *testing.T) {
        rand := rand.New(rand.NewSource(time.Now().UnixNano()))
 
-       for i, iface := range tests {
-               ty := reflect.ValueOf(iface).Type()
+       for i, m := range tests {
+               ty := reflect.ValueOf(m).Type()
 
                n := 100
                if testing.Short() {
@@ -66,15 +65,14 @@ func TestMarshalUnmarshal(t *testing.T) {
 
                        m1 := v.Interface().(handshakeMessage)
                        marshaled := mustMarshal(t, m1)
-                       m2 := iface.(handshakeMessage)
-                       if !m2.unmarshal(marshaled) {
+                       if !m.unmarshal(marshaled) {
                                t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled)
                                break
                        }
-                       m2.marshal() // to fill any marshal cache in the message
+                       m.marshal() // to fill any marshal cache in the message
 
-                       if !reflect.DeepEqual(m1, m2) {
-                               t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled)
+                       if !reflect.DeepEqual(m1, m) {
+                               t.Errorf("#%d got:%#v want:%#v %x", i, m, m1, marshaled)
                                break
                        }
 
@@ -85,7 +83,7 @@ func TestMarshalUnmarshal(t *testing.T) {
                                // data is optional and the length of the
                                // Finished varies across versions.
                                for j := 0; j < len(marshaled); j++ {
-                                       if m2.unmarshal(marshaled[0:j]) {
+                                       if m.unmarshal(marshaled[0:j]) {
                                                t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1)
                                                break
                                        }
@@ -97,9 +95,7 @@ func TestMarshalUnmarshal(t *testing.T) {
 
 func TestFuzz(t *testing.T) {
        rand := rand.New(rand.NewSource(0))
-       for _, iface := range tests {
-               m := iface.(handshakeMessage)
-
+       for _, m := range tests {
                for j := 0; j < 1000; j++ {
                        len := rand.Intn(100)
                        bytes := randomBytes(len, rand)
@@ -317,22 +313,11 @@ func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value {
        return reflect.ValueOf(m)
 }
 
-func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value {
-       s := &sessionState{}
-       s.vers = uint16(rand.Intn(10000))
+func (*SessionState) Generate(rand *rand.Rand, size int) reflect.Value {
+       s := &SessionState{}
+       s.version = uint16(rand.Intn(10000))
        s.cipherSuite = uint16(rand.Intn(10000))
-       s.masterSecret = randomBytes(rand.Intn(100)+1, rand)
-       s.createdAt = uint64(rand.Int63())
-       for i := 0; i < rand.Intn(20); i++ {
-               s.certificates = append(s.certificates, randomBytes(rand.Intn(500)+1, rand))
-       }
-       return reflect.ValueOf(s)
-}
-
-func (*sessionStateTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
-       s := &sessionStateTLS13{}
-       s.cipherSuite = uint16(rand.Intn(10000))
-       s.resumptionSecret = randomBytes(rand.Intn(100)+1, rand)
+       s.secret = randomBytes(rand.Intn(100)+1, rand)
        s.createdAt = uint64(rand.Int63())
        for i := 0; i < rand.Intn(2)+1; i++ {
                s.certificate.Certificate = append(
@@ -350,6 +335,16 @@ func (*sessionStateTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
        return reflect.ValueOf(s)
 }
 
+func (s *SessionState) marshal() ([]byte, error) { return s.Bytes() }
+func (s *SessionState) unmarshal(b []byte) bool {
+       ss, err := ParseSessionState(b)
+       if err != nil {
+               return false
+       }
+       *s = *ss
+       return true
+}
+
 func (*endOfEarlyDataMsg) Generate(rand *rand.Rand, size int) reflect.Value {
        m := &endOfEarlyDataMsg{}
        return reflect.ValueOf(m)