]> Cypherpunks.ru repositories - nncp.git/blob - src/pkt.go
Raise copyright years
[nncp.git] / src / pkt.go
1 /*
2 NNCP -- Node to Node copy, utilities for store-and-forward data exchange
3 Copyright (C) 2016-2022 Sergey Matveev <stargrave@stargrave.org>
4
5 This program is free software: you can redistribute it and/or modify
6 it under the terms of the GNU General Public License as published by
7 the Free Software Foundation, version 3 of the License.
8
9 This program is distributed in the hope that it will be useful,
10 but WITHOUT ANY WARRANTY; without even the implied warranty of
11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 GNU General Public License for more details.
13
14 You should have received a copy of the GNU General Public License
15 along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 */
17
18 package nncp
19
20 import (
21         "bytes"
22         "crypto/cipher"
23         "crypto/rand"
24         "errors"
25         "io"
26
27         xdr "github.com/davecgh/go-xdr/xdr2"
28         "golang.org/x/crypto/chacha20poly1305"
29         "golang.org/x/crypto/curve25519"
30         "golang.org/x/crypto/ed25519"
31         "golang.org/x/crypto/nacl/box"
32         "golang.org/x/crypto/poly1305"
33         "lukechampine.com/blake3"
34 )
35
36 type PktType uint8
37
38 const (
39         EncBlkSize = 128 * (1 << 10)
40
41         PktTypeFile    PktType = iota
42         PktTypeFreq    PktType = iota
43         PktTypeExec    PktType = iota
44         PktTypeTrns    PktType = iota
45         PktTypeExecFat PktType = iota
46         PktTypeArea    PktType = iota
47
48         MaxPathSize = 1<<8 - 1
49
50         NNCPBundlePrefix = "NNCP"
51 )
52
53 var (
54         BadPktType error = errors.New("Unknown packet type")
55
56         DeriveKeyFullCtx = string(MagicNNCPEv6.B[:]) + " FULL"
57         DeriveKeySizeCtx = string(MagicNNCPEv6.B[:]) + " SIZE"
58         DeriveKeyPadCtx  = string(MagicNNCPEv6.B[:]) + " PAD"
59
60         PktOverhead     int64
61         PktEncOverhead  int64
62         PktSizeOverhead int64
63
64         TooBig = errors.New("Too big than allowed")
65 )
66
67 type Pkt struct {
68         Magic   [8]byte
69         Type    PktType
70         Nice    uint8
71         PathLen uint8
72         Path    [MaxPathSize]byte
73 }
74
75 type PktTbs struct {
76         Magic     [8]byte
77         Nice      uint8
78         Sender    *NodeId
79         Recipient *NodeId
80         ExchPub   [32]byte
81 }
82
83 type PktEnc struct {
84         Magic     [8]byte
85         Nice      uint8
86         Sender    *NodeId
87         Recipient *NodeId
88         ExchPub   [32]byte
89         Sign      [ed25519.SignatureSize]byte
90 }
91
92 type PktSize struct {
93         Payload uint64
94         Pad     uint64
95 }
96
97 func NewPkt(typ PktType, nice uint8, path []byte) (*Pkt, error) {
98         if len(path) > MaxPathSize {
99                 return nil, errors.New("Too long path")
100         }
101         pkt := Pkt{
102                 Magic:   MagicNNCPPv3.B,
103                 Type:    typ,
104                 Nice:    nice,
105                 PathLen: uint8(len(path)),
106         }
107         copy(pkt.Path[:], path)
108         return &pkt, nil
109 }
110
111 func init() {
112         var buf bytes.Buffer
113         pkt := Pkt{Type: PktTypeFile}
114         n, err := xdr.Marshal(&buf, pkt)
115         if err != nil {
116                 panic(err)
117         }
118         PktOverhead = int64(n)
119         buf.Reset()
120
121         dummyId, err := NodeIdFromString(DummyB32Id)
122         if err != nil {
123                 panic(err)
124         }
125         pktEnc := PktEnc{
126                 Magic:     MagicNNCPEv6.B,
127                 Sender:    dummyId,
128                 Recipient: dummyId,
129         }
130         n, err = xdr.Marshal(&buf, pktEnc)
131         if err != nil {
132                 panic(err)
133         }
134         PktEncOverhead = int64(n)
135         buf.Reset()
136
137         size := PktSize{}
138         n, err = xdr.Marshal(&buf, size)
139         if err != nil {
140                 panic(err)
141         }
142         PktSizeOverhead = int64(n)
143 }
144
145 func ctrIncr(b []byte) {
146         for i := len(b) - 1; i >= 0; i-- {
147                 b[i]++
148                 if b[i] != 0 {
149                         return
150                 }
151         }
152         panic("counter overflow")
153 }
154
155 func TbsPrepare(our *NodeOur, their *Node, pktEnc *PktEnc) []byte {
156         tbs := PktTbs{
157                 Magic:     MagicNNCPEv6.B,
158                 Nice:      pktEnc.Nice,
159                 Sender:    their.Id,
160                 Recipient: our.Id,
161                 ExchPub:   pktEnc.ExchPub,
162         }
163         var tbsBuf bytes.Buffer
164         if _, err := xdr.Marshal(&tbsBuf, &tbs); err != nil {
165                 panic(err)
166         }
167         return tbsBuf.Bytes()
168 }
169
170 func TbsVerify(our *NodeOur, their *Node, pktEnc *PktEnc) ([]byte, bool, error) {
171         tbs := TbsPrepare(our, their, pktEnc)
172         return tbs, ed25519.Verify(their.SignPub, tbs, pktEnc.Sign[:]), nil
173 }
174
175 func sizeWithTags(size int64) (fullSize int64) {
176         size += PktSizeOverhead
177         fullSize = size + (size/EncBlkSize)*poly1305.TagSize
178         if size%EncBlkSize != 0 {
179                 fullSize += poly1305.TagSize
180         }
181         return
182 }
183
184 func sizePadCalc(sizePayload, minSize int64, wrappers int) (sizePad int64) {
185         expectedSize := sizePayload - PktOverhead
186         for i := 0; i < wrappers; i++ {
187                 expectedSize = PktEncOverhead + sizeWithTags(PktOverhead+expectedSize)
188         }
189         sizePad = minSize - expectedSize
190         if sizePad < 0 {
191                 sizePad = 0
192         }
193         return
194 }
195
196 func PktEncWrite(
197         our *NodeOur, their *Node,
198         pkt *Pkt, nice uint8,
199         minSize, maxSize int64, wrappers int,
200         r io.Reader, w io.Writer,
201 ) (pktEncRaw []byte, size int64, err error) {
202         pub, prv, err := box.GenerateKey(rand.Reader)
203         if err != nil {
204                 return nil, 0, err
205         }
206
207         var buf bytes.Buffer
208         _, err = xdr.Marshal(&buf, pkt)
209         if err != nil {
210                 return
211         }
212         pktRaw := make([]byte, buf.Len())
213         copy(pktRaw, buf.Bytes())
214         buf.Reset()
215
216         tbs := PktTbs{
217                 Magic:     MagicNNCPEv6.B,
218                 Nice:      nice,
219                 Sender:    our.Id,
220                 Recipient: their.Id,
221                 ExchPub:   *pub,
222         }
223         _, err = xdr.Marshal(&buf, &tbs)
224         if err != nil {
225                 return
226         }
227         signature := new([ed25519.SignatureSize]byte)
228         copy(signature[:], ed25519.Sign(our.SignPrv, buf.Bytes()))
229         ad := blake3.Sum256(buf.Bytes())
230         buf.Reset()
231
232         pktEnc := PktEnc{
233                 Magic:     MagicNNCPEv6.B,
234                 Nice:      nice,
235                 Sender:    our.Id,
236                 Recipient: their.Id,
237                 ExchPub:   *pub,
238                 Sign:      *signature,
239         }
240         _, err = xdr.Marshal(&buf, &pktEnc)
241         if err != nil {
242                 return
243         }
244         pktEncRaw = make([]byte, buf.Len())
245         copy(pktEncRaw, buf.Bytes())
246         buf.Reset()
247         _, err = w.Write(pktEncRaw)
248         if err != nil {
249                 return
250         }
251
252         sharedKey := new([32]byte)
253         curve25519.ScalarMult(sharedKey, prv, their.ExchPub)
254         keyFull := make([]byte, chacha20poly1305.KeySize)
255         keySize := make([]byte, chacha20poly1305.KeySize)
256         blake3.DeriveKey(keyFull, DeriveKeyFullCtx, sharedKey[:])
257         blake3.DeriveKey(keySize, DeriveKeySizeCtx, sharedKey[:])
258         aeadFull, err := chacha20poly1305.New(keyFull)
259         if err != nil {
260                 return
261         }
262         aeadSize, err := chacha20poly1305.New(keySize)
263         if err != nil {
264                 return
265         }
266         nonce := make([]byte, aeadFull.NonceSize())
267
268         data := make([]byte, EncBlkSize, EncBlkSize+aeadFull.Overhead())
269         mr := io.MultiReader(bytes.NewReader(pktRaw), r)
270         var sizePayload int64
271         var n int
272         var ct []byte
273         for {
274                 n, err = io.ReadFull(mr, data)
275                 sizePayload += int64(n)
276                 if sizePayload > maxSize {
277                         err = TooBig
278                         return
279                 }
280                 if err == nil {
281                         ct = aeadFull.Seal(data[:0], nonce, data[:n], ad[:])
282                         _, err = w.Write(ct)
283                         if err != nil {
284                                 return
285                         }
286                         ctrIncr(nonce)
287                         continue
288                 }
289                 if !(err == io.EOF || err == io.ErrUnexpectedEOF) {
290                         return
291                 }
292                 break
293         }
294
295         sizePad := sizePadCalc(sizePayload, minSize, wrappers)
296         _, err = xdr.Marshal(&buf, &PktSize{uint64(sizePayload), uint64(sizePad)})
297         if err != nil {
298                 return
299         }
300
301         var aeadLast cipher.AEAD
302         if n+int(PktSizeOverhead) > EncBlkSize {
303                 left := make([]byte, (n+int(PktSizeOverhead))-EncBlkSize)
304                 copy(left, data[n-len(left):])
305                 copy(data[PktSizeOverhead:], data[:n-len(left)])
306                 copy(data[:PktSizeOverhead], buf.Bytes())
307                 ct = aeadSize.Seal(data[:0], nonce, data[:EncBlkSize], ad[:])
308                 _, err = w.Write(ct)
309                 if err != nil {
310                         return
311                 }
312                 ctrIncr(nonce)
313                 copy(data, left)
314                 n = len(left)
315                 aeadLast = aeadFull
316         } else {
317                 copy(data[PktSizeOverhead:], data[:n])
318                 copy(data[:PktSizeOverhead], buf.Bytes())
319                 n += int(PktSizeOverhead)
320                 aeadLast = aeadSize
321         }
322
323         var sizeBlockPadded int
324         var sizePadLeft int64
325         if sizePad > EncBlkSize-int64(n) {
326                 sizeBlockPadded = EncBlkSize
327                 sizePadLeft = sizePad - (EncBlkSize - int64(n))
328         } else {
329                 sizeBlockPadded = n + int(sizePad)
330                 sizePadLeft = 0
331         }
332         for i := n; i < sizeBlockPadded; i++ {
333                 data[i] = 0
334         }
335         ct = aeadLast.Seal(data[:0], nonce, data[:sizeBlockPadded], ad[:])
336         _, err = w.Write(ct)
337         if err != nil {
338                 return
339         }
340
341         size = sizePayload
342         if sizePadLeft > 0 {
343                 keyPad := make([]byte, chacha20poly1305.KeySize)
344                 blake3.DeriveKey(keyPad, DeriveKeyPadCtx, sharedKey[:])
345                 _, err = io.CopyN(w, blake3.New(32, keyPad).XOF(), sizePadLeft)
346         }
347         return
348 }
349
350 func PktEncRead(
351         our *NodeOur, nodes map[NodeId]*Node,
352         r io.Reader, w io.Writer,
353         signatureVerify bool,
354         sharedKeyCached []byte,
355 ) (sharedKey []byte, their *Node, size int64, err error) {
356         var pktEnc PktEnc
357         _, err = xdr.Unmarshal(r, &pktEnc)
358         if err != nil {
359                 return
360         }
361         switch pktEnc.Magic {
362         case MagicNNCPEv1.B:
363                 err = MagicNNCPEv1.TooOld()
364         case MagicNNCPEv2.B:
365                 err = MagicNNCPEv2.TooOld()
366         case MagicNNCPEv3.B:
367                 err = MagicNNCPEv3.TooOld()
368         case MagicNNCPEv4.B:
369                 err = MagicNNCPEv4.TooOld()
370         case MagicNNCPEv5.B:
371                 err = MagicNNCPEv5.TooOld()
372         case MagicNNCPEv6.B:
373         default:
374                 err = BadMagic
375         }
376         if err != nil {
377                 return
378         }
379         if *pktEnc.Recipient != *our.Id {
380                 err = errors.New("Invalid recipient")
381                 return
382         }
383
384         var tbsRaw []byte
385         if signatureVerify {
386                 their = nodes[*pktEnc.Sender]
387                 if their == nil {
388                         err = errors.New("Unknown sender")
389                         return
390                 }
391                 var verified bool
392                 tbsRaw, verified, err = TbsVerify(our, their, &pktEnc)
393                 if err != nil {
394                         return
395                 }
396                 if !verified {
397                         err = errors.New("Invalid signature")
398                         return
399                 }
400         } else {
401                 tbsRaw = TbsPrepare(our, &Node{Id: pktEnc.Sender}, &pktEnc)
402         }
403         ad := blake3.Sum256(tbsRaw)
404         if sharedKeyCached == nil {
405                 key := new([32]byte)
406                 curve25519.ScalarMult(key, our.ExchPrv, &pktEnc.ExchPub)
407                 sharedKey = key[:]
408         } else {
409                 sharedKey = sharedKeyCached
410         }
411
412         keyFull := make([]byte, chacha20poly1305.KeySize)
413         keySize := make([]byte, chacha20poly1305.KeySize)
414         blake3.DeriveKey(keyFull, DeriveKeyFullCtx, sharedKey[:])
415         blake3.DeriveKey(keySize, DeriveKeySizeCtx, sharedKey[:])
416         aeadFull, err := chacha20poly1305.New(keyFull)
417         if err != nil {
418                 return
419         }
420         aeadSize, err := chacha20poly1305.New(keySize)
421         if err != nil {
422                 return
423         }
424         nonce := make([]byte, aeadFull.NonceSize())
425
426         ct := make([]byte, EncBlkSize+aeadFull.Overhead())
427         pt := make([]byte, EncBlkSize)
428         var n int
429 FullRead:
430         for {
431                 n, err = io.ReadFull(r, ct)
432                 switch err {
433                 case nil:
434                         pt, err = aeadFull.Open(pt[:0], nonce, ct, ad[:])
435                         if err != nil {
436                                 break FullRead
437                         }
438                         size += EncBlkSize
439                         _, err = w.Write(pt)
440                         if err != nil {
441                                 return
442                         }
443                         ctrIncr(nonce)
444                         continue
445                 case io.ErrUnexpectedEOF:
446                         break FullRead
447                 default:
448                         return
449                 }
450         }
451
452         pt, err = aeadSize.Open(pt[:0], nonce, ct[:n], ad[:])
453         if err != nil {
454                 return
455         }
456         var pktSize PktSize
457         _, err = xdr.Unmarshal(bytes.NewReader(pt), &pktSize)
458         if err != nil {
459                 return
460         }
461         pt = pt[PktSizeOverhead:]
462
463         left := int64(pktSize.Payload) - size
464         for left > int64(len(pt)) {
465                 size += int64(len(pt))
466                 left -= int64(len(pt))
467                 _, err = w.Write(pt)
468                 if err != nil {
469                         return
470                 }
471                 n, err = io.ReadFull(r, ct)
472                 if err != nil && err != io.ErrUnexpectedEOF {
473                         return
474                 }
475                 ctrIncr(nonce)
476                 pt, err = aeadFull.Open(pt[:0], nonce, ct[:n], ad[:])
477                 if err != nil {
478                         return
479                 }
480         }
481         size += left
482         _, err = w.Write(pt[:left])
483         if err != nil {
484                 return
485         }
486         pt = pt[left:]
487
488         if pktSize.Pad < uint64(len(pt)) {
489                 err = errors.New("unexpected pad")
490                 return
491         }
492         for i := 0; i < len(pt); i++ {
493                 if pt[i] != 0 {
494                         err = errors.New("non-zero pad byte")
495                         return
496                 }
497         }
498         sizePad := int64(pktSize.Pad) - int64(len(pt))
499         if sizePad == 0 {
500                 return
501         }
502
503         keyPad := make([]byte, chacha20poly1305.KeySize)
504         blake3.DeriveKey(keyPad, DeriveKeyPadCtx, sharedKey[:])
505         xof := blake3.New(32, keyPad).XOF()
506         pt = make([]byte, len(ct))
507         for sizePad > 0 {
508                 n, err = io.ReadFull(r, ct)
509                 if err != nil && err != io.ErrUnexpectedEOF {
510                         return
511                 }
512                 _, err = io.ReadFull(xof, pt[:n])
513                 if err != nil {
514                         panic(err)
515                 }
516                 if bytes.Compare(ct[:n], pt[:n]) != 0 {
517                         err = errors.New("wrong pad value")
518                         return
519                 }
520                 sizePad -= int64(n)
521         }
522         if sizePad < 0 {
523                 err = errors.New("excess pad")
524         }
525         return
526 }