/* udpobfs -- simple point-to-point UDP obfuscation proxy Copyright (C) 2023 Sergey Matveev This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, version 3 of the License. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program. If not, see . */ package main import ( "bufio" "crypto/rand" "crypto/subtle" "encoding/base32" "flag" "fmt" "io" "log" "net" "os" "os/signal" "syscall" "golang.org/x/crypto/blowfish" "golang.org/x/crypto/chacha20" "golang.org/x/crypto/poly1305" "golang.org/x/crypto/sha3" ) const KeyLen = 32 func mustWrite(w io.Writer, data []byte) { if n, err := w.Write(data); err != nil || n != len(data) { log.Fatal("non full write") } } func incr(buf []byte) (overflow bool) { for i := len(buf) - 1; i >= 0; i-- { buf[i]++ if buf[i] != 0 { return } } overflow = true return } type State struct { ourKey, theirKey []byte ourObfs, theirObfs *blowfish.Cipher ourNonce, theirNonce []byte ourSeq, theirSeq []byte } var Base32Codec *base32.Encoding = base32.StdEncoding.WithPadding(base32.NoPadding) func main() { keygen := flag.Bool("keygen", false, "Generate random key") responder := flag.Bool("responder", false, "Are we responder?") bind := flag.String("bind", "[::]:1194", "Address to bind to") dst := flag.String("dst", "[2001:db8::1234]::1194", "Address to connect to") flag.Parse() log.SetFlags(log.Ldate | log.Lmicroseconds | log.Lshortfile) if *keygen { key := make([]byte, KeyLen) if _, err := io.ReadFull(rand.Reader, key); err != nil { log.Fatal(err) } fmt.Println(Base32Codec.EncodeToString(key)) return } var state *State stateReady := make(chan struct{}) go func() { first := true s := bufio.NewScanner(os.Stdin) h := sha3.NewShake128() for s.Scan() { key, err := Base32Codec.DecodeString(s.Text()) if err != nil { log.Fatal(err) } if len(key) != KeyLen { log.Fatal("wrong key length") } h.Reset() mustWrite(h, []byte("go.cypherpunks.ru/udpobfs")) mustWrite(h, key) iEncKey := make([]byte, chacha20.KeySize) iBlkKey := make([]byte, 32) rEncKey := make([]byte, chacha20.KeySize) rBlkKey := make([]byte, 32) if _, err := io.ReadFull(h, iEncKey); err != nil { log.Fatal(err) } if _, err := io.ReadFull(h, iBlkKey); err != nil { log.Fatal(err) } if _, err := io.ReadFull(h, rEncKey); err != nil { log.Fatal(err) } if _, err := io.ReadFull(h, rBlkKey); err != nil { log.Fatal(err) } iObfs, err := blowfish.NewCipher(iBlkKey) if err != nil { log.Fatal(err) } rObfs, err := blowfish.NewCipher(rBlkKey) if err != nil { log.Fatal(err) } var newState State if *responder { newState = State{ ourKey: rEncKey, theirKey: iEncKey, ourObfs: rObfs, theirObfs: iObfs, ourNonce: make([]byte, chacha20.NonceSize), theirNonce: make([]byte, chacha20.NonceSize), } } else { newState = State{ ourKey: iEncKey, theirKey: rEncKey, ourObfs: iObfs, theirObfs: rObfs, ourNonce: make([]byte, chacha20.NonceSize), theirNonce: make([]byte, chacha20.NonceSize), } } newState.ourSeq = newState.ourNonce[4:] newState.theirSeq = newState.theirNonce[4:] state = &newState if first { close(stateReady) first = false } } if s.Err() != nil { log.Fatal(s.Err()) } }() <-stateReady addr, err := net.ResolveUDPAddr("udp", *bind) if err != nil { log.Fatal(err) } connBind, err := net.ListenUDP("udp", addr) if err != nil { log.Fatal(err) } var connLocal *net.UDPConn var connRemote *net.UDPConn var addrLocal *net.UDPAddr var addrRemote *net.UDPAddr addr, err = net.ResolveUDPAddr("udp", *dst) if err != nil { log.Fatal(err) } if *responder { connLocal, err = net.DialUDP("udp", nil, addr) } else { connRemote, err = net.DialUDP("udp", nil, addr) } if err != nil { log.Fatal(err) } log.Println(*bind, "->", *dst) go func() { rx := make([]byte, 1<<14) tx := make([]byte, 1<<14) var n int var polyKey [32]byte var s *chacha20.Cipher var p *poly1305.MAC tag := make([]byte, poly1305.TagSize) var from *net.UDPAddr for { if *responder { n, err = connLocal.Read(rx) } else { n, from, err = connBind.ReadFromUDP(rx) } if err != nil { log.Fatal(err) } if *responder && addrRemote == nil { continue } if !*responder && (addrLocal == nil || from.Port != addrLocal.Port || !from.IP.Equal(addrLocal.IP)) { addrLocal = from } if incr(state.ourSeq[5:]) { incr(state.ourSeq[:5]) } copy(tx, state.ourSeq[5:]) s, err = chacha20.NewUnauthenticatedCipher(state.ourKey, state.ourNonce) if err != nil { log.Fatal(err) } clear(polyKey[:]) s.XORKeyStream(polyKey[:], polyKey[:]) s.SetCounter(1) s.XORKeyStream(tx[8:], rx[:n]) p = poly1305.New(&polyKey) mustWrite(p, state.ourSeq) mustWrite(p, tx[8:8+n]) p.Sum(tag[:0]) copy(tx[3:8], tag) state.ourObfs.Encrypt(tx[:8], tx[:8]) if *responder { connBind.WriteTo(tx[:8+n], addrRemote) } else { connRemote.Write(tx[:8+n]) } } }() go func(state **State) { rx := make([]byte, 1<<14) tx := make([]byte, 1<<14) var n int var polyKey [32]byte var s *chacha20.Cipher var p *poly1305.MAC var from *net.UDPAddr tag := make([]byte, poly1305.TagSize) nonce := make([]byte, chacha20.NonceSize) seq := nonce[4:] var seqOur, seqTheir uint32 for { if *responder { n, from, err = connBind.ReadFromUDP(rx) } else { n, err = connRemote.Read(rx) } if err != nil { log.Fatal(err) } if n < 8 { log.Println("too short") continue } if *responder && (addrRemote == nil || from.Port != addrRemote.Port || !from.IP.Equal(addrRemote.IP)) { addrRemote = from } (*state).theirObfs.Decrypt(rx[:8], rx[:8]) seqOur = uint32((*state).theirSeq[0])<<16 | uint32((*state).theirSeq[1])<<8 | uint32((*state).theirSeq[2]) seqTheir = uint32(rx[0])<<16 | uint32(rx[1])<<8 | uint32(rx[2]) if seqOur == seqTheir { log.Println("replay") continue } copy(seq, (*state).theirNonce[:5]) copy(seq[5:], rx[:3]) if seqTheir < seqOur && incr(seq[:5]) { log.Fatal("seq is overflowed") } s, err = chacha20.NewUnauthenticatedCipher((*state).theirKey, nonce) if err != nil { log.Fatal(err) } clear(polyKey[:]) s.XORKeyStream(polyKey[:], polyKey[:]) s.SetCounter(1) p = poly1305.New(&polyKey) mustWrite(p, seq) mustWrite(p, rx[8:n]) p.Sum(tag[:0]) if subtle.ConstantTimeCompare(tag[:5], rx[3:8]) != 1 { log.Print("bad MAC") continue } copy((*state).theirSeq, seq) s.XORKeyStream(tx, rx[8:n]) if *responder { connLocal.Write(tx[:n-8]) } else { connBind.WriteTo(tx[:n-8], addrLocal) } } }(&state) exit := make(chan os.Signal, 1) signal.Notify(exit, syscall.SIGTERM, syscall.SIGINT) <-exit }