]> Cypherpunks.ru repositories - gostls13.git/commitdiff
encoding: optimize growth behavior in Encoding.AppendDecode
authorJoe Tsai <joetsai@digital-static.net>
Thu, 17 Aug 2023 04:27:15 +0000 (21:27 -0700)
committerGopher Robot <gobot@golang.org>
Sat, 19 Aug 2023 22:25:23 +0000 (22:25 +0000)
The Encoding.DecodedLen API only returns the maximum length of the
expected decoded output, since it does not know about padding.
Since we have the input, we can do better by computing the
input length without padding, and then perform the DecodedLen
calculation as if there were no padding.

This avoids over-growing the destination slice if possible.
Over-growth is still possible since the input may contain
ignore characters like newlines and carriage returns,
but those a rarely encountered in practice.

Change-Id: I38b8f91de1f4fbd3a7128c491a25098bd385cf74
Reviewed-on: https://go-review.googlesource.com/c/go/+/520267
Run-TryBot: Joseph Tsai <joetsai@digital-static.net>
Reviewed-by: Ian Lance Taylor <iant@google.com>
Auto-Submit: Joseph Tsai <joetsai@digital-static.net>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>

src/encoding/base32/base32.go
src/encoding/base32/base32_test.go
src/encoding/base64/base64.go
src/encoding/base64/base64_test.go

index de95df0043339fbb34ca686d431ac689e3a7c7ff..e92188728592a1383c88b9c69c72dff645af4f17 100644 (file)
@@ -402,7 +402,13 @@ func (enc *Encoding) Decode(dst, src []byte) (n int, err error) {
 // and returns the extended buffer.
 // If the input is malformed, it returns the partially decoded src and an error.
 func (enc *Encoding) AppendDecode(dst, src []byte) ([]byte, error) {
-       n := enc.DecodedLen(len(src))
+       // Compute the output size without padding to avoid over allocating.
+       n := len(src)
+       for n > 0 && rune(src[n-1]) == enc.padChar {
+               n--
+       }
+       n = decodedLen(n, NoPadding)
+
        dst = slices.Grow(dst, n)
        n, err := enc.Decode(dst[len(dst):][:n], src)
        return dst[:len(dst)+n], err
@@ -567,7 +573,11 @@ func NewDecoder(enc *Encoding, r io.Reader) io.Reader {
 // DecodedLen returns the maximum length in bytes of the decoded data
 // corresponding to n bytes of base32-encoded data.
 func (enc *Encoding) DecodedLen(n int) int {
-       if enc.padChar == NoPadding {
+       return decodedLen(n, enc.padChar)
+}
+
+func decodedLen(n int, padChar rune) int {
+       if padChar == NoPadding {
                return n/8*5 + n%8*5/8
        }
        return n / 8 * 5
index 0132744507502720da87d9c98c44d9157f535e4e..33638adeac691313c0a12ab8e1edca32820ac215 100644 (file)
@@ -110,6 +110,13 @@ func TestDecode(t *testing.T) {
                dst, err := StdEncoding.AppendDecode([]byte("lead"), []byte(p.encoded))
                testEqual(t, "AppendDecode(%q) = error %v, want %v", p.encoded, err, error(nil))
                testEqual(t, `AppendDecode("lead", %q) = %q, want %q`, p.encoded, string(dst), "lead"+p.decoded)
+
+               dst2, err := StdEncoding.AppendDecode(dst[:0:len(p.decoded)], []byte(p.encoded))
+               testEqual(t, "AppendDecode(%q) = error %v, want %v", p.encoded, err, error(nil))
+               testEqual(t, `AppendDecode("", %q) = %q, want %q`, p.encoded, string(dst2), p.decoded)
+               if len(dst) > 0 && len(dst2) > 0 && &dst[0] != &dst2[0] {
+                       t.Errorf("unexpected capacity growth: got %d, want %d", cap(dst2), cap(dst))
+               }
        }
 }
 
index 802ef14c38992ec84d800beaffb268cbfd4db962..9445cbd4efb8913d240091f64664bc1d43c33e8e 100644 (file)
@@ -411,7 +411,13 @@ func (enc *Encoding) decodeQuantum(dst, src []byte, si int) (nsi, n int, err err
 // and returns the extended buffer.
 // If the input is malformed, it returns the partially decoded src and an error.
 func (enc *Encoding) AppendDecode(dst, src []byte) ([]byte, error) {
-       n := enc.DecodedLen(len(src))
+       // Compute the output size without padding to avoid over allocating.
+       n := len(src)
+       for n > 0 && rune(src[n-1]) == enc.padChar {
+               n--
+       }
+       n = decodedLen(n, NoPadding)
+
        dst = slices.Grow(dst, n)
        n, err := enc.Decode(dst[len(dst):][:n], src)
        return dst[:len(dst)+n], err
@@ -643,7 +649,11 @@ func NewDecoder(enc *Encoding, r io.Reader) io.Reader {
 // DecodedLen returns the maximum length in bytes of the decoded data
 // corresponding to n bytes of base64-encoded data.
 func (enc *Encoding) DecodedLen(n int) int {
-       if enc.padChar == NoPadding {
+       return decodedLen(n, enc.padChar)
+}
+
+func decodedLen(n int, padChar rune) int {
+       if padChar == NoPadding {
                // Unpadded data may end with partial block of 2-3 characters.
                return n/4*3 + n%4*6/8
        }
index 4d7437b919520a79c846a552343b070ef11a5b28..6dfdaef1f1a11b61f853c61576d08b2f96605dfe 100644 (file)
@@ -167,6 +167,13 @@ func TestDecode(t *testing.T) {
                        dst, err := tt.enc.AppendDecode([]byte("lead"), []byte(encoded))
                        testEqual(t, "AppendDecode(%q) = error %v, want %v", p.encoded, err, error(nil))
                        testEqual(t, `AppendDecode("lead", %q) = %q, want %q`, p.encoded, string(dst), "lead"+p.decoded)
+
+                       dst2, err := tt.enc.AppendDecode(dst[:0:len(p.decoded)], []byte(encoded))
+                       testEqual(t, "AppendDecode(%q) = error %v, want %v", p.encoded, err, error(nil))
+                       testEqual(t, `AppendDecode("", %q) = %q, want %q`, p.encoded, string(dst2), p.decoded)
+                       if len(dst) > 0 && len(dst2) > 0 && &dst[0] != &dst2[0] {
+                               t.Errorf("unexpected capacity growth: got %d, want %d", cap(dst2), cap(dst))
+                       }
                }
        }
 }