]> Cypherpunks.ru repositories - gostls13.git/blob - src/encoding/base32/base32.go
encoding: add AppendEncode and AppendDecode
[gostls13.git] / src / encoding / base32 / base32.go
1 // Copyright 2011 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 // Package base32 implements base32 encoding as specified by RFC 4648.
6 package base32
7
8 import (
9         "io"
10         "slices"
11         "strconv"
12 )
13
14 /*
15  * Encodings
16  */
17
18 // An Encoding is a radix 32 encoding/decoding scheme, defined by a
19 // 32-character alphabet. The most common is the "base32" encoding
20 // introduced for SASL GSSAPI and standardized in RFC 4648.
21 // The alternate "base32hex" encoding is used in DNSSEC.
22 type Encoding struct {
23         encode    [32]byte
24         decodeMap [256]byte
25         padChar   rune
26 }
27
28 const (
29         StdPadding          rune = '=' // Standard padding character
30         NoPadding           rune = -1  // No padding
31         decodeMapInitialize      = "" +
32                 "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
33                 "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
34                 "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
35                 "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
36                 "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
37                 "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
38                 "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
39                 "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
40                 "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
41                 "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
42                 "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
43                 "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
44                 "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
45                 "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
46                 "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
47                 "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"
48 )
49
50 const encodeStd = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567"
51 const encodeHex = "0123456789ABCDEFGHIJKLMNOPQRSTUV"
52
53 // NewEncoding returns a new Encoding defined by the given alphabet,
54 // which must be a 32-byte string. The alphabet is treated as sequence
55 // of byte values without any special treatment for multi-byte UTF-8.
56 func NewEncoding(encoder string) *Encoding {
57         if len(encoder) != 32 {
58                 panic("encoding alphabet is not 32-bytes long")
59         }
60
61         e := new(Encoding)
62         e.padChar = StdPadding
63         copy(e.encode[:], encoder)
64         copy(e.decodeMap[:], decodeMapInitialize)
65
66         for i := 0; i < len(encoder); i++ {
67                 e.decodeMap[encoder[i]] = byte(i)
68         }
69         return e
70 }
71
72 // StdEncoding is the standard base32 encoding, as defined in
73 // RFC 4648.
74 var StdEncoding = NewEncoding(encodeStd)
75
76 // HexEncoding is the “Extended Hex Alphabet” defined in RFC 4648.
77 // It is typically used in DNS.
78 var HexEncoding = NewEncoding(encodeHex)
79
80 // WithPadding creates a new encoding identical to enc except
81 // with a specified padding character, or NoPadding to disable padding.
82 // The padding character must not be '\r' or '\n', must not
83 // be contained in the encoding's alphabet and must be a rune equal or
84 // below '\xff'.
85 // Padding characters above '\x7f' are encoded as their exact byte value
86 // rather than using the UTF-8 representation of the codepoint.
87 func (enc Encoding) WithPadding(padding rune) *Encoding {
88         if padding == '\r' || padding == '\n' || padding > 0xff {
89                 panic("invalid padding")
90         }
91
92         for i := 0; i < len(enc.encode); i++ {
93                 if rune(enc.encode[i]) == padding {
94                         panic("padding contained in alphabet")
95                 }
96         }
97
98         enc.padChar = padding
99         return &enc
100 }
101
102 /*
103  * Encoder
104  */
105
106 // Encode encodes src using the encoding enc, writing
107 // EncodedLen(len(src)) bytes to dst.
108 //
109 // The encoding pads the output to a multiple of 8 bytes,
110 // so Encode is not appropriate for use on individual blocks
111 // of a large data stream. Use NewEncoder() instead.
112 func (enc *Encoding) Encode(dst, src []byte) {
113         if len(src) == 0 {
114                 return
115         }
116         // enc is a pointer receiver, so the use of enc.encode within the hot
117         // loop below means a nil check at every operation. Lift that nil check
118         // outside of the loop to speed up the encoder.
119         _ = enc.encode
120
121         di, si := 0, 0
122         n := (len(src) / 5) * 5
123         for si < n {
124                 // Combining two 32 bit loads allows the same code to be used
125                 // for 32 and 64 bit platforms.
126                 hi := uint32(src[si+0])<<24 | uint32(src[si+1])<<16 | uint32(src[si+2])<<8 | uint32(src[si+3])
127                 lo := hi<<8 | uint32(src[si+4])
128
129                 dst[di+0] = enc.encode[(hi>>27)&0x1F]
130                 dst[di+1] = enc.encode[(hi>>22)&0x1F]
131                 dst[di+2] = enc.encode[(hi>>17)&0x1F]
132                 dst[di+3] = enc.encode[(hi>>12)&0x1F]
133                 dst[di+4] = enc.encode[(hi>>7)&0x1F]
134                 dst[di+5] = enc.encode[(hi>>2)&0x1F]
135                 dst[di+6] = enc.encode[(lo>>5)&0x1F]
136                 dst[di+7] = enc.encode[(lo)&0x1F]
137
138                 si += 5
139                 di += 8
140         }
141
142         // Add the remaining small block
143         remain := len(src) - si
144         if remain == 0 {
145                 return
146         }
147
148         // Encode the remaining bytes in reverse order.
149         val := uint32(0)
150         switch remain {
151         case 4:
152                 val |= uint32(src[si+3])
153                 dst[di+6] = enc.encode[val<<3&0x1F]
154                 dst[di+5] = enc.encode[val>>2&0x1F]
155                 fallthrough
156         case 3:
157                 val |= uint32(src[si+2]) << 8
158                 dst[di+4] = enc.encode[val>>7&0x1F]
159                 fallthrough
160         case 2:
161                 val |= uint32(src[si+1]) << 16
162                 dst[di+3] = enc.encode[val>>12&0x1F]
163                 dst[di+2] = enc.encode[val>>17&0x1F]
164                 fallthrough
165         case 1:
166                 val |= uint32(src[si+0]) << 24
167                 dst[di+1] = enc.encode[val>>22&0x1F]
168                 dst[di+0] = enc.encode[val>>27&0x1F]
169         }
170
171         // Pad the final quantum
172         if enc.padChar != NoPadding {
173                 nPad := (remain * 8 / 5) + 1
174                 for i := nPad; i < 8; i++ {
175                         dst[di+i] = byte(enc.padChar)
176                 }
177         }
178 }
179
180 // AppendEncode appends the base32 encoded src to dst
181 // and returns the extended buffer.
182 func (enc *Encoding) AppendEncode(dst, src []byte) []byte {
183         n := enc.EncodedLen(len(src))
184         dst = slices.Grow(dst, n)
185         enc.Encode(dst[len(dst):][:n], src)
186         return dst[:len(dst)+n]
187 }
188
189 // EncodeToString returns the base32 encoding of src.
190 func (enc *Encoding) EncodeToString(src []byte) string {
191         buf := make([]byte, enc.EncodedLen(len(src)))
192         enc.Encode(buf, src)
193         return string(buf)
194 }
195
196 type encoder struct {
197         err  error
198         enc  *Encoding
199         w    io.Writer
200         buf  [5]byte    // buffered data waiting to be encoded
201         nbuf int        // number of bytes in buf
202         out  [1024]byte // output buffer
203 }
204
205 func (e *encoder) Write(p []byte) (n int, err error) {
206         if e.err != nil {
207                 return 0, e.err
208         }
209
210         // Leading fringe.
211         if e.nbuf > 0 {
212                 var i int
213                 for i = 0; i < len(p) && e.nbuf < 5; i++ {
214                         e.buf[e.nbuf] = p[i]
215                         e.nbuf++
216                 }
217                 n += i
218                 p = p[i:]
219                 if e.nbuf < 5 {
220                         return
221                 }
222                 e.enc.Encode(e.out[0:], e.buf[0:])
223                 if _, e.err = e.w.Write(e.out[0:8]); e.err != nil {
224                         return n, e.err
225                 }
226                 e.nbuf = 0
227         }
228
229         // Large interior chunks.
230         for len(p) >= 5 {
231                 nn := len(e.out) / 8 * 5
232                 if nn > len(p) {
233                         nn = len(p)
234                         nn -= nn % 5
235                 }
236                 e.enc.Encode(e.out[0:], p[0:nn])
237                 if _, e.err = e.w.Write(e.out[0 : nn/5*8]); e.err != nil {
238                         return n, e.err
239                 }
240                 n += nn
241                 p = p[nn:]
242         }
243
244         // Trailing fringe.
245         copy(e.buf[:], p)
246         e.nbuf = len(p)
247         n += len(p)
248         return
249 }
250
251 // Close flushes any pending output from the encoder.
252 // It is an error to call Write after calling Close.
253 func (e *encoder) Close() error {
254         // If there's anything left in the buffer, flush it out
255         if e.err == nil && e.nbuf > 0 {
256                 e.enc.Encode(e.out[0:], e.buf[0:e.nbuf])
257                 encodedLen := e.enc.EncodedLen(e.nbuf)
258                 e.nbuf = 0
259                 _, e.err = e.w.Write(e.out[0:encodedLen])
260         }
261         return e.err
262 }
263
264 // NewEncoder returns a new base32 stream encoder. Data written to
265 // the returned writer will be encoded using enc and then written to w.
266 // Base32 encodings operate in 5-byte blocks; when finished
267 // writing, the caller must Close the returned encoder to flush any
268 // partially written blocks.
269 func NewEncoder(enc *Encoding, w io.Writer) io.WriteCloser {
270         return &encoder{enc: enc, w: w}
271 }
272
273 // EncodedLen returns the length in bytes of the base32 encoding
274 // of an input buffer of length n.
275 func (enc *Encoding) EncodedLen(n int) int {
276         if enc.padChar == NoPadding {
277                 return n/5*8 + (n%5*8+4)/5
278         }
279         return (n + 4) / 5 * 8
280 }
281
282 /*
283  * Decoder
284  */
285
286 type CorruptInputError int64
287
288 func (e CorruptInputError) Error() string {
289         return "illegal base32 data at input byte " + strconv.FormatInt(int64(e), 10)
290 }
291
292 // decode is like Decode but returns an additional 'end' value, which
293 // indicates if end-of-message padding was encountered and thus any
294 // additional data is an error. This method assumes that src has been
295 // stripped of all supported whitespace ('\r' and '\n').
296 func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) {
297         // Lift the nil check outside of the loop.
298         _ = enc.decodeMap
299
300         dsti := 0
301         olen := len(src)
302
303         for len(src) > 0 && !end {
304                 // Decode quantum using the base32 alphabet
305                 var dbuf [8]byte
306                 dlen := 8
307
308                 for j := 0; j < 8; {
309
310                         if len(src) == 0 {
311                                 if enc.padChar != NoPadding {
312                                         // We have reached the end and are missing padding
313                                         return n, false, CorruptInputError(olen - len(src) - j)
314                                 }
315                                 // We have reached the end and are not expecting any padding
316                                 dlen, end = j, true
317                                 break
318                         }
319                         in := src[0]
320                         src = src[1:]
321                         if in == byte(enc.padChar) && j >= 2 && len(src) < 8 {
322                                 // We've reached the end and there's padding
323                                 if len(src)+j < 8-1 {
324                                         // not enough padding
325                                         return n, false, CorruptInputError(olen)
326                                 }
327                                 for k := 0; k < 8-1-j; k++ {
328                                         if len(src) > k && src[k] != byte(enc.padChar) {
329                                                 // incorrect padding
330                                                 return n, false, CorruptInputError(olen - len(src) + k - 1)
331                                         }
332                                 }
333                                 dlen, end = j, true
334                                 // 7, 5 and 2 are not valid padding lengths, and so 1, 3 and 6 are not
335                                 // valid dlen values. See RFC 4648 Section 6 "Base 32 Encoding" listing
336                                 // the five valid padding lengths, and Section 9 "Illustrations and
337                                 // Examples" for an illustration for how the 1st, 3rd and 6th base32
338                                 // src bytes do not yield enough information to decode a dst byte.
339                                 if dlen == 1 || dlen == 3 || dlen == 6 {
340                                         return n, false, CorruptInputError(olen - len(src) - 1)
341                                 }
342                                 break
343                         }
344                         dbuf[j] = enc.decodeMap[in]
345                         if dbuf[j] == 0xFF {
346                                 return n, false, CorruptInputError(olen - len(src) - 1)
347                         }
348                         j++
349                 }
350
351                 // Pack 8x 5-bit source blocks into 5 byte destination
352                 // quantum
353                 switch dlen {
354                 case 8:
355                         dst[dsti+4] = dbuf[6]<<5 | dbuf[7]
356                         n++
357                         fallthrough
358                 case 7:
359                         dst[dsti+3] = dbuf[4]<<7 | dbuf[5]<<2 | dbuf[6]>>3
360                         n++
361                         fallthrough
362                 case 5:
363                         dst[dsti+2] = dbuf[3]<<4 | dbuf[4]>>1
364                         n++
365                         fallthrough
366                 case 4:
367                         dst[dsti+1] = dbuf[1]<<6 | dbuf[2]<<1 | dbuf[3]>>4
368                         n++
369                         fallthrough
370                 case 2:
371                         dst[dsti+0] = dbuf[0]<<3 | dbuf[1]>>2
372                         n++
373                 }
374                 dsti += 5
375         }
376         return n, end, nil
377 }
378
379 // Decode decodes src using the encoding enc. It writes at most
380 // DecodedLen(len(src)) bytes to dst and returns the number of bytes
381 // written. If src contains invalid base32 data, it will return the
382 // number of bytes successfully written and CorruptInputError.
383 // New line characters (\r and \n) are ignored.
384 func (enc *Encoding) Decode(dst, src []byte) (n int, err error) {
385         buf := make([]byte, len(src))
386         l := stripNewlines(buf, src)
387         n, _, err = enc.decode(dst, buf[:l])
388         return
389 }
390
391 // AppendDecode appends the base32 decoded src to dst
392 // and returns the extended buffer.
393 // If the input is malformed, it returns the partially decoded src and an error.
394 func (enc *Encoding) AppendDecode(dst, src []byte) ([]byte, error) {
395         n := enc.DecodedLen(len(src))
396         dst = slices.Grow(dst, n)
397         n, err := enc.Decode(dst[len(dst):][:n], src)
398         return dst[:len(dst)+n], err
399 }
400
401 // DecodeString returns the bytes represented by the base32 string s.
402 func (enc *Encoding) DecodeString(s string) ([]byte, error) {
403         buf := []byte(s)
404         l := stripNewlines(buf, buf)
405         n, _, err := enc.decode(buf, buf[:l])
406         return buf[:n], err
407 }
408
409 type decoder struct {
410         err    error
411         enc    *Encoding
412         r      io.Reader
413         end    bool       // saw end of message
414         buf    [1024]byte // leftover input
415         nbuf   int
416         out    []byte // leftover decoded output
417         outbuf [1024 / 8 * 5]byte
418 }
419
420 func readEncodedData(r io.Reader, buf []byte, min int, expectsPadding bool) (n int, err error) {
421         for n < min && err == nil {
422                 var nn int
423                 nn, err = r.Read(buf[n:])
424                 n += nn
425         }
426         // data was read, less than min bytes could be read
427         if n < min && n > 0 && err == io.EOF {
428                 err = io.ErrUnexpectedEOF
429         }
430         // no data was read, the buffer already contains some data
431         // when padding is disabled this is not an error, as the message can be of
432         // any length
433         if expectsPadding && min < 8 && n == 0 && err == io.EOF {
434                 err = io.ErrUnexpectedEOF
435         }
436         return
437 }
438
439 func (d *decoder) Read(p []byte) (n int, err error) {
440         // Use leftover decoded output from last read.
441         if len(d.out) > 0 {
442                 n = copy(p, d.out)
443                 d.out = d.out[n:]
444                 if len(d.out) == 0 {
445                         return n, d.err
446                 }
447                 return n, nil
448         }
449
450         if d.err != nil {
451                 return 0, d.err
452         }
453
454         // Read a chunk.
455         nn := len(p) / 5 * 8
456         if nn < 8 {
457                 nn = 8
458         }
459         if nn > len(d.buf) {
460                 nn = len(d.buf)
461         }
462
463         // Minimum amount of bytes that needs to be read each cycle
464         var min int
465         var expectsPadding bool
466         if d.enc.padChar == NoPadding {
467                 min = 1
468                 expectsPadding = false
469         } else {
470                 min = 8 - d.nbuf
471                 expectsPadding = true
472         }
473
474         nn, d.err = readEncodedData(d.r, d.buf[d.nbuf:nn], min, expectsPadding)
475         d.nbuf += nn
476         if d.nbuf < min {
477                 return 0, d.err
478         }
479         if nn > 0 && d.end {
480                 return 0, CorruptInputError(0)
481         }
482
483         // Decode chunk into p, or d.out and then p if p is too small.
484         var nr int
485         if d.enc.padChar == NoPadding {
486                 nr = d.nbuf
487         } else {
488                 nr = d.nbuf / 8 * 8
489         }
490         nw := d.enc.DecodedLen(d.nbuf)
491
492         if nw > len(p) {
493                 nw, d.end, err = d.enc.decode(d.outbuf[0:], d.buf[0:nr])
494                 d.out = d.outbuf[0:nw]
495                 n = copy(p, d.out)
496                 d.out = d.out[n:]
497         } else {
498                 n, d.end, err = d.enc.decode(p, d.buf[0:nr])
499         }
500         d.nbuf -= nr
501         for i := 0; i < d.nbuf; i++ {
502                 d.buf[i] = d.buf[i+nr]
503         }
504
505         if err != nil && (d.err == nil || d.err == io.EOF) {
506                 d.err = err
507         }
508
509         if len(d.out) > 0 {
510                 // We cannot return all the decoded bytes to the caller in this
511                 // invocation of Read, so we return a nil error to ensure that Read
512                 // will be called again.  The error stored in d.err, if any, will be
513                 // returned with the last set of decoded bytes.
514                 return n, nil
515         }
516
517         return n, d.err
518 }
519
520 type newlineFilteringReader struct {
521         wrapped io.Reader
522 }
523
524 // stripNewlines removes newline characters and returns the number
525 // of non-newline characters copied to dst.
526 func stripNewlines(dst, src []byte) int {
527         offset := 0
528         for _, b := range src {
529                 if b == '\r' || b == '\n' {
530                         continue
531                 }
532                 dst[offset] = b
533                 offset++
534         }
535         return offset
536 }
537
538 func (r *newlineFilteringReader) Read(p []byte) (int, error) {
539         n, err := r.wrapped.Read(p)
540         for n > 0 {
541                 s := p[0:n]
542                 offset := stripNewlines(s, s)
543                 if err != nil || offset > 0 {
544                         return offset, err
545                 }
546                 // Previous buffer entirely whitespace, read again
547                 n, err = r.wrapped.Read(p)
548         }
549         return n, err
550 }
551
552 // NewDecoder constructs a new base32 stream decoder.
553 func NewDecoder(enc *Encoding, r io.Reader) io.Reader {
554         return &decoder{enc: enc, r: &newlineFilteringReader{r}}
555 }
556
557 // DecodedLen returns the maximum length in bytes of the decoded data
558 // corresponding to n bytes of base32-encoded data.
559 func (enc *Encoding) DecodedLen(n int) int {
560         if enc.padChar == NoPadding {
561                 return n/8*5 + n%8*5/8
562         }
563         return n / 8 * 5
564 }