]> Cypherpunks.ru repositories - gostls13.git/blob - src/crypto/tls/handshake_messages_test.go
[dev.boringcrypto] all: merge master (nearly Go 1.10 beta 1) into dev.boringcrypto
[gostls13.git] / src / crypto / tls / handshake_messages_test.go
1 // Copyright 2009 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package tls
6
7 import (
8         "bytes"
9         "math/rand"
10         "reflect"
11         "strings"
12         "testing"
13         "testing/quick"
14 )
15
16 var tests = []interface{}{
17         &clientHelloMsg{},
18         &serverHelloMsg{},
19         &finishedMsg{},
20
21         &certificateMsg{},
22         &certificateRequestMsg{},
23         &certificateVerifyMsg{},
24         &certificateStatusMsg{},
25         &clientKeyExchangeMsg{},
26         &nextProtoMsg{},
27         &newSessionTicketMsg{},
28         &sessionState{},
29 }
30
31 type testMessage interface {
32         marshal() []byte
33         unmarshal([]byte) bool
34         equal(interface{}) bool
35 }
36
37 func TestMarshalUnmarshal(t *testing.T) {
38         rand := rand.New(rand.NewSource(0))
39
40         for i, iface := range tests {
41                 ty := reflect.ValueOf(iface).Type()
42
43                 n := 100
44                 if testing.Short() {
45                         n = 5
46                 }
47                 for j := 0; j < n; j++ {
48                         v, ok := quick.Value(ty, rand)
49                         if !ok {
50                                 t.Errorf("#%d: failed to create value", i)
51                                 break
52                         }
53
54                         m1 := v.Interface().(testMessage)
55                         marshaled := m1.marshal()
56                         m2 := iface.(testMessage)
57                         if !m2.unmarshal(marshaled) {
58                                 t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled)
59                                 break
60                         }
61                         m2.marshal() // to fill any marshal cache in the message
62
63                         if !m1.equal(m2) {
64                                 t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled)
65                                 break
66                         }
67
68                         if i >= 3 {
69                                 // The first three message types (ClientHello,
70                                 // ServerHello and Finished) are allowed to
71                                 // have parsable prefixes because the extension
72                                 // data is optional and the length of the
73                                 // Finished varies across versions.
74                                 for j := 0; j < len(marshaled); j++ {
75                                         if m2.unmarshal(marshaled[0:j]) {
76                                                 t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1)
77                                                 break
78                                         }
79                                 }
80                         }
81                 }
82         }
83 }
84
85 func TestFuzz(t *testing.T) {
86         rand := rand.New(rand.NewSource(0))
87         for _, iface := range tests {
88                 m := iface.(testMessage)
89
90                 for j := 0; j < 1000; j++ {
91                         len := rand.Intn(100)
92                         bytes := randomBytes(len, rand)
93                         // This just looks for crashes due to bounds errors etc.
94                         m.unmarshal(bytes)
95                 }
96         }
97 }
98
99 func randomBytes(n int, rand *rand.Rand) []byte {
100         r := make([]byte, n)
101         if _, err := rand.Read(r); err != nil {
102                 panic("rand.Read failed: " + err.Error())
103         }
104         return r
105 }
106
107 func randomString(n int, rand *rand.Rand) string {
108         b := randomBytes(n, rand)
109         return string(b)
110 }
111
112 func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
113         m := &clientHelloMsg{}
114         m.vers = uint16(rand.Intn(65536))
115         m.random = randomBytes(32, rand)
116         m.sessionId = randomBytes(rand.Intn(32), rand)
117         m.cipherSuites = make([]uint16, rand.Intn(63)+1)
118         for i := 0; i < len(m.cipherSuites); i++ {
119                 cs := uint16(rand.Int31())
120                 if cs == scsvRenegotiation {
121                         cs += 1
122                 }
123                 m.cipherSuites[i] = cs
124         }
125         m.compressionMethods = randomBytes(rand.Intn(63)+1, rand)
126         if rand.Intn(10) > 5 {
127                 m.nextProtoNeg = true
128         }
129         if rand.Intn(10) > 5 {
130                 m.serverName = randomString(rand.Intn(255), rand)
131                 for strings.HasSuffix(m.serverName, ".") {
132                         m.serverName = m.serverName[:len(m.serverName)-1]
133                 }
134         }
135         m.ocspStapling = rand.Intn(10) > 5
136         m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
137         m.supportedCurves = make([]CurveID, rand.Intn(5)+1)
138         for i := range m.supportedCurves {
139                 m.supportedCurves[i] = CurveID(rand.Intn(30000))
140         }
141         if rand.Intn(10) > 5 {
142                 m.ticketSupported = true
143                 if rand.Intn(10) > 5 {
144                         m.sessionTicket = randomBytes(rand.Intn(300), rand)
145                 }
146         }
147         if rand.Intn(10) > 5 {
148                 m.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
149         }
150         m.alpnProtocols = make([]string, rand.Intn(5))
151         for i := range m.alpnProtocols {
152                 m.alpnProtocols[i] = randomString(rand.Intn(20)+1, rand)
153         }
154         if rand.Intn(10) > 5 {
155                 m.scts = true
156         }
157
158         return reflect.ValueOf(m)
159 }
160
161 func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
162         m := &serverHelloMsg{}
163         m.vers = uint16(rand.Intn(65536))
164         m.random = randomBytes(32, rand)
165         m.sessionId = randomBytes(rand.Intn(32), rand)
166         m.cipherSuite = uint16(rand.Int31())
167         m.compressionMethod = uint8(rand.Intn(256))
168
169         if rand.Intn(10) > 5 {
170                 m.nextProtoNeg = true
171
172                 n := rand.Intn(10)
173                 m.nextProtos = make([]string, n)
174                 for i := 0; i < n; i++ {
175                         m.nextProtos[i] = randomString(20, rand)
176                 }
177         }
178
179         if rand.Intn(10) > 5 {
180                 m.ocspStapling = true
181         }
182         if rand.Intn(10) > 5 {
183                 m.ticketSupported = true
184         }
185         m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
186
187         if rand.Intn(10) > 5 {
188                 numSCTs := rand.Intn(4)
189                 m.scts = make([][]byte, numSCTs)
190                 for i := range m.scts {
191                         m.scts[i] = randomBytes(rand.Intn(500), rand)
192                 }
193         }
194
195         return reflect.ValueOf(m)
196 }
197
198 func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
199         m := &certificateMsg{}
200         numCerts := rand.Intn(20)
201         m.certificates = make([][]byte, numCerts)
202         for i := 0; i < numCerts; i++ {
203                 m.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
204         }
205         return reflect.ValueOf(m)
206 }
207
208 func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value {
209         m := &certificateRequestMsg{}
210         m.certificateTypes = randomBytes(rand.Intn(5)+1, rand)
211         numCAs := rand.Intn(100)
212         m.certificateAuthorities = make([][]byte, numCAs)
213         for i := 0; i < numCAs; i++ {
214                 m.certificateAuthorities[i] = randomBytes(rand.Intn(15)+1, rand)
215         }
216         return reflect.ValueOf(m)
217 }
218
219 func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value {
220         m := &certificateVerifyMsg{}
221         m.signature = randomBytes(rand.Intn(15)+1, rand)
222         return reflect.ValueOf(m)
223 }
224
225 func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value {
226         m := &certificateStatusMsg{}
227         if rand.Intn(10) > 5 {
228                 m.statusType = statusTypeOCSP
229                 m.response = randomBytes(rand.Intn(10)+1, rand)
230         } else {
231                 m.statusType = 42
232         }
233         return reflect.ValueOf(m)
234 }
235
236 func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value {
237         m := &clientKeyExchangeMsg{}
238         m.ciphertext = randomBytes(rand.Intn(1000)+1, rand)
239         return reflect.ValueOf(m)
240 }
241
242 func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value {
243         m := &finishedMsg{}
244         m.verifyData = randomBytes(12, rand)
245         return reflect.ValueOf(m)
246 }
247
248 func (*nextProtoMsg) Generate(rand *rand.Rand, size int) reflect.Value {
249         m := &nextProtoMsg{}
250         m.proto = randomString(rand.Intn(255), rand)
251         return reflect.ValueOf(m)
252 }
253
254 func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value {
255         m := &newSessionTicketMsg{}
256         m.ticket = randomBytes(rand.Intn(4), rand)
257         return reflect.ValueOf(m)
258 }
259
260 func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value {
261         s := &sessionState{}
262         s.vers = uint16(rand.Intn(10000))
263         s.cipherSuite = uint16(rand.Intn(10000))
264         s.masterSecret = randomBytes(rand.Intn(100), rand)
265         numCerts := rand.Intn(20)
266         s.certificates = make([][]byte, numCerts)
267         for i := 0; i < numCerts; i++ {
268                 s.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
269         }
270         return reflect.ValueOf(s)
271 }
272
273 func TestRejectEmptySCTList(t *testing.T) {
274         // https://tools.ietf.org/html/rfc6962#section-3.3.1 specifies that
275         // empty SCT lists are invalid.
276
277         var random [32]byte
278         sct := []byte{0x42, 0x42, 0x42, 0x42}
279         serverHello := serverHelloMsg{
280                 vers:   VersionTLS12,
281                 random: random[:],
282                 scts:   [][]byte{sct},
283         }
284         serverHelloBytes := serverHello.marshal()
285
286         var serverHelloCopy serverHelloMsg
287         if !serverHelloCopy.unmarshal(serverHelloBytes) {
288                 t.Fatal("Failed to unmarshal initial message")
289         }
290
291         // Change serverHelloBytes so that the SCT list is empty
292         i := bytes.Index(serverHelloBytes, sct)
293         if i < 0 {
294                 t.Fatal("Cannot find SCT in ServerHello")
295         }
296
297         var serverHelloEmptySCT []byte
298         serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[:i-6]...)
299         // Append the extension length and SCT list length for an empty list.
300         serverHelloEmptySCT = append(serverHelloEmptySCT, []byte{0, 2, 0, 0}...)
301         serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[i+4:]...)
302
303         // Update the handshake message length.
304         serverHelloEmptySCT[1] = byte((len(serverHelloEmptySCT) - 4) >> 16)
305         serverHelloEmptySCT[2] = byte((len(serverHelloEmptySCT) - 4) >> 8)
306         serverHelloEmptySCT[3] = byte(len(serverHelloEmptySCT) - 4)
307
308         // Update the extensions length
309         serverHelloEmptySCT[42] = byte((len(serverHelloEmptySCT) - 44) >> 8)
310         serverHelloEmptySCT[43] = byte((len(serverHelloEmptySCT) - 44))
311
312         if serverHelloCopy.unmarshal(serverHelloEmptySCT) {
313                 t.Fatal("Unmarshaled ServerHello with empty SCT list")
314         }
315 }
316
317 func TestRejectEmptySCT(t *testing.T) {
318         // Not only must the SCT list be non-empty, but the SCT elements must
319         // not be zero length.
320
321         var random [32]byte
322         serverHello := serverHelloMsg{
323                 vers:   VersionTLS12,
324                 random: random[:],
325                 scts:   [][]byte{nil},
326         }
327         serverHelloBytes := serverHello.marshal()
328
329         var serverHelloCopy serverHelloMsg
330         if serverHelloCopy.unmarshal(serverHelloBytes) {
331                 t.Fatal("Unmarshaled ServerHello with zero-length SCT")
332         }
333 }