]> Cypherpunks.ru repositories - gohpenc.git/blobdiff - main.go
More secure and faster version
[gohpenc.git] / main.go
diff --git a/main.go b/main.go
index 43eb680f8f1e29193d12bc36a81adc60fa097fc2..9eacacb03ea5377b8c37e045ca2fec81e0323af0 100644 (file)
--- a/main.go
+++ b/main.go
@@ -20,196 +20,253 @@ package main
 
 import (
        "bufio"
-       "crypto/cipher"
        "crypto/rand"
+       "crypto/sha512"
        "encoding/base32"
        "encoding/binary"
        "flag"
        "fmt"
        "io"
+       "log"
        "os"
        "runtime"
        "sync"
 
+       "golang.org/x/crypto/chacha20"
        "golang.org/x/crypto/chacha20poly1305"
+       "golang.org/x/crypto/hkdf"
        "golang.org/x/crypto/poly1305"
 )
 
 const (
-       LenSize  = 4
+       Magic    = "GOHPENC\n"
        SaltSize = 16
 )
 
-var (
-       doPSK     = flag.Bool("psk", false, "Generate PSK")
-       keyB32    = flag.String("k", "", "Encryption key")
-       decrypt   = flag.Bool("d", false, "Decrypt, instead of encrypt")
-       blockSize = flag.Int("b", 1<<10, "Blocksize, in KiB")
-       threads   = flag.Int("c", runtime.NumCPU(), "Number of threads")
-
-       Base32Codec *base32.Encoding = base32.StdEncoding.WithPadding(base32.NoPadding)
-
-       key []byte
-       bs  int
-       wg  sync.WaitGroup
-)
-
-type Task struct {
-       ctr  uint64
-       size int
-}
+var Base32Codec *base32.Encoding = base32.StdEncoding.WithPadding(base32.NoPadding)
 
 type Worker struct {
-       aead    cipher.AEAD
-       nonce   []byte
-       input   []byte
-       ready   chan struct{}
-       task    chan Task
-       output  chan []byte
-       written chan struct{}
+       ctr      uint64
+       buf      []byte
+       readyIn  chan struct{}
+       readyOut chan struct{}
+       last     bool
 }
 
