X-Git-Url: http://www.git.cypherpunks.ru/?p=gohpenc.git;a=blobdiff_plain;f=main.go;fp=main.go;h=9eacacb03ea5377b8c37e045ca2fec81e0323af0;hp=43eb680f8f1e29193d12bc36a81adc60fa097fc2;hb=d863cee82aad34900144198abf740bb7f75a4642;hpb=43162227c8405de5a2d835ee306459993ff0ba6d diff --git a/main.go b/main.go index 43eb680..9eacacb 100644 --- a/main.go +++ b/main.go @@ -20,196 +20,253 @@ package main import ( "bufio" - "crypto/cipher" "crypto/rand" + "crypto/sha512" "encoding/base32" "encoding/binary" "flag" "fmt" "io" + "log" "os" "runtime" "sync" + "golang.org/x/crypto/chacha20" "golang.org/x/crypto/chacha20poly1305" + "golang.org/x/crypto/hkdf" "golang.org/x/crypto/poly1305" ) const ( - LenSize = 4 + Magic = "GOHPENC\n" SaltSize = 16 ) -var ( - doPSK = flag.Bool("psk", false, "Generate PSK") - keyB32 = flag.String("k", "", "Encryption key") - decrypt = flag.Bool("d", false, "Decrypt, instead of encrypt") - blockSize = flag.Int("b", 1<<10, "Blocksize, in KiB") - threads = flag.Int("c", runtime.NumCPU(), "Number of threads") - - Base32Codec *base32.Encoding = base32.StdEncoding.WithPadding(base32.NoPadding) - - key []byte - bs int - wg sync.WaitGroup -) - -type Task struct { - ctr uint64 - size int -} +var Base32Codec *base32.Encoding = base32.StdEncoding.WithPadding(base32.NoPadding) type Worker struct { - aead cipher.AEAD - nonce []byte - input []byte - ready chan struct{} - task chan Task - output chan []byte - written chan struct{} + ctr uint64 + buf []byte + readyIn chan struct{} + readyOut chan struct{} + last bool } -func NewWorker(key, salt []byte) *Worker { - aead, err := chacha20poly1305.NewX(key) - if err != nil { - panic(err) - } - w := Worker{ - aead: aead, - nonce: make([]byte, chacha20poly1305.NonceSizeX), - input: make([]byte, LenSize+bs+poly1305.TagSize), - ready: make(chan struct{}), - task: make(chan Task), - output: make(chan []byte), - written: make(chan struct{}), - } - copy(w.nonce, salt) - go w.Run() - return &w -} - -func (w *Worker) Run() { - var output []byte +func readBuf(dst []byte, src io.Reader) ([]byte, error) { + var n, full int var err error - for { - w.ready <- struct{}{} - task := <-w.task - binary.BigEndian.PutUint64(w.nonce[SaltSize:], task.ctr) - if *decrypt { - output, err = w.aead.Open( - w.input[:LenSize], - w.nonce, - w.input[LenSize:LenSize+task.size+poly1305.TagSize], - w.input[:LenSize], - ) - if err != nil { - panic(err) + for full < len(dst) { + n, err = src.Read(dst[full:]) + full += n + if err != nil { + if err == io.EOF { + break } - output = output[LenSize:] - } else { - binary.BigEndian.PutUint32(w.input, uint32(task.size)) - output = w.aead.Seal( - w.input[:LenSize], - w.nonce, - w.input[LenSize:LenSize+task.size], - w.input[:LenSize], - ) - } - w.output <- output - <-w.written - wg.Done() + return nil, err + } } + return dst[:full], err +} + +type DummyReader struct{} + +func (r *DummyReader) Read(b []byte) (int, error) { + return len(b), nil } func main() { + doRNG := flag.Bool("r", false, "Random number generator") + doGen := flag.Bool("psk", false, "Generate key") + doDec := flag.Bool("d", false, "Decrypt, instead of encrypt") + bs := flag.Int("b", 1<<10, "Blocksize, KiB") + jobs := flag.Int("c", runtime.NumCPU(), "Number of parallel threads") + keyB32 := flag.String("k", "", "Encryption key") flag.Parse() + log.SetFlags(log.Ldate | log.Lmicroseconds | log.Lshortfile) - if *doPSK { - key := make([]byte, chacha20poly1305.KeySize) + var err error + key := make([]byte, chacha20poly1305.KeySize) + if *doGen || *doRNG { if _, err := io.ReadFull(rand.Reader, key); err != nil { - panic(err) + log.Fatalln(err) + } + if *doGen { + fmt.Println(Base32Codec.EncodeToString(key)) + return + } + } else { + key, err = Base32Codec.DecodeString(*keyB32) + if err != nil { + log.Fatalln(err) + } + if len(key) != chacha20poly1305.KeySize { + log.Fatalln("invalid key size") } - fmt.Println(Base32Codec.EncodeToString(key)) - return } - var err error - key, err = Base32Codec.DecodeString(*keyB32) - if err != nil { - panic(err) - } - if len(key) != chacha20poly1305.KeySize { - panic("Invalid key size") - } salt := make([]byte, SaltSize) - - if *decrypt { + if *doDec { + if _, err = io.ReadFull(os.Stdin, salt[:len(Magic)]); err != nil { + log.Fatalln(err) + } + if string(salt[:len(Magic)]) != Magic { + log.Fatalln("invalid magic") + } + if _, err = io.ReadFull(os.Stdin, salt[:4]); err != nil { + log.Fatalln(err) + } + *bs = int(binary.BigEndian.Uint32(salt[:4])) if _, err = io.ReadFull(os.Stdin, salt); err != nil { - panic(err) + log.Fatalln(err) } } else { + if _, err = os.Stdout.WriteString(Magic); err != nil { + log.Fatalln(err) + } + *bs = *bs * 1024 + binary.BigEndian.PutUint32(salt, uint32(*bs)) + if _, err = os.Stdout.Write(salt[:4]); err != nil { + log.Fatalln(err) + } if _, err = io.ReadFull(rand.Reader, salt); err != nil { - panic(err) + log.Fatalln(err) } if _, err = os.Stdout.Write(salt); err != nil { - panic(err) + log.Fatalln(err) } } - bs = *blockSize * (1 << 10) - if bs > 1<<32 { - panic("blocksize exceeds 32-bits") + kdf := hkdf.New(sha512.New, key, salt, []byte(Magic)) + if _, err = io.ReadFull(kdf, key); err != nil { + log.Fatalln(err) } - stdin := bufio.NewReaderSize(os.Stdin, LenSize+bs+poly1305.TagSize) - workers := make([]*Worker, *threads) - for i := 0; i < *threads; i++ { - workers[i] = NewWorker(key, salt) + var wg sync.WaitGroup + var lastMet bool + workers := make([]*Worker, 0, *jobs) + for i := 0; i < *jobs; i++ { + w := Worker{ + buf: make([]byte, *bs+chacha20poly1305.Overhead), + readyIn: make(chan struct{}), + readyOut: make(chan struct{}), + } + go func() { + ciph, err := chacha20poly1305.New(key) + if err != nil { + log.Fatalln(err) + } + nonce := make([]byte, chacha20poly1305.NonceSize) + var ciphertext, tag []byte + var s *chacha20.Cipher + var p *poly1305.MAC + for { + w.readyIn <- struct{}{} + <-w.readyIn + binary.BigEndian.PutUint64(nonce, w.ctr) + if *doDec { + tag = w.buf[len(w.buf)-poly1305.TagSize:] + ciphertext = w.buf[:len(w.buf)-poly1305.TagSize] + s, err = chacha20.NewUnauthenticatedCipher(key, nonce) + if err != nil { + log.Fatalln(err) + } + var polyKey [32]byte + s.XORKeyStream(polyKey[:], polyKey[:]) + s.SetCounter(1) + p = poly1305.New(&polyKey) + writeWithPadding(p, nil) + writeWithPadding(p, ciphertext) + writeUint64(p, 0) + writeUint64(p, len(ciphertext)) + if p.Verify(tag) { + w.buf = ciphertext + s.XORKeyStream(ciphertext, ciphertext) + } else { + lastMet = true + if _, err = io.ReadFull(kdf, key); err != nil { + log.Fatalln(err) + } + ciph, err = chacha20poly1305.New(key) + if err != nil { + log.Fatalln(err) + } + w.buf, err = ciph.Open(w.buf[:0], nonce, w.buf, nil) + if err != nil { + log.Fatalln(err) + } + lastMet = true + } + } else { + if w.last { + if _, err = io.ReadFull(kdf, key); err != nil { + log.Fatalln(err) + } + ciph, err = chacha20poly1305.New(key) + if err != nil { + log.Fatalln(err) + } + } + w.buf = ciph.Seal(w.buf[:0], nonce, w.buf, nil) + } + w.readyOut <- struct{}{} + <-w.readyOut + wg.Done() + } + }() + workers = append(workers, &w) } + go func() { - var ctr uint64 + var ctr int64 + var w *Worker + var err error for { - w := workers[ctr%uint64(len(workers))] - if _, err := os.Stdout.Write(<-w.output); err != nil { - panic(err) + w = workers[ctr%int64(len(workers))] + <-w.readyOut + if _, err = os.Stdout.Write(w.buf); err != nil { + log.Fatalln(err) } - w.written <- struct{}{} + w.readyOut <- struct{}{} ctr++ } }() + var stdin io.Reader + if *doRNG { + stdin = &DummyReader{} + } else { + stdin = bufio.NewReaderSize(os.Stdin, *bs+chacha20poly1305.Overhead) + } var ctr uint64 - var size int + var w *Worker for { - w := workers[ctr%uint64(len(workers))] - <-w.ready - if *decrypt { - _, err = io.ReadFull(stdin, w.input[:LenSize]) - if err != nil { - if err == io.EOF { - break - } - panic(err) - } - size = int(binary.BigEndian.Uint32(w.input[:LenSize])) - if _, err = io.ReadFull( - stdin, - w.input[LenSize:LenSize+size+poly1305.TagSize], - ); err != nil { - panic(err) - } + w = workers[ctr%uint64(len(workers))] + <-w.readyIn + if *doDec { + w.buf, err = readBuf(w.buf[:*bs+chacha20poly1305.Overhead], stdin) } else { - size, err = stdin.Read(w.input[LenSize : LenSize+bs]) - if err != nil { - if err == io.EOF { - break - } - panic(err) - } + w.buf, err = readBuf(w.buf[:*bs], stdin) + } + if err != nil && err != io.EOF { + log.Fatalln(err) + } + if *doDec && len(w.buf) < chacha20poly1305.Overhead { + break } + w.ctr = ctr wg.Add(1) - w.task <- Task{ctr, size} + if err == io.EOF { + w.last = true + } + w.readyIn <- struct{}{} + if err == io.EOF { + break + } ctr++ } wg.Wait() + if *doDec && !lastMet { + log.Fatalln("did not meet explicit last block") + } }