]> Cypherpunks.ru repositories - gostls13.git/blob - src/crypto/tls/handshake_messages_test.go
[dev.boringcrypto] all: merge master into dev.boringcrypto
[gostls13.git] / src / crypto / tls / handshake_messages_test.go
1 // Copyright 2009 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package tls
6
7 import (
8         "bytes"
9         "math/rand"
10         "reflect"
11         "strings"
12         "testing"
13         "testing/quick"
14         "time"
15 )
16
17 var tests = []interface{}{
18         &clientHelloMsg{},
19         &serverHelloMsg{},
20         &finishedMsg{},
21
22         &certificateMsg{},
23         &certificateRequestMsg{},
24         &certificateVerifyMsg{
25                 hasSignatureAlgorithm: true,
26         },
27         &certificateStatusMsg{},
28         &clientKeyExchangeMsg{},
29         &nextProtoMsg{},
30         &newSessionTicketMsg{},
31         &sessionState{},
32 }
33
34 func TestMarshalUnmarshal(t *testing.T) {
35         rand := rand.New(rand.NewSource(time.Now().UnixNano()))
36
37         for i, iface := range tests {
38                 ty := reflect.ValueOf(iface).Type()
39
40                 n := 100
41                 if testing.Short() {
42                         n = 5
43                 }
44                 for j := 0; j < n; j++ {
45                         v, ok := quick.Value(ty, rand)
46                         if !ok {
47                                 t.Errorf("#%d: failed to create value", i)
48                                 break
49                         }
50
51                         m1 := v.Interface().(handshakeMessage)
52                         marshaled := m1.marshal()
53                         m2 := iface.(handshakeMessage)
54                         if !m2.unmarshal(marshaled) {
55                                 t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled)
56                                 break
57                         }
58                         m2.marshal() // to fill any marshal cache in the message
59
60                         if !reflect.DeepEqual(m1, m2) {
61                                 t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled)
62                                 break
63                         }
64
65                         if i >= 3 {
66                                 // The first three message types (ClientHello,
67                                 // ServerHello and Finished) are allowed to
68                                 // have parsable prefixes because the extension
69                                 // data is optional and the length of the
70                                 // Finished varies across versions.
71                                 for j := 0; j < len(marshaled); j++ {
72                                         if m2.unmarshal(marshaled[0:j]) {
73                                                 t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1)
74                                                 break
75                                         }
76                                 }
77                         }
78                 }
79         }
80 }
81
82 func TestFuzz(t *testing.T) {
83         rand := rand.New(rand.NewSource(0))
84         for _, iface := range tests {
85                 m := iface.(handshakeMessage)
86
87                 for j := 0; j < 1000; j++ {
88                         len := rand.Intn(100)
89                         bytes := randomBytes(len, rand)
90                         // This just looks for crashes due to bounds errors etc.
91                         m.unmarshal(bytes)
92                 }
93         }
94 }
95
96 func randomBytes(n int, rand *rand.Rand) []byte {
97         r := make([]byte, n)
98         if _, err := rand.Read(r); err != nil {
99                 panic("rand.Read failed: " + err.Error())
100         }
101         return r
102 }
103
104 func randomString(n int, rand *rand.Rand) string {
105         b := randomBytes(n, rand)
106         return string(b)
107 }
108
109 func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
110         m := &clientHelloMsg{}
111         m.vers = uint16(rand.Intn(65536))
112         m.random = randomBytes(32, rand)
113         m.sessionId = randomBytes(rand.Intn(32), rand)
114         m.cipherSuites = make([]uint16, rand.Intn(63)+1)
115         for i := 0; i < len(m.cipherSuites); i++ {
116                 cs := uint16(rand.Int31())
117                 if cs == scsvRenegotiation {
118                         cs += 1
119                 }
120                 m.cipherSuites[i] = cs
121         }
122         m.compressionMethods = randomBytes(rand.Intn(63)+1, rand)
123         if rand.Intn(10) > 5 {
124                 m.nextProtoNeg = true
125         }
126         if rand.Intn(10) > 5 {
127                 m.serverName = randomString(rand.Intn(255), rand)
128                 for strings.HasSuffix(m.serverName, ".") {
129                         m.serverName = m.serverName[:len(m.serverName)-1]
130                 }
131         }
132         m.ocspStapling = rand.Intn(10) > 5
133         m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
134         m.supportedCurves = make([]CurveID, rand.Intn(5)+1)
135         for i := range m.supportedCurves {
136                 m.supportedCurves[i] = CurveID(rand.Intn(30000) + 1)
137         }
138         if rand.Intn(10) > 5 {
139                 m.ticketSupported = true
140                 if rand.Intn(10) > 5 {
141                         m.sessionTicket = randomBytes(rand.Intn(300), rand)
142                 } else {
143                         m.sessionTicket = make([]byte, 0)
144                 }
145         }
146         if rand.Intn(10) > 5 {
147                 m.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
148         }
149         if rand.Intn(10) > 5 {
150                 m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms()
151         }
152         for i := 0; i < rand.Intn(5); i++ {
153                 m.alpnProtocols = append(m.alpnProtocols, randomString(rand.Intn(20)+1, rand))
154         }
155         if rand.Intn(10) > 5 {
156                 m.scts = true
157         }
158         if rand.Intn(10) > 5 {
159                 m.secureRenegotiationSupported = true
160                 m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand)
161         }
162         for i := 0; i < rand.Intn(5); i++ {
163                 m.supportedVersions = append(m.supportedVersions, uint16(rand.Intn(0xffff)+1))
164         }
165         if rand.Intn(10) > 5 {
166                 m.cookie = randomBytes(rand.Intn(500)+1, rand)
167         }
168         for i := 0; i < rand.Intn(5); i++ {
169                 var ks keyShare
170                 ks.group = CurveID(rand.Intn(30000) + 1)
171                 ks.data = randomBytes(rand.Intn(200)+1, rand)
172                 m.keyShares = append(m.keyShares, ks)
173         }
174         switch rand.Intn(3) {
175         case 1:
176                 m.pskModes = []uint8{pskModeDHE}
177         case 2:
178                 m.pskModes = []uint8{pskModeDHE, pskModePlain}
179         }
180         for i := 0; i < rand.Intn(5); i++ {
181                 var psk pskIdentity
182                 psk.obfuscatedTicketAge = uint32(rand.Intn(500000))
183                 psk.label = randomBytes(rand.Intn(500)+1, rand)
184                 m.pskIdentities = append(m.pskIdentities, psk)
185                 m.pskBinders = append(m.pskBinders, randomBytes(rand.Intn(50)+32, rand))
186         }
187
188         return reflect.ValueOf(m)
189 }
190
191 func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
192         m := &serverHelloMsg{}
193         m.vers = uint16(rand.Intn(65536))
194         m.random = randomBytes(32, rand)
195         m.sessionId = randomBytes(rand.Intn(32), rand)
196         m.cipherSuite = uint16(rand.Int31())
197         m.compressionMethod = uint8(rand.Intn(256))
198
199         if rand.Intn(10) > 5 {
200                 m.nextProtoNeg = true
201                 for i := 0; i < rand.Intn(10); i++ {
202                         m.nextProtos = append(m.nextProtos, randomString(20, rand))
203                 }
204         }
205
206         if rand.Intn(10) > 5 {
207                 m.ocspStapling = true
208         }
209         if rand.Intn(10) > 5 {
210                 m.ticketSupported = true
211         }
212         m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
213
214         for i := 0; i < rand.Intn(4); i++ {
215                 m.scts = append(m.scts, randomBytes(rand.Intn(500)+1, rand))
216         }
217
218         if rand.Intn(10) > 5 {
219                 m.secureRenegotiationSupported = true
220                 m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand)
221         }
222         if rand.Intn(10) > 5 {
223                 m.supportedVersion = uint16(rand.Intn(0xffff) + 1)
224         }
225         if rand.Intn(10) > 5 {
226                 m.cookie = randomBytes(rand.Intn(500)+1, rand)
227         }
228         if rand.Intn(10) > 5 {
229                 for i := 0; i < rand.Intn(5); i++ {
230                         m.serverShare.group = CurveID(rand.Intn(30000) + 1)
231                         m.serverShare.data = randomBytes(rand.Intn(200)+1, rand)
232                 }
233         } else if rand.Intn(10) > 5 {
234                 m.selectedGroup = CurveID(rand.Intn(30000) + 1)
235         }
236         if rand.Intn(10) > 5 {
237                 m.selectedIdentityPresent = true
238                 m.selectedIdentity = uint16(rand.Intn(0xffff))
239         }
240
241         return reflect.ValueOf(m)
242 }
243
244 func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
245         m := &certificateMsg{}
246         numCerts := rand.Intn(20)
247         m.certificates = make([][]byte, numCerts)
248         for i := 0; i < numCerts; i++ {
249                 m.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
250         }
251         return reflect.ValueOf(m)
252 }
253
254 func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value {
255         m := &certificateRequestMsg{}
256         m.certificateTypes = randomBytes(rand.Intn(5)+1, rand)
257         for i := 0; i < rand.Intn(100); i++ {
258                 m.certificateAuthorities = append(m.certificateAuthorities, randomBytes(rand.Intn(15)+1, rand))
259         }
260         return reflect.ValueOf(m)
261 }
262
263 func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value {
264         m := &certificateVerifyMsg{}
265         m.hasSignatureAlgorithm = true
266         m.signatureAlgorithm = SignatureScheme(rand.Intn(30000))
267         m.signature = randomBytes(rand.Intn(15)+1, rand)
268         return reflect.ValueOf(m)
269 }
270
271 func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value {
272         m := &certificateStatusMsg{}
273         if rand.Intn(10) > 5 {
274                 m.statusType = statusTypeOCSP
275                 m.response = randomBytes(rand.Intn(10)+1, rand)
276         } else {
277                 m.statusType = 42
278         }
279         return reflect.ValueOf(m)
280 }
281
282 func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value {
283         m := &clientKeyExchangeMsg{}
284         m.ciphertext = randomBytes(rand.Intn(1000)+1, rand)
285         return reflect.ValueOf(m)
286 }
287
288 func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value {
289         m := &finishedMsg{}
290         m.verifyData = randomBytes(12, rand)
291         return reflect.ValueOf(m)
292 }
293
294 func (*nextProtoMsg) Generate(rand *rand.Rand, size int) reflect.Value {
295         m := &nextProtoMsg{}
296         m.proto = randomString(rand.Intn(255), rand)
297         return reflect.ValueOf(m)
298 }
299
300 func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value {
301         m := &newSessionTicketMsg{}
302         m.ticket = randomBytes(rand.Intn(4), rand)
303         return reflect.ValueOf(m)
304 }
305
306 func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value {
307         s := &sessionState{}
308         s.vers = uint16(rand.Intn(10000))
309         s.cipherSuite = uint16(rand.Intn(10000))
310         s.masterSecret = randomBytes(rand.Intn(100), rand)
311         numCerts := rand.Intn(20)
312         s.certificates = make([][]byte, numCerts)
313         for i := 0; i < numCerts; i++ {
314                 s.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
315         }
316         return reflect.ValueOf(s)
317 }
318
319 func TestRejectEmptySCTList(t *testing.T) {
320         // RFC 6962, Section 3.3.1 specifies that empty SCT lists are invalid.
321
322         var random [32]byte
323         sct := []byte{0x42, 0x42, 0x42, 0x42}
324         serverHello := serverHelloMsg{
325                 vers:   VersionTLS12,
326                 random: random[:],
327                 scts:   [][]byte{sct},
328         }
329         serverHelloBytes := serverHello.marshal()
330
331         var serverHelloCopy serverHelloMsg
332         if !serverHelloCopy.unmarshal(serverHelloBytes) {
333                 t.Fatal("Failed to unmarshal initial message")
334         }
335
336         // Change serverHelloBytes so that the SCT list is empty
337         i := bytes.Index(serverHelloBytes, sct)
338         if i < 0 {
339                 t.Fatal("Cannot find SCT in ServerHello")
340         }
341
342         var serverHelloEmptySCT []byte
343         serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[:i-6]...)
344         // Append the extension length and SCT list length for an empty list.
345         serverHelloEmptySCT = append(serverHelloEmptySCT, []byte{0, 2, 0, 0}...)
346         serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[i+4:]...)
347
348         // Update the handshake message length.
349         serverHelloEmptySCT[1] = byte((len(serverHelloEmptySCT) - 4) >> 16)
350         serverHelloEmptySCT[2] = byte((len(serverHelloEmptySCT) - 4) >> 8)
351         serverHelloEmptySCT[3] = byte(len(serverHelloEmptySCT) - 4)
352
353         // Update the extensions length
354         serverHelloEmptySCT[42] = byte((len(serverHelloEmptySCT) - 44) >> 8)
355         serverHelloEmptySCT[43] = byte((len(serverHelloEmptySCT) - 44))
356
357         if serverHelloCopy.unmarshal(serverHelloEmptySCT) {
358                 t.Fatal("Unmarshaled ServerHello with empty SCT list")
359         }
360 }
361
362 func TestRejectEmptySCT(t *testing.T) {
363         // Not only must the SCT list be non-empty, but the SCT elements must
364         // not be zero length.
365
366         var random [32]byte
367         serverHello := serverHelloMsg{
368                 vers:   VersionTLS12,
369                 random: random[:],
370                 scts:   [][]byte{nil},
371         }
372         serverHelloBytes := serverHello.marshal()
373
374         var serverHelloCopy serverHelloMsg
375         if serverHelloCopy.unmarshal(serverHelloBytes) {
376                 t.Fatal("Unmarshaled ServerHello with zero-length SCT")
377         }
378 }