]> Cypherpunks.ru repositories - gostls13.git/blob - src/crypto/tls/handshake_messages_test.go
[dev.boringcrypto] all: merge master into dev.boringcrypto
[gostls13.git] / src / crypto / tls / handshake_messages_test.go
1 // Copyright 2009 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package tls
6
7 import (
8         "bytes"
9         "math/rand"
10         "reflect"
11         "strings"
12         "testing"
13         "testing/quick"
14         "time"
15 )
16
17 var tests = []interface{}{
18         &clientHelloMsg{},
19         &serverHelloMsg{},
20         &finishedMsg{},
21
22         &certificateMsg{},
23         &certificateRequestMsg{},
24         &certificateVerifyMsg{
25                 hasSignatureAlgorithm: true,
26         },
27         &certificateStatusMsg{},
28         &clientKeyExchangeMsg{},
29         &newSessionTicketMsg{},
30         &sessionState{},
31         &sessionStateTLS13{},
32         &encryptedExtensionsMsg{},
33         &endOfEarlyDataMsg{},
34         &keyUpdateMsg{},
35         &newSessionTicketMsgTLS13{},
36         &certificateRequestMsgTLS13{},
37         &certificateMsgTLS13{},
38 }
39
40 func TestMarshalUnmarshal(t *testing.T) {
41         rand := rand.New(rand.NewSource(time.Now().UnixNano()))
42
43         for i, iface := range tests {
44                 ty := reflect.ValueOf(iface).Type()
45
46                 n := 100
47                 if testing.Short() {
48                         n = 5
49                 }
50                 for j := 0; j < n; j++ {
51                         v, ok := quick.Value(ty, rand)
52                         if !ok {
53                                 t.Errorf("#%d: failed to create value", i)
54                                 break
55                         }
56
57                         m1 := v.Interface().(handshakeMessage)
58                         marshaled := m1.marshal()
59                         m2 := iface.(handshakeMessage)
60                         if !m2.unmarshal(marshaled) {
61                                 t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled)
62                                 break
63                         }
64                         m2.marshal() // to fill any marshal cache in the message
65
66                         if !reflect.DeepEqual(m1, m2) {
67                                 t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled)
68                                 break
69                         }
70
71                         if i >= 3 {
72                                 // The first three message types (ClientHello,
73                                 // ServerHello and Finished) are allowed to
74                                 // have parsable prefixes because the extension
75                                 // data is optional and the length of the
76                                 // Finished varies across versions.
77                                 for j := 0; j < len(marshaled); j++ {
78                                         if m2.unmarshal(marshaled[0:j]) {
79                                                 t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1)
80                                                 break
81                                         }
82                                 }
83                         }
84                 }
85         }
86 }
87
88 func TestFuzz(t *testing.T) {
89         rand := rand.New(rand.NewSource(0))
90         for _, iface := range tests {
91                 m := iface.(handshakeMessage)
92
93                 for j := 0; j < 1000; j++ {
94                         len := rand.Intn(100)
95                         bytes := randomBytes(len, rand)
96                         // This just looks for crashes due to bounds errors etc.
97                         m.unmarshal(bytes)
98                 }
99         }
100 }
101
102 func randomBytes(n int, rand *rand.Rand) []byte {
103         r := make([]byte, n)
104         if _, err := rand.Read(r); err != nil {
105                 panic("rand.Read failed: " + err.Error())
106         }
107         return r
108 }
109
110 func randomString(n int, rand *rand.Rand) string {
111         b := randomBytes(n, rand)
112         return string(b)
113 }
114
115 func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
116         m := &clientHelloMsg{}
117         m.vers = uint16(rand.Intn(65536))
118         m.random = randomBytes(32, rand)
119         m.sessionId = randomBytes(rand.Intn(32), rand)
120         m.cipherSuites = make([]uint16, rand.Intn(63)+1)
121         for i := 0; i < len(m.cipherSuites); i++ {
122                 cs := uint16(rand.Int31())
123                 if cs == scsvRenegotiation {
124                         cs += 1
125                 }
126                 m.cipherSuites[i] = cs
127         }
128         m.compressionMethods = randomBytes(rand.Intn(63)+1, rand)
129         if rand.Intn(10) > 5 {
130                 m.serverName = randomString(rand.Intn(255), rand)
131                 for strings.HasSuffix(m.serverName, ".") {
132                         m.serverName = m.serverName[:len(m.serverName)-1]
133                 }
134         }
135         m.ocspStapling = rand.Intn(10) > 5
136         m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
137         m.supportedCurves = make([]CurveID, rand.Intn(5)+1)
138         for i := range m.supportedCurves {
139                 m.supportedCurves[i] = CurveID(rand.Intn(30000) + 1)
140         }
141         if rand.Intn(10) > 5 {
142                 m.ticketSupported = true
143                 if rand.Intn(10) > 5 {
144                         m.sessionTicket = randomBytes(rand.Intn(300), rand)
145                 } else {
146                         m.sessionTicket = make([]byte, 0)
147                 }
148         }
149         if rand.Intn(10) > 5 {
150                 m.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
151         }
152         if rand.Intn(10) > 5 {
153                 m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms()
154         }
155         for i := 0; i < rand.Intn(5); i++ {
156                 m.alpnProtocols = append(m.alpnProtocols, randomString(rand.Intn(20)+1, rand))
157         }
158         if rand.Intn(10) > 5 {
159                 m.scts = true
160         }
161         if rand.Intn(10) > 5 {
162                 m.secureRenegotiationSupported = true
163                 m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand)
164         }
165         for i := 0; i < rand.Intn(5); i++ {
166                 m.supportedVersions = append(m.supportedVersions, uint16(rand.Intn(0xffff)+1))
167         }
168         if rand.Intn(10) > 5 {
169                 m.cookie = randomBytes(rand.Intn(500)+1, rand)
170         }
171         for i := 0; i < rand.Intn(5); i++ {
172                 var ks keyShare
173                 ks.group = CurveID(rand.Intn(30000) + 1)
174                 ks.data = randomBytes(rand.Intn(200)+1, rand)
175                 m.keyShares = append(m.keyShares, ks)
176         }
177         switch rand.Intn(3) {
178         case 1:
179                 m.pskModes = []uint8{pskModeDHE}
180         case 2:
181                 m.pskModes = []uint8{pskModeDHE, pskModePlain}
182         }
183         for i := 0; i < rand.Intn(5); i++ {
184                 var psk pskIdentity
185                 psk.obfuscatedTicketAge = uint32(rand.Intn(500000))
186                 psk.label = randomBytes(rand.Intn(500)+1, rand)
187                 m.pskIdentities = append(m.pskIdentities, psk)
188                 m.pskBinders = append(m.pskBinders, randomBytes(rand.Intn(50)+32, rand))
189         }
190         if rand.Intn(10) > 5 {
191                 m.earlyData = true
192         }
193
194         return reflect.ValueOf(m)
195 }
196
197 func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
198         m := &serverHelloMsg{}
199         m.vers = uint16(rand.Intn(65536))
200         m.random = randomBytes(32, rand)
201         m.sessionId = randomBytes(rand.Intn(32), rand)
202         m.cipherSuite = uint16(rand.Int31())
203         m.compressionMethod = uint8(rand.Intn(256))
204         m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
205
206         if rand.Intn(10) > 5 {
207                 m.ocspStapling = true
208         }
209         if rand.Intn(10) > 5 {
210                 m.ticketSupported = true
211         }
212         if rand.Intn(10) > 5 {
213                 m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
214         }
215
216         for i := 0; i < rand.Intn(4); i++ {
217                 m.scts = append(m.scts, randomBytes(rand.Intn(500)+1, rand))
218         }
219
220         if rand.Intn(10) > 5 {
221                 m.secureRenegotiationSupported = true
222                 m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand)
223         }
224         if rand.Intn(10) > 5 {
225                 m.supportedVersion = uint16(rand.Intn(0xffff) + 1)
226         }
227         if rand.Intn(10) > 5 {
228                 m.cookie = randomBytes(rand.Intn(500)+1, rand)
229         }
230         if rand.Intn(10) > 5 {
231                 for i := 0; i < rand.Intn(5); i++ {
232                         m.serverShare.group = CurveID(rand.Intn(30000) + 1)
233                         m.serverShare.data = randomBytes(rand.Intn(200)+1, rand)
234                 }
235         } else if rand.Intn(10) > 5 {
236                 m.selectedGroup = CurveID(rand.Intn(30000) + 1)
237         }
238         if rand.Intn(10) > 5 {
239                 m.selectedIdentityPresent = true
240                 m.selectedIdentity = uint16(rand.Intn(0xffff))
241         }
242
243         return reflect.ValueOf(m)
244 }
245
246 func (*encryptedExtensionsMsg) Generate(rand *rand.Rand, size int) reflect.Value {
247         m := &encryptedExtensionsMsg{}
248
249         if rand.Intn(10) > 5 {
250                 m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
251         }
252
253         return reflect.ValueOf(m)
254 }
255
256 func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
257         m := &certificateMsg{}
258         numCerts := rand.Intn(20)
259         m.certificates = make([][]byte, numCerts)
260         for i := 0; i < numCerts; i++ {
261                 m.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
262         }
263         return reflect.ValueOf(m)
264 }
265
266 func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value {
267         m := &certificateRequestMsg{}
268         m.certificateTypes = randomBytes(rand.Intn(5)+1, rand)
269         for i := 0; i < rand.Intn(100); i++ {
270                 m.certificateAuthorities = append(m.certificateAuthorities, randomBytes(rand.Intn(15)+1, rand))
271         }
272         return reflect.ValueOf(m)
273 }
274
275 func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value {
276         m := &certificateVerifyMsg{}
277         m.hasSignatureAlgorithm = true
278         m.signatureAlgorithm = SignatureScheme(rand.Intn(30000))
279         m.signature = randomBytes(rand.Intn(15)+1, rand)
280         return reflect.ValueOf(m)
281 }
282
283 func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value {
284         m := &certificateStatusMsg{}
285         m.response = randomBytes(rand.Intn(10)+1, rand)
286         return reflect.ValueOf(m)
287 }
288
289 func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value {
290         m := &clientKeyExchangeMsg{}
291         m.ciphertext = randomBytes(rand.Intn(1000)+1, rand)
292         return reflect.ValueOf(m)
293 }
294
295 func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value {
296         m := &finishedMsg{}
297         m.verifyData = randomBytes(12, rand)
298         return reflect.ValueOf(m)
299 }
300
301 func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value {
302         m := &newSessionTicketMsg{}
303         m.ticket = randomBytes(rand.Intn(4), rand)
304         return reflect.ValueOf(m)
305 }
306
307 func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value {
308         s := &sessionState{}
309         s.vers = uint16(rand.Intn(10000))
310         s.cipherSuite = uint16(rand.Intn(10000))
311         s.masterSecret = randomBytes(rand.Intn(100)+1, rand)
312         s.createdAt = uint64(rand.Int63())
313         for i := 0; i < rand.Intn(20); i++ {
314                 s.certificates = append(s.certificates, randomBytes(rand.Intn(500)+1, rand))
315         }
316         return reflect.ValueOf(s)
317 }
318
319 func (*sessionStateTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
320         s := &sessionStateTLS13{}
321         s.cipherSuite = uint16(rand.Intn(10000))
322         s.resumptionSecret = randomBytes(rand.Intn(100)+1, rand)
323         s.createdAt = uint64(rand.Int63())
324         for i := 0; i < rand.Intn(2)+1; i++ {
325                 s.certificate.Certificate = append(
326                         s.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand))
327         }
328         if rand.Intn(10) > 5 {
329                 s.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand)
330         }
331         if rand.Intn(10) > 5 {
332                 for i := 0; i < rand.Intn(2)+1; i++ {
333                         s.certificate.SignedCertificateTimestamps = append(
334                                 s.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand))
335                 }
336         }
337         return reflect.ValueOf(s)
338 }
339
340 func (*endOfEarlyDataMsg) Generate(rand *rand.Rand, size int) reflect.Value {
341         m := &endOfEarlyDataMsg{}
342         return reflect.ValueOf(m)
343 }
344
345 func (*keyUpdateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
346         m := &keyUpdateMsg{}
347         m.updateRequested = rand.Intn(10) > 5
348         return reflect.ValueOf(m)
349 }
350
351 func (*newSessionTicketMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
352         m := &newSessionTicketMsgTLS13{}
353         m.lifetime = uint32(rand.Intn(500000))
354         m.ageAdd = uint32(rand.Intn(500000))
355         m.nonce = randomBytes(rand.Intn(100), rand)
356         m.label = randomBytes(rand.Intn(1000), rand)
357         if rand.Intn(10) > 5 {
358                 m.maxEarlyData = uint32(rand.Intn(500000))
359         }
360         return reflect.ValueOf(m)
361 }
362
363 func (*certificateRequestMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
364         m := &certificateRequestMsgTLS13{}
365         if rand.Intn(10) > 5 {
366                 m.ocspStapling = true
367         }
368         if rand.Intn(10) > 5 {
369                 m.scts = true
370         }
371         if rand.Intn(10) > 5 {
372                 m.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
373         }
374         if rand.Intn(10) > 5 {
375                 m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms()
376         }
377         if rand.Intn(10) > 5 {
378                 m.certificateAuthorities = make([][]byte, 3)
379                 for i := 0; i < 3; i++ {
380                         m.certificateAuthorities[i] = randomBytes(rand.Intn(10)+1, rand)
381                 }
382         }
383         return reflect.ValueOf(m)
384 }
385
386 func (*certificateMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
387         m := &certificateMsgTLS13{}
388         for i := 0; i < rand.Intn(2)+1; i++ {
389                 m.certificate.Certificate = append(
390                         m.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand))
391         }
392         if rand.Intn(10) > 5 {
393                 m.ocspStapling = true
394                 m.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand)
395         }
396         if rand.Intn(10) > 5 {
397                 m.scts = true
398                 for i := 0; i < rand.Intn(2)+1; i++ {
399                         m.certificate.SignedCertificateTimestamps = append(
400                                 m.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand))
401                 }
402         }
403         return reflect.ValueOf(m)
404 }
405
406 func TestRejectEmptySCTList(t *testing.T) {
407         // RFC 6962, Section 3.3.1 specifies that empty SCT lists are invalid.
408
409         var random [32]byte
410         sct := []byte{0x42, 0x42, 0x42, 0x42}
411         serverHello := serverHelloMsg{
412                 vers:   VersionTLS12,
413                 random: random[:],
414                 scts:   [][]byte{sct},
415         }
416         serverHelloBytes := serverHello.marshal()
417
418         var serverHelloCopy serverHelloMsg
419         if !serverHelloCopy.unmarshal(serverHelloBytes) {
420                 t.Fatal("Failed to unmarshal initial message")
421         }
422
423         // Change serverHelloBytes so that the SCT list is empty
424         i := bytes.Index(serverHelloBytes, sct)
425         if i < 0 {
426                 t.Fatal("Cannot find SCT in ServerHello")
427         }
428
429         var serverHelloEmptySCT []byte
430         serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[:i-6]...)
431         // Append the extension length and SCT list length for an empty list.
432         serverHelloEmptySCT = append(serverHelloEmptySCT, []byte{0, 2, 0, 0}...)
433         serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[i+4:]...)
434
435         // Update the handshake message length.
436         serverHelloEmptySCT[1] = byte((len(serverHelloEmptySCT) - 4) >> 16)
437         serverHelloEmptySCT[2] = byte((len(serverHelloEmptySCT) - 4) >> 8)
438         serverHelloEmptySCT[3] = byte(len(serverHelloEmptySCT) - 4)
439
440         // Update the extensions length
441         serverHelloEmptySCT[42] = byte((len(serverHelloEmptySCT) - 44) >> 8)
442         serverHelloEmptySCT[43] = byte((len(serverHelloEmptySCT) - 44))
443
444         if serverHelloCopy.unmarshal(serverHelloEmptySCT) {
445                 t.Fatal("Unmarshaled ServerHello with empty SCT list")
446         }
447 }
448
449 func TestRejectEmptySCT(t *testing.T) {
450         // Not only must the SCT list be non-empty, but the SCT elements must
451         // not be zero length.
452
453         var random [32]byte
454         serverHello := serverHelloMsg{
455                 vers:   VersionTLS12,
456                 random: random[:],
457                 scts:   [][]byte{nil},
458         }
459         serverHelloBytes := serverHello.marshal()
460
461         var serverHelloCopy serverHelloMsg
462         if serverHelloCopy.unmarshal(serverHelloBytes) {
463                 t.Fatal("Unmarshaled ServerHello with zero-length SCT")
464         }
465 }