1 // Copyright 2009 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
// tests holds one zero-valued message pointer per message type under test.
// TestMarshalUnmarshal uses each entry's dynamic type to generate random
// values with testing/quick and reuses the entry itself as the unmarshal
// target; TestFuzz iterates over every entry with random byte buffers.
// NOTE(review): the embedded numbering skips 17-21, 26 and 28-29, so this
// copy appears to be missing entries and the closing brace — confirm
// against the canonical file.
16 var tests = []interface{}{
22 &certificateRequestMsg{},
23 &certificateVerifyMsg{},
24 &certificateStatusMsg{},
25 &clientKeyExchangeMsg{},
27 &newSessionTicketMsg{},
// testMessage is the minimal method set the tests in this file need from a
// message type. unmarshal reports whether its input parsed successfully;
// equal presumably compares a round-tripped message with the original (its
// call site is not visible in this copy).
// NOTE(review): marshal() is called through this interface (m1.marshal(),
// m2.marshal()) but its declaration is not visible here (line 32 is
// absent) — confirm it is declared in the canonical file.
31 type testMessage interface {
33 unmarshal([]byte) bool
34 equal(interface{}) bool
// TestMarshalUnmarshal checks that every message type survives a
// marshal/unmarshal round trip. For each entry in tests it draws random
// values via quick.Value (which uses the type's Generate method, since those
// methods match the quick.Generator signature), marshals each value,
// unmarshals the bytes back into the table entry, reports any mismatch, and
// then verifies that no strictly shorter prefix of the encoding unmarshals
// successfully (subject to the exemptions in the comment inside the loop).
37 func TestMarshalUnmarshal(t *testing.T) {
// Fixed seed: the generated messages are identical on every run, so a
// failure is reproducible. This local variable shadows the rand package.
38 rand := rand.New(rand.NewSource(0))
40 for i, iface := range tests {
41 ty := reflect.ValueOf(iface).Type()
// NOTE(review): n, the number of random values generated per type, is
// assigned on lines absent from this copy (numbering skips 42-46).
47 for j := 0; j < n; j++ {
48 v, ok := quick.Value(ty, rand)
50 t.Errorf("#%d: failed to create value", i)
54 m1 := v.Interface().(testMessage)
55 marshaled := m1.marshal()
// m2 is the table entry itself, not a fresh value, so it is reused as the
// unmarshal target for every generated message of this type.
56 m2 := iface.(testMessage)
57 if !m2.unmarshal(marshaled) {
58 t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled)
61 m2.marshal() // to fill any marshal cache in the message
// NOTE(review): the comparison guarding this error (line 63, presumably an
// equal() check between m1 and m2) is absent from this copy.
64 t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled)
69 // The first three message types (ClientHello,
70 // ServerHello and Finished) are allowed to
71 // have parsable prefixes because the extension
72 // data is optional and the length of the
73 // Finished varies across versions.
// Every strictly shorter prefix (lengths 0 .. len(marshaled)-1) must be
// rejected by unmarshal.
74 for j := 0; j < len(marshaled); j++ {
75 if m2.unmarshal(marshaled[0:j]) {
76 t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1)
// TestFuzz runs 1000 random byte buffers per message type in tests, using a
// fixed-seed generator so runs are reproducible. Per the in-loop comment it
// is a crash test only: it looks for panics such as out-of-range slice
// indexing rather than asserting on results.
// NOTE(review): lines 89, 91 and 94 onward are absent from this copy, so the
// call that consumes each buffer is not visible. Because len is passed as
// randomBytes' int argument, it must be a local variable that shadows the
// builtin len. bytes likewise shadows the bytes package (which
// TestRejectEmptySCTList uses); consider renaming both.
85 func TestFuzz(t *testing.T) {
86 rand := rand.New(rand.NewSource(0))
87 for _, iface := range tests {
88 m := iface.(testMessage)
90 for j := 0; j < 1000; j++ {
92 bytes := randomBytes(len, rand)
93 // This just looks for crashes due to bounds errors etc.
// randomBytes returns a []byte filled from rand. It panics if rand.Read
// reports an error. (*rand.Rand).Read is documented to always return a nil
// error, so the panic is purely defensive.
// NOTE(review): the allocation of r (presumably make([]byte, n)) and the
// return statement are on lines absent from this copy.
99 func randomBytes(n int, rand *rand.Rand) []byte {
101 if _, err := rand.Read(r); err != nil {
102 panic("rand.Read failed: " + err.Error())
// randomString builds a string from n random bytes obtained via randomBytes.
// NOTE(review): the conversion and return are absent from this copy. If the
// result is string(b), it is arbitrary bytes and need not be valid UTF-8.
107 func randomString(n int, rand *rand.Rand) string {
108 b := randomBytes(n, rand)
// Generate builds a clientHelloMsg with randomized fields. Its signature
// matches testing/quick's Generator interface, so quick.Value (used by
// TestMarshalUnmarshal) calls it to produce test inputs. The size hint is not
// referenced in the visible lines.
//
// Throughout, rand.Intn(10) > 5 is true for 4 of the 10 possible results, so
// each optional feature is enabled roughly 40% of the time.
112 func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
113 m := &clientHelloMsg{}
114 m.vers = uint16(rand.Intn(65536))
115 m.random = randomBytes(32, rand)
116 m.sessionId = randomBytes(rand.Intn(32), rand)
// 1 to 63 cipher suites; uint16(rand.Int31()) keeps only the low 16 bits.
117 m.cipherSuites = make([]uint16, rand.Intn(63)+1)
118 for i := 0; i < len(m.cipherSuites); i++ {
119 cs := uint16(rand.Int31())
// The renegotiation SCSV value is special-cased. The adjustment made on a
// match is on lines absent from this copy.
120 if cs == scsvRenegotiation {
123 m.cipherSuites[i] = cs
125 m.compressionMethods = randomBytes(rand.Intn(63)+1, rand)
126 if rand.Intn(10) > 5 {
127 m.nextProtoNeg = true
129 if rand.Intn(10) > 5 {
130 m.serverName = randomString(rand.Intn(255), rand)
// Strip trailing dots. Presumably the parser rejects or normalizes names
// ending in "." so they would not round-trip — confirm.
131 for strings.HasSuffix(m.serverName, ".") {
132 m.serverName = m.serverName[:len(m.serverName)-1]
135 m.ocspStapling = rand.Intn(10) > 5
136 m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
137 m.supportedCurves = make([]CurveID, rand.Intn(5)+1)
138 for i := range m.supportedCurves {
139 m.supportedCurves[i] = CurveID(rand.Intn(30000))
// The session-ticket conditional below appears to be nested under
// ticketSupported (its closing braces are absent from this copy).
141 if rand.Intn(10) > 5 {
142 m.ticketSupported = true
143 if rand.Intn(10) > 5 {
144 m.sessionTicket = randomBytes(rand.Intn(300), rand)
147 if rand.Intn(10) > 5 {
148 m.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
// Zero to four ALPN protocols, each requested at 1-20 random bytes.
150 m.alpnProtocols = make([]string, rand.Intn(5))
151 for i := range m.alpnProtocols {
152 m.alpnProtocols[i] = randomString(rand.Intn(20)+1, rand)
// NOTE(review): the body of this conditional is absent from this copy.
154 if rand.Intn(10) > 5 {
158 return reflect.ValueOf(m)
// Generate builds a serverHelloMsg with randomized fields for testing/quick
// (its signature matches the Generator interface). Each optional field is
// enabled when rand.Intn(10) > 5, i.e. about 40% of the time. The size hint
// is not referenced in the visible lines.
161 func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
162 m := &serverHelloMsg{}
163 m.vers = uint16(rand.Intn(65536))
164 m.random = randomBytes(32, rand)
165 m.sessionId = randomBytes(rand.Intn(32), rand)
// uint16(rand.Int31()) keeps only the low 16 bits of the random value.
166 m.cipherSuite = uint16(rand.Int31())
167 m.compressionMethod = uint8(rand.Intn(256))
169 if rand.Intn(10) > 5 {
170 m.nextProtoNeg = true
// NOTE(review): n is assigned on a line absent from this copy.
173 m.nextProtos = make([]string, n)
174 for i := 0; i < n; i++ {
175 m.nextProtos[i] = randomString(20, rand)
179 if rand.Intn(10) > 5 {
180 m.ocspStapling = true
182 if rand.Intn(10) > 5 {
183 m.ticketSupported = true
185 m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
187 if rand.Intn(10) > 5 {
// numSCTs may be 0. TestRejectEmptySCTList shows unmarshal rejects an
// empty SCT list, so presumably marshal omits the extension when m.scts is
// empty — confirm.
188 numSCTs := rand.Intn(4)
189 m.scts = make([][]byte, numSCTs)
190 for i := range m.scts {
// NOTE(review): rand.Intn(500) can return 0, which (assuming randomBytes
// returns n bytes) produces a zero-length SCT. TestRejectEmptySCT requires
// unmarshal to reject zero-length SCTs, so such a message would fail the
// round trip in TestMarshalUnmarshal; the fixed seed may simply never hit
// it. Consider randomBytes(rand.Intn(500)+1, rand).
191 m.scts[i] = randomBytes(rand.Intn(500), rand)
195 return reflect.ValueOf(m)
// Generate builds a certificateMsg for testing/quick holding 0-19
// certificates. Each certificate is requested from randomBytes with a length
// of 1-10 bytes.
198 func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
199 m := &certificateMsg{}
200 numCerts := rand.Intn(20)
201 m.certificates = make([][]byte, numCerts)
202 for i := 0; i < numCerts; i++ {
203 m.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
205 return reflect.ValueOf(m)
// Generate builds a certificateRequestMsg for testing/quick with 1-5
// certificate-type bytes and 0-99 certificate authorities. Each authority is
// requested at 1-15 random bytes.
208 func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value {
209 m := &certificateRequestMsg{}
210 m.certificateTypes = randomBytes(rand.Intn(5)+1, rand)
211 numCAs := rand.Intn(100)
212 m.certificateAuthorities = make([][]byte, numCAs)
213 for i := 0; i < numCAs; i++ {
214 m.certificateAuthorities[i] = randomBytes(rand.Intn(15)+1, rand)
216 return reflect.ValueOf(m)
// Generate builds a certificateVerifyMsg for testing/quick whose signature is
// requested at 1-15 random bytes.
219 func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value {
220 m := &certificateVerifyMsg{}
221 m.signature = randomBytes(rand.Intn(15)+1, rand)
222 return reflect.ValueOf(m)
// Generate builds a certificateStatusMsg for testing/quick. About 40% of the
// time (rand.Intn(10) > 5) it is an OCSP status with a response requested at
// 1-10 random bytes.
// NOTE(review): the alternative branch (lines 230-232) is absent from this
// copy.
225 func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value {
226 m := &certificateStatusMsg{}
227 if rand.Intn(10) > 5 {
228 m.statusType = statusTypeOCSP
229 m.response = randomBytes(rand.Intn(10)+1, rand)
233 return reflect.ValueOf(m)
// Generate builds a clientKeyExchangeMsg for testing/quick whose ciphertext
// is requested at 1-1000 random bytes.
236 func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value {
237 m := &clientKeyExchangeMsg{}
238 m.ciphertext = randomBytes(rand.Intn(1000)+1, rand)
239 return reflect.ValueOf(m)
// Generate produces a finished message for testing/quick whose verifyData is
// 12 random bytes.
// NOTE(review): the declaration of m (line 243) is absent from this copy.
242 func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value {
244 m.verifyData = randomBytes(12, rand)
245 return reflect.ValueOf(m)
// Generate produces a nextProtoMsg for testing/quick whose proto is a random
// string requested at 0-254 bytes.
// NOTE(review): the declaration of m (line 249) is absent from this copy.
248 func (*nextProtoMsg) Generate(rand *rand.Rand, size int) reflect.Value {
250 m.proto = randomString(rand.Intn(255), rand)
251 return reflect.ValueOf(m)
// Generate builds a newSessionTicketMsg for testing/quick whose ticket is
// requested at 0-3 random bytes.
254 func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value {
255 m := &newSessionTicketMsg{}
256 m.ticket = randomBytes(rand.Intn(4), rand)
257 return reflect.ValueOf(m)
// Generate builds a sessionState for testing/quick with these fields:
//   - version and cipher suite each below 10000
//   - a master secret requested at 0-99 bytes
//   - 0-19 certificates, each requested at 1-10 bytes
// NOTE(review): the declaration of s (line 261) is absent from this copy.
260 func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value {
262 s.vers = uint16(rand.Intn(10000))
263 s.cipherSuite = uint16(rand.Intn(10000))
264 s.masterSecret = randomBytes(rand.Intn(100), rand)
265 numCerts := rand.Intn(20)
266 s.certificates = make([][]byte, numCerts)
267 for i := 0; i < numCerts; i++ {
268 s.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
270 return reflect.ValueOf(s)
// TestRejectEmptySCTList checks that unmarshal rejects a ServerHello whose
// SCT extension contains an empty SCT list. It works in four steps:
//  1. Marshal a ServerHello carrying a single 4-byte SCT.
//  2. Confirm the encoding parses.
//  3. Splice the encoding so the SCT extension holds an empty list, then
//     patch the outer length fields.
//  4. Require unmarshal to fail on the result.
273 func TestRejectEmptySCTList(t *testing.T) {
274 // https://tools.ietf.org/html/rfc6962#section-3.3.1 specifies that
275 // empty SCT lists are invalid.
// The distinctive 0x42 pattern lets bytes.Index locate the SCT payload in
// the encoding.
278 sct := []byte{0x42, 0x42, 0x42, 0x42}
279 serverHello := serverHelloMsg{
284 serverHelloBytes := serverHello.marshal()
286 var serverHelloCopy serverHelloMsg
287 if !serverHelloCopy.unmarshal(serverHelloBytes) {
288 t.Fatal("Failed to unmarshal initial message")
291 // Change serverHelloBytes so that the SCT list is empty
// Locate the SCT payload. The guard for a failed search is on a line absent
// from this copy; only its t.Fatal is visible.
292 i := bytes.Index(serverHelloBytes, sct)
294 t.Fatal("Cannot find SCT in ServerHello")
297 var serverHelloEmptySCT []byte
// i-6 backs up over three 2-byte length prefixes: the SCT, the SCT list and
// the extension. Everything up to and including the extension type is kept.
298 serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[:i-6]...)
299 // Append the extension length and SCT list length for an empty list.
300 serverHelloEmptySCT = append(serverHelloEmptySCT, []byte{0, 2, 0, 0}...)
// Skip the 4-byte SCT itself and copy whatever follows it.
301 serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[i+4:]...)
303 // Update the handshake message length.
// Bytes 1-3 hold a 24-bit big-endian length that excludes the 4-byte
// handshake header.
304 serverHelloEmptySCT[1] = byte((len(serverHelloEmptySCT) - 4) >> 16)
305 serverHelloEmptySCT[2] = byte((len(serverHelloEmptySCT) - 4) >> 8)
306 serverHelloEmptySCT[3] = byte(len(serverHelloEmptySCT) - 4)
308 // Update the extensions length
// Offset 42 is the sum of the fields that precede the extensions length:
// 4 (header) + 2 (version) + 32 (random) + 1 (session ID length)
// + 2 (cipher suite) + 1 (compression method).
// This assumes the ServerHello literal (lines 280-283, not visible here)
// leaves the session ID empty — confirm.
309 serverHelloEmptySCT[42] = byte((len(serverHelloEmptySCT) - 44) >> 8)
310 serverHelloEmptySCT[43] = byte((len(serverHelloEmptySCT) - 44))
312 if serverHelloCopy.unmarshal(serverHelloEmptySCT) {
313 t.Fatal("Unmarshaled ServerHello with empty SCT list")
317 func TestRejectEmptySCT(t *testing.T) {
318 // Not only must the SCT list be non-empty, but the SCT elements must
319 // not be zero length.
322 serverHello := serverHelloMsg{
327 serverHelloBytes := serverHello.marshal()
329 var serverHelloCopy serverHelloMsg
330 if serverHelloCopy.unmarshal(serverHelloBytes) {
331 t.Fatal("Unmarshaled ServerHello with zero-length SCT")