]> Cypherpunks.ru repositories - gostls13.git/blob - src/crypto/tls/quic_test.go
crypto/tls: support QUIC as a transport
[gostls13.git] / src / crypto / tls / quic_test.go
1 // Copyright 2023 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package tls
6
7 import (
8         "context"
9         "errors"
10         "reflect"
11         "testing"
12 )
13
14 type testQUICConn struct {
15         t           *testing.T
16         conn        *QUICConn
17         readSecret  map[QUICEncryptionLevel]suiteSecret
18         writeSecret map[QUICEncryptionLevel]suiteSecret
19         gotParams   []byte
20         complete    bool
21 }
22
23 func newTestQUICClient(t *testing.T, config *Config) *testQUICConn {
24         q := &testQUICConn{t: t}
25         q.conn = QUICClient(&QUICConfig{
26                 TLSConfig: config,
27         })
28         t.Cleanup(func() {
29                 q.conn.Close()
30         })
31         return q
32 }
33
34 func newTestQUICServer(t *testing.T, config *Config) *testQUICConn {
35         q := &testQUICConn{t: t}
36         q.conn = QUICServer(&QUICConfig{
37                 TLSConfig: config,
38         })
39         t.Cleanup(func() {
40                 q.conn.Close()
41         })
42         return q
43 }
44
45 type suiteSecret struct {
46         suite  uint16
47         secret []byte
48 }
49
50 func (q *testQUICConn) setReadSecret(level QUICEncryptionLevel, suite uint16, secret []byte) {
51         if _, ok := q.writeSecret[level]; !ok {
52                 q.t.Errorf("SetReadSecret for level %v called before SetWriteSecret", level)
53         }
54         if level == QUICEncryptionLevelApplication && !q.complete {
55                 q.t.Errorf("SetReadSecret for level %v called before HandshakeComplete", level)
56         }
57         if _, ok := q.readSecret[level]; ok {
58                 q.t.Errorf("SetReadSecret for level %v called twice", level)
59         }
60         if q.readSecret == nil {
61                 q.readSecret = map[QUICEncryptionLevel]suiteSecret{}
62         }
63         switch level {
64         case QUICEncryptionLevelHandshake, QUICEncryptionLevelApplication:
65                 q.readSecret[level] = suiteSecret{suite, secret}
66         default:
67                 q.t.Errorf("SetReadSecret for unexpected level %v", level)
68         }
69 }
70
71 func (q *testQUICConn) setWriteSecret(level QUICEncryptionLevel, suite uint16, secret []byte) {
72         if _, ok := q.writeSecret[level]; ok {
73                 q.t.Errorf("SetWriteSecret for level %v called twice", level)
74         }
75         if q.writeSecret == nil {
76                 q.writeSecret = map[QUICEncryptionLevel]suiteSecret{}
77         }
78         switch level {
79         case QUICEncryptionLevelHandshake, QUICEncryptionLevelApplication:
80                 q.writeSecret[level] = suiteSecret{suite, secret}
81         default:
82                 q.t.Errorf("SetWriteSecret for unexpected level %v", level)
83         }
84 }
85
86 var errTransportParametersRequired = errors.New("transport parameters required")
87
88 func runTestQUICConnection(ctx context.Context, a, b *testQUICConn, onHandleCryptoData func()) error {
89         for _, c := range []*testQUICConn{a, b} {
90                 if !c.conn.conn.quic.started {
91                         if err := c.conn.Start(ctx); err != nil {
92                                 return err
93                         }
94                 }
95         }
96         idleCount := 0
97         for {
98                 e := a.conn.NextEvent()
99                 switch e.Kind {
100                 case QUICNoEvent:
101                         idleCount++
102                         if idleCount == 2 {
103                                 if !a.complete || !b.complete {
104                                         return errors.New("handshake incomplete")
105                                 }
106                                 return nil
107                         }
108                         a, b = b, a
109                 case QUICSetReadSecret:
110                         a.setReadSecret(e.Level, e.Suite, e.Data)
111                 case QUICSetWriteSecret:
112                         a.setWriteSecret(e.Level, e.Suite, e.Data)
113                 case QUICWriteData:
114                         if err := b.conn.HandleData(e.Level, e.Data); err != nil {
115                                 return err
116                         }
117                 case QUICTransportParameters:
118                         a.gotParams = e.Data
119                         if a.gotParams == nil {
120                                 a.gotParams = []byte{}
121                         }
122                 case QUICTransportParametersRequired:
123                         return errTransportParametersRequired
124                 case QUICHandshakeDone:
125                         a.complete = true
126                 }
127                 if e.Kind != QUICNoEvent {
128                         idleCount = 0
129                 }
130         }
131 }
132
133 func TestQUICConnection(t *testing.T) {
134         config := testConfig.Clone()
135         config.MinVersion = VersionTLS13
136
137         cli := newTestQUICClient(t, config)
138         cli.conn.SetTransportParameters(nil)
139
140         srv := newTestQUICServer(t, config)
141         srv.conn.SetTransportParameters(nil)
142
143         if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil {
144                 t.Fatalf("error during connection handshake: %v", err)
145         }
146
147         if _, ok := cli.readSecret[QUICEncryptionLevelHandshake]; !ok {
148                 t.Errorf("client has no Handshake secret")
149         }
150         if _, ok := cli.readSecret[QUICEncryptionLevelApplication]; !ok {
151                 t.Errorf("client has no Application secret")
152         }
153         if _, ok := srv.readSecret[QUICEncryptionLevelHandshake]; !ok {
154                 t.Errorf("server has no Handshake secret")
155         }
156         if _, ok := srv.readSecret[QUICEncryptionLevelApplication]; !ok {
157                 t.Errorf("server has no Application secret")
158         }
159         for _, level := range []QUICEncryptionLevel{QUICEncryptionLevelHandshake, QUICEncryptionLevelApplication} {
160                 if _, ok := cli.readSecret[level]; !ok {
161                         t.Errorf("client has no %v read secret", level)
162                 }
163                 if _, ok := srv.readSecret[level]; !ok {
164                         t.Errorf("server has no %v read secret", level)
165                 }
166                 if !reflect.DeepEqual(cli.readSecret[level], srv.writeSecret[level]) {
167                         t.Errorf("client read secret does not match server write secret for level %v", level)
168                 }
169                 if !reflect.DeepEqual(cli.writeSecret[level], srv.readSecret[level]) {
170                         t.Errorf("client write secret does not match server read secret for level %v", level)
171                 }
172         }
173 }
174
175 func TestQUICSessionResumption(t *testing.T) {
176         clientConfig := testConfig.Clone()
177         clientConfig.MinVersion = VersionTLS13
178         clientConfig.ClientSessionCache = NewLRUClientSessionCache(1)
179         clientConfig.ServerName = "example.go.dev"
180
181         serverConfig := testConfig.Clone()
182         serverConfig.MinVersion = VersionTLS13
183
184         cli := newTestQUICClient(t, clientConfig)
185         cli.conn.SetTransportParameters(nil)
186         srv := newTestQUICServer(t, serverConfig)
187         srv.conn.SetTransportParameters(nil)
188         if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil {
189                 t.Fatalf("error during first connection handshake: %v", err)
190         }
191         if cli.conn.ConnectionState().DidResume {
192                 t.Errorf("first connection unexpectedly used session resumption")
193         }
194
195         cli2 := newTestQUICClient(t, clientConfig)
196         cli2.conn.SetTransportParameters(nil)
197         srv2 := newTestQUICServer(t, serverConfig)
198         srv2.conn.SetTransportParameters(nil)
199         if err := runTestQUICConnection(context.Background(), cli2, srv2, nil); err != nil {
200                 t.Fatalf("error during second connection handshake: %v", err)
201         }
202         if !cli2.conn.ConnectionState().DidResume {
203                 t.Errorf("second connection did not use session resumption")
204         }
205 }
206
207 func TestQUICPostHandshakeClientAuthentication(t *testing.T) {
208         // RFC 9001, Section 4.4.
209         config := testConfig.Clone()
210         config.MinVersion = VersionTLS13
211         cli := newTestQUICClient(t, config)
212         cli.conn.SetTransportParameters(nil)
213         srv := newTestQUICServer(t, config)
214         srv.conn.SetTransportParameters(nil)
215         if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil {
216                 t.Fatalf("error during connection handshake: %v", err)
217         }
218
219         certReq := new(certificateRequestMsgTLS13)
220         certReq.ocspStapling = true
221         certReq.scts = true
222         certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
223         certReqBytes, err := certReq.marshal()
224         if err != nil {
225                 t.Fatal(err)
226         }
227         if err := cli.conn.HandleData(QUICEncryptionLevelApplication, append([]byte{
228                 byte(typeCertificateRequest),
229                 byte(0), byte(0), byte(len(certReqBytes)),
230         }, certReqBytes...)); err == nil {
231                 t.Fatalf("post-handshake authentication request: got no error, want one")
232         }
233 }
234
235 func TestQUICPostHandshakeKeyUpdate(t *testing.T) {
236         // RFC 9001, Section 6.
237         config := testConfig.Clone()
238         config.MinVersion = VersionTLS13
239         cli := newTestQUICClient(t, config)
240         cli.conn.SetTransportParameters(nil)
241         srv := newTestQUICServer(t, config)
242         srv.conn.SetTransportParameters(nil)
243         if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil {
244                 t.Fatalf("error during connection handshake: %v", err)
245         }
246
247         keyUpdate := new(keyUpdateMsg)
248         keyUpdateBytes, err := keyUpdate.marshal()
249         if err != nil {
250                 t.Fatal(err)
251         }
252         if err := cli.conn.HandleData(QUICEncryptionLevelApplication, append([]byte{
253                 byte(typeKeyUpdate),
254                 byte(0), byte(0), byte(len(keyUpdateBytes)),
255         }, keyUpdateBytes...)); !errors.Is(err, alertUnexpectedMessage) {
256                 t.Fatalf("key update request: got error %v, want alertUnexpectedMessage", err)
257         }
258 }
259
260 func TestQUICHandshakeError(t *testing.T) {
261         clientConfig := testConfig.Clone()
262         clientConfig.MinVersion = VersionTLS13
263         clientConfig.InsecureSkipVerify = false
264         clientConfig.ServerName = "name"
265
266         serverConfig := testConfig.Clone()
267         serverConfig.MinVersion = VersionTLS13
268
269         cli := newTestQUICClient(t, clientConfig)
270         cli.conn.SetTransportParameters(nil)
271         srv := newTestQUICServer(t, serverConfig)
272         srv.conn.SetTransportParameters(nil)
273         err := runTestQUICConnection(context.Background(), cli, srv, nil)
274         if !errors.Is(err, AlertError(alertBadCertificate)) {
275                 t.Errorf("connection handshake terminated with error %q, want alertBadCertificate", err)
276         }
277         var e *CertificateVerificationError
278         if !errors.As(err, &e) {
279                 t.Errorf("connection handshake terminated with error %q, want CertificateVerificationError", err)
280         }
281 }
282
283 // Test that QUICConn.ConnectionState can be used during the handshake,
284 // and that it reports the application protocol as soon as it has been
285 // negotiated.
286 func TestQUICConnectionState(t *testing.T) {
287         config := testConfig.Clone()
288         config.MinVersion = VersionTLS13
289         config.NextProtos = []string{"h3"}
290         cli := newTestQUICClient(t, config)
291         cli.conn.SetTransportParameters(nil)
292         srv := newTestQUICServer(t, config)
293         srv.conn.SetTransportParameters(nil)
294         onHandleCryptoData := func() {
295                 cliCS := cli.conn.ConnectionState()
296                 cliWantALPN := ""
297                 if _, ok := cli.readSecret[QUICEncryptionLevelApplication]; ok {
298                         cliWantALPN = "h3"
299                 }
300                 if want, got := cliCS.NegotiatedProtocol, cliWantALPN; want != got {
301                         t.Errorf("cli.ConnectionState().NegotiatedProtocol = %q, want %q", want, got)
302                 }
303
304                 srvCS := srv.conn.ConnectionState()
305                 srvWantALPN := ""
306                 if _, ok := srv.readSecret[QUICEncryptionLevelHandshake]; ok {
307                         srvWantALPN = "h3"
308                 }
309                 if want, got := srvCS.NegotiatedProtocol, srvWantALPN; want != got {
310                         t.Errorf("srv.ConnectionState().NegotiatedProtocol = %q, want %q", want, got)
311                 }
312         }
313         if err := runTestQUICConnection(context.Background(), cli, srv, onHandleCryptoData); err != nil {
314                 t.Fatalf("error during connection handshake: %v", err)
315         }
316 }
317
318 func TestQUICStartContextPropagation(t *testing.T) {
319         const key = "key"
320         const value = "value"
321         ctx := context.WithValue(context.Background(), key, value)
322         config := testConfig.Clone()
323         config.MinVersion = VersionTLS13
324         calls := 0
325         config.GetConfigForClient = func(info *ClientHelloInfo) (*Config, error) {
326                 calls++
327                 got, _ := info.Context().Value(key).(string)
328                 if got != value {
329                         t.Errorf("GetConfigForClient context key %q has value %q, want %q", key, got, value)
330                 }
331                 return nil, nil
332         }
333         cli := newTestQUICClient(t, config)
334         cli.conn.SetTransportParameters(nil)
335         srv := newTestQUICServer(t, config)
336         srv.conn.SetTransportParameters(nil)
337         if err := runTestQUICConnection(ctx, cli, srv, nil); err != nil {
338                 t.Fatalf("error during connection handshake: %v", err)
339         }
340         if calls != 1 {
341                 t.Errorf("GetConfigForClient called %v times, want 1", calls)
342         }
343 }
344
345 func TestQUICDelayedTransportParameters(t *testing.T) {
346         clientConfig := testConfig.Clone()
347         clientConfig.MinVersion = VersionTLS13
348         clientConfig.ClientSessionCache = NewLRUClientSessionCache(1)
349         clientConfig.ServerName = "example.go.dev"
350
351         serverConfig := testConfig.Clone()
352         serverConfig.MinVersion = VersionTLS13
353
354         cliParams := "client params"
355         srvParams := "server params"
356
357         cli := newTestQUICClient(t, clientConfig)
358         srv := newTestQUICServer(t, serverConfig)
359         if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != errTransportParametersRequired {
360                 t.Fatalf("handshake with no client parameters: %v; want errTransportParametersRequired", err)
361         }
362         cli.conn.SetTransportParameters([]byte(cliParams))
363         if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != errTransportParametersRequired {
364                 t.Fatalf("handshake with no server parameters: %v; want errTransportParametersRequired", err)
365         }
366         srv.conn.SetTransportParameters([]byte(srvParams))
367         if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil {
368                 t.Fatalf("error during connection handshake: %v", err)
369         }
370
371         if got, want := string(cli.gotParams), srvParams; got != want {
372                 t.Errorf("client got transport params: %q, want %q", got, want)
373         }
374         if got, want := string(srv.gotParams), cliParams; got != want {
375                 t.Errorf("server got transport params: %q, want %q", got, want)
376         }
377 }
378
379 func TestQUICEmptyTransportParameters(t *testing.T) {
380         config := testConfig.Clone()
381         config.MinVersion = VersionTLS13
382
383         cli := newTestQUICClient(t, config)
384         cli.conn.SetTransportParameters(nil)
385         srv := newTestQUICServer(t, config)
386         srv.conn.SetTransportParameters(nil)
387         if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil {
388                 t.Fatalf("error during connection handshake: %v", err)
389         }
390
391         if cli.gotParams == nil {
392                 t.Errorf("client did not get transport params")
393         }
394         if srv.gotParams == nil {
395                 t.Errorf("server did not get transport params")
396         }
397         if len(cli.gotParams) != 0 {
398                 t.Errorf("client got transport params: %v, want empty", cli.gotParams)
399         }
400         if len(srv.gotParams) != 0 {
401                 t.Errorf("server got transport params: %v, want empty", srv.gotParams)
402         }
403 }
404
405 func TestQUICCanceledWaitingForData(t *testing.T) {
406         config := testConfig.Clone()
407         config.MinVersion = VersionTLS13
408         cli := newTestQUICClient(t, config)
409         cli.conn.SetTransportParameters(nil)
410         cli.conn.Start(context.Background())
411         for cli.conn.NextEvent().Kind != QUICNoEvent {
412         }
413         err := cli.conn.Close()
414         if !errors.Is(err, alertCloseNotify) {
415                 t.Errorf("conn.Close() = %v, want alertCloseNotify", err)
416         }
417 }
418
419 func TestQUICCanceledWaitingForTransportParams(t *testing.T) {
420         config := testConfig.Clone()
421         config.MinVersion = VersionTLS13
422         cli := newTestQUICClient(t, config)
423         cli.conn.Start(context.Background())
424         for cli.conn.NextEvent().Kind != QUICTransportParametersRequired {
425         }
426         err := cli.conn.Close()
427         if !errors.Is(err, alertCloseNotify) {
428                 t.Errorf("conn.Close() = %v, want alertCloseNotify", err)
429         }
430 }