]> Cypherpunks.ru repositories - udpobfs.git/commitdiff
Reset nonces with the new key
authorSergey Matveev <stargrave@stargrave.org>
Thu, 17 Aug 2023 12:24:36 +0000 (15:24 +0300)
committerSergey Matveev <stargrave@stargrave.org>
Thu, 17 Aug 2023 12:24:36 +0000 (15:24 +0300)
PROTOCOL
main.go

index 4cef2f15a80410a191ea86579a55b0c488eb45602111838ccfba0b6ba4a36e48..257eacd441afabc58e595582e537d0a621fb9e63c1a0441c61a95c952821f887 100644 (file)
--- a/PROTOCOL
+++ b/PROTOCOL
@@ -1,5 +1,5 @@
-Protocol is trivial. Both peers has shared 256-bit key. SHA3 is used to
-derive four more keys from it:
+Protocol is trivial. Both peers have shared 256-bit key.
+SHA3 is used to derive four more keys from it:
 
     SHAKE128("go.cypherpunks.ru/udpobfs" || key) ->
         256-bit InitiatorEncryptionKey ||
@@ -7,11 +7,11 @@ derive four more keys from it:
         256-bit ResponderEncryptionKey ||
         256-bit ResponderObfuscationKey
 
-Each side has 64-bit packet number counter, that is used as a nonce.
-That counter is kept in memory and only its lower 24 bits are sent.
-When remote side receives 24-bit counter with lower value, then it
-increments in-memory counter's remaining part. Completely the same
-as Extended Sequence Numbers are done in IPsec's ESP.
+Each side has big-endian 64-bit packet number counter, that is used as a
+nonce. That counter is kept in memory and only its lower 24 bits are
+sent. When remote side receives 24-bit counter with lower value, then it
+increments in-memory counter's remaining part. Completely the same as
+Extended Sequence Numbers are done in IPsec's ESP.
 
 ChaCha20 is initialised with corresponding EncryptionKey and nonce equal
 to the full sequence number value. Its first 256-bit of output will be
diff --git a/main.go b/main.go
index 7c71433b3f51cae7fb19312bb7ccb232cae33a9dda9b2da5a6e452663a03c08a..b4c317e48200e7ffe0e120cf032c9c10257b65c58ef2f0a0fe363c64ae779035 100644 (file)
--- a/main.go
+++ b/main.go
@@ -56,9 +56,11 @@ func incr(buf []byte) (overflow bool) {
        return
 }
 
