]> Cypherpunks.ru repositories - netstring.git/blobdiff - ns_test.go
Do not require explicit Read for zero length
[netstring.git] / ns_test.go
index 1f63e31bbef2c90b54dd91717e9fa5307a9b5ceb..97e8deb2aac886d0c3aa1e4d1b31462a6ef34741 100644 (file)
@@ -1,6 +1,6 @@
 /*
 netstring -- netstring format serialization library
-Copyright (C) 2015-2021 Sergey Matveev <stargrave@stargrave.org>
+Copyright (C) 2015-2023 Sergey Matveev <stargrave@stargrave.org>
 
 This program is free software: you can redistribute it and/or modify
 it under the terms of the GNU General Public License as published by
@@ -19,7 +19,7 @@ package netstring
 
 import (
        "bytes"
-       "io/ioutil"
+       "io"
        "testing"
        "testing/quick"
 )
@@ -36,7 +36,7 @@ func TestTrivial(t *testing.T) {
        if n, err := w.WriteChunk([]byte("barz")); err != nil || n != 7 {
                t.FailNow()
        }
-       if string(buf.Bytes()) != "3:foo,4:barz," {
+       if buf.String() != "3:foo,4:barz," {
                t.FailNow()
        }
        r := NewReader(&buf)
@@ -53,7 +53,7 @@ func TestTrivial(t *testing.T) {
        if n, err := r.Read(m); err != nil || n != 4 {
                t.FailNow()
        }
-       if bytes.Compare(m, []byte("barz")) != 0 {
+       if !bytes.Equal(m, []byte("barz")) {
                t.FailNow()
        }
 }
@@ -72,8 +72,8 @@ func TestSymmetric(t *testing.T) {
                        if n, err := r.Next(); err != nil || n != uint64(len(data)) {
                                return false
                        }
-                       got, err := ioutil.ReadAll(r)
-                       if err != nil || bytes.Compare(got, data) != 0 {
+                       got, err := io.ReadAll(r)
+                       if err != nil || !bytes.Equal(got, data) {
                                return false
                        }
                }
@@ -107,14 +107,11 @@ func TestErrors(t *testing.T) {
 
        b = bytes.NewBufferString("0:foobar,")
        r = NewReader(b)
-       if _, err := r.Next(); err != nil {
-               t.FailNow()
-       }
-       if _, err := r.Read(data); err == nil {
+       if _, err := r.Next(); err == nil {
                t.FailNow()
        }
 
-       b = bytes.NewBufferString("0:foobar")
+       b = bytes.NewBufferString("6:foobar")
        r = NewReader(b)
        if _, err := r.Next(); err != nil {
                t.FailNow()
@@ -123,12 +120,15 @@ func TestErrors(t *testing.T) {
                t.FailNow()
        }
 
-       b = bytes.NewBufferString("6:foobar")
+       b = bytes.NewBufferString(":foobar,")
        r = NewReader(b)
-       if _, err := r.Next(); err != nil {
+       if _, err := r.Next(); err == nil {
                t.FailNow()
        }
-       if _, err := r.Read(data); err == nil {
+
+       b = bytes.NewBufferString("06:foobar,")
+       r = NewReader(b)
+       if _, err := r.Next(); err == nil {
                t.FailNow()
        }
 }
@@ -158,7 +158,7 @@ func TestExample(t *testing.T) {
        if size, err := r.Next(); err != nil || size != 6 {
                t.FailNow()
        }
-       if data, err := ioutil.ReadAll(r); err != nil || string(data) != "world!" {
+       if data, err := io.ReadAll(r); err != nil || string(data) != "world!" {
                t.FailNow()
        }
 }