]> Cypherpunks.ru repositories - nncp.git/blob - src/pkt.go
ACK
[nncp.git] / src / pkt.go
1 /*
2 NNCP -- Node to Node copy, utilities for store-and-forward data exchange
3 Copyright (C) 2016-2022 Sergey Matveev <stargrave@stargrave.org>
4
5 This program is free software: you can redistribute it and/or modify
6 it under the terms of the GNU General Public License as published by
7 the Free Software Foundation, version 3 of the License.
8
9 This program is distributed in the hope that it will be useful,
10 but WITHOUT ANY WARRANTY; without even the implied warranty of
11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 GNU General Public License for more details.
13
14 You should have received a copy of the GNU General Public License
15 along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 */
17
18 package nncp
19
20 import (
21         "bytes"
22         "crypto/cipher"
23         "crypto/rand"
24         "errors"
25         "io"
26
27         xdr "github.com/davecgh/go-xdr/xdr2"
28         "golang.org/x/crypto/chacha20poly1305"
29         "golang.org/x/crypto/curve25519"
30         "golang.org/x/crypto/ed25519"
31         "golang.org/x/crypto/nacl/box"
32         "golang.org/x/crypto/poly1305"
33         "lukechampine.com/blake3"
34 )
35
36 type PktType uint8
37
38 const (
39         EncBlkSize = 128 * (1 << 10)
40
41         PktTypeFile    PktType = iota
42         PktTypeFreq    PktType = iota
43         PktTypeExec    PktType = iota
44         PktTypeTrns    PktType = iota
45         PktTypeExecFat PktType = iota
46         PktTypeArea    PktType = iota
47         PktTypeACK     PktType = iota
48
49         MaxPathSize = 1<<8 - 1
50
51         NNCPBundlePrefix = "NNCP"
52 )
53
54 var (
55         BadPktType error = errors.New("Unknown packet type")
56
57         DeriveKeyFullCtx = string(MagicNNCPEv6.B[:]) + " FULL"
58         DeriveKeySizeCtx = string(MagicNNCPEv6.B[:]) + " SIZE"
59         DeriveKeyPadCtx  = string(MagicNNCPEv6.B[:]) + " PAD"
60
61         PktOverhead     int64
62         PktEncOverhead  int64
63         PktSizeOverhead int64
64
65         TooBig = errors.New("Too big than allowed")
66 )
67
68 type Pkt struct {
69         Magic   [8]byte
70         Type    PktType
71         Nice    uint8
72         PathLen uint8
73         Path    [MaxPathSize]byte
74 }
75
76 type PktTbs struct {
77         Magic     [8]byte
78         Nice      uint8
79         Sender    *NodeId
80         Recipient *NodeId
81         ExchPub   [32]byte
82 }
83
84 type PktEnc struct {
85         Magic     [8]byte
86         Nice      uint8
87         Sender    *NodeId
88         Recipient *NodeId
89         ExchPub   [32]byte
90         Sign      [ed25519.SignatureSize]byte
91 }
92
93 type PktSize struct {
94         Payload uint64
95         Pad     uint64
96 }
97
98 func NewPkt(typ PktType, nice uint8, path []byte) (*Pkt, error) {
99         if len(path) > MaxPathSize {
100                 return nil, errors.New("Too long path")
101         }
102         pkt := Pkt{
103                 Magic:   MagicNNCPPv3.B,
104                 Type:    typ,
105                 Nice:    nice,
106                 PathLen: uint8(len(path)),
107         }
108         copy(pkt.Path[:], path)
109         return &pkt, nil
110 }
111
112 func init() {
113         var buf bytes.Buffer
114         pkt := Pkt{Type: PktTypeFile}
115         n, err := xdr.Marshal(&buf, pkt)
116         if err != nil {
117                 panic(err)
118         }
119         PktOverhead = int64(n)
120         buf.Reset()
121
122         dummyId, err := NodeIdFromString(DummyB32Id)
123         if err != nil {
124                 panic(err)
125         }
126         pktEnc := PktEnc{
127                 Magic:     MagicNNCPEv6.B,
128                 Sender:    dummyId,
129                 Recipient: dummyId,
130         }
131         n, err = xdr.Marshal(&buf, pktEnc)
132         if err != nil {
133                 panic(err)
134         }
135         PktEncOverhead = int64(n)
136         buf.Reset()
137
138         size := PktSize{}
139         n, err = xdr.Marshal(&buf, size)
140         if err != nil {
141                 panic(err)
142         }
143         PktSizeOverhead = int64(n)
144 }
145
146 func ctrIncr(b []byte) {
147         for i := len(b) - 1; i >= 0; i-- {
148                 b[i]++
149                 if b[i] != 0 {
150                         return
151                 }
152         }
153         panic("counter overflow")
154 }
155
156 func TbsPrepare(our *NodeOur, their *Node, pktEnc *PktEnc) []byte {
157         tbs := PktTbs{
158                 Magic:     MagicNNCPEv6.B,
159                 Nice:      pktEnc.Nice,
160                 Sender:    their.Id,
161                 Recipient: our.Id,
162                 ExchPub:   pktEnc.ExchPub,
163         }
164         var tbsBuf bytes.Buffer
165         if _, err := xdr.Marshal(&tbsBuf, &tbs); err != nil {
166                 panic(err)
167         }
168         return tbsBuf.Bytes()
169 }
170
171 func TbsVerify(our *NodeOur, their *Node, pktEnc *PktEnc) ([]byte, bool, error) {
172         tbs := TbsPrepare(our, their, pktEnc)
173         return tbs, ed25519.Verify(their.SignPub, tbs, pktEnc.Sign[:]), nil
174 }
175
176 func sizeWithTags(size int64) (fullSize int64) {
177         size += PktSizeOverhead
178         fullSize = size + (size/EncBlkSize)*poly1305.TagSize
179         if size%EncBlkSize != 0 {
180                 fullSize += poly1305.TagSize
181         }
182         return
183 }
184
185 func sizePadCalc(sizePayload, minSize int64, wrappers int) (sizePad int64) {
186         expectedSize := sizePayload - PktOverhead
187         for i := 0; i < wrappers; i++ {
188                 expectedSize = PktEncOverhead + sizeWithTags(PktOverhead+expectedSize)
189         }
190         sizePad = minSize - expectedSize
191         if sizePad < 0 {
192                 sizePad = 0
193         }
194         return
195 }
196
197 func PktEncWrite(
198         our *NodeOur, their *Node,
199         pkt *Pkt, nice uint8,
200         minSize, maxSize int64, wrappers int,
201         r io.Reader, w io.Writer,
202 ) (pktEncRaw []byte, size int64, err error) {
203         pub, prv, err := box.GenerateKey(rand.Reader)
204         if err != nil {
205                 return nil, 0, err
206         }
207
208         var buf bytes.Buffer
209         _, err = xdr.Marshal(&buf, pkt)
210         if err != nil {
211                 return
212         }
213         pktRaw := make([]byte, buf.Len())
214         copy(pktRaw, buf.Bytes())
215         buf.Reset()
216
217         tbs := PktTbs{
218                 Magic:     MagicNNCPEv6.B,
219                 Nice:      nice,
220                 Sender:    our.Id,
221                 Recipient: their.Id,
222                 ExchPub:   *pub,
223         }
224         _, err = xdr.Marshal(&buf, &tbs)
225         if err != nil {
226                 return
227         }
228         signature := new([ed25519.SignatureSize]byte)
229         copy(signature[:], ed25519.Sign(our.SignPrv, buf.Bytes()))
230         ad := blake3.Sum256(buf.Bytes())
231         buf.Reset()
232
233         pktEnc := PktEnc{
234                 Magic:     MagicNNCPEv6.B,
235                 Nice:      nice,
236                 Sender:    our.Id,
237                 Recipient: their.Id,
238                 ExchPub:   *pub,
239                 Sign:      *signature,
240         }
241         _, err = xdr.Marshal(&buf, &pktEnc)
242         if err != nil {
243                 return
244         }
245         pktEncRaw = make([]byte, buf.Len())
246         copy(pktEncRaw, buf.Bytes())
247         buf.Reset()
248         _, err = w.Write(pktEncRaw)
249         if err != nil {
250                 return
251         }
252
253         sharedKey := new([32]byte)
254         curve25519.ScalarMult(sharedKey, prv, their.ExchPub)
255         keyFull := make([]byte, chacha20poly1305.KeySize)
256         keySize := make([]byte, chacha20poly1305.KeySize)
257         blake3.DeriveKey(keyFull, DeriveKeyFullCtx, sharedKey[:])
258         blake3.DeriveKey(keySize, DeriveKeySizeCtx, sharedKey[:])
259         aeadFull, err := chacha20poly1305.New(keyFull)
260         if err != nil {
261                 return
262         }
263         aeadSize, err := chacha20poly1305.New(keySize)
264         if err != nil {
265                 return
266         }
267         nonce := make([]byte, aeadFull.NonceSize())
268
269         data := make([]byte, EncBlkSize, EncBlkSize+aeadFull.Overhead())
270         mr := io.MultiReader(bytes.NewReader(pktRaw), r)
271         var sizePayload int64
272         var n int
273         var ct []byte
274         for {
275                 n, err = io.ReadFull(mr, data)
276                 sizePayload += int64(n)
277                 if sizePayload > maxSize {
278                         err = TooBig
279                         return
280                 }
281                 if err == nil {
282                         ct = aeadFull.Seal(data[:0], nonce, data[:n], ad[:])
283                         _, err = w.Write(ct)
284                         if err != nil {
285                                 return
286                         }
287                         ctrIncr(nonce)
288                         continue
289                 }
290                 if !(err == io.EOF || err == io.ErrUnexpectedEOF) {
291                         return
292                 }
293                 break
294         }
295
296         sizePad := sizePadCalc(sizePayload, minSize, wrappers)
297         _, err = xdr.Marshal(&buf, &PktSize{uint64(sizePayload), uint64(sizePad)})
298         if err != nil {
299                 return
300         }
301
302         var aeadLast cipher.AEAD
303         if n+int(PktSizeOverhead) > EncBlkSize {
304                 left := make([]byte, (n+int(PktSizeOverhead))-EncBlkSize)
305                 copy(left, data[n-len(left):])
306                 copy(data[PktSizeOverhead:], data[:n-len(left)])
307                 copy(data[:PktSizeOverhead], buf.Bytes())
308                 ct = aeadSize.Seal(data[:0], nonce, data[:EncBlkSize], ad[:])
309                 _, err = w.Write(ct)
310                 if err != nil {
311                         return
312                 }
313                 ctrIncr(nonce)
314                 copy(data, left)
315                 n = len(left)
316                 aeadLast = aeadFull
317         } else {
318                 copy(data[PktSizeOverhead:], data[:n])
319                 copy(data[:PktSizeOverhead], buf.Bytes())
320                 n += int(PktSizeOverhead)
321                 aeadLast = aeadSize
322         }
323
324         var sizeBlockPadded int
325         var sizePadLeft int64
326         if sizePad > EncBlkSize-int64(n) {
327                 sizeBlockPadded = EncBlkSize
328                 sizePadLeft = sizePad - (EncBlkSize - int64(n))
329         } else {
330                 sizeBlockPadded = n + int(sizePad)
331                 sizePadLeft = 0
332         }
333         for i := n; i < sizeBlockPadded; i++ {
334                 data[i] = 0
335         }
336         ct = aeadLast.Seal(data[:0], nonce, data[:sizeBlockPadded], ad[:])
337         _, err = w.Write(ct)
338         if err != nil {
339                 return
340         }
341
342         size = sizePayload
343         if sizePadLeft > 0 {
344                 keyPad := make([]byte, chacha20poly1305.KeySize)
345                 blake3.DeriveKey(keyPad, DeriveKeyPadCtx, sharedKey[:])
346                 _, err = io.CopyN(w, blake3.New(32, keyPad).XOF(), sizePadLeft)
347         }
348         return
349 }
350
351 func PktEncRead(
352         our *NodeOur, nodes map[NodeId]*Node,
353         r io.Reader, w io.Writer,
354         signatureVerify bool,
355         sharedKeyCached []byte,
356 ) (sharedKey []byte, their *Node, size int64, err error) {
357         var pktEnc PktEnc
358         _, err = xdr.Unmarshal(r, &pktEnc)
359         if err != nil {
360                 return
361         }
362         switch pktEnc.Magic {
363         case MagicNNCPEv1.B:
364                 err = MagicNNCPEv1.TooOld()
365         case MagicNNCPEv2.B:
366                 err = MagicNNCPEv2.TooOld()
367         case MagicNNCPEv3.B:
368                 err = MagicNNCPEv3.TooOld()
369         case MagicNNCPEv4.B:
370                 err = MagicNNCPEv4.TooOld()
371         case MagicNNCPEv5.B:
372                 err = MagicNNCPEv5.TooOld()
373         case MagicNNCPEv6.B:
374         default:
375                 err = BadMagic
376         }
377         if err != nil {
378                 return
379         }
380         if *pktEnc.Recipient != *our.Id {
381                 err = errors.New("Invalid recipient")
382                 return
383         }
384
385         var tbsRaw []byte
386         if signatureVerify {
387                 their = nodes[*pktEnc.Sender]
388                 if their == nil {
389                         err = errors.New("Unknown sender")
390                         return
391                 }
392                 var verified bool
393                 tbsRaw, verified, err = TbsVerify(our, their, &pktEnc)
394                 if err != nil {
395                         return
396                 }
397                 if !verified {
398                         err = errors.New("Invalid signature")
399                         return
400                 }
401         } else {
402                 tbsRaw = TbsPrepare(our, &Node{Id: pktEnc.Sender}, &pktEnc)
403         }
404         ad := blake3.Sum256(tbsRaw)
405         if sharedKeyCached == nil {
406                 key := new([32]byte)
407                 curve25519.ScalarMult(key, our.ExchPrv, &pktEnc.ExchPub)
408                 sharedKey = key[:]
409         } else {
410                 sharedKey = sharedKeyCached
411         }
412
413         keyFull := make([]byte, chacha20poly1305.KeySize)
414         keySize := make([]byte, chacha20poly1305.KeySize)
415         blake3.DeriveKey(keyFull, DeriveKeyFullCtx, sharedKey[:])
416         blake3.DeriveKey(keySize, DeriveKeySizeCtx, sharedKey[:])
417         aeadFull, err := chacha20poly1305.New(keyFull)
418         if err != nil {
419                 return
420         }
421         aeadSize, err := chacha20poly1305.New(keySize)
422         if err != nil {
423                 return
424         }
425         nonce := make([]byte, aeadFull.NonceSize())
426
427         ct := make([]byte, EncBlkSize+aeadFull.Overhead())
428         pt := make([]byte, EncBlkSize)
429         var n int
430 FullRead:
431         for {
432                 n, err = io.ReadFull(r, ct)
433                 switch err {
434                 case nil:
435                         pt, err = aeadFull.Open(pt[:0], nonce, ct, ad[:])
436                         if err != nil {
437                                 break FullRead
438                         }
439                         size += EncBlkSize
440                         _, err = w.Write(pt)
441                         if err != nil {
442                                 return
443                         }
444                         ctrIncr(nonce)
445                         continue
446                 case io.ErrUnexpectedEOF:
447                         break FullRead
448                 default:
449                         return
450                 }
451         }
452
453         pt, err = aeadSize.Open(pt[:0], nonce, ct[:n], ad[:])
454         if err != nil {
455                 return
456         }
457         var pktSize PktSize
458         _, err = xdr.Unmarshal(bytes.NewReader(pt), &pktSize)
459         if err != nil {
460                 return
461         }
462         pt = pt[PktSizeOverhead:]
463
464         left := int64(pktSize.Payload) - size
465         for left > int64(len(pt)) {
466                 size += int64(len(pt))
467                 left -= int64(len(pt))
468                 _, err = w.Write(pt)
469                 if err != nil {
470                         return
471                 }
472                 n, err = io.ReadFull(r, ct)
473                 if err != nil && err != io.ErrUnexpectedEOF {
474                         return
475                 }
476                 ctrIncr(nonce)
477                 pt, err = aeadFull.Open(pt[:0], nonce, ct[:n], ad[:])
478                 if err != nil {
479                         return
480                 }
481         }
482         size += left
483         _, err = w.Write(pt[:left])
484         if err != nil {
485                 return
486         }
487         pt = pt[left:]
488
489         if pktSize.Pad < uint64(len(pt)) {
490                 err = errors.New("unexpected pad")
491                 return
492         }
493         for i := 0; i < len(pt); i++ {
494                 if pt[i] != 0 {
495                         err = errors.New("non-zero pad byte")
496                         return
497                 }
498         }
499         sizePad := int64(pktSize.Pad) - int64(len(pt))
500         if sizePad == 0 {
501                 return
502         }
503
504         keyPad := make([]byte, chacha20poly1305.KeySize)
505         blake3.DeriveKey(keyPad, DeriveKeyPadCtx, sharedKey[:])
506         xof := blake3.New(32, keyPad).XOF()
507         pt = make([]byte, len(ct))
508         for sizePad > 0 {
509                 n, err = io.ReadFull(r, ct)
510                 if err != nil && err != io.ErrUnexpectedEOF {
511                         return
512                 }
513                 _, err = io.ReadFull(xof, pt[:n])
514                 if err != nil {
515                         panic(err)
516                 }
517                 if bytes.Compare(ct[:n], pt[:n]) != 0 {
518                         err = errors.New("wrong pad value")
519                         return
520                 }
521                 sizePad -= int64(n)
522         }
523         if sizePad < 0 {
524                 err = errors.New("excess pad")
525         }
526         return
527 }