]> Cypherpunks.ru repositories - netstring.git/blobdiff - r.go
Stricter header validation
[netstring.git] / r.go
diff --git a/r.go b/r.go
index 68a93ca4985d19ad81966c4f9327f350d7c24452..9b8f51069e541d86986497e383dec57349f563b8 100644 (file)
--- a/r.go
+++ b/r.go
@@ -20,6 +20,7 @@ package netstring
 import (
        "bufio"
        "errors"
+       "fmt"
        "io"
        "strconv"
 )
@@ -45,11 +46,15 @@ func (r *Reader) Next() (uint64, error) {
        }
        lenRaw, err := r.r.ReadSlice(':')
        if err != nil {
-               return 0, err
+               return 0, fmt.Errorf("netstring header: %w", err)
        }
-       size, err := strconv.ParseUint(string(lenRaw[:len(lenRaw)-1]), 10, 64)
+       lenRaw = lenRaw[:len(lenRaw)-1]
+       if len(lenRaw) > 1 && lenRaw[0] == '0' {
+               return 0, errors.New("netstring header: leading zero")
+       }
+       size, err := strconv.ParseUint(string(lenRaw), 10, 64)
        if err != nil {
-               return 0, err
+               return 0, fmt.Errorf("netstring header: %w", err)
        }
        r.left = size
        r.eof = false
@@ -59,10 +64,10 @@ func (r *Reader) Next() (uint64, error) {
 func (r *Reader) checkTerminator() error {
        b, err := r.r.ReadByte()
        if err != nil {
-               return err
+               return fmt.Errorf("netstring terminator: %w", err)
        }
        if b != ',' {
-               return errors.New("no terminator found")
+               return errors.New("netstring terminator: not found")
        }
        return nil
 }