]> Cypherpunks.ru repositories - nncp.git/blob - src/pkt.go
Update dependencies
[nncp.git] / src / pkt.go
1 /*
2 NNCP -- Node to Node copy, utilities for store-and-forward data exchange
3 Copyright (C) 2016-2023 Sergey Matveev <stargrave@stargrave.org>
4
5 This program is free software: you can redistribute it and/or modify
6 it under the terms of the GNU General Public License as published by
7 the Free Software Foundation, version 3 of the License.
8
9 This program is distributed in the hope that it will be useful,
10 but WITHOUT ANY WARRANTY; without even the implied warranty of
11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 GNU General Public License for more details.
13
14 You should have received a copy of the GNU General Public License
15 along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 */
17
18 package nncp
19
20 import (
21         "bytes"
22         "crypto/cipher"
23         "crypto/rand"
24         "errors"
25         "io"
26
27         xdr "github.com/davecgh/go-xdr/xdr2"
28         "golang.org/x/crypto/chacha20poly1305"
29         "golang.org/x/crypto/curve25519"
30         "golang.org/x/crypto/ed25519"
31         "golang.org/x/crypto/nacl/box"
32         "golang.org/x/crypto/poly1305"
33         "lukechampine.com/blake3"
34 )
35
36 type PktType uint8
37
38 const (
39         EncBlkSize = 128 * (1 << 10)
40
41         PktTypeFile    PktType = iota
42         PktTypeFreq    PktType = iota
43         PktTypeExec    PktType = iota
44         PktTypeTrns    PktType = iota
45         PktTypeExecFat PktType = iota
46         PktTypeArea    PktType = iota
47         PktTypeACK     PktType = iota
48
49         MaxPathSize = 1<<8 - 1
50
51         NNCPBundlePrefix = "NNCP"
52 )
53
54 var (
55         BadPktType error = errors.New("Unknown packet type")
56
57         DeriveKeyFullCtx = string(MagicNNCPEv6.B[:]) + " FULL"
58         DeriveKeySizeCtx = string(MagicNNCPEv6.B[:]) + " SIZE"
59         DeriveKeyPadCtx  = string(MagicNNCPEv6.B[:]) + " PAD"
60
61         PktOverhead     int64
62         PktEncOverhead  int64
63         PktSizeOverhead int64
64
65         TooBig = errors.New("Too big than allowed")
66 )
67
68 type Pkt struct {
69         Magic   [8]byte
70         Type    PktType
71         Nice    uint8
72         PathLen uint8
73         Path    [MaxPathSize]byte
74 }
75
76 type PktTbs struct {
77         Magic     [8]byte
78         Nice      uint8
79         Sender    *NodeId
80         Recipient *NodeId
81         ExchPub   [32]byte
82 }
83
84 type PktEnc struct {
85         Magic     [8]byte
86         Nice      uint8
87         Sender    *NodeId
88         Recipient *NodeId
89         ExchPub   [32]byte
90         Sign      [ed25519.SignatureSize]byte
91 }
92
93 type PktSize struct {
94         Payload uint64
95         Pad     uint64
96 }
97
98 func NewPkt(typ PktType, nice uint8, path []byte) (*Pkt, error) {
99         if len(path) > MaxPathSize {
100                 return nil, errors.New("Too long path")
101         }
102         pkt := Pkt{
103                 Magic:   MagicNNCPPv3.B,
104                 Type:    typ,
105                 Nice:    nice,
106                 PathLen: uint8(len(path)),
107         }
108         copy(pkt.Path[:], path)
109         return &pkt, nil
110 }
111
112 func init() {
113         var buf bytes.Buffer
114         pkt := Pkt{Type: PktTypeFile}
115         n, err := xdr.Marshal(&buf, pkt)
116         if err != nil {
117                 panic(err)
118         }
119         PktOverhead = int64(n)
120         buf.Reset()
121
122         dummyId, err := NodeIdFromString(DummyB32Id)
123         if err != nil {
124                 panic(err)
125         }
126         pktEnc := PktEnc{
127                 Magic:     MagicNNCPEv6.B,
128                 Sender:    dummyId,
129                 Recipient: dummyId,
130         }
131         n, err = xdr.Marshal(&buf, pktEnc)
132         if err != nil {
133                 panic(err)
134         }
135         PktEncOverhead = int64(n)
136         buf.Reset()
137
138         size := PktSize{}
139         n, err = xdr.Marshal(&buf, size)
140         if err != nil {
141                 panic(err)
142         }
143         PktSizeOverhead = int64(n)
144 }
145
146 func ctrIncr(b []byte) {
147         for i := len(b) - 1; i >= 0; i-- {
148                 b[i]++
149                 if b[i] != 0 {
150                         return
151                 }
152         }
153         panic("counter overflow")
154 }
155
156 func TbsPrepare(our *NodeOur, their *Node, pktEnc *PktEnc) []byte {
157         tbs := PktTbs{
158                 Magic:     MagicNNCPEv6.B,
159                 Nice:      pktEnc.Nice,
160                 Sender:    their.Id,
161                 Recipient: our.Id,
162                 ExchPub:   pktEnc.ExchPub,
163         }
164         var tbsBuf bytes.Buffer
165         if _, err := xdr.Marshal(&tbsBuf, &tbs); err != nil {
166                 panic(err)
167         }
168         return tbsBuf.Bytes()
169 }
170
171 func TbsVerify(our *NodeOur, their *Node, pktEnc *PktEnc) ([]byte, bool, error) {
172         tbs := TbsPrepare(our, their, pktEnc)
173         return tbs, ed25519.Verify(their.SignPub, tbs, pktEnc.Sign[:]), nil
174 }
175
176 func sizeWithTags(size int64) (fullSize int64) {
177         size += PktSizeOverhead
178         fullSize = size + (size/EncBlkSize)*poly1305.TagSize
179         if size%EncBlkSize != 0 {
180                 fullSize += poly1305.TagSize
181         }
182         return
183 }
184
185 func sizePadCalc(sizePayload, minSize int64, wrappers int) (sizePad int64) {
186         expectedSize := sizePayload - PktOverhead
187         for i := 0; i < wrappers; i++ {
188                 expectedSize = PktEncOverhead + sizeWithTags(PktOverhead+expectedSize)
189         }
190         sizePad = minSize - expectedSize
191         if sizePad < 0 {
192                 sizePad = 0
193         }
194         return
195 }
196
197 func PktEncWrite(
198         our *NodeOur, their *Node,
199         pkt *Pkt, nice uint8,
200         minSize, maxSize int64, wrappers int,
201         r io.Reader, w io.Writer,
202 ) (pktEncRaw []byte, size int64, err error) {
203         pub, prv, err := box.GenerateKey(rand.Reader)
204         if err != nil {
205                 return nil, 0, err
206         }
207
208         var buf bytes.Buffer
209         _, err = xdr.Marshal(&buf, pkt)
210         if err != nil {
211                 return
212         }
213         pktRaw := make([]byte, buf.Len())
214         copy(pktRaw, buf.Bytes())
215         buf.Reset()
216
217         tbs := PktTbs{
218                 Magic:     MagicNNCPEv6.B,
219                 Nice:      nice,
220                 Sender:    our.Id,
221                 Recipient: their.Id,
222                 ExchPub:   *pub,
223         }
224         _, err = xdr.Marshal(&buf, &tbs)
225         if err != nil {
226                 return
227         }
228         signature := new([ed25519.SignatureSize]byte)
229         copy(signature[:], ed25519.Sign(our.SignPrv, buf.Bytes()))
230         ad := blake3.Sum256(buf.Bytes())
231         buf.Reset()
232
233         pktEnc := PktEnc{
234                 Magic:     MagicNNCPEv6.B,
235                 Nice:      nice,
236                 Sender:    our.Id,
237                 Recipient: their.Id,
238                 ExchPub:   *pub,
239                 Sign:      *signature,
240         }
241         _, err = xdr.Marshal(&buf, &pktEnc)
242         if err != nil {
243                 return
244         }
245         pktEncRaw = make([]byte, buf.Len())
246         copy(pktEncRaw, buf.Bytes())
247         buf.Reset()
248         _, err = w.Write(pktEncRaw)
249         if err != nil {
250                 return
251         }
252
253         sharedKey, err := curve25519.X25519(prv[:], their.ExchPub[:])
254         if err != nil {
255                 return
256         }
257         keyFull := make([]byte, chacha20poly1305.KeySize)
258         keySize := make([]byte, chacha20poly1305.KeySize)
259         blake3.DeriveKey(keyFull, DeriveKeyFullCtx, sharedKey)
260         blake3.DeriveKey(keySize, DeriveKeySizeCtx, sharedKey)
261         aeadFull, err := chacha20poly1305.New(keyFull)
262         if err != nil {
263                 return
264         }
265         aeadSize, err := chacha20poly1305.New(keySize)
266         if err != nil {
267                 return
268         }
269         nonce := make([]byte, aeadFull.NonceSize())
270
271         data := make([]byte, EncBlkSize, EncBlkSize+aeadFull.Overhead())
272         mr := io.MultiReader(bytes.NewReader(pktRaw), r)
273         var sizePayload int64
274         var n int
275         var ct []byte
276         for {
277                 n, err = io.ReadFull(mr, data)
278                 sizePayload += int64(n)
279                 if sizePayload > maxSize {
280                         err = TooBig
281                         return
282                 }
283                 if err == nil {
284                         ct = aeadFull.Seal(data[:0], nonce, data[:n], ad[:])
285                         _, err = w.Write(ct)
286                         if err != nil {
287                                 return
288                         }
289                         ctrIncr(nonce)
290                         continue
291                 }
292                 if !(err == io.EOF || err == io.ErrUnexpectedEOF) {
293                         return
294                 }
295                 break
296         }
297
298         sizePad := sizePadCalc(sizePayload, minSize, wrappers)
299         _, err = xdr.Marshal(&buf, &PktSize{uint64(sizePayload), uint64(sizePad)})
300         if err != nil {
301                 return
302         }
303
304         var aeadLast cipher.AEAD
305         if n+int(PktSizeOverhead) > EncBlkSize {
306                 left := make([]byte, (n+int(PktSizeOverhead))-EncBlkSize)
307                 copy(left, data[n-len(left):])
308                 copy(data[PktSizeOverhead:], data[:n-len(left)])
309                 copy(data[:PktSizeOverhead], buf.Bytes())
310                 ct = aeadSize.Seal(data[:0], nonce, data[:EncBlkSize], ad[:])
311                 _, err = w.Write(ct)
312                 if err != nil {
313                         return
314                 }
315                 ctrIncr(nonce)
316                 copy(data, left)
317                 n = len(left)
318                 aeadLast = aeadFull
319         } else {
320                 copy(data[PktSizeOverhead:], data[:n])
321                 copy(data[:PktSizeOverhead], buf.Bytes())
322                 n += int(PktSizeOverhead)
323                 aeadLast = aeadSize
324         }
325
326         var sizeBlockPadded int
327         var sizePadLeft int64
328         if sizePad > EncBlkSize-int64(n) {
329                 sizeBlockPadded = EncBlkSize
330                 sizePadLeft = sizePad - (EncBlkSize - int64(n))
331         } else {
332                 sizeBlockPadded = n + int(sizePad)
333                 sizePadLeft = 0
334         }
335         for i := n; i < sizeBlockPadded; i++ {
336                 data[i] = 0
337         }
338         ct = aeadLast.Seal(data[:0], nonce, data[:sizeBlockPadded], ad[:])
339         _, err = w.Write(ct)
340         if err != nil {
341                 return
342         }
343
344         size = sizePayload
345         if sizePadLeft > 0 {
346                 keyPad := make([]byte, chacha20poly1305.KeySize)
347                 blake3.DeriveKey(keyPad, DeriveKeyPadCtx, sharedKey[:])
348                 _, err = io.CopyN(w, blake3.New(32, keyPad).XOF(), sizePadLeft)
349         }
350         return
351 }
352
353 func PktEncRead(
354         our *NodeOur, nodes map[NodeId]*Node,
355         r io.Reader, w io.Writer,
356         signatureVerify bool,
357         sharedKeyCached []byte,
358 ) (sharedKey []byte, their *Node, size int64, err error) {
359         var pktEnc PktEnc
360         _, err = xdr.Unmarshal(r, &pktEnc)
361         if err != nil {
362                 return
363         }
364         switch pktEnc.Magic {
365         case MagicNNCPEv1.B:
366                 err = MagicNNCPEv1.TooOld()
367         case MagicNNCPEv2.B:
368                 err = MagicNNCPEv2.TooOld()
369         case MagicNNCPEv3.B:
370                 err = MagicNNCPEv3.TooOld()
371         case MagicNNCPEv4.B:
372                 err = MagicNNCPEv4.TooOld()
373         case MagicNNCPEv5.B:
374                 err = MagicNNCPEv5.TooOld()
375         case MagicNNCPEv6.B:
376         default:
377                 err = BadMagic
378         }
379         if err != nil {
380                 return
381         }
382         if *pktEnc.Recipient != *our.Id {
383                 err = errors.New("Invalid recipient")
384                 return
385         }
386
387         var tbsRaw []byte
388         if signatureVerify {
389                 their = nodes[*pktEnc.Sender]
390                 if their == nil {
391                         err = errors.New("Unknown sender")
392                         return
393                 }
394                 var verified bool
395                 tbsRaw, verified, err = TbsVerify(our, their, &pktEnc)
396                 if err != nil {
397                         return
398                 }
399                 if !verified {
400                         err = errors.New("Invalid signature")
401                         return
402                 }
403         } else {
404                 tbsRaw = TbsPrepare(our, &Node{Id: pktEnc.Sender}, &pktEnc)
405         }
406         ad := blake3.Sum256(tbsRaw)
407         if sharedKeyCached == nil {
408                 var key []byte
409                 key, err = curve25519.X25519(our.ExchPrv[:], pktEnc.ExchPub[:])
410                 if err != nil {
411                         return
412                 }
413                 sharedKey = key[:]
414         } else {
415                 sharedKey = sharedKeyCached
416         }
417
418         keyFull := make([]byte, chacha20poly1305.KeySize)
419         keySize := make([]byte, chacha20poly1305.KeySize)
420         blake3.DeriveKey(keyFull, DeriveKeyFullCtx, sharedKey[:])
421         blake3.DeriveKey(keySize, DeriveKeySizeCtx, sharedKey[:])
422         aeadFull, err := chacha20poly1305.New(keyFull)
423         if err != nil {
424                 return
425         }
426         aeadSize, err := chacha20poly1305.New(keySize)
427         if err != nil {
428                 return
429         }
430         nonce := make([]byte, aeadFull.NonceSize())
431
432         ct := make([]byte, EncBlkSize+aeadFull.Overhead())
433         pt := make([]byte, EncBlkSize)
434         var n int
435 FullRead:
436         for {
437                 n, err = io.ReadFull(r, ct)
438                 switch err {
439                 case nil:
440                         pt, err = aeadFull.Open(pt[:0], nonce, ct, ad[:])
441                         if err != nil {
442                                 break FullRead
443                         }
444                         size += EncBlkSize
445                         _, err = w.Write(pt)
446                         if err != nil {
447                                 return
448                         }
449                         ctrIncr(nonce)
450                         continue
451                 case io.ErrUnexpectedEOF:
452                         break FullRead
453                 default:
454                         return
455                 }
456         }
457
458         pt, err = aeadSize.Open(pt[:0], nonce, ct[:n], ad[:])
459         if err != nil {
460                 return
461         }
462         var pktSize PktSize
463         _, err = xdr.Unmarshal(bytes.NewReader(pt), &pktSize)
464         if err != nil {
465                 return
466         }
467         pt = pt[PktSizeOverhead:]
468
469         left := int64(pktSize.Payload) - size
470         for left > int64(len(pt)) {
471                 size += int64(len(pt))
472                 left -= int64(len(pt))
473                 _, err = w.Write(pt)
474                 if err != nil {
475                         return
476                 }
477                 n, err = io.ReadFull(r, ct)
478                 if err != nil && err != io.ErrUnexpectedEOF {
479                         return
480                 }
481                 ctrIncr(nonce)
482                 pt, err = aeadFull.Open(pt[:0], nonce, ct[:n], ad[:])
483                 if err != nil {
484                         return
485                 }
486         }
487         size += left
488         _, err = w.Write(pt[:left])
489         if err != nil {
490                 return
491         }
492         pt = pt[left:]
493
494         if pktSize.Pad < uint64(len(pt)) {
495                 err = errors.New("unexpected pad")
496                 return
497         }
498         for i := 0; i < len(pt); i++ {
499                 if pt[i] != 0 {
500                         err = errors.New("non-zero pad byte")
501                         return
502                 }
503         }
504         sizePad := int64(pktSize.Pad) - int64(len(pt))
505         if sizePad == 0 {
506                 return
507         }
508
509         keyPad := make([]byte, chacha20poly1305.KeySize)
510         blake3.DeriveKey(keyPad, DeriveKeyPadCtx, sharedKey[:])
511         xof := blake3.New(32, keyPad).XOF()
512         pt = make([]byte, len(ct))
513         for sizePad > 0 {
514                 n, err = io.ReadFull(r, ct)
515                 if err != nil && err != io.ErrUnexpectedEOF {
516                         return
517                 }
518                 _, err = io.ReadFull(xof, pt[:n])
519                 if err != nil {
520                         panic(err)
521                 }
522                 if !bytes.Equal(ct[:n], pt[:n]) {
523                         err = errors.New("wrong pad value")
524                         return
525                 }
526                 sizePad -= int64(n)
527         }
528         if sizePad < 0 {
529                 err = errors.New("excess pad")
530         }
531         return
532 }