-type Keys struct {
-       ourKey, theirKey   []byte
-       ourObfs, theirObfs *blowfish.Cipher
+type State struct {
+       ourKey, theirKey     []byte
+       ourObfs, theirObfs   *blowfish.Cipher
+       ourNonce, theirNonce []byte
+       ourSeq, theirSeq     []byte
 }
 
 var Base32Codec *base32.Encoding = base32.StdEncoding.WithPadding(base32.NoPadding)
@@ -80,8 +82,8 @@ func main() {
                return
        }
 
-       keys := &Keys{}
-       hasKeys := make(chan struct{})
+       var state *State
+       stateReady := make(chan struct{})
        go func() {
                first := true
                s := bufio.NewScanner(os.Stdin)
@@ -121,13 +123,31 @@ func main() {
                        if err != nil {
                                log.Fatal(err)
                        }
+                       var newState State
                        if *responder {
-                               keys.ourKey, keys.theirKey, keys.ourObfs, keys.theirObfs = rEncKey, iEncKey, rObfs, iObfs
+                               newState = State{
+                                       ourKey:     rEncKey,
+                                       theirKey:   iEncKey,
+                                       ourObfs:    rObfs,
+                                       theirObfs:  iObfs,
+                                       ourNonce:   make([]byte, chacha20.NonceSize),
+                                       theirNonce: make([]byte, chacha20.NonceSize),
+                               }
                        } else {
-                               keys.ourKey, keys.theirKey, keys.ourObfs, keys.theirObfs = iEncKey, rEncKey, iObfs, rObfs
+                               newState = State{
+                                       ourKey:     iEncKey,
+                                       theirKey:   rEncKey,
+                                       ourObfs:    iObfs,
+                                       theirObfs:  rObfs,
+                                       ourNonce:   make([]byte, chacha20.NonceSize),
+                                       theirNonce: make([]byte, chacha20.NonceSize),
+                               }
                        }
+                       newState.ourSeq = newState.ourNonce[4:]
+                       newState.theirSeq = newState.theirNonce[4:]
+                       state = &newState
                        if first {
-                               close(hasKeys)
+                               close(stateReady)
                                first = false
                        }
                }
@@ -135,7 +155,7 @@ func main() {
                        log.Fatal(s.Err())
                }
        }()
-       <-hasKeys
+       <-stateReady
 
        addr, err := net.ResolveUDPAddr("udp", *bind)
        if err != nil {
@@ -173,8 +193,6 @@ func main() {
                var p *poly1305.MAC
                tag := make([]byte, poly1305.TagSize)
                var from *net.UDPAddr
-               nonce := make([]byte, chacha20.NonceSize)
-               seq := nonce[4:]
                for {
                        if *responder {
                                n, err = connLocal.Read(rx)
@@ -191,11 +209,11 @@ func main() {
                                from.Port != addrLocal.Port || !from.IP.Equal(addrLocal.IP)) {
                                addrLocal = from
                        }
-                       if incr(seq[5:]) {
-                               incr(seq[:5])
+                       if incr(state.ourSeq[5:]) {
+                               incr(state.ourSeq[:5])
                        }
-                       copy(tx, seq[5:])
-                       s, err = chacha20.NewUnauthenticatedCipher(keys.ourKey, nonce)
+                       copy(tx, state.ourSeq[5:])
+                       s, err = chacha20.NewUnauthenticatedCipher(state.ourKey, state.ourNonce)
                        if err != nil {
                                log.Fatal(err)
                        }
@@ -204,11 +222,11 @@ func main() {
                        s.SetCounter(1)
                        s.XORKeyStream(tx[8:], rx[:n])
                        p = poly1305.New(&polyKey)
-                       mustWrite(p, seq)
+                       mustWrite(p, state.ourSeq)
                        mustWrite(p, tx[8:8+n])
                        p.Sum(tag[:0])
                        copy(tx[3:8], tag)
-                       keys.ourObfs.Encrypt(tx[:8], tx[:8])
+                       state.ourObfs.Encrypt(tx[:8], tx[:8])
                        if *responder {
                                connBind.WriteTo(tx[:8+n], addrRemote)
                        } else {
@@ -216,7 +234,7 @@ func main() {
                        }
                }
        }()
-       go func() {
+       go func(state **State) {
                rx := make([]byte, 1<<14)
                tx := make([]byte, 1<<14)
                var n int
@@ -225,8 +243,6 @@ func main() {
                var p *poly1305.MAC
                var from *net.UDPAddr
                tag := make([]byte, poly1305.TagSize)
-               ourNonce := make([]byte, chacha20.NonceSize)
-               ourSeq := ourNonce[4:]
                nonce := make([]byte, chacha20.NonceSize)
                seq := nonce[4:]
                var seqOur, seqTheir uint32
@@ -247,19 +263,21 @@ func main() {
                                from.Port != addrRemote.Port || !from.IP.Equal(addrRemote.IP)) {
                                addrRemote = from
                        }
-                       keys.theirObfs.Decrypt(rx[:8], rx[:8])
-                       seqOur = uint32(ourSeq[0])<<16 | uint32(ourSeq[1])<<8 | uint32(ourSeq[2])
+                       (*state).theirObfs.Decrypt(rx[:8], rx[:8])
+                       seqOur = uint32((*state).theirSeq[0])<<16 |
+                               uint32((*state).theirSeq[1])<<8 |
+                               uint32((*state).theirSeq[2])
                        seqTheir = uint32(rx[0])<<16 | uint32(rx[1])<<8 | uint32(rx[2])
                        if seqOur == seqTheir {
                                log.Println("replay")
                                continue
                        }
-                       copy(seq, ourNonce[:5])
+                       copy(seq, (*state).theirNonce[:5])
                        copy(seq[5:], rx[:3])
                        if seqTheir < seqOur && incr(seq[:5]) {
                                log.Fatal("seq is overflowed")
                        }
-                       s, err = chacha20.NewUnauthenticatedCipher(keys.theirKey, nonce)
+                       s, err = chacha20.NewUnauthenticatedCipher((*state).theirKey, nonce)
                        if err != nil {
                                log.Fatal(err)
                        }
@@ -274,7 +292,7 @@ func main() {
                                log.Print("bad MAC")
                                continue
                        }
-                       copy(ourSeq, seq)
+                       copy((*state).theirSeq, seq)
                        s.XORKeyStream(tx, rx[8:n])
                        if *responder {
                                connLocal.Write(tx[:n-8])
@@ -282,7 +300,7 @@ func main() {
                                connBind.WriteTo(tx[:n-8], addrLocal)
                        }
                }
-       }()
+       }(&state)
        exit := make(chan os.Signal, 1)
        signal.Notify(exit, syscall.SIGTERM, syscall.SIGINT)
        <-exit