]> Cypherpunks.ru repositories - gostls13.git/blob - src/encoding/base64/base64_test.go
encoding: add AppendEncode and AppendDecode
[gostls13.git] / src / encoding / base64 / base64_test.go
1 // Copyright 2009 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package base64
6
7 import (
8         "bytes"
9         "errors"
10         "fmt"
11         "io"
12         "math"
13         "reflect"
14         "runtime/debug"
15         "strconv"
16         "strings"
17         "testing"
18         "time"
19 )
20
21 type testpair struct {
22         decoded, encoded string
23 }
24
25 var pairs = []testpair{
26         // RFC 3548 examples
27         {"\x14\xfb\x9c\x03\xd9\x7e", "FPucA9l+"},
28         {"\x14\xfb\x9c\x03\xd9", "FPucA9k="},
29         {"\x14\xfb\x9c\x03", "FPucAw=="},
30
31         // RFC 4648 examples
32         {"", ""},
33         {"f", "Zg=="},
34         {"fo", "Zm8="},
35         {"foo", "Zm9v"},
36         {"foob", "Zm9vYg=="},
37         {"fooba", "Zm9vYmE="},
38         {"foobar", "Zm9vYmFy"},
39
40         // Wikipedia examples
41         {"sure.", "c3VyZS4="},
42         {"sure", "c3VyZQ=="},
43         {"sur", "c3Vy"},
44         {"su", "c3U="},
45         {"leasure.", "bGVhc3VyZS4="},
46         {"easure.", "ZWFzdXJlLg=="},
47         {"asure.", "YXN1cmUu"},
48         {"sure.", "c3VyZS4="},
49 }
50
51 // Do nothing to a reference base64 string (leave in standard format)
52 func stdRef(ref string) string {
53         return ref
54 }
55
56 // Convert a reference string to URL-encoding
57 func urlRef(ref string) string {
58         ref = strings.ReplaceAll(ref, "+", "-")
59         ref = strings.ReplaceAll(ref, "/", "_")
60         return ref
61 }
62
63 // Convert a reference string to raw, unpadded format
64 func rawRef(ref string) string {
65         return strings.TrimRight(ref, "=")
66 }
67
68 // Both URL and unpadding conversions
69 func rawURLRef(ref string) string {
70         return rawRef(urlRef(ref))
71 }
72
73 // A nonstandard encoding with a funny padding character, for testing
74 var funnyEncoding = NewEncoding(encodeStd).WithPadding(rune('@'))
75
76 func funnyRef(ref string) string {
77         return strings.ReplaceAll(ref, "=", "@")
78 }
79
80 type encodingTest struct {
81         enc  *Encoding           // Encoding to test
82         conv func(string) string // Reference string converter
83 }
84
85 var encodingTests = []encodingTest{
86         {StdEncoding, stdRef},
87         {URLEncoding, urlRef},
88         {RawStdEncoding, rawRef},
89         {RawURLEncoding, rawURLRef},
90         {funnyEncoding, funnyRef},
91         {StdEncoding.Strict(), stdRef},
92         {URLEncoding.Strict(), urlRef},
93         {RawStdEncoding.Strict(), rawRef},
94         {RawURLEncoding.Strict(), rawURLRef},
95         {funnyEncoding.Strict(), funnyRef},
96 }
97
98 var bigtest = testpair{
99         "Twas brillig, and the slithy toves",
100         "VHdhcyBicmlsbGlnLCBhbmQgdGhlIHNsaXRoeSB0b3Zlcw==",
101 }
102
103 func testEqual(t *testing.T, msg string, args ...any) bool {
104         t.Helper()
105         if args[len(args)-2] != args[len(args)-1] {
106                 t.Errorf(msg, args...)
107                 return false
108         }
109         return true
110 }
111
112 func TestEncode(t *testing.T) {
113         for _, p := range pairs {
114                 for _, tt := range encodingTests {
115                         got := tt.enc.EncodeToString([]byte(p.decoded))
116                         testEqual(t, "Encode(%q) = %q, want %q", p.decoded, got, tt.conv(p.encoded))
117                         dst := tt.enc.AppendEncode([]byte("lead"), []byte(p.decoded))
118                         testEqual(t, `AppendEncode("lead", %q) = %q, want %q`, p.decoded, string(dst), "lead"+tt.conv(p.encoded))
119                 }
120         }
121 }
122
123 func TestEncoder(t *testing.T) {
124         for _, p := range pairs {
125                 bb := &strings.Builder{}
126                 encoder := NewEncoder(StdEncoding, bb)
127                 encoder.Write([]byte(p.decoded))
128                 encoder.Close()
129                 testEqual(t, "Encode(%q) = %q, want %q", p.decoded, bb.String(), p.encoded)
130         }
131 }
132
133 func TestEncoderBuffering(t *testing.T) {
134         input := []byte(bigtest.decoded)
135         for bs := 1; bs <= 12; bs++ {
136                 bb := &strings.Builder{}
137                 encoder := NewEncoder(StdEncoding, bb)
138                 for pos := 0; pos < len(input); pos += bs {
139                         end := pos + bs
140                         if end > len(input) {
141                                 end = len(input)
142                         }
143                         n, err := encoder.Write(input[pos:end])
144                         testEqual(t, "Write(%q) gave error %v, want %v", input[pos:end], err, error(nil))
145                         testEqual(t, "Write(%q) gave length %v, want %v", input[pos:end], n, end-pos)
146                 }
147                 err := encoder.Close()
148                 testEqual(t, "Close gave error %v, want %v", err, error(nil))
149                 testEqual(t, "Encoding/%d of %q = %q, want %q", bs, bigtest.decoded, bb.String(), bigtest.encoded)
150         }
151 }
152
153 func TestDecode(t *testing.T) {
154         for _, p := range pairs {
155                 for _, tt := range encodingTests {
156                         encoded := tt.conv(p.encoded)
157                         dbuf := make([]byte, tt.enc.DecodedLen(len(encoded)))
158                         count, err := tt.enc.Decode(dbuf, []byte(encoded))
159                         testEqual(t, "Decode(%q) = error %v, want %v", encoded, err, error(nil))
160                         testEqual(t, "Decode(%q) = length %v, want %v", encoded, count, len(p.decoded))
161                         testEqual(t, "Decode(%q) = %q, want %q", encoded, string(dbuf[0:count]), p.decoded)
162
163                         dbuf, err = tt.enc.DecodeString(encoded)
164                         testEqual(t, "DecodeString(%q) = error %v, want %v", encoded, err, error(nil))
165                         testEqual(t, "DecodeString(%q) = %q, want %q", encoded, string(dbuf), p.decoded)
166
167                         dst, err := tt.enc.AppendDecode([]byte("lead"), []byte(encoded))
168                         testEqual(t, "AppendDecode(%q) = error %v, want %v", p.encoded, err, error(nil))
169                         testEqual(t, `AppendDecode("lead", %q) = %q, want %q`, p.encoded, string(dst), "lead"+p.decoded)
170                 }
171         }
172 }
173
174 func TestDecoder(t *testing.T) {
175         for _, p := range pairs {
176                 decoder := NewDecoder(StdEncoding, strings.NewReader(p.encoded))
177                 dbuf := make([]byte, StdEncoding.DecodedLen(len(p.encoded)))
178                 count, err := decoder.Read(dbuf)
179                 if err != nil && err != io.EOF {
180                         t.Fatal("Read failed", err)
181                 }
182                 testEqual(t, "Read from %q = length %v, want %v", p.encoded, count, len(p.decoded))
183                 testEqual(t, "Decoding of %q = %q, want %q", p.encoded, string(dbuf[0:count]), p.decoded)
184                 if err != io.EOF {
185                         _, err = decoder.Read(dbuf)
186                 }
187                 testEqual(t, "Read from %q = %v, want %v", p.encoded, err, io.EOF)
188         }
189 }
190
191 func TestDecoderBuffering(t *testing.T) {
192         for bs := 1; bs <= 12; bs++ {
193                 decoder := NewDecoder(StdEncoding, strings.NewReader(bigtest.encoded))
194                 buf := make([]byte, len(bigtest.decoded)+12)
195                 var total int
196                 var n int
197                 var err error
198                 for total = 0; total < len(bigtest.decoded) && err == nil; {
199                         n, err = decoder.Read(buf[total : total+bs])
200                         total += n
201                 }
202                 if err != nil && err != io.EOF {
203                         t.Errorf("Read from %q at pos %d = %d, unexpected error %v", bigtest.encoded, total, n, err)
204                 }
205                 testEqual(t, "Decoding/%d of %q = %q, want %q", bs, bigtest.encoded, string(buf[0:total]), bigtest.decoded)
206         }
207 }
208
209 func TestDecodeCorrupt(t *testing.T) {
210         testCases := []struct {
211                 input  string
212                 offset int // -1 means no corruption.
213         }{
214                 {"", -1},
215                 {"\n", -1},
216                 {"AAA=\n", -1},
217                 {"AAAA\n", -1},
218                 {"!!!!", 0},
219                 {"====", 0},
220                 {"x===", 1},
221                 {"=AAA", 0},
222                 {"A=AA", 1},
223                 {"AA=A", 2},
224                 {"AA==A", 4},
225                 {"AAA=AAAA", 4},
226                 {"AAAAA", 4},
227                 {"AAAAAA", 4},
228                 {"A=", 1},
229                 {"A==", 1},
230                 {"AA=", 3},
231                 {"AA==", -1},
232                 {"AAA=", -1},
233                 {"AAAA", -1},
234                 {"AAAAAA=", 7},
235                 {"YWJjZA=====", 8},
236                 {"A!\n", 1},
237                 {"A=\n", 1},
238         }
239         for _, tc := range testCases {
240                 dbuf := make([]byte, StdEncoding.DecodedLen(len(tc.input)))
241                 _, err := StdEncoding.Decode(dbuf, []byte(tc.input))
242                 if tc.offset == -1 {
243                         if err != nil {
244                                 t.Error("Decoder wrongly detected corruption in", tc.input)
245                         }
246                         continue
247                 }
248                 switch err := err.(type) {
249                 case CorruptInputError:
250                         testEqual(t, "Corruption in %q at offset %v, want %v", tc.input, int(err), tc.offset)
251                 default:
252                         t.Error("Decoder failed to detect corruption in", tc)
253                 }
254         }
255 }
256
257 func TestDecodeBounds(t *testing.T) {
258         var buf [32]byte
259         s := StdEncoding.EncodeToString(buf[:])
260         defer func() {
261                 if err := recover(); err != nil {
262                         t.Fatalf("Decode panicked unexpectedly: %v\n%s", err, debug.Stack())
263                 }
264         }()
265         n, err := StdEncoding.Decode(buf[:], []byte(s))
266         if n != len(buf) || err != nil {
267                 t.Fatalf("StdEncoding.Decode = %d, %v, want %d, nil", n, err, len(buf))
268         }
269 }
270
271 func TestEncodedLen(t *testing.T) {
272         type test struct {
273                 enc  *Encoding
274                 n    int
275                 want int64
276         }
277         tests := []test{
278                 {RawStdEncoding, 0, 0},
279                 {RawStdEncoding, 1, 2},
280                 {RawStdEncoding, 2, 3},
281                 {RawStdEncoding, 3, 4},
282                 {RawStdEncoding, 7, 10},
283                 {StdEncoding, 0, 0},
284                 {StdEncoding, 1, 4},
285                 {StdEncoding, 2, 4},
286                 {StdEncoding, 3, 4},
287                 {StdEncoding, 4, 8},
288                 {StdEncoding, 7, 12},
289         }
290         // check overflow
291         switch strconv.IntSize {
292         case 32:
293                 tests = append(tests, test{RawStdEncoding, (math.MaxInt-5)/8 + 1, 357913942})
294                 tests = append(tests, test{RawStdEncoding, math.MaxInt/4*3 + 2, math.MaxInt})
295         case 64:
296                 tests = append(tests, test{RawStdEncoding, (math.MaxInt-5)/8 + 1, 1537228672809129302})
297                 tests = append(tests, test{RawStdEncoding, math.MaxInt/4*3 + 2, math.MaxInt})
298         }
299         for _, tt := range tests {
300                 if got := tt.enc.EncodedLen(tt.n); int64(got) != tt.want {
301                         t.Errorf("EncodedLen(%d): got %d, want %d", tt.n, got, tt.want)
302                 }
303         }
304 }
305
306 func TestDecodedLen(t *testing.T) {
307         type test struct {
308                 enc  *Encoding
309                 n    int
310                 want int64
311         }
312         tests := []test{
313                 {RawStdEncoding, 0, 0},
314                 {RawStdEncoding, 2, 1},
315                 {RawStdEncoding, 3, 2},
316                 {RawStdEncoding, 4, 3},
317                 {RawStdEncoding, 10, 7},
318                 {StdEncoding, 0, 0},
319                 {StdEncoding, 4, 3},
320                 {StdEncoding, 8, 6},
321         }
322         // check overflow
323         switch strconv.IntSize {
324         case 32:
325                 tests = append(tests, test{RawStdEncoding, math.MaxInt/6 + 1, 268435456})
326                 tests = append(tests, test{RawStdEncoding, math.MaxInt, 1610612735})
327         case 64:
328                 tests = append(tests, test{RawStdEncoding, math.MaxInt/6 + 1, 1152921504606846976})
329                 tests = append(tests, test{RawStdEncoding, math.MaxInt, 6917529027641081855})
330         }
331         for _, tt := range tests {
332                 if got := tt.enc.DecodedLen(tt.n); int64(got) != tt.want {
333                         t.Errorf("DecodedLen(%d): got %d, want %d", tt.n, got, tt.want)
334                 }
335         }
336 }
337
338 func TestBig(t *testing.T) {
339         n := 3*1000 + 1
340         raw := make([]byte, n)
341         const alpha = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
342         for i := 0; i < n; i++ {
343                 raw[i] = alpha[i%len(alpha)]
344         }
345         encoded := new(bytes.Buffer)
346         w := NewEncoder(StdEncoding, encoded)
347         nn, err := w.Write(raw)
348         if nn != n || err != nil {
349                 t.Fatalf("Encoder.Write(raw) = %d, %v want %d, nil", nn, err, n)
350         }
351         err = w.Close()
352         if err != nil {
353                 t.Fatalf("Encoder.Close() = %v want nil", err)
354         }
355         decoded, err := io.ReadAll(NewDecoder(StdEncoding, encoded))
356         if err != nil {
357                 t.Fatalf("io.ReadAll(NewDecoder(...)): %v", err)
358         }
359
360         if !bytes.Equal(raw, decoded) {
361                 var i int
362                 for i = 0; i < len(decoded) && i < len(raw); i++ {
363                         if decoded[i] != raw[i] {
364                                 break
365                         }
366                 }
367                 t.Errorf("Decode(Encode(%d-byte string)) failed at offset %d", n, i)
368         }
369 }
370
371 func TestNewLineCharacters(t *testing.T) {
372         // Each of these should decode to the string "sure", without errors.
373         const expected = "sure"
374         examples := []string{
375                 "c3VyZQ==",
376                 "c3VyZQ==\r",
377                 "c3VyZQ==\n",
378                 "c3VyZQ==\r\n",
379                 "c3VyZ\r\nQ==",
380                 "c3V\ryZ\nQ==",
381                 "c3V\nyZ\rQ==",
382                 "c3VyZ\nQ==",
383                 "c3VyZQ\n==",
384                 "c3VyZQ=\n=",
385                 "c3VyZQ=\r\n\r\n=",
386         }
387         for _, e := range examples {
388                 buf, err := StdEncoding.DecodeString(e)
389                 if err != nil {
390                         t.Errorf("Decode(%q) failed: %v", e, err)
391                         continue
392                 }
393                 if s := string(buf); s != expected {
394                         t.Errorf("Decode(%q) = %q, want %q", e, s, expected)
395                 }
396         }
397 }
398
399 type nextRead struct {
400         n   int   // bytes to return
401         err error // error to return
402 }
403
404 // faultInjectReader returns data from source, rate-limited
405 // and with the errors as written to nextc.
406 type faultInjectReader struct {
407         source string
408         nextc  <-chan nextRead
409 }
410
411 func (r *faultInjectReader) Read(p []byte) (int, error) {
412         nr := <-r.nextc
413         if len(p) > nr.n {
414                 p = p[:nr.n]
415         }
416         n := copy(p, r.source)
417         r.source = r.source[n:]
418         return n, nr.err
419 }
420
421 // tests that we don't ignore errors from our underlying reader
422 func TestDecoderIssue3577(t *testing.T) {
423         next := make(chan nextRead, 10)
424         wantErr := errors.New("my error")
425         next <- nextRead{5, nil}
426         next <- nextRead{10, wantErr}
427         next <- nextRead{0, wantErr}
428         d := NewDecoder(StdEncoding, &faultInjectReader{
429                 source: "VHdhcyBicmlsbGlnLCBhbmQgdGhlIHNsaXRoeSB0b3Zlcw==", // twas brillig...
430                 nextc:  next,
431         })
432         errc := make(chan error, 1)
433         go func() {
434                 _, err := io.ReadAll(d)
435                 errc <- err
436         }()
437         select {
438         case err := <-errc:
439                 if err != wantErr {
440                         t.Errorf("got error %v; want %v", err, wantErr)
441                 }
442         case <-time.After(5 * time.Second):
443                 t.Errorf("timeout; Decoder blocked without returning an error")
444         }
445 }
446
447 func TestDecoderIssue4779(t *testing.T) {
448         encoded := `CP/EAT8AAAEF
449 AQEBAQEBAAAAAAAAAAMAAQIEBQYHCAkKCwEAAQUBAQEBAQEAAAAAAAAAAQACAwQFBgcICQoLEAAB
450 BAEDAgQCBQcGCAUDDDMBAAIRAwQhEjEFQVFhEyJxgTIGFJGhsUIjJBVSwWIzNHKC0UMHJZJT8OHx
451 Y3M1FqKygyZEk1RkRcKjdDYX0lXiZfKzhMPTdePzRieUpIW0lcTU5PSltcXV5fVWZnaGlqa2xtbm
452 9jdHV2d3h5ent8fX5/cRAAICAQIEBAMEBQYHBwYFNQEAAhEDITESBEFRYXEiEwUygZEUobFCI8FS
453 0fAzJGLhcoKSQ1MVY3M08SUGFqKygwcmNcLSRJNUoxdkRVU2dGXi8rOEw9N14/NGlKSFtJXE1OT0
454 pbXF1eX1VmZ2hpamtsbW5vYnN0dXZ3eHl6e3x//aAAwDAQACEQMRAD8A9VSSSSUpJJJJSkkkJ+Tj
455 1kiy1jCJJDnAcCTykpKkuQ6p/jN6FgmxlNduXawwAzaGH+V6jn/R/wCt71zdn+N/qL3kVYFNYB4N
456 ji6PDVjWpKp9TSXnvTf8bFNjg3qOEa2n6VlLpj/rT/pf567DpX1i6L1hs9Py67X8mqdtg/rUWbbf
457 +gkp0kkkklKSSSSUpJJJJT//0PVUkkklKVLq3WMDpGI7KzrNjADtYNXvI/Mqr/Pd/q9W3vaxjnvM
458 NaCXE9gNSvGPrf8AWS3qmba5jjsJhoB0DAf0NDf6sevf+/lf8Hj0JJATfWT6/dV6oXU1uOLQeKKn
459 EQP+Hubtfe/+R7Mf/g7f5xcocp++Z11JMCJPgFBxOg7/AOuqDx8I/ikpkXkmSdU8mJIJA/O8EMAy
460 j+mSARB/17pKVXYWHXjsj7yIex0PadzXMO1zT5KHoNA3HT8ietoGhgjsfA+CSnvvqh/jJtqsrwOv
461 2b6NGNzXfTYexzJ+nU7/ALkf4P8Awv6P9KvTQQ4AgyDqCF85Pho3CTB7eHwXoH+LT65uZbX9X+o2
462 bqbPb06551Y4
463 `
464         encodedShort := strings.ReplaceAll(encoded, "\n", "")
465
466         dec := NewDecoder(StdEncoding, strings.NewReader(encoded))
467         res1, err := io.ReadAll(dec)
468         if err != nil {
469                 t.Errorf("ReadAll failed: %v", err)
470         }
471
472         dec = NewDecoder(StdEncoding, strings.NewReader(encodedShort))
473         var res2 []byte
474         res2, err = io.ReadAll(dec)
475         if err != nil {
476                 t.Errorf("ReadAll failed: %v", err)
477         }
478
479         if !bytes.Equal(res1, res2) {
480                 t.Error("Decoded results not equal")
481         }
482 }
483
484 func TestDecoderIssue7733(t *testing.T) {
485         s, err := StdEncoding.DecodeString("YWJjZA=====")
486         want := CorruptInputError(8)
487         if !reflect.DeepEqual(want, err) {
488                 t.Errorf("Error = %v; want CorruptInputError(8)", err)
489         }
490         if string(s) != "abcd" {
491                 t.Errorf("DecodeString = %q; want abcd", s)
492         }
493 }
494
495 func TestDecoderIssue15656(t *testing.T) {
496         _, err := StdEncoding.Strict().DecodeString("WvLTlMrX9NpYDQlEIFlnDB==")
497         want := CorruptInputError(22)
498         if !reflect.DeepEqual(want, err) {
499                 t.Errorf("Error = %v; want CorruptInputError(22)", err)
500         }
501         _, err = StdEncoding.Strict().DecodeString("WvLTlMrX9NpYDQlEIFlnDA==")
502         if err != nil {
503                 t.Errorf("Error = %v; want nil", err)
504         }
505         _, err = StdEncoding.DecodeString("WvLTlMrX9NpYDQlEIFlnDB==")
506         if err != nil {
507                 t.Errorf("Error = %v; want nil", err)
508         }
509 }
510
511 func BenchmarkEncodeToString(b *testing.B) {
512         data := make([]byte, 8192)
513         b.SetBytes(int64(len(data)))
514         for i := 0; i < b.N; i++ {
515                 StdEncoding.EncodeToString(data)
516         }
517 }
518
519 func BenchmarkDecodeString(b *testing.B) {
520         sizes := []int{2, 4, 8, 64, 8192}
521         benchFunc := func(b *testing.B, benchSize int) {
522                 data := StdEncoding.EncodeToString(make([]byte, benchSize))
523                 b.SetBytes(int64(len(data)))
524                 b.ResetTimer()
525                 for i := 0; i < b.N; i++ {
526                         StdEncoding.DecodeString(data)
527                 }
528         }
529         for _, size := range sizes {
530                 b.Run(fmt.Sprintf("%d", size), func(b *testing.B) {
531                         benchFunc(b, size)
532                 })
533         }
534 }
535
536 func BenchmarkNewEncoding(b *testing.B) {
537         b.SetBytes(int64(len(Encoding{}.decodeMap)))
538         for i := 0; i < b.N; i++ {
539                 e := NewEncoding(encodeStd)
540                 for _, v := range e.decodeMap {
541                         _ = v
542                 }
543         }
544 }
545
546 func TestDecoderRaw(t *testing.T) {
547         source := "AAAAAA"
548         want := []byte{0, 0, 0, 0}
549
550         // Direct.
551         dec1, err := RawURLEncoding.DecodeString(source)
552         if err != nil || !bytes.Equal(dec1, want) {
553                 t.Errorf("RawURLEncoding.DecodeString(%q) = %x, %v, want %x, nil", source, dec1, err, want)
554         }
555
556         // Through reader. Used to fail.
557         r := NewDecoder(RawURLEncoding, bytes.NewReader([]byte(source)))
558         dec2, err := io.ReadAll(io.LimitReader(r, 100))
559         if err != nil || !bytes.Equal(dec2, want) {
560                 t.Errorf("reading NewDecoder(RawURLEncoding, %q) = %x, %v, want %x, nil", source, dec2, err, want)
561         }
562
563         // Should work with padding.
564         r = NewDecoder(URLEncoding, bytes.NewReader([]byte(source+"==")))
565         dec3, err := io.ReadAll(r)
566         if err != nil || !bytes.Equal(dec3, want) {
567                 t.Errorf("reading NewDecoder(URLEncoding, %q) = %x, %v, want %x, nil", source+"==", dec3, err, want)
568         }
569 }