]> Cypherpunks.ru repositories - gohpenc.git/blobdiff - main.go
More secure and faster version
[gohpenc.git] / main.go
diff --git a/main.go b/main.go
index 9793cf11532635086ce3462b284e0f213a3a8d81..9eacacb03ea5377b8c37e045ca2fec81e0323af0 100644 (file)
--- a/main.go
+++ b/main.go
@@ -1,6 +1,6 @@
 /*
 gohpenc -- Go high-performance encryption utility
-Copyright (C) 2017-2019 Sergey Matveev <stargrave@stargrave.org>
+Copyright (C) 2017-2022 Sergey Matveev <stargrave@stargrave.org>
 
 This program is free software: you can redistribute it and/or modify
 it under the terms of the GNU General Public License as published by
@@ -21,178 +21,252 @@ package main
 import (
        "bufio"
        "crypto/rand"
+       "crypto/sha512"
+       "encoding/base32"
        "encoding/binary"
        "flag"
        "fmt"
        "io"
+       "log"
        "os"
        "runtime"
        "sync"
 
-       "golang.org/x/crypto/blake2b"
+       "golang.org/x/crypto/chacha20"
        "golang.org/x/crypto/chacha20poly1305"
+       "golang.org/x/crypto/hkdf"
+       "golang.org/x/crypto/poly1305"
 )
 
 const (
-       LenSize  = 4
-       SaltSize = 32
+       Magic    = "GOHPENC\n"
+       SaltSize = 16
 )
 
-type WorkerTask struct {
-       key []byte
-       n   int
-}
+var Base32Codec *base32.Encoding = base32.StdEncoding.WithPadding(base32.NoPadding)
 
 type Worker struct {
-       wg      *sync.WaitGroup
-       buf     []byte
-       ready   chan struct{}
-       task    chan WorkerTask
-       done    chan []byte
-       written chan struct{}
+       ctr      uint64
+       buf      []byte
+       readyIn  chan struct{}
+       readyOut chan struct{}
+       last     bool
 }
 
-func (w *Worker) Thread(doDecrypt bool) {
-       var task WorkerTask
-       nonce := make([]byte, chacha20poly1305.NonceSize)
-       var done []byte
-       for {
-               w.ready <- struct{}{}
-               task = <-w.task
-               aead, err := chacha20poly1305.New(task.key)
+func readBuf(dst []byte, src io.Reader) ([]byte, error) {
+       var n, full int
+       var err error
+       for full < len(dst) {
+               n, err = src.Read(dst[full:])
+               full += n
                if err != nil {
-                       panic(err)
-               }
-               if doDecrypt {
-                       done, err = aead.Open(w.buf[:0], nonce, w.buf[LenSize:LenSize+task.n], w.buf[:LenSize])
-                       if err != nil {
-                               panic(err)
+                       if err == io.EOF {
+                               break
                        }
-               } else {
-                       binary.BigEndian.PutUint32(w.buf, uint32(task.n))
-                       done = aead.Seal(w.buf[:LenSize], nonce, w.buf[LenSize:LenSize+task.n], w.buf[:LenSize])
+                       return nil, err
                }
-               w.done <- done
-               <-w.written
-               w.wg.Done()
        }
+       return dst[:full], err
+}
+
+type DummyReader struct{}
+
+func (r *DummyReader) Read(b []byte) (int, error) {
+       return len(b), nil
 }
 
 func main() {
-       var (
-               doPSK     = flag.Bool("psk", false, "Generate PSK")
-               key       = flag.String("k", "", "Encryption key")
-               doDecrypt = flag.Bool("d", false, "Decrypt, instead of encrypt")
-               blockSize = flag.Int("b", 1<<10, "Blocksize, in KiB")
-               threads   = flag.Int("c", runtime.NumCPU(), "Number of threads")
-       )
+       doRNG := flag.Bool("r", false, "Random number generator")
+       doGen := flag.Bool("psk", false, "Generate key")
+       doDec := flag.Bool("d", false, "Decrypt, instead of encrypt")
+       bs := flag.Int("b", 1<<10, "Blocksize, KiB")
+       jobs := flag.Int("c", runtime.NumCPU(), "Number of parallel threads")
+       keyB32 := flag.String("k", "", "Encryption key")
        flag.Parse()
-       bs := *blockSize * 1 << 10
+       log.SetFlags(log.Ldate | log.Lmicroseconds | log.Lshortfile)
 
-       if *doPSK {
-               key := make([]byte, 32)
-               if _, err := rand.Read(key); err != nil {
-                       panic(err)
+       var err error
+       key := make([]byte, chacha20poly1305.KeySize)
+       if *doGen || *doRNG {
+               if _, err := io.ReadFull(rand.Reader, key); err != nil {
+                       log.Fatalln(err)
+               }
+               if *doGen {
+                       fmt.Println(Base32Codec.EncodeToString(key))
+                       return
+               }
+       } else {
+               key, err = Base32Codec.DecodeString(*keyB32)
+               if err != nil {
+                       log.Fatalln(err)
+               }
+               if len(key) != chacha20poly1305.KeySize {
+                       log.Fatalln("invalid key size")
                }
-               fmt.Println(ToBase32(key))
-               return
        }
 
-       if len(*key) != 52 {
-               panic("Invalid key size")
-       }
-       keyDecoded, err := FromBase32(*key)
-       if err != nil {
-               panic(err)
-       }
-       tmpAEAD, err := chacha20poly1305.New(make([]byte, chacha20poly1305.KeySize))
-       if err != nil {
-               panic(err)
+       salt := make([]byte, SaltSize)
+       if *doDec {
+               if _, err = io.ReadFull(os.Stdin, salt[:len(Magic)]); err != nil {
+                       log.Fatalln(err)
+               }
+               if string(salt[:len(Magic)]) != Magic {
+                       log.Fatalln("invalid magic")
+               }
+               if _, err = io.ReadFull(os.Stdin, salt[:4]); err != nil {
+                       log.Fatalln(err)
+               }
+               *bs = int(binary.BigEndian.Uint32(salt[:4]))
+               if _, err = io.ReadFull(os.Stdin, salt); err != nil {
+                       log.Fatalln(err)
+               }
+       } else {
+               if _, err = os.Stdout.WriteString(Magic); err != nil {
+                       log.Fatalln(err)
+               }
+               *bs = *bs * 1024
+               binary.BigEndian.PutUint32(salt, uint32(*bs))
+               if _, err = os.Stdout.Write(salt[:4]); err != nil {
+                       log.Fatalln(err)
+               }
+               if _, err = io.ReadFull(rand.Reader, salt); err != nil {
+                       log.Fatalln(err)
+               }
+               if _, err = os.Stdout.Write(salt); err != nil {
+                       log.Fatalln(err)
+               }
        }
-       keys, err := blake2b.NewXOF(blake2b.OutputLengthUnknown, keyDecoded)
-       if err != nil {
-               panic(err)
+
+       kdf := hkdf.New(sha512.New, key, salt, []byte(Magic))
+       if _, err = io.ReadFull(kdf, key); err != nil {
+               log.Fatalln(err)
        }
 
        var wg sync.WaitGroup
-       workers := make([]*Worker, 0, *threads)
-       for i := 0; i < *threads; i++ {
+       var lastMet bool
+       workers := make([]*Worker, 0, *jobs)
+       for i := 0; i < *jobs; i++ {
                w := Worker{
-                       wg:      &wg,
-                       buf:     make([]byte, LenSize+bs+tmpAEAD.Overhead()),
-                       ready:   make(chan struct{}),
-                       task:    make(chan WorkerTask),
-                       done:    make(chan []byte),
-                       written: make(chan struct{}),
-               }
-               go w.Thread(*doDecrypt)
+                       buf:      make([]byte, *bs+chacha20poly1305.Overhead),
+                       readyIn:  make(chan struct{}),
+                       readyOut: make(chan struct{}),
+               }
+               go func() {
+                       ciph, err := chacha20poly1305.New(key)
+                       if err != nil {
+                               log.Fatalln(err)
+                       }
+                       nonce := make([]byte, chacha20poly1305.NonceSize)
+                       var ciphertext, tag []byte
+                       var s *chacha20.Cipher
+                       var p *poly1305.MAC
+                       for {
+                               w.readyIn <- struct{}{}
+                               <-w.readyIn
+                               binary.BigEndian.PutUint64(nonce, w.ctr)
+                               if *doDec {
+                                       tag = w.buf[len(w.buf)-poly1305.TagSize:]
+                                       ciphertext = w.buf[:len(w.buf)-poly1305.TagSize]
+                                       s, err = chacha20.NewUnauthenticatedCipher(key, nonce)
+                                       if err != nil {
+                                               log.Fatalln(err)
+                                       }
+                                       var polyKey [32]byte
+                                       s.XORKeyStream(polyKey[:], polyKey[:])
+                                       s.SetCounter(1)
+                                       p = poly1305.New(&polyKey)
+                                       writeWithPadding(p, nil)
+                                       writeWithPadding(p, ciphertext)
+                                       writeUint64(p, 0)
+                                       writeUint64(p, len(ciphertext))
+                                       if p.Verify(tag) {
+                                               w.buf = ciphertext
+                                               s.XORKeyStream(ciphertext, ciphertext)
+                                       } else {
+                                               lastMet = true
+                                               if _, err = io.ReadFull(kdf, key); err != nil {
+                                                       log.Fatalln(err)
+                                               }
+                                               ciph, err = chacha20poly1305.New(key)
+                                               if err != nil {
+                                                       log.Fatalln(err)
+                                               }
+                                               w.buf, err = ciph.Open(w.buf[:0], nonce, w.buf, nil)
+                                               if err != nil {
+                                                       log.Fatalln(err)
+                                               }
+                                               lastMet = true
+                                       }
+                               } else {
+                                       if w.last {
+                                               if _, err = io.ReadFull(kdf, key); err != nil {
+                                                       log.Fatalln(err)
+                                               }
+                                               ciph, err = chacha20poly1305.New(key)
+                                               if err != nil {
+                                                       log.Fatalln(err)
+                                               }
+                                       }
+                                       w.buf = ciph.Seal(w.buf[:0], nonce, w.buf, nil)
+                               }
+                               w.readyOut <- struct{}{}
+                               <-w.readyOut
+                               wg.Done()
+                       }
+               }()
                workers = append(workers, &w)
        }
+
        go func() {
-               i := 0
+               var ctr int64
                var w *Worker
+               var err error
                for {
-                       w = workers[i%len(workers)]
-                       if _, err = os.Stdout.Write(<-w.done); err != nil {
-                               panic(err)
+                       w = workers[ctr%int64(len(workers))]
+                       <-w.readyOut
+                       if _, err = os.Stdout.Write(w.buf); err != nil {
+                               log.Fatalln(err)
                        }
-                       w.written <- struct{}{}
-                       i++
+                       w.readyOut <- struct{}{}
+                       ctr++
                }
        }()
 
-       stdin := bufio.NewReaderSize(os.Stdin, LenSize+bs)
-       if *doDecrypt {
-               if _, err = io.CopyN(keys, stdin, SaltSize); err != nil {
-                       panic(err)
-               }
+       var stdin io.Reader
+       if *doRNG {
+               stdin = &DummyReader{}
        } else {
-               salt := make([]byte, SaltSize)
-               if _, err = rand.Read(salt); err != nil {
-                       panic(err)
-               }
-               if _, err = keys.Write(salt); err != nil {
-                       panic(err)
-               }
-               if _, err = os.Stdout.Write(salt); err != nil {
-                       panic(err)
-               }
+               stdin = bufio.NewReaderSize(os.Stdin, *bs+chacha20poly1305.Overhead)
        }
-
-       i := 0
-       var n int
+       var ctr uint64
        var w *Worker
        for {
-               key := make([]byte, chacha20poly1305.KeySize)
-               if _, err = io.ReadFull(keys, key); err != nil {
-                       panic(err)
-               }
-               w = workers[i%len(workers)]
-               <-w.ready
-               if *doDecrypt {
-                       _, err = io.ReadFull(stdin, w.buf[:LenSize])
-                       if err != nil {
-                               if err == io.EOF {
-                                       break
-                               }
-                               panic(err)
-                       }
-                       n = int(binary.BigEndian.Uint32(w.buf[:LenSize]))
-                       if n, err = io.ReadFull(stdin, w.buf[LenSize:LenSize+n+tmpAEAD.Overhead()]); err != nil {
-                               panic(err)
-                       }
+               w = workers[ctr%uint64(len(workers))]
+               <-w.readyIn
+               if *doDec {
+                       w.buf, err = readBuf(w.buf[:*bs+chacha20poly1305.Overhead], stdin)
                } else {
-                       n, err = stdin.Read(w.buf[LenSize : LenSize+bs])
-                       if err != nil {
-                               if err == io.EOF {
-                                       break
-                               }
-                               panic(err)
-                       }
+                       w.buf, err = readBuf(w.buf[:*bs], stdin)
+               }
+               if err != nil && err != io.EOF {
+                       log.Fatalln(err)
+               }
+               if *doDec && len(w.buf) < chacha20poly1305.Overhead {
+                       break
                }
+               w.ctr = ctr
                wg.Add(1)
-               w.task <- WorkerTask{key, n}
-               i++
+               if err == io.EOF {
+                       w.last = true
+               }
+               w.readyIn <- struct{}{}
+               if err == io.EOF {
+                       break
+               }
+               ctr++
        }
        wg.Wait()
+       if *doDec && !lastMet {
+               log.Fatalln("did not meet explicit last block")
+       }
 }