--- /dev/null
+/*
+udpobfs -- simple point-to-point UDP obfuscation proxy
+Copyright (C) 2023 Sergey Matveev <stargrave@stargrave.org>
+
+This program is free software: you can redistribute it and/or modify
+it under the terms of the GNU General Public License as published by
+the Free Software Foundation, version 3 of the License.
+
+This program is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU General Public License for more details.
+
+You should have received a copy of the GNU General Public License
+along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package main
+
+import (
+ "crypto/tls"
+ "flag"
+ "io"
+ "log"
+ "log/slog"
+ "net"
+ "os"
+ "sync"
+ "time"
+
+ "go.cypherpunks.ru/udpobfs/v2"
+)
+
+var (
+ DstAddrUDP *net.UDPAddr
+ TLSConfig *tls.Config
+ LnUDP *net.UDPConn
+ LnTCP *net.TCPListener
+ Peers = make(map[string]chan udpobfs.Buf)
+ Bufs = sync.Pool{New: func() any { return new([udpobfs.BufLen]byte) }}
+)
+
+func newPeer(conn net.Conn) {
+ logger := slog.With("remote", conn.RemoteAddr().String())
+ logger.Info("connected")
+ defer conn.Close()
+ remoteAddr := udpobfs.MustResolveUDPAddr(conn.RemoteAddr().String())
+ localUDP, err := net.DialUDP("udp", nil, DstAddrUDP)
+ if err != nil {
+ log.Fatal(err)
+ }
+ logger = logger.With("local", localUDP.LocalAddr().String())
+ connTLS := tls.Server(conn, TLSConfig)
+ err = connTLS.Handshake()
+ if err != nil {
+ logger.Error(err.Error())
+ return
+ }
+ defer connTLS.Close()
+ tlsState := connTLS.ConnectionState()
+ logger = logger.With("cn", tlsState.PeerCertificates[0].Subject.CommonName)
+ logger.Info("authenticated")
+ seed, err := tlsState.ExportKeyingMaterial(udpobfs.App, nil, udpobfs.SeedLen)
+ if err != nil {
+ logger.Error(err.Error())
+ return
+ }
+ cryptoState := udpobfs.NewCryptoState(seed, false)
+ txs := make(chan udpobfs.Buf)
+ txFinished := make(chan struct{})
+ var rxPkts, txPkts, rxBytes, txBytes int64
+ go func() {
+ pkts := make(chan []byte)
+ pktRead := make(chan struct{})
+ go func() {
+ var n int
+ var err error
+ rx := make([]byte, udpobfs.BufLen)
+ for {
+ localUDP.SetReadDeadline(time.Now().Add(time.Minute))
+ n, err = localUDP.Read(rx)
+ if n != 0 {
+ pkts <- rx[:n]
+ <-pktRead
+ }
+ if err != nil {
+ break
+ }
+ }
+ close(pkts)
+ }()
+ ticker := time.NewTicker(udpobfs.PingDuration)
+ defer ticker.Stop()
+ var pkt []byte
+ var ok bool
+ tx := make([]byte, udpobfs.BufLen)
+ now := time.Now()
+ lastPing := now
+ var got []byte
+ for {
+ select {
+ case <-ticker.C:
+ now = time.Now()
+ if now.Sub(lastPing) > udpobfs.PingDuration {
+ LnUDP.WriteTo(
+ cryptoState.Tx(tx[:udpobfs.SeqLen], nil), remoteAddr)
+ lastPing = now
+ }
+ case pkt, ok = <-pkts:
+ if !ok {
+ close(txFinished)
+ return
+ }
+ got = cryptoState.Tx(tx[:udpobfs.SeqLen+len(pkt)], pkt)
+ pktRead <- struct{}{}
+ LnUDP.WriteTo(got, remoteAddr)
+ txPkts++
+ txBytes += int64(len(pkt))
+ lastPing = time.Now()
+ }
+ }
+ }()
+ go func() {
+ ticker := time.NewTicker(2 * udpobfs.LifetimeDuration)
+ defer ticker.Stop()
+ now := time.Now()
+ last := now
+ buf := make([]byte, udpobfs.BufLen)
+ var tx udpobfs.Buf
+ var got []byte
+ for {
+ select {
+ case <-txFinished:
+ return
+ case <-ticker.C:
+ now = time.Now()
+ if now.Sub(last) > 2*udpobfs.LifetimeDuration {
+ localUDP.Close()
+ return
+ }
+ case tx = <-txs:
+ if tx.Buf == nil {
+ break
+ }
+ if tx.N < udpobfs.SeqLen {
+ logger.Warn("too short")
+ Bufs.Put(tx.Buf)
+ continue
+ }
+ got = cryptoState.Rx(buf[:tx.N], (*tx.Buf)[:tx.N])
+ Bufs.Put(tx.Buf)
+ if got == nil {
+ logger.Warn("bad MAC")
+ continue
+ }
+ if len(got) != 0 {
+ rxPkts++
+ rxBytes += int64(len(got))
+ localUDP.Write(got)
+ }
+ last = time.Now()
+ }
+ }
+ }()
+ Peers[remoteAddr.String()] = txs
+ go func() {
+ buf := make([]byte, 8)
+ for {
+ connTLS.SetReadDeadline(time.Now().Add(2 * udpobfs.LifetimeDuration))
+ if _, err = io.ReadFull(connTLS, buf); err != nil {
+ break
+ }
+ if _, err = connTLS.Write(buf); err != nil {
+ break
+ }
+ }
+ }()
+ <-txFinished
+ logger.Info("finishing",
+ "rxPkts", rxPkts,
+ "rxBytes", rxBytes,
+ "txPkts", txPkts,
+ "txBytes", txBytes)
+ delete(Peers, remoteAddr.String())
+ txs <- udpobfs.Buf{Buf: nil}
+ go func() {
+ for range txs {
+ }
+ }()
+}
+
+func main() {
+ bind := flag.String("bind", "[::]:1194", "Address to bind to")
+ dst := flag.String("dst", "[2001:db8::1234]:1194", "Address to connect to")
+ keypairPath := flag.String("keypair", "keypair.pem", "X.509 keypair")
+ caPath := flag.String("ca", "ca.pem", "CA certificate")
+ flag.Parse()
+ log.SetFlags(log.Lshortfile)
+ slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, nil)))
+
+ crtRaw, _, err := udpobfs.CertificateFromFile(*keypairPath)
+ if err != nil {
+ log.Fatal(err)
+ }
+ prv, err := udpobfs.PrivateKeyFromFile(*keypairPath)
+ if err != nil {
+ log.Fatal(err)
+ }
+ TLSConfig = &tls.Config{
+ MinVersion: tls.VersionTLS13,
+ ClientAuth: tls.RequireAndVerifyClientCert,
+ Certificates: []tls.Certificate{{
+ Certificate: [][]byte{crtRaw},
+ PrivateKey: prv,
+ }},
+ }
+ _, TLSConfig.ClientCAs, err = udpobfs.CertPoolFromFile(*caPath)
+ if err != nil {
+ log.Fatalln(err)
+ }
+
+ DstAddrUDP = udpobfs.MustResolveUDPAddr(*dst)
+ LnUDP, err = net.ListenUDP("udp", udpobfs.MustResolveUDPAddr(*bind))
+ if err != nil {
+ log.Fatal(err)
+ }
+ LnTCP, err = net.ListenTCP("tcp", udpobfs.MustResolveTCPAddr(*bind))
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ go func() {
+ var n int
+ var from net.Addr
+ var txs chan udpobfs.Buf
+ var buf *[udpobfs.BufLen]byte
+ for {
+ buf = Bufs.Get().(*[udpobfs.BufLen]byte)
+ n, from, _ = LnUDP.ReadFrom((*buf)[:])
+ if n == 0 {
+ continue
+ }
+ txs = Peers[from.String()]
+ if txs != nil {
+ txs <- udpobfs.Buf{Buf: buf, N: n}
+ }
+ }
+ }()
+
+ for {
+ conn, err := LnTCP.Accept()
+ if err != nil {
+ slog.Error(err.Error())
+ continue
+ }
+ go newPeer(conn)
+ }
+}