]> Cypherpunks.ru repositories - gohpenc.git/blobdiff - main.go
Huge simplifications
[gohpenc.git] / main.go
diff --git a/main.go b/main.go
index 9793cf11532635086ce3462b284e0f213a3a8d81..a05432858e5dfece95fd30e9a1bec8ea7a9b9b62 100644 (file)
--- a/main.go
+++ b/main.go
@@ -1,6 +1,6 @@
 /*
 gohpenc -- Go high-performance encryption utility
-Copyright (C) 2017-2019 Sergey Matveev <stargrave@stargrave.org>
+Copyright (C) 2017-2020 Sergey Matveev <stargrave@stargrave.org>
 
 This program is free software: you can redistribute it and/or modify
 it under the terms of the GNU General Public License as published by
@@ -20,7 +20,9 @@ package main
 
 import (
        "bufio"
+       "crypto/cipher"
        "crypto/rand"
+       "encoding/base32"
        "encoding/binary"
        "flag"
        "fmt"
@@ -29,160 +31,172 @@ import (
        "runtime"
        "sync"
 
-       "golang.org/x/crypto/blake2b"
        "golang.org/x/crypto/chacha20poly1305"
+       "golang.org/x/crypto/poly1305"
 )
 
 const (
        LenSize  = 4
-       SaltSize = 32
+       SaltSize = 16
 )
 
-type WorkerTask struct {
+var (
+       doPSK     = flag.Bool("psk", false, "Generate PSK")
+       keyB32    = flag.String("k", "", "Encryption key")
+       decrypt   = flag.Bool("d", false, "Decrypt, instead of encrypt")
+       blockSize = flag.Int("b", 1<<10, "Blocksize, in KiB")
+       threads   = flag.Int("c", runtime.NumCPU(), "Number of threads")
+
+       Base32Codec *base32.Encoding = base32.StdEncoding.WithPadding(base32.NoPadding)
+
        key []byte
-       n   int
+       bs  int
+       wg  sync.WaitGroup
+)
+
+type Task struct {
+       ctr  uint64
+       size int
 }
 
 type Worker struct {
-       wg      *sync.WaitGroup
-       buf     []byte
+       aead    cipher.AEAD
+       nonce   []byte
+       input   []byte
        ready   chan struct{}
-       task    chan WorkerTask
-       done    chan []byte
+       task    chan Task
+       output  chan []byte
        written chan struct{}
 }
 
-func (w *Worker) Thread(doDecrypt bool) {
-       var task WorkerTask
-       nonce := make([]byte, chacha20poly1305.NonceSize)
-       var done []byte
+func NewWorker(key, salt []byte) *Worker {
+       aead, err := chacha20poly1305.NewX(key)
+       if err != nil {
+               panic(err)
+       }
+       w := Worker{
+               aead:    aead,
+               nonce:   make([]byte, chacha20poly1305.NonceSizeX),
+               input:   make([]byte, LenSize+bs+poly1305.TagSize),
+               ready:   make(chan struct{}),
+               task:    make(chan Task),
+               output:  make(chan []byte),
+               written: make(chan struct{}),
+       }
+       copy(w.nonce, salt)
+       go w.Run()
+       return &w
+}
+
+func (w *Worker) Run() {
+       var output []byte
+       var err error
        for {
                w.ready <- struct{}{}
-               task = <-w.task
-               aead, err := chacha20poly1305.New(task.key)
-               if err != nil {
-                       panic(err)
-               }
-               if doDecrypt {
-                       done, err = aead.Open(w.buf[:0], nonce, w.buf[LenSize:LenSize+task.n], w.buf[:LenSize])
+               task := <-w.task
+               binary.BigEndian.PutUint64(w.nonce[SaltSize:], task.ctr)
+               if *decrypt {
+                       output, err = w.aead.Open(
+                               w.input[:LenSize],
+                               w.nonce,
+                               w.input[LenSize:LenSize+task.size+poly1305.TagSize],
+                               w.input[:LenSize],
+                       )
                        if err != nil {
                                panic(err)
                        }
+                       output = output[LenSize:]
                } else {
-                       binary.BigEndian.PutUint32(w.buf, uint32(task.n))
-                       done = aead.Seal(w.buf[:LenSize], nonce, w.buf[LenSize:LenSize+task.n], w.buf[:LenSize])
+                       binary.BigEndian.PutUint32(w.input, uint32(task.size))
+                       output = w.aead.Seal(
+                               w.input[:LenSize],
+                               w.nonce,
+                               w.input[LenSize:LenSize+task.size],
+                               w.input[:LenSize],
+                       )
                }
-               w.done <- done
+               w.output <- output
                <-w.written
-               w.wg.Done()
+               wg.Done()
        }
 }
 
 func main() {
-       var (
-               doPSK     = flag.Bool("psk", false, "Generate PSK")
-               key       = flag.String("k", "", "Encryption key")
-               doDecrypt = flag.Bool("d", false, "Decrypt, instead of encrypt")
-               blockSize = flag.Int("b", 1<<10, "Blocksize, in KiB")
-               threads   = flag.Int("c", runtime.NumCPU(), "Number of threads")
-       )
        flag.Parse()
-       bs := *blockSize * 1 << 10
 
        if *doPSK {
-               key := make([]byte, 32)
-               if _, err := rand.Read(key); err != nil {
+               key := make([]byte, chacha20poly1305.KeySize)
+               if _, err := io.ReadFull(rand.Reader, key); err != nil {
                        panic(err)
                }
-               fmt.Println(ToBase32(key))
+               fmt.Println(Base32Codec.EncodeToString(key))
                return
        }
 
-       if len(*key) != 52 {
-               panic("Invalid key size")
-       }
-       keyDecoded, err := FromBase32(*key)
+       var err error
+       key, err = Base32Codec.DecodeString(*keyB32)
        if err != nil {
                panic(err)
        }
-       tmpAEAD, err := chacha20poly1305.New(make([]byte, chacha20poly1305.KeySize))
-       if err != nil {
-               panic(err)
+       if len(key) != chacha20poly1305.KeySize {
+               panic("Invalid key size")
        }
-       keys, err := blake2b.NewXOF(blake2b.OutputLengthUnknown, keyDecoded)
-       if err != nil {
-               panic(err)
+       salt := make([]byte, SaltSize)
+
+       if *decrypt {
+               if _, err = io.ReadFull(os.Stdin, salt); err != nil {
+                       panic(err)
+               }
+       } else {
+               if _, err = io.ReadFull(rand.Reader, salt); err != nil {
+                       panic(err)
+               }
+               if _, err = os.Stdout.Write(salt); err != nil {
+                       panic(err)
+               }
        }
 
-       var wg sync.WaitGroup
-       workers := make([]*Worker, 0, *threads)
+       bs = *blockSize * (1 << 10)
+       stdin := bufio.NewReaderSize(os.Stdin, LenSize+bs+poly1305.TagSize)
+
+       workers := make([]*Worker, *threads)
        for i := 0; i < *threads; i++ {
-               w := Worker{
-                       wg:      &wg,
-                       buf:     make([]byte, LenSize+bs+tmpAEAD.Overhead()),
-                       ready:   make(chan struct{}),
-                       task:    make(chan WorkerTask),
-                       done:    make(chan []byte),
-                       written: make(chan struct{}),
-               }
-               go w.Thread(*doDecrypt)
-               workers = append(workers, &w)
+               workers[i] = NewWorker(key, salt)
        }
        go func() {
-               i := 0
-               var w *Worker
+               var ctr uint64
                for {
-                       w = workers[i%len(workers)]
-                       if _, err = os.Stdout.Write(<-w.done); err != nil {
+                       w := workers[ctr%uint64(len(workers))]
+                       if _, err := os.Stdout.Write(<-w.output); err != nil {
                                panic(err)
                        }
                        w.written <- struct{}{}
-                       i++
+                       ctr++
                }
        }()
 
-       stdin := bufio.NewReaderSize(os.Stdin, LenSize+bs)
-       if *doDecrypt {
-               if _, err = io.CopyN(keys, stdin, SaltSize); err != nil {
-                       panic(err)
-               }
-       } else {
-               salt := make([]byte, SaltSize)
-               if _, err = rand.Read(salt); err != nil {
-                       panic(err)
-               }
-               if _, err = keys.Write(salt); err != nil {
-                       panic(err)
-               }
-               if _, err = os.Stdout.Write(salt); err != nil {
-                       panic(err)
-               }
-       }
-
-       i := 0
-       var n int
-       var w *Worker
+       var ctr uint64
+       var size int
        for {
-               key := make([]byte, chacha20poly1305.KeySize)
-               if _, err = io.ReadFull(keys, key); err != nil {
-                       panic(err)
-               }
-               w = workers[i%len(workers)]
+               w := workers[ctr%uint64(len(workers))]
                <-w.ready
-               if *doDecrypt {
-                       _, err = io.ReadFull(stdin, w.buf[:LenSize])
+               if *decrypt {
+                       _, err = io.ReadFull(stdin, w.input[:LenSize])
                        if err != nil {
                                if err == io.EOF {
                                        break
                                }
                                panic(err)
                        }
-                       n = int(binary.BigEndian.Uint32(w.buf[:LenSize]))
-                       if n, err = io.ReadFull(stdin, w.buf[LenSize:LenSize+n+tmpAEAD.Overhead()]); err != nil {
+                       size = int(binary.BigEndian.Uint32(w.input[:LenSize]))
+                       if _, err = io.ReadFull(
+                               stdin,
+                               w.input[LenSize:LenSize+size+poly1305.TagSize],
+                       ); err != nil {
                                panic(err)
                        }
                } else {
-                       n, err = stdin.Read(w.buf[LenSize : LenSize+bs])
+                       size, err = stdin.Read(w.input[LenSize : LenSize+bs])
                        if err != nil {
                                if err == io.EOF {
                                        break
@@ -191,8 +205,8 @@ func main() {
                        }
                }
                wg.Add(1)
-               w.task <- WorkerTask{key, n}
-               i++
+               w.task <- Task{ctr, size}
+               ctr++
        }
        wg.Wait()
 }