]> Cypherpunks.ru repositories - nncp.git/blobdiff - src/pkt.go
Streamed NNCPE format
[nncp.git] / src / pkt.go
index aa3025d988b40cc73e41436a84deb36a7a3337c3..a1993e7a8ed98334e2c96dc14e22cc28d73235a1 100644 (file)
@@ -21,7 +21,6 @@ import (
        "bytes"
        "crypto/cipher"
        "crypto/rand"
-       "encoding/binary"
        "errors"
        "io"
 
@@ -49,15 +48,20 @@ const (
        MaxPathSize = 1<<8 - 1
 
        NNCPBundlePrefix = "NNCP"
-
-       PktSizeOverhead = 8 + poly1305.TagSize
 )
 
 var (
        BadPktType error = errors.New("Unknown packet type")
 
-       PktOverhead    int64
-       PktEncOverhead int64
+       DeriveKeyFullCtx = string(MagicNNCPEv6.B[:]) + " FULL"
+       DeriveKeySizeCtx = string(MagicNNCPEv6.B[:]) + " SIZE"
+       DeriveKeyPadCtx  = string(MagicNNCPEv6.B[:]) + " PAD"
+
+       PktOverhead     int64
+       PktEncOverhead  int64
+       PktSizeOverhead int64
+
+       TooBig = errors.New("Too big than allowed")
 )
 
 type Pkt struct {
@@ -85,9 +89,28 @@ type PktEnc struct {
        Sign      [ed25519.SignatureSize]byte
 }
 
+type PktSize struct {
+       Payload uint64
+       Pad     uint64
+}
+
+func NewPkt(typ PktType, nice uint8, path []byte) (*Pkt, error) {
+       if len(path) > MaxPathSize {
+               return nil, errors.New("Too long path")
+       }
+       pkt := Pkt{
+               Magic:   MagicNNCPPv3.B,
+               Type:    typ,
+               Nice:    nice,
+               PathLen: uint8(len(path)),
+       }
+       copy(pkt.Path[:], path)
+       return &pkt, nil
+}
+
 func init() {
-       pkt := Pkt{Type: PktTypeFile}
        var buf bytes.Buffer
+       pkt := Pkt{Type: PktTypeFile}
        n, err := xdr.Marshal(&buf, pkt)
        if err != nil {
                panic(err)
@@ -100,7 +123,7 @@ func init() {
                panic(err)
        }
        pktEnc := PktEnc{
-               Magic:     MagicNNCPEv5.B,
+               Magic:     MagicNNCPEv6.B,
                Sender:    dummyId,
                Recipient: dummyId,
        }
@@ -109,20 +132,14 @@ func init() {
                panic(err)
        }
        PktEncOverhead = int64(n)
-}
+       buf.Reset()
 
-func NewPkt(typ PktType, nice uint8, path []byte) (*Pkt, error) {
-       if len(path) > MaxPathSize {
-               return nil, errors.New("Too long path")
-       }
-       pkt := Pkt{
-               Magic:   MagicNNCPPv3.B,
-               Type:    typ,
-               Nice:    nice,
-               PathLen: uint8(len(path)),
+       size := PktSize{}
+       n, err = xdr.Marshal(&buf, size)
+       if err != nil {
+               panic(err)
        }
-       copy(pkt.Path[:], path)
-       return &pkt, nil
+       PktSizeOverhead = int64(n)
 }
 
 func ctrIncr(b []byte) {
@@ -135,53 +152,28 @@ func ctrIncr(b []byte) {
        panic("counter overflow")
 }
 
-func aeadProcess(
-       aead cipher.AEAD,
-       nonce, ad []byte,
-       doEncrypt bool,
-       r io.Reader,
-       w io.Writer,
-) (int64, error) {
-       ciphCtr := nonce[len(nonce)-8:]
-       buf := make([]byte, EncBlkSize+aead.Overhead())
-       var toRead []byte
-       var toWrite []byte
-       var n int
-       var readBytes int64
-       var err error
-       if doEncrypt {
-               toRead = buf[:EncBlkSize]
-       } else {
-               toRead = buf
+func TbsPrepare(our *NodeOur, their *Node, pktEnc *PktEnc) []byte {
+       tbs := PktTbs{
+               Magic:     MagicNNCPEv6.B,
+               Nice:      pktEnc.Nice,
+               Sender:    their.Id,
+               Recipient: our.Id,
+               ExchPub:   pktEnc.ExchPub,
        }
-       for {
-               n, err = io.ReadFull(r, toRead)
-               if err != nil {
-                       if err == io.EOF {
-                               break
-                       }
-                       if err != io.ErrUnexpectedEOF {
-                               return readBytes + int64(n), err
-                       }
-               }
-               readBytes += int64(n)
-               ctrIncr(ciphCtr)
-               if doEncrypt {
-                       toWrite = aead.Seal(buf[:0], nonce, buf[:n], ad)
-               } else {
-                       toWrite, err = aead.Open(buf[:0], nonce, buf[:n], ad)
-                       if err != nil {
-                               return readBytes, err
-                       }
-               }
-               if _, err = w.Write(toWrite); err != nil {
-                       return readBytes, err
-               }
+       var tbsBuf bytes.Buffer
+       if _, err := xdr.Marshal(&tbsBuf, &tbs); err != nil {
+               panic(err)
        }
-       return readBytes, nil
+       return tbsBuf.Bytes()
+}
+
+func TbsVerify(our *NodeOur, their *Node, pktEnc *PktEnc) ([]byte, bool, error) {
+       tbs := TbsPrepare(our, their, pktEnc)
+       return tbs, ed25519.Verify(their.SignPub, tbs, pktEnc.Sign[:]), nil
 }
 
 func sizeWithTags(size int64) (fullSize int64) {
+       size += PktSizeOverhead
        fullSize = size + (size/EncBlkSize)*poly1305.TagSize
        if size%EncBlkSize != 0 {
                fullSize += poly1305.TagSize
@@ -189,122 +181,182 @@ func sizeWithTags(size int64) (fullSize int64) {
        return
 }
 
+func sizePadCalc(sizePayload, minSize int64, wrappers int) (sizePad int64) {
+       expectedSize := sizePayload - PktOverhead
+       for i := 0; i < wrappers; i++ {
+               expectedSize = PktEncOverhead + sizeWithTags(PktOverhead+expectedSize)
+       }
+       sizePad = minSize - expectedSize
+       if sizePad < 0 {
+               sizePad = 0
+       }
+       return
+}
+
 func PktEncWrite(
-       our *NodeOur,
-       their *Node,
-       pkt *Pkt,
-       nice uint8,
-       size, padSize int64,
-       data io.Reader,
-       out io.Writer,
-) ([]byte, error) {
-       pubEph, prvEph, err := box.GenerateKey(rand.Reader)
+       our *NodeOur, their *Node,
+       pkt *Pkt, nice uint8,
+       minSize, maxSize int64, wrappers int,
+       r io.Reader, w io.Writer,
+) (pktEncRaw []byte, size int64, err error) {
+       pub, prv, err := box.GenerateKey(rand.Reader)
        if err != nil {
-               return nil, err
+               return nil, 0, err
        }
-       var pktBuf bytes.Buffer
-       if _, err := xdr.Marshal(&pktBuf, pkt); err != nil {
-               return nil, err
+
+       var buf bytes.Buffer
+       _, err = xdr.Marshal(&buf, pkt)
+       if err != nil {
+               return
        }
+       pktRaw := make([]byte, buf.Len())
+       copy(pktRaw, buf.Bytes())
+       buf.Reset()
+
        tbs := PktTbs{
-               Magic:     MagicNNCPEv5.B,
+               Magic:     MagicNNCPEv6.B,
                Nice:      nice,
                Sender:    our.Id,
                Recipient: their.Id,
-               ExchPub:   *pubEph,
+               ExchPub:   *pub,
        }
-       var tbsBuf bytes.Buffer
-       if _, err = xdr.Marshal(&tbsBuf, &tbs); err != nil {
-               return nil, err
+       _, err = xdr.Marshal(&buf, &tbs)
+       if err != nil {
+               return
        }
        signature := new([ed25519.SignatureSize]byte)
-       copy(signature[:], ed25519.Sign(our.SignPrv, tbsBuf.Bytes()))
+       copy(signature[:], ed25519.Sign(our.SignPrv, buf.Bytes()))
+       ad := blake3.Sum256(buf.Bytes())
+       buf.Reset()
+
        pktEnc := PktEnc{
-               Magic:     MagicNNCPEv5.B,
+               Magic:     MagicNNCPEv6.B,
                Nice:      nice,
                Sender:    our.Id,
                Recipient: their.Id,
-               ExchPub:   *pubEph,
+               ExchPub:   *pub,
                Sign:      *signature,
        }
-       ad := blake3.Sum256(tbsBuf.Bytes())
-       tbsBuf.Reset()
-       if _, err = xdr.Marshal(&tbsBuf, &pktEnc); err != nil {
-               return nil, err
+       _, err = xdr.Marshal(&buf, &pktEnc)
+       if err != nil {
+               return
        }
-       pktEncRaw := tbsBuf.Bytes()
-       if _, err = out.Write(pktEncRaw); err != nil {
-               return nil, err
+       pktEncRaw = make([]byte, buf.Len())
+       copy(pktEncRaw, buf.Bytes())
+       buf.Reset()
+       _, err = w.Write(pktEncRaw)
+       if err != nil {
+               return
        }
-       sharedKey := new([32]byte)
-       curve25519.ScalarMult(sharedKey, prvEph, their.ExchPub)
 
-       key := make([]byte, chacha20poly1305.KeySize)
-       blake3.DeriveKey(key, string(MagicNNCPEv5.B[:]), sharedKey[:])
-       aead, err := chacha20poly1305.New(key)
+       sharedKey := new([32]byte)
+       curve25519.ScalarMult(sharedKey, prv, their.ExchPub)
+       keyFull := make([]byte, chacha20poly1305.KeySize)
+       keySize := make([]byte, chacha20poly1305.KeySize)
+       blake3.DeriveKey(keyFull, DeriveKeyFullCtx, sharedKey[:])
+       blake3.DeriveKey(keySize, DeriveKeySizeCtx, sharedKey[:])
+       aeadFull, err := chacha20poly1305.New(keyFull)
+       if err != nil {
+               return
+       }
+       aeadSize, err := chacha20poly1305.New(keySize)
        if err != nil {
-               return nil, err
+               return
        }
-       nonce := make([]byte, aead.NonceSize())
+       nonce := make([]byte, aeadFull.NonceSize())
 
-       fullSize := int64(pktBuf.Len()) + size
-       sizeBuf := make([]byte, 8+aead.Overhead())
-       binary.BigEndian.PutUint64(sizeBuf, uint64(sizeWithTags(fullSize)))
-       if _, err = out.Write(aead.Seal(sizeBuf[:0], nonce, sizeBuf[:8], ad[:])); err != nil {
-               return nil, err
+       data := make([]byte, EncBlkSize, EncBlkSize+aeadFull.Overhead())
+       mr := io.MultiReader(bytes.NewReader(pktRaw), r)
+       var sizePayload int64
+       var n int
+       var ct []byte
+       for {
+               n, err = io.ReadFull(mr, data)
+               sizePayload += int64(n)
+               if sizePayload > maxSize {
+                       err = TooBig
+                       return
+               }
+               if err == nil {
+                       ct = aeadFull.Seal(data[:0], nonce, data[:n], ad[:])
+                       _, err = w.Write(ct)
+                       if err != nil {
+                               return
+                       }
+                       ctrIncr(nonce)
+                       continue
+               }
+               if !(err == io.EOF || err == io.ErrUnexpectedEOF) {
+                       return
+               }
+               break
        }
 
-       lr := io.LimitedReader{R: data, N: size}
-       mr := io.MultiReader(&pktBuf, &lr)
-       written, err := aeadProcess(aead, nonce, ad[:], true, mr, out)
+       sizePad := sizePadCalc(sizePayload, minSize, wrappers)
+       _, err = xdr.Marshal(&buf, &PktSize{uint64(sizePayload), uint64(sizePad)})
        if err != nil {
-               return nil, err
-       }
-       if written != fullSize {
-               return nil, io.ErrUnexpectedEOF
+               return
        }
-       if padSize > 0 {
-               blake3.DeriveKey(key, string(MagicNNCPEv5.B[:])+" PAD", sharedKey[:])
-               xof := blake3.New(32, key).XOF()
-               if _, err = io.CopyN(out, xof, padSize); err != nil {
-                       return nil, err
+
+       var aeadLast cipher.AEAD
+       if n+int(PktSizeOverhead) > EncBlkSize {
+               left := make([]byte, (n+int(PktSizeOverhead))-EncBlkSize)
+               copy(left, data[n-len(left):])
+               copy(data[PktSizeOverhead:], data[:n-len(left)])
+               copy(data[:PktSizeOverhead], buf.Bytes())
+               ct = aeadSize.Seal(data[:0], nonce, data[:EncBlkSize], ad[:])
+               _, err = w.Write(ct)
+               if err != nil {
+                       return
                }
+               ctrIncr(nonce)
+               copy(data, left)
+               n = len(left)
+               aeadLast = aeadFull
+       } else {
+               copy(data[PktSizeOverhead:], data[:n])
+               copy(data[:PktSizeOverhead], buf.Bytes())
+               n += int(PktSizeOverhead)
+               aeadLast = aeadSize
        }
-       return pktEncRaw, nil
-}
 
-func TbsPrepare(our *NodeOur, their *Node, pktEnc *PktEnc) []byte {
-       tbs := PktTbs{
-               Magic:     MagicNNCPEv5.B,
-               Nice:      pktEnc.Nice,
-               Sender:    their.Id,
-               Recipient: our.Id,
-               ExchPub:   pktEnc.ExchPub,
+       var sizeBlockPadded int
+       var sizePadLeft int64
+       if sizePad > EncBlkSize-int64(n) {
+               sizeBlockPadded = EncBlkSize
+               sizePadLeft = sizePad - (EncBlkSize - int64(n))
+       } else {
+               sizeBlockPadded = n + int(sizePad)
+               sizePadLeft = 0
        }
-       var tbsBuf bytes.Buffer
-       if _, err := xdr.Marshal(&tbsBuf, &tbs); err != nil {
-               panic(err)
+       for i := n; i < sizeBlockPadded; i++ {
+               data[i] = 0
+       }
+       ct = aeadLast.Seal(data[:0], nonce, data[:sizeBlockPadded], ad[:])
+       _, err = w.Write(ct)
+       if err != nil {
+               return
        }
-       return tbsBuf.Bytes()
-}
 
-func TbsVerify(our *NodeOur, their *Node, pktEnc *PktEnc) ([]byte, bool, error) {
-       tbs := TbsPrepare(our, their, pktEnc)
-       return tbs, ed25519.Verify(their.SignPub, tbs, pktEnc.Sign[:]), nil
+       size = sizePayload
+       if sizePadLeft > 0 {
+               keyPad := make([]byte, chacha20poly1305.KeySize)
+               blake3.DeriveKey(keyPad, DeriveKeyPadCtx, sharedKey[:])
+               _, err = io.CopyN(w, blake3.New(32, keyPad).XOF(), sizePadLeft)
+       }
+       return
 }
 
 func PktEncRead(
-       our *NodeOur,
-       nodes map[NodeId]*Node,
-       data io.Reader,
-       out io.Writer,
+       our *NodeOur, nodes map[NodeId]*Node,
+       r io.Reader, w io.Writer,
        signatureVerify bool,
        sharedKeyCached []byte,
-) ([]byte, *Node, int64, error) {
+) (sharedKey []byte, their *Node, size int64, err error) {
        var pktEnc PktEnc
-       _, err := xdr.Unmarshal(data, &pktEnc)
+       _, err = xdr.Unmarshal(r, &pktEnc)
        if err != nil {
-               return nil, nil, 0, err
+               return
        }
        switch pktEnc.Magic {
        case MagicNNCPEv1.B:
@@ -316,66 +368,159 @@ func PktEncRead(
        case MagicNNCPEv4.B:
                err = MagicNNCPEv4.TooOld()
        case MagicNNCPEv5.B:
+               err = MagicNNCPEv5.TooOld()
+       case MagicNNCPEv6.B:
        default:
                err = BadMagic
        }
        if err != nil {
-               return nil, nil, 0, err
+               return
        }
        if *pktEnc.Recipient != *our.Id {
-               return nil, nil, 0, errors.New("Invalid recipient")
+               err = errors.New("Invalid recipient")
+               return
        }
+
        var tbsRaw []byte
-       var their *Node
        if signatureVerify {
                their = nodes[*pktEnc.Sender]
                if their == nil {
-                       return nil, nil, 0, errors.New("Unknown sender")
+                       err = errors.New("Unknown sender")
+                       return
                }
                var verified bool
                tbsRaw, verified, err = TbsVerify(our, their, &pktEnc)
                if err != nil {
-                       return nil, nil, 0, err
+                       return
                }
                if !verified {
-                       return nil, their, 0, errors.New("Invalid signature")
+                       err = errors.New("Invalid signature")
+                       return
                }
        } else {
                tbsRaw = TbsPrepare(our, &Node{Id: pktEnc.Sender}, &pktEnc)
        }
        ad := blake3.Sum256(tbsRaw)
-       sharedKey := new([32]byte)
        if sharedKeyCached == nil {
-               curve25519.ScalarMult(sharedKey, our.ExchPrv, &pktEnc.ExchPub)
+               key := new([32]byte)
+               curve25519.ScalarMult(key, our.ExchPrv, &pktEnc.ExchPub)
+               sharedKey = key[:]
        } else {
-               copy(sharedKey[:], sharedKeyCached)
+               sharedKey = sharedKeyCached
        }
 
-       key := make([]byte, chacha20poly1305.KeySize)
-       blake3.DeriveKey(key, string(MagicNNCPEv5.B[:]), sharedKey[:])
-       aead, err := chacha20poly1305.New(key)
+       keyFull := make([]byte, chacha20poly1305.KeySize)
+       keySize := make([]byte, chacha20poly1305.KeySize)
+       blake3.DeriveKey(keyFull, DeriveKeyFullCtx, sharedKey[:])
+       blake3.DeriveKey(keySize, DeriveKeySizeCtx, sharedKey[:])
+       aeadFull, err := chacha20poly1305.New(keyFull)
        if err != nil {
-               return sharedKey[:], their, 0, err
+               return
+       }
+       aeadSize, err := chacha20poly1305.New(keySize)
+       if err != nil {
+               return
+       }
+       nonce := make([]byte, aeadFull.NonceSize())
+
+       ct := make([]byte, EncBlkSize+aeadFull.Overhead())
+       pt := make([]byte, EncBlkSize)
+       var n int
+FullRead:
+       for {
+               n, err = io.ReadFull(r, ct)
+               switch err {
+               case nil:
+                       pt, err = aeadFull.Open(pt[:0], nonce, ct, ad[:])
+                       if err != nil {
+                               break FullRead
+                       }
+                       size += EncBlkSize
+                       _, err = w.Write(pt)
+                       if err != nil {
+                               return
+                       }
+                       ctrIncr(nonce)
+                       continue
+               case io.ErrUnexpectedEOF:
+                       break FullRead
+               default:
+                       return
+               }
        }
-       nonce := make([]byte, aead.NonceSize())
 
-       sizeBuf := make([]byte, 8+aead.Overhead())
-       if _, err = io.ReadFull(data, sizeBuf); err != nil {
-               return sharedKey[:], their, 0, err
+       pt, err = aeadSize.Open(pt[:0], nonce, ct[:n], ad[:])
+       if err != nil {
+               return
        }
-       sizeBuf, err = aead.Open(sizeBuf[:0], nonce, sizeBuf, ad[:])
+       var pktSize PktSize
+       _, err = xdr.Unmarshal(bytes.NewReader(pt), &pktSize)
        if err != nil {
-               return sharedKey[:], their, 0, err
+               return
        }
-       size := int64(binary.BigEndian.Uint64(sizeBuf))
+       pt = pt[PktSizeOverhead:]
 
-       lr := io.LimitedReader{R: data, N: size}
-       written, err := aeadProcess(aead, nonce, ad[:], false, &lr, out)
+       left := int64(pktSize.Payload) - size
+       for left > int64(len(pt)) {
+               size += int64(len(pt))
+               left -= int64(len(pt))
+               _, err = w.Write(pt)
+               if err != nil {
+                       return
+               }
+               n, err = io.ReadFull(r, ct)
+               if err != nil && err != io.ErrUnexpectedEOF {
+                       return
+               }
+               ctrIncr(nonce)
+               pt, err = aeadFull.Open(pt[:0], nonce, ct[:n], ad[:])
+               if err != nil {
+                       return
+               }
+       }
+       size += left
+       _, err = w.Write(pt[:left])
        if err != nil {
-               return sharedKey[:], their, written, err
+               return
+       }
+       pt = pt[left:]
+
+       if pktSize.Pad < uint64(len(pt)) {
+               err = errors.New("unexpected pad")
+               return
        }
-       if written != size {
-               return sharedKey[:], their, written, io.ErrUnexpectedEOF
+       for i := 0; i < len(pt); i++ {
+               if pt[i] != 0 {
+                       err = errors.New("non-zero pad byte")
+                       return
+               }
        }
-       return sharedKey[:], their, size, nil
+       sizePad := int64(pktSize.Pad) - int64(len(pt))
+       if sizePad == 0 {
+               return
+       }
+
+       keyPad := make([]byte, chacha20poly1305.KeySize)
+       blake3.DeriveKey(keyPad, DeriveKeyPadCtx, sharedKey[:])
+       xof := blake3.New(32, keyPad).XOF()
+       pt = make([]byte, len(ct))
+       for sizePad > 0 {
+               n, err = io.ReadFull(r, ct)
+               if err != nil && err != io.ErrUnexpectedEOF {
+                       return
+               }
+               _, err = io.ReadFull(xof, pt[:n])
+               if err != nil {
+                       panic(err)
+               }
+               if bytes.Compare(ct[:n], pt[:n]) != 0 {
+                       err = errors.New("wrong pad value")
+                       return
+               }
+               sizePad -= int64(n)
+       }
+       if sizePad < 0 {
+               err = errors.New("excess pad")
+       }
+       return
 }