]> Cypherpunks.ru repositories - netstring.git/blobdiff - r.go
Unify copyright comment format
[netstring.git] / r.go
diff --git a/r.go b/r.go
index ca1925908d09e22a1c6b33fe487da9880c06216b..9bb12adf92a3b09e8df9a8569ed77bc15fb192b1 100644 (file)
--- a/r.go
+++ b/r.go
@@ -1,26 +1,24 @@
-/*
-netstring -- netstring format serialization library
-Copyright (C) 2015-2021 Sergey Matveev <stargrave@stargrave.org>
-
-This program is free software: you can redistribute it and/or modify
-it under the terms of the GNU General Public License as published by
-the Free Software Foundation, version 3 of the License.
-
-This program is distributed in the hope that it will be useful,
-but WITHOUT ANY WARRANTY; without even the implied warranty of
-MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-GNU General Public License for more details.
-
-You should have received a copy of the GNU General Public License
-along with this program.  If not, see <http://www.gnu.org/licenses/>.
-*/
+// netstring -- netstring format serialization library
+// Copyright (C) 2015-2024 Sergey Matveev <stargrave@stargrave.org>
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, version 3 of the License.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
 package netstring
 
 import (
        "bufio"
-       "bytes"
        "errors"
+       "fmt"
        "io"
        "strconv"
 )
@@ -44,33 +42,36 @@ func (r *Reader) Next() (uint64, error) {
        if !r.eof {
                return 0, errors.New("current chunk is unread")
        }
-       p, _ := r.r.Peek(21)
-       if len(p) == 0 {
-               return 0, io.EOF
+       lenRaw, err := r.r.ReadSlice(':')
+       if err != nil {
+               return 0, fmt.Errorf("netstring header: %w", err)
        }
-       idx := bytes.Index(p, []byte{':'})
-       if idx == -1 {
-               return 0, errors.New("no length separator found")
+       lenRaw = lenRaw[:len(lenRaw)-1]
+       if len(lenRaw) > 1 && lenRaw[0] == '0' {
+               return 0, errors.New("netstring header: leading zero")
        }
-       size, err := strconv.ParseUint(string(p[:idx]), 10, 64)
+       size, err := strconv.ParseUint(string(lenRaw), 10, 64)
        if err != nil {
-               return 0, err
-       }
-       if _, err = r.r.Discard(idx + 1); err != nil {
-               return 0, err
+               return 0, fmt.Errorf("netstring header: %w", err)
        }
        r.left = size
        r.eof = false
-       return size, nil
+       if r.left == 0 {
+               err = r.checkTerminator()
+               if err == nil {
+                       r.eof = true
+               }
+       }
+       return size, err
 }
 
 func (r *Reader) checkTerminator() error {
        b, err := r.r.ReadByte()
        if err != nil {
-               return err
+               return fmt.Errorf("netstring terminator: %w", err)
        }
        if b != ',' {
-               return errors.New("no terminator found")
+               return errors.New("netstring terminator: not found")
        }
        return nil
 }