-func NewWorker(key, salt []byte) *Worker {
-       aead, err := chacha20poly1305.NewX(key)
-       if err != nil {
-               panic(err)
-       }
-       w := Worker{
-               aead:    aead,
-               nonce:   make([]byte, chacha20poly1305.NonceSizeX),
-               input:   make([]byte, LenSize+bs+poly1305.TagSize),
-               ready:   make(chan struct{}),
-               task:    make(chan Task),
-               output:  make(chan []byte),
-               written: make(chan struct{}),
-       }
-       copy(w.nonce, salt)
-       go w.Run()
-       return &w
-}
-
-func (w *Worker) Run() {
-       var output []byte
+func readBuf(dst []byte, src io.Reader) ([]byte, error) {
+       var n, full int
        var err error
-       for {
-               w.ready <- struct{}{}
-               task := <-w.task
-               binary.BigEndian.PutUint64(w.nonce[SaltSize:], task.ctr)
-               if *decrypt {
-                       output, err = w.aead.Open(
-                               w.input[:LenSize],
-                               w.nonce,
-                               w.input[LenSize:LenSize+task.size+poly1305.TagSize],
-                               w.input[:LenSize],
-                       )
-                       if err != nil {
-                               panic(err)
+       for full < len(dst) {
+               n, err = src.Read(dst[full:])
+               full += n
+               if err != nil {
+                       if err == io.EOF {
+                               break
                        }
-                       output = output[LenSize:]
-               } else {
-                       binary.BigEndian.PutUint32(w.input, uint32(task.size))
-                       output = w.aead.Seal(
-                               w.input[:LenSize],
-                               w.nonce,
-                               w.input[LenSize:LenSize+task.size],
-                               w.input[:LenSize],
-                       )
-               }
-               w.output <- output
-               <-w.written
-               wg.Done()
+                       return nil, err
+               }
        }
+       return dst[:full], err
+}
+
+type DummyReader struct{}
+
+func (r *DummyReader) Read(b []byte) (int, error) {
+       return len(b), nil
 }
 
 func main() {
+       doRNG := flag.Bool("r", false, "Random number generator")
+       doGen := flag.Bool("psk", false, "Generate key")
+       doDec := flag.Bool("d", false, "Decrypt, instead of encrypt")
+       bs := flag.Int("b", 1<<10, "Blocksize, KiB")
+       jobs := flag.Int("c", runtime.NumCPU(), "Number of parallel threads")
+       keyB32 := flag.String("k", "", "Encryption key")
        flag.Parse()
+       log.SetFlags(log.Ldate | log.Lmicroseconds | log.Lshortfile)
 
-       if *doPSK {
-               key := make([]byte, chacha20poly1305.KeySize)
+       var err error
+       key := make([]byte, chacha20poly1305.KeySize)
+       if *doGen || *doRNG {
                if _, err := io.ReadFull(rand.Reader, key); err != nil {
-                       panic(err)
+                       log.Fatalln(err)
+               }
+               if *doGen {
+                       fmt.Println(Base32Codec.EncodeToString(key))
+                       return
+               }
+       } else {
+               key, err = Base32Codec.DecodeString(*keyB32)
+               if err != nil {
+                       log.Fatalln(err)
+               }
+               if len(key) != chacha20poly1305.KeySize {
+                       log.Fatalln("invalid key size")
                }
-               fmt.Println(Base32Codec.EncodeToString(key))
-               return
        }
 
-       var err error
-       key, err = Base32Codec.DecodeString(*keyB32)
-       if err != nil {
-               panic(err)
-       }
-       if len(key) != chacha20poly1305.KeySize {
-               panic("Invalid key size")
-       }
        salt := make([]byte, SaltSize)
-
-       if *decrypt {
+       if *doDec {
+               if _, err = io.ReadFull(os.Stdin, salt[:len(Magic)]); err != nil {
+                       log.Fatalln(err)
+               }
+               if string(salt[:len(Magic)]) != Magic {
+                       log.Fatalln("invalid magic")
+               }
+               if _, err = io.ReadFull(os.Stdin, salt[:4]); err != nil {
+                       log.Fatalln(err)
+               }
+               *bs = int(binary.BigEndian.Uint32(salt[:4]))
                if _, err = io.ReadFull(os.Stdin, salt); err != nil {
-                       panic(err)
+                       log.Fatalln(err)
                }
        } else {
+               if _, err = os.Stdout.WriteString(Magic); err != nil {
+                       log.Fatalln(err)
+               }
+               *bs = *bs * 1024
+               binary.BigEndian.PutUint32(salt, uint32(*bs))
+               if _, err = os.Stdout.Write(salt[:4]); err != nil {
+                       log.Fatalln(err)
+               }
                if _, err = io.ReadFull(rand.Reader, salt); err != nil {
-                       panic(err)
+                       log.Fatalln(err)
                }
                if _, err = os.Stdout.Write(salt); err != nil {
-                       panic(err)
+                       log.Fatalln(err)
                }
        }
 
-       bs = *blockSize * (1 << 10)
-       if bs > 1<<32 {
-               panic("blocksize exceeds 32-bits")
+       kdf := hkdf.New(sha512.New, key, salt, []byte(Magic))
+       if _, err = io.ReadFull(kdf, key); err != nil {
+               log.Fatalln(err)
        }
-       stdin := bufio.NewReaderSize(os.Stdin, LenSize+bs+poly1305.TagSize)
 
-       workers := make([]*Worker, *threads)
-       for i := 0; i < *threads; i++ {
-               workers[i] = NewWorker(key, salt)
+       var wg sync.WaitGroup
+       var lastMet bool
+       workers := make([]*Worker, 0, *jobs)
+       for i := 0; i < *jobs; i++ {
+               w := Worker{
+                       buf:      make([]byte, *bs+chacha20poly1305.Overhead),
+                       readyIn:  make(chan struct{}),
+                       readyOut: make(chan struct{}),
+               }
+               go func() {
+                       ciph, err := chacha20poly1305.New(key)
+                       if err != nil {
+                               log.Fatalln(err)
+                       }
+                       nonce := make([]byte, chacha20poly1305.NonceSize)
+                       var ciphertext, tag []byte
+                       var s *chacha20.Cipher
+                       var p *poly1305.MAC
+                       for {
+                               w.readyIn <- struct{}{}
+                               <-w.readyIn
+                               binary.BigEndian.PutUint64(nonce, w.ctr)
+                               if *doDec {
+                                       tag = w.buf[len(w.buf)-poly1305.TagSize:]
+                                       ciphertext = w.buf[:len(w.buf)-poly1305.TagSize]
+                                       s, err = chacha20.NewUnauthenticatedCipher(key, nonce)
+                                       if err != nil {
+                                               log.Fatalln(err)
+                                       }
+                                       var polyKey [32]byte
+                                       s.XORKeyStream(polyKey[:], polyKey[:])
+                                       s.SetCounter(1)
+                                       p = poly1305.New(&polyKey)
+                                       writeWithPadding(p, nil)
+                                       writeWithPadding(p, ciphertext)
+                                       writeUint64(p, 0)
+                                       writeUint64(p, len(ciphertext))
+                                       if p.Verify(tag) {
+                                               w.buf = ciphertext
+                                               s.XORKeyStream(ciphertext, ciphertext)
+                                       } else {
+                                               lastMet = true
+                                               if _, err = io.ReadFull(kdf, key); err != nil {
+                                                       log.Fatalln(err)
+                                               }
+                                               ciph, err = chacha20poly1305.New(key)
+                                               if err != nil {
+                                                       log.Fatalln(err)
+                                               }
+                                               w.buf, err = ciph.Open(w.buf[:0], nonce, w.buf, nil)
+                                               if err != nil {
+                                                       log.Fatalln(err)
+                                               }
+                                               lastMet = true
+                                       }
+                               } else {
+                                       if w.last {
+                                               if _, err = io.ReadFull(kdf, key); err != nil {
+                                                       log.Fatalln(err)
+                                               }
+                                               ciph, err = chacha20poly1305.New(key)
+                                               if err != nil {
+                                                       log.Fatalln(err)
+                                               }
+                                       }
+                                       w.buf = ciph.Seal(w.buf[:0], nonce, w.buf, nil)
+                               }
+                               w.readyOut <- struct{}{}
+                               <-w.readyOut
+                               wg.Done()
+                       }
+               }()
+               workers = append(workers, &w)
        }
+
        go func() {
-               var ctr uint64
+               var ctr int64
+               var w *Worker
+               var err error
                for {
-                       w := workers[ctr%uint64(len(workers))]
-                       if _, err := os.Stdout.Write(<-w.output); err != nil {
-                               panic(err)
+                       w = workers[ctr%int64(len(workers))]
+                       <-w.readyOut
+                       if _, err = os.Stdout.Write(w.buf); err != nil {
+                               log.Fatalln(err)
                        }
-                       w.written <- struct{}{}
+                       w.readyOut <- struct{}{}
                        ctr++
                }
        }()
 
+       var stdin io.Reader
+       if *doRNG {
+               stdin = &DummyReader{}
+       } else {
+               stdin = bufio.NewReaderSize(os.Stdin, *bs+chacha20poly1305.Overhead)
+       }
        var ctr uint64
-       var size int
+       var w *Worker
        for {
-               w := workers[ctr%uint64(len(workers))]
-               <-w.ready
-               if *decrypt {
-                       _, err = io.ReadFull(stdin, w.input[:LenSize])
-                       if err != nil {
-                               if err == io.EOF {
-                                       break
-                               }
-                               panic(err)
-                       }
-                       size = int(binary.BigEndian.Uint32(w.input[:LenSize]))
-                       if _, err = io.ReadFull(
-                               stdin,
-                               w.input[LenSize:LenSize+size+poly1305.TagSize],
-                       ); err != nil {
-                               panic(err)
-                       }
+               w = workers[ctr%uint64(len(workers))]
+               <-w.readyIn
+               if *doDec {
+                       w.buf, err = readBuf(w.buf[:*bs+chacha20poly1305.Overhead], stdin)
                } else {
-                       size, err = stdin.Read(w.input[LenSize : LenSize+bs])
-                       if err != nil {
-                               if err == io.EOF {
-                                       break
-                               }
-                               panic(err)
-                       }
+                       w.buf, err = readBuf(w.buf[:*bs], stdin)
+               }
+               if err != nil && err != io.EOF {
+                       log.Fatalln(err)
+               }
+               if *doDec && len(w.buf) < chacha20poly1305.Overhead {
+                       break
                }
+               w.ctr = ctr
                wg.Add(1)
-               w.task <- Task{ctr, size}
+               if err == io.EOF {
+                       w.last = true
+               }
+               w.readyIn <- struct{}{}
+               if err == io.EOF {
+                       break
+               }
                ctr++
        }
        wg.Wait()
+       if *doDec && !lastMet {
+               log.Fatalln("did not meet explicit last block")
+       }
 }