]> Cypherpunks.ru repositories - gostls13.git/blob - src/encoding/base64/base64.go
encoding: add AppendEncode and AppendDecode
[gostls13.git] / src / encoding / base64 / base64.go
1 // Copyright 2009 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 // Package base64 implements base64 encoding as specified by RFC 4648.
6 package base64
7
8 import (
9         "encoding/binary"
10         "io"
11         "slices"
12         "strconv"
13 )
14
15 /*
16  * Encodings
17  */
18
19 // An Encoding is a radix 64 encoding/decoding scheme, defined by a
20 // 64-character alphabet. The most common encoding is the "base64"
21 // encoding defined in RFC 4648 and used in MIME (RFC 2045) and PEM
22 // (RFC 1421).  RFC 4648 also defines an alternate encoding, which is
23 // the standard encoding with - and _ substituted for + and /.
24 type Encoding struct {
25         encode    [64]byte
26         decodeMap [256]byte
27         padChar   rune
28         strict    bool
29 }
30
31 const (
32         StdPadding          rune = '=' // Standard padding character
33         NoPadding           rune = -1  // No padding
34         decodeMapInitialize      = "" +
35                 "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
36                 "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
37                 "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
38                 "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
39                 "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
40                 "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
41                 "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
42                 "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
43                 "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
44                 "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
45                 "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
46                 "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
47                 "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
48                 "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
49                 "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
50                 "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"
51 )
52
53 const encodeStd = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"
54 const encodeURL = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"
55
56 // NewEncoding returns a new padded Encoding defined by the given alphabet,
57 // which must be a 64-byte string that does not contain the padding character
58 // or CR / LF ('\r', '\n'). The alphabet is treated as sequence of byte values
59 // without any special treatment for multi-byte UTF-8.
60 // The resulting Encoding uses the default padding character ('='),
61 // which may be changed or disabled via WithPadding.
62 func NewEncoding(encoder string) *Encoding {
63         if len(encoder) != 64 {
64                 panic("encoding alphabet is not 64-bytes long")
65         }
66         for i := 0; i < len(encoder); i++ {
67                 if encoder[i] == '\n' || encoder[i] == '\r' {
68                         panic("encoding alphabet contains newline character")
69                 }
70         }
71
72         e := new(Encoding)
73         e.padChar = StdPadding
74         copy(e.encode[:], encoder)
75         copy(e.decodeMap[:], decodeMapInitialize)
76
77         for i := 0; i < len(encoder); i++ {
78                 e.decodeMap[encoder[i]] = byte(i)
79         }
80         return e
81 }
82
83 // WithPadding creates a new encoding identical to enc except
84 // with a specified padding character, or NoPadding to disable padding.
85 // The padding character must not be '\r' or '\n', must not
86 // be contained in the encoding's alphabet and must be a rune equal or
87 // below '\xff'.
88 // Padding characters above '\x7f' are encoded as their exact byte value
89 // rather than using the UTF-8 representation of the codepoint.
90 func (enc Encoding) WithPadding(padding rune) *Encoding {
91         if padding == '\r' || padding == '\n' || padding > 0xff {
92                 panic("invalid padding")
93         }
94
95         for i := 0; i < len(enc.encode); i++ {
96                 if rune(enc.encode[i]) == padding {
97                         panic("padding contained in alphabet")
98                 }
99         }
100
101         enc.padChar = padding
102         return &enc
103 }
104
105 // Strict creates a new encoding identical to enc except with
106 // strict decoding enabled. In this mode, the decoder requires that
107 // trailing padding bits are zero, as described in RFC 4648 section 3.5.
108 //
109 // Note that the input is still malleable, as new line characters
110 // (CR and LF) are still ignored.
111 func (enc Encoding) Strict() *Encoding {
112         enc.strict = true
113         return &enc
114 }
115
116 // StdEncoding is the standard base64 encoding, as defined in
117 // RFC 4648.
118 var StdEncoding = NewEncoding(encodeStd)
119
120 // URLEncoding is the alternate base64 encoding defined in RFC 4648.
121 // It is typically used in URLs and file names.
122 var URLEncoding = NewEncoding(encodeURL)
123
124 // RawStdEncoding is the standard raw, unpadded base64 encoding,
125 // as defined in RFC 4648 section 3.2.
126 // This is the same as StdEncoding but omits padding characters.
127 var RawStdEncoding = StdEncoding.WithPadding(NoPadding)
128
129 // RawURLEncoding is the unpadded alternate base64 encoding defined in RFC 4648.
130 // It is typically used in URLs and file names.
131 // This is the same as URLEncoding but omits padding characters.
132 var RawURLEncoding = URLEncoding.WithPadding(NoPadding)
133
134 /*
135  * Encoder
136  */
137
138 // Encode encodes src using the encoding enc, writing
139 // EncodedLen(len(src)) bytes to dst.
140 //
141 // The encoding pads the output to a multiple of 4 bytes,
142 // so Encode is not appropriate for use on individual blocks
143 // of a large data stream. Use NewEncoder() instead.
144 func (enc *Encoding) Encode(dst, src []byte) {
145         if len(src) == 0 {
146                 return
147         }
148         // enc is a pointer receiver, so the use of enc.encode within the hot
149         // loop below means a nil check at every operation. Lift that nil check
150         // outside of the loop to speed up the encoder.
151         _ = enc.encode
152
153         di, si := 0, 0
154         n := (len(src) / 3) * 3
155         for si < n {
156                 // Convert 3x 8bit source bytes into 4 bytes
157                 val := uint(src[si+0])<<16 | uint(src[si+1])<<8 | uint(src[si+2])
158
159                 dst[di+0] = enc.encode[val>>18&0x3F]
160                 dst[di+1] = enc.encode[val>>12&0x3F]
161                 dst[di+2] = enc.encode[val>>6&0x3F]
162                 dst[di+3] = enc.encode[val&0x3F]
163
164                 si += 3
165                 di += 4
166         }
167
168         remain := len(src) - si
169         if remain == 0 {
170                 return
171         }
172         // Add the remaining small block
173         val := uint(src[si+0]) << 16
174         if remain == 2 {
175                 val |= uint(src[si+1]) << 8
176         }
177
178         dst[di+0] = enc.encode[val>>18&0x3F]
179         dst[di+1] = enc.encode[val>>12&0x3F]
180
181         switch remain {
182         case 2:
183                 dst[di+2] = enc.encode[val>>6&0x3F]
184                 if enc.padChar != NoPadding {
185                         dst[di+3] = byte(enc.padChar)
186                 }
187         case 1:
188                 if enc.padChar != NoPadding {
189                         dst[di+2] = byte(enc.padChar)
190                         dst[di+3] = byte(enc.padChar)
191                 }
192         }
193 }
194
195 // AppendEncode appends the base64 encoded src to dst
196 // and returns the extended buffer.
197 func (enc *Encoding) AppendEncode(dst, src []byte) []byte {
198         n := enc.EncodedLen(len(src))
199         dst = slices.Grow(dst, n)
200         enc.Encode(dst[len(dst):][:n], src)
201         return dst[:len(dst)+n]
202 }
203
204 // EncodeToString returns the base64 encoding of src.
205 func (enc *Encoding) EncodeToString(src []byte) string {
206         buf := make([]byte, enc.EncodedLen(len(src)))
207         enc.Encode(buf, src)
208         return string(buf)
209 }
210
211 type encoder struct {
212         err  error
213         enc  *Encoding
214         w    io.Writer
215         buf  [3]byte    // buffered data waiting to be encoded
216         nbuf int        // number of bytes in buf
217         out  [1024]byte // output buffer
218 }
219
220 func (e *encoder) Write(p []byte) (n int, err error) {
221         if e.err != nil {
222                 return 0, e.err
223         }
224
225         // Leading fringe.
226         if e.nbuf > 0 {
227                 var i int
228                 for i = 0; i < len(p) && e.nbuf < 3; i++ {
229                         e.buf[e.nbuf] = p[i]
230                         e.nbuf++
231                 }
232                 n += i
233                 p = p[i:]
234                 if e.nbuf < 3 {
235                         return
236                 }
237                 e.enc.Encode(e.out[:], e.buf[:])
238                 if _, e.err = e.w.Write(e.out[:4]); e.err != nil {
239                         return n, e.err
240                 }
241                 e.nbuf = 0
242         }
243
244         // Large interior chunks.
245         for len(p) >= 3 {
246                 nn := len(e.out) / 4 * 3
247                 if nn > len(p) {
248                         nn = len(p)
249                         nn -= nn % 3
250                 }
251                 e.enc.Encode(e.out[:], p[:nn])
252                 if _, e.err = e.w.Write(e.out[0 : nn/3*4]); e.err != nil {
253                         return n, e.err
254                 }
255                 n += nn
256                 p = p[nn:]
257         }
258
259         // Trailing fringe.
260         copy(e.buf[:], p)
261         e.nbuf = len(p)
262         n += len(p)
263         return
264 }
265
266 // Close flushes any pending output from the encoder.
267 // It is an error to call Write after calling Close.
268 func (e *encoder) Close() error {
269         // If there's anything left in the buffer, flush it out
270         if e.err == nil && e.nbuf > 0 {
271                 e.enc.Encode(e.out[:], e.buf[:e.nbuf])
272                 _, e.err = e.w.Write(e.out[:e.enc.EncodedLen(e.nbuf)])
273                 e.nbuf = 0
274         }
275         return e.err
276 }
277
278 // NewEncoder returns a new base64 stream encoder. Data written to
279 // the returned writer will be encoded using enc and then written to w.
280 // Base64 encodings operate in 4-byte blocks; when finished
281 // writing, the caller must Close the returned encoder to flush any
282 // partially written blocks.
283 func NewEncoder(enc *Encoding, w io.Writer) io.WriteCloser {
284         return &encoder{enc: enc, w: w}
285 }
286
287 // EncodedLen returns the length in bytes of the base64 encoding
288 // of an input buffer of length n.
289 func (enc *Encoding) EncodedLen(n int) int {
290         if enc.padChar == NoPadding {
291                 return n/3*4 + (n%3*8+5)/6 // minimum # chars at 6 bits per char
292         }
293         return (n + 2) / 3 * 4 // minimum # 4-char quanta, 3 bytes each
294 }
295
296 /*
297  * Decoder
298  */
299
300 type CorruptInputError int64
301
302 func (e CorruptInputError) Error() string {
303         return "illegal base64 data at input byte " + strconv.FormatInt(int64(e), 10)
304 }
305
306 // decodeQuantum decodes up to 4 base64 bytes. The received parameters are
307 // the destination buffer dst, the source buffer src and an index in the
308 // source buffer si.
309 // It returns the number of bytes read from src, the number of bytes written
310 // to dst, and an error, if any.
311 func (enc *Encoding) decodeQuantum(dst, src []byte, si int) (nsi, n int, err error) {
312         // Decode quantum using the base64 alphabet
313         var dbuf [4]byte
314         dlen := 4
315
316         // Lift the nil check outside of the loop.
317         _ = enc.decodeMap
318
319         for j := 0; j < len(dbuf); j++ {
320                 if len(src) == si {
321                         switch {
322                         case j == 0:
323                                 return si, 0, nil
324                         case j == 1, enc.padChar != NoPadding:
325                                 return si, 0, CorruptInputError(si - j)
326                         }
327                         dlen = j
328                         break
329                 }
330                 in := src[si]
331                 si++
332
333                 out := enc.decodeMap[in]
334                 if out != 0xff {
335                         dbuf[j] = out
336                         continue
337                 }
338
339                 if in == '\n' || in == '\r' {
340                         j--
341                         continue
342                 }
343
344                 if rune(in) != enc.padChar {
345                         return si, 0, CorruptInputError(si - 1)
346                 }
347
348                 // We've reached the end and there's padding
349                 switch j {
350                 case 0, 1:
351                         // incorrect padding
352                         return si, 0, CorruptInputError(si - 1)
353                 case 2:
354                         // "==" is expected, the first "=" is already consumed.
355                         // skip over newlines
356                         for si < len(src) && (src[si] == '\n' || src[si] == '\r') {
357                                 si++
358                         }
359                         if si == len(src) {
360                                 // not enough padding
361                                 return si, 0, CorruptInputError(len(src))
362                         }
363                         if rune(src[si]) != enc.padChar {
364                                 // incorrect padding
365                                 return si, 0, CorruptInputError(si - 1)
366                         }
367
368                         si++
369                 }
370
371                 // skip over newlines
372                 for si < len(src) && (src[si] == '\n' || src[si] == '\r') {
373                         si++
374                 }
375                 if si < len(src) {
376                         // trailing garbage
377                         err = CorruptInputError(si)
378                 }
379                 dlen = j
380                 break
381         }
382
383         // Convert 4x 6bit source bytes into 3 bytes
384         val := uint(dbuf[0])<<18 | uint(dbuf[1])<<12 | uint(dbuf[2])<<6 | uint(dbuf[3])
385         dbuf[2], dbuf[1], dbuf[0] = byte(val>>0), byte(val>>8), byte(val>>16)
386         switch dlen {
387         case 4:
388                 dst[2] = dbuf[2]
389                 dbuf[2] = 0
390                 fallthrough
391         case 3:
392                 dst[1] = dbuf[1]
393                 if enc.strict && dbuf[2] != 0 {
394                         return si, 0, CorruptInputError(si - 1)
395                 }
396                 dbuf[1] = 0
397                 fallthrough
398         case 2:
399                 dst[0] = dbuf[0]
400                 if enc.strict && (dbuf[1] != 0 || dbuf[2] != 0) {
401                         return si, 0, CorruptInputError(si - 2)
402                 }
403         }
404
405         return si, dlen - 1, err
406 }
407
408 // AppendDecode appends the base64 decoded src to dst
409 // and returns the extended buffer.
410 // If the input is malformed, it returns the partially decoded src and an error.
411 func (enc *Encoding) AppendDecode(dst, src []byte) ([]byte, error) {
412         n := enc.DecodedLen(len(src))
413         dst = slices.Grow(dst, n)
414         n, err := enc.Decode(dst[len(dst):][:n], src)
415         return dst[:len(dst)+n], err
416 }
417
418 // DecodeString returns the bytes represented by the base64 string s.
419 func (enc *Encoding) DecodeString(s string) ([]byte, error) {
420         dbuf := make([]byte, enc.DecodedLen(len(s)))
421         n, err := enc.Decode(dbuf, []byte(s))
422         return dbuf[:n], err
423 }
424
425 type decoder struct {
426         err     error
427         readErr error // error from r.Read
428         enc     *Encoding
429         r       io.Reader
430         buf     [1024]byte // leftover input
431         nbuf    int
432         out     []byte // leftover decoded output
433         outbuf  [1024 / 4 * 3]byte
434 }
435
436 func (d *decoder) Read(p []byte) (n int, err error) {
437         // Use leftover decoded output from last read.
438         if len(d.out) > 0 {
439                 n = copy(p, d.out)
440                 d.out = d.out[n:]
441                 return n, nil
442         }
443
444         if d.err != nil {
445                 return 0, d.err
446         }
447
448         // This code assumes that d.r strips supported whitespace ('\r' and '\n').
449
450         // Refill buffer.
451         for d.nbuf < 4 && d.readErr == nil {
452                 nn := len(p) / 3 * 4
453                 if nn < 4 {
454                         nn = 4
455                 }
456                 if nn > len(d.buf) {
457                         nn = len(d.buf)
458                 }
459                 nn, d.readErr = d.r.Read(d.buf[d.nbuf:nn])
460                 d.nbuf += nn
461         }
462
463         if d.nbuf < 4 {
464                 if d.enc.padChar == NoPadding && d.nbuf > 0 {
465                         // Decode final fragment, without padding.
466                         var nw int
467                         nw, d.err = d.enc.Decode(d.outbuf[:], d.buf[:d.nbuf])
468                         d.nbuf = 0
469                         d.out = d.outbuf[:nw]
470                         n = copy(p, d.out)
471                         d.out = d.out[n:]
472                         if n > 0 || len(p) == 0 && len(d.out) > 0 {
473                                 return n, nil
474                         }
475                         if d.err != nil {
476                                 return 0, d.err
477                         }
478                 }
479                 d.err = d.readErr
480                 if d.err == io.EOF && d.nbuf > 0 {
481                         d.err = io.ErrUnexpectedEOF
482                 }
483                 return 0, d.err
484         }
485
486         // Decode chunk into p, or d.out and then p if p is too small.
487         nr := d.nbuf / 4 * 4
488         nw := d.nbuf / 4 * 3
489         if nw > len(p) {
490                 nw, d.err = d.enc.Decode(d.outbuf[:], d.buf[:nr])
491                 d.out = d.outbuf[:nw]
492                 n = copy(p, d.out)
493                 d.out = d.out[n:]
494         } else {
495                 n, d.err = d.enc.Decode(p, d.buf[:nr])
496         }
497         d.nbuf -= nr
498         copy(d.buf[:d.nbuf], d.buf[nr:])
499         return n, d.err
500 }
501
502 // Decode decodes src using the encoding enc. It writes at most
503 // DecodedLen(len(src)) bytes to dst and returns the number of bytes
504 // written. If src contains invalid base64 data, it will return the
505 // number of bytes successfully written and CorruptInputError.
506 // New line characters (\r and \n) are ignored.
507 func (enc *Encoding) Decode(dst, src []byte) (n int, err error) {
508         if len(src) == 0 {
509                 return 0, nil
510         }
511
512         // Lift the nil check outside of the loop. enc.decodeMap is directly
513         // used later in this function, to let the compiler know that the
514         // receiver can't be nil.
515         _ = enc.decodeMap
516
517         si := 0
518         for strconv.IntSize >= 64 && len(src)-si >= 8 && len(dst)-n >= 8 {
519                 src2 := src[si : si+8]
520                 if dn, ok := assemble64(
521                         enc.decodeMap[src2[0]],
522                         enc.decodeMap[src2[1]],
523                         enc.decodeMap[src2[2]],
524                         enc.decodeMap[src2[3]],
525                         enc.decodeMap[src2[4]],
526                         enc.decodeMap[src2[5]],
527                         enc.decodeMap[src2[6]],
528                         enc.decodeMap[src2[7]],
529                 ); ok {
530                         binary.BigEndian.PutUint64(dst[n:], dn)
531                         n += 6
532                         si += 8
533                 } else {
534                         var ninc int
535                         si, ninc, err = enc.decodeQuantum(dst[n:], src, si)
536                         n += ninc
537                         if err != nil {
538                                 return n, err
539                         }
540                 }
541         }
542
543         for len(src)-si >= 4 && len(dst)-n >= 4 {
544                 src2 := src[si : si+4]
545                 if dn, ok := assemble32(
546                         enc.decodeMap[src2[0]],
547                         enc.decodeMap[src2[1]],
548                         enc.decodeMap[src2[2]],
549                         enc.decodeMap[src2[3]],
550                 ); ok {
551                         binary.BigEndian.PutUint32(dst[n:], dn)
552                         n += 3
553                         si += 4
554                 } else {
555                         var ninc int
556                         si, ninc, err = enc.decodeQuantum(dst[n:], src, si)
557                         n += ninc
558                         if err != nil {
559                                 return n, err
560                         }
561                 }
562         }
563
564         for si < len(src) {
565                 var ninc int
566                 si, ninc, err = enc.decodeQuantum(dst[n:], src, si)
567                 n += ninc
568                 if err != nil {
569                         return n, err
570                 }
571         }
572         return n, err
573 }
574
575 // assemble32 assembles 4 base64 digits into 3 bytes.
576 // Each digit comes from the decode map, and will be 0xff
577 // if it came from an invalid character.
578 func assemble32(n1, n2, n3, n4 byte) (dn uint32, ok bool) {
579         // Check that all the digits are valid. If any of them was 0xff, their
580         // bitwise OR will be 0xff.
581         if n1|n2|n3|n4 == 0xff {
582                 return 0, false
583         }
584         return uint32(n1)<<26 |
585                         uint32(n2)<<20 |
586                         uint32(n3)<<14 |
587                         uint32(n4)<<8,
588                 true
589 }
590
591 // assemble64 assembles 8 base64 digits into 6 bytes.
592 // Each digit comes from the decode map, and will be 0xff
593 // if it came from an invalid character.
594 func assemble64(n1, n2, n3, n4, n5, n6, n7, n8 byte) (dn uint64, ok bool) {
595         // Check that all the digits are valid. If any of them was 0xff, their
596         // bitwise OR will be 0xff.
597         if n1|n2|n3|n4|n5|n6|n7|n8 == 0xff {
598                 return 0, false
599         }
600         return uint64(n1)<<58 |
601                         uint64(n2)<<52 |
602                         uint64(n3)<<46 |
603                         uint64(n4)<<40 |
604                         uint64(n5)<<34 |
605                         uint64(n6)<<28 |
606                         uint64(n7)<<22 |
607                         uint64(n8)<<16,
608                 true
609 }
610
611 type newlineFilteringReader struct {
612         wrapped io.Reader
613 }
614
615 func (r *newlineFilteringReader) Read(p []byte) (int, error) {
616         n, err := r.wrapped.Read(p)
617         for n > 0 {
618                 offset := 0
619                 for i, b := range p[:n] {
620                         if b != '\r' && b != '\n' {
621                                 if i != offset {
622                                         p[offset] = b
623                                 }
624                                 offset++
625                         }
626                 }
627                 if offset > 0 {
628                         return offset, err
629                 }
630                 // Previous buffer entirely whitespace, read again
631                 n, err = r.wrapped.Read(p)
632         }
633         return n, err
634 }
635
636 // NewDecoder constructs a new base64 stream decoder.
637 func NewDecoder(enc *Encoding, r io.Reader) io.Reader {
638         return &decoder{enc: enc, r: &newlineFilteringReader{r}}
639 }
640
641 // DecodedLen returns the maximum length in bytes of the decoded data
642 // corresponding to n bytes of base64-encoded data.
643 func (enc *Encoding) DecodedLen(n int) int {
644         if enc.padChar == NoPadding {
645                 // Unpadded data may end with partial block of 2-3 characters.
646                 return n/4*3 + n%4*6/8
647         }
648         // Padded base64 should always be a multiple of 4 characters in length.
649         return n / 4 * 3
650 }