]> Cypherpunks.ru repositories - udpobfs.git/blobdiff - main.go
Reset nonces with the new key
[udpobfs.git] / main.go
diff --git a/main.go b/main.go
index 7c71433b3f51cae7fb19312bb7ccb232cae33a9dda9b2da5a6e452663a03c08a..b4c317e48200e7ffe0e120cf032c9c10257b65c58ef2f0a0fe363c64ae779035 100644 (file)
--- a/main.go
+++ b/main.go
@@ -56,9 +56,11 @@ func incr(buf []byte) (overflow bool) {
        return
 }
 
-type Keys struct {
-       ourKey, theirKey   []byte
-       ourObfs, theirObfs *blowfish.Cipher
+type State struct {
+       ourKey, theirKey     []byte
+       ourObfs, theirObfs   *blowfish.Cipher
+       ourNonce, theirNonce []byte
+       ourSeq, theirSeq     []byte
 }
 
 var Base32Codec *base32.Encoding = base32.StdEncoding.WithPadding(base32.NoPadding)
@@ -80,8 +82,8 @@ func main() {
                return
        }
 
-       keys := &Keys{}
-       hasKeys := make(chan struct{})
+       var state *State
+       stateReady := make(chan struct{})
        go func() {
                first := true
                s := bufio.NewScanner(os.Stdin)
@@ -121,13 +123,31 @@ func main() {
                        if err != nil {
                                log.Fatal(err)
                        }
+                       var newState State
                        if *responder {
-                               keys.ourKey, keys.theirKey, keys.ourObfs, keys.theirObfs = rEncKey, iEncKey, rObfs, iObfs
+                               newState = State{
+                                       ourKey:     rEncKey,
+                                       theirKey:   iEncKey,
+                                       ourObfs:    rObfs,
+                                       theirObfs:  iObfs,
+                                       ourNonce:   make([]byte, chacha20.NonceSize),
+                                       theirNonce: make([]byte, chacha20.NonceSize),
+                               }
                        } else {
-                               keys.ourKey, keys.theirKey, keys.ourObfs, keys.theirObfs = iEncKey, rEncKey, iObfs, rObfs
+                               newState = State{
+                                       ourKey:     iEncKey,
+                                       theirKey:   rEncKey,
+                                       ourObfs:    iObfs,
+                                       theirObfs:  rObfs,
+                                       ourNonce:   make([]byte, chacha20.NonceSize),
+                                       theirNonce: make([]byte, chacha20.NonceSize),
+                               }
                        }
+                       newState.ourSeq = newState.ourNonce[4:]
+                       newState.theirSeq = newState.theirNonce[4:]
+                       state = &newState
                        if first {
-                               close(hasKeys)
+                               close(stateReady)
                                first = false
                        }
                }
@@ -135,7 +155,7 @@ func main() {
                        log.Fatal(s.Err())
                }
        }()
-       <-hasKeys
+       <-stateReady
 
        addr, err := net.ResolveUDPAddr("udp", *bind)
        if err != nil {
@@ -173,8 +193,6 @@ func main() {
                var p *poly1305.MAC
                tag := make([]byte, poly1305.TagSize)
                var from *net.UDPAddr
-               nonce := make([]byte, chacha20.NonceSize)
-               seq := nonce[4:]
                for {
                        if *responder {
                                n, err = connLocal.Read(rx)
@@ -191,11 +209,11 @@ func main() {
                                from.Port != addrLocal.Port || !from.IP.Equal(addrLocal.IP)) {
                                addrLocal = from
                        }
-                       if incr(seq[5:]) {
-                               incr(seq[:5])
+                       if incr(state.ourSeq[5:]) {
+                               incr(state.ourSeq[:5])
                        }
-                       copy(tx, seq[5:])
-                       s, err = chacha20.NewUnauthenticatedCipher(keys.ourKey, nonce)
+                       copy(tx, state.ourSeq[5:])
+                       s, err = chacha20.NewUnauthenticatedCipher(state.ourKey, state.ourNonce)
                        if err != nil {
                                log.Fatal(err)
                        }
@@ -204,11 +222,11 @@ func main() {
                        s.SetCounter(1)
                        s.XORKeyStream(tx[8:], rx[:n])
                        p = poly1305.New(&polyKey)
-                       mustWrite(p, seq)
+                       mustWrite(p, state.ourSeq)
                        mustWrite(p, tx[8:8+n])
                        p.Sum(tag[:0])
                        copy(tx[3:8], tag)
-                       keys.ourObfs.Encrypt(tx[:8], tx[:8])
+                       state.ourObfs.Encrypt(tx[:8], tx[:8])
                        if *responder {
                                connBind.WriteTo(tx[:8+n], addrRemote)
                        } else {
@@ -216,7 +234,7 @@ func main() {
                        }
                }
        }()
-       go func() {
+       go func(state **State) {
                rx := make([]byte, 1<<14)
                tx := make([]byte, 1<<14)
                var n int
@@ -225,8 +243,6 @@ func main() {
                var p *poly1305.MAC
                var from *net.UDPAddr
                tag := make([]byte, poly1305.TagSize)
-               ourNonce := make([]byte, chacha20.NonceSize)
-               ourSeq := ourNonce[4:]
                nonce := make([]byte, chacha20.NonceSize)
                seq := nonce[4:]
                var seqOur, seqTheir uint32
@@ -247,19 +263,21 @@ func main() {
                                from.Port != addrRemote.Port || !from.IP.Equal(addrRemote.IP)) {
                                addrRemote = from
                        }
-                       keys.theirObfs.Decrypt(rx[:8], rx[:8])
-                       seqOur = uint32(ourSeq[0])<<16 | uint32(ourSeq[1])<<8 | uint32(ourSeq[2])
+                       (*state).theirObfs.Decrypt(rx[:8], rx[:8])
+                       seqOur = uint32((*state).theirSeq[0])<<16 |
+                               uint32((*state).theirSeq[1])<<8 |
+                               uint32((*state).theirSeq[2])
                        seqTheir = uint32(rx[0])<<16 | uint32(rx[1])<<8 | uint32(rx[2])
                        if seqOur == seqTheir {
                                log.Println("replay")
                                continue
                        }
-                       copy(seq, ourNonce[:5])
+                       copy(seq, (*state).theirNonce[:5])
                        copy(seq[5:], rx[:3])
                        if seqTheir < seqOur && incr(seq[:5]) {
                                log.Fatal("seq is overflowed")
                        }
-                       s, err = chacha20.NewUnauthenticatedCipher(keys.theirKey, nonce)
+                       s, err = chacha20.NewUnauthenticatedCipher((*state).theirKey, nonce)
                        if err != nil {
                                log.Fatal(err)
                        }
@@ -274,7 +292,7 @@ func main() {
                                log.Print("bad MAC")
                                continue
                        }
-                       copy(ourSeq, seq)
+                       copy((*state).theirSeq, seq)
                        s.XORKeyStream(tx, rx[8:n])
                        if *responder {
                                connLocal.Write(tx[:n-8])
@@ -282,7 +300,7 @@ func main() {
                                connBind.WriteTo(tx[:n-8], addrLocal)
                        }
                }
-       }()
+       }(&state)
        exit := make(chan os.Signal, 1)
        signal.Notify(exit, syscall.SIGTERM, syscall.SIGINT)
        <-exit