]> Cypherpunks.ru repositories - udpobfs.git/blobdiff - cmd/resp/main.go
v2
[udpobfs.git] / cmd / resp / main.go
diff --git a/cmd/resp/main.go b/cmd/resp/main.go
new file mode 100644 (file)
index 0000000..cc97336
--- /dev/null
@@ -0,0 +1,258 @@
+/*
+udpobfs -- simple point-to-point UDP obfuscation proxy
+Copyright (C) 2023 Sergey Matveev <stargrave@stargrave.org>
+
+This program is free software: you can redistribute it and/or modify
+it under the terms of the GNU General Public License as published by
+the Free Software Foundation, version 3 of the License.
+
+This program is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+GNU General Public License for more details.
+
+You should have received a copy of the GNU General Public License
+along with this program.  If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package main
+
+import (
+       "crypto/tls"
+       "flag"
+       "io"
+       "log"
+       "log/slog"
+       "net"
+       "os"
+       "sync"
+       "time"
+
+       "go.cypherpunks.ru/udpobfs/v2"
+)
+
+var (
+       DstAddrUDP *net.UDPAddr
+       TLSConfig  *tls.Config
+       LnUDP      *net.UDPConn
+       LnTCP      *net.TCPListener
+       Peers      = make(map[string]chan udpobfs.Buf)
+       Bufs       = sync.Pool{New: func() any { return new([udpobfs.BufLen]byte) }}
+)
+
+func newPeer(conn net.Conn) {
+       logger := slog.With("remote", conn.RemoteAddr().String())
+       logger.Info("connected")
+       defer conn.Close()
+       remoteAddr := udpobfs.MustResolveUDPAddr(conn.RemoteAddr().String())
+       localUDP, err := net.DialUDP("udp", nil, DstAddrUDP)
+       if err != nil {
+               log.Fatal(err)
+       }
+       logger = logger.With("local", localUDP.LocalAddr().String())
+       connTLS := tls.Server(conn, TLSConfig)
+       err = connTLS.Handshake()
+       if err != nil {
+               logger.Error(err.Error())
+               return
+       }
+       defer connTLS.Close()
+       tlsState := connTLS.ConnectionState()
+       logger = logger.With("cn", tlsState.PeerCertificates[0].Subject.CommonName)
+       logger.Info("authenticated")
+       seed, err := tlsState.ExportKeyingMaterial(udpobfs.App, nil, udpobfs.SeedLen)
+       if err != nil {
+               logger.Error(err.Error())
+               return
+       }
+       cryptoState := udpobfs.NewCryptoState(seed, false)
+       txs := make(chan udpobfs.Buf)
+       txFinished := make(chan struct{})
+       var rxPkts, txPkts, rxBytes, txBytes int64
+       go func() {
+               pkts := make(chan []byte)
+               pktRead := make(chan struct{})
+               go func() {
+                       var n int
+                       var err error
+                       rx := make([]byte, udpobfs.BufLen)
+                       for {
+                               localUDP.SetReadDeadline(time.Now().Add(time.Minute))
+                               n, err = localUDP.Read(rx)
+                               if n != 0 {
+                                       pkts <- rx[:n]
+                                       <-pktRead
+                               }
+                               if err != nil {
+                                       break
+                               }
+                       }
+                       close(pkts)
+               }()
+               ticker := time.NewTicker(udpobfs.PingDuration)
+               defer ticker.Stop()
+               var pkt []byte
+               var ok bool
+               tx := make([]byte, udpobfs.BufLen)
+               now := time.Now()
+               lastPing := now
+               var got []byte
+               for {
+                       select {
+                       case <-ticker.C:
+                               now = time.Now()
+                               if now.Sub(lastPing) > udpobfs.PingDuration {
+                                       LnUDP.WriteTo(
+                                               cryptoState.Tx(tx[:udpobfs.SeqLen], nil), remoteAddr)
+                                       lastPing = now
+                               }
+                       case pkt, ok = <-pkts:
+                               if !ok {
+                                       close(txFinished)
+                                       return
+                               }
+                               got = cryptoState.Tx(tx[:udpobfs.SeqLen+len(pkt)], pkt)
+                               pktRead <- struct{}{}
+                               LnUDP.WriteTo(got, remoteAddr)
+                               txPkts++
+                               txBytes += int64(len(pkt))
+                               lastPing = time.Now()
+                       }
+               }
+       }()
+       go func() {
+               ticker := time.NewTicker(2 * udpobfs.LifetimeDuration)
+               defer ticker.Stop()
+               now := time.Now()
+               last := now
+               buf := make([]byte, udpobfs.BufLen)
+               var tx udpobfs.Buf
+               var got []byte
+               for {
+                       select {
+                       case <-txFinished:
+                               return
+                       case <-ticker.C:
+                               now = time.Now()
+                               if now.Sub(last) > 2*udpobfs.LifetimeDuration {
+                                       localUDP.Close()
+                                       return
+                               }
+                       case tx = <-txs:
+                               if tx.Buf == nil {
+                                       break
+                               }
+                               if tx.N < udpobfs.SeqLen {
+                                       logger.Warn("too short")
+                                       Bufs.Put(tx.Buf)
+                                       continue
+                               }
+                               got = cryptoState.Rx(buf[:tx.N], (*tx.Buf)[:tx.N])
+                               Bufs.Put(tx.Buf)
+                               if got == nil {
+                                       logger.Warn("bad MAC")
+                                       continue
+                               }
+                               if len(got) != 0 {
+                                       rxPkts++
+                                       rxBytes += int64(len(got))
+                                       localUDP.Write(got)
+                               }
+                               last = time.Now()
+                       }
+               }
+       }()
+       Peers[remoteAddr.String()] = txs
+       go func() {
+               buf := make([]byte, 8)
+               for {
+                       connTLS.SetReadDeadline(time.Now().Add(2 * udpobfs.LifetimeDuration))
+                       if _, err = io.ReadFull(connTLS, buf); err != nil {
+                               break
+                       }
+                       if _, err = connTLS.Write(buf); err != nil {
+                               break
+                       }
+               }
+       }()
+       <-txFinished
+       logger.Info("finishing",
+               "rxPkts", rxPkts,
+               "rxBytes", rxBytes,
+               "txPkts", txPkts,
+               "txBytes", txBytes)
+       delete(Peers, remoteAddr.String())
+       txs <- udpobfs.Buf{Buf: nil}
+       go func() {
+               for range txs {
+               }
+       }()
+}
+
+func main() {
+       bind := flag.String("bind", "[::]:1194", "Address to bind to")
+       dst := flag.String("dst", "[2001:db8::1234]:1194", "Address to connect to")
+       keypairPath := flag.String("keypair", "keypair.pem", "X.509 keypair")
+       caPath := flag.String("ca", "ca.pem", "CA certificate")
+       flag.Parse()
+       log.SetFlags(log.Lshortfile)
+       slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, nil)))
+
+       crtRaw, _, err := udpobfs.CertificateFromFile(*keypairPath)
+       if err != nil {
+               log.Fatal(err)
+       }
+       prv, err := udpobfs.PrivateKeyFromFile(*keypairPath)
+       if err != nil {
+               log.Fatal(err)
+       }
+       TLSConfig = &tls.Config{
+               MinVersion: tls.VersionTLS13,
+               ClientAuth: tls.RequireAndVerifyClientCert,
+               Certificates: []tls.Certificate{{
+                       Certificate: [][]byte{crtRaw},
+                       PrivateKey:  prv,
+               }},
+       }
+       _, TLSConfig.ClientCAs, err = udpobfs.CertPoolFromFile(*caPath)
+       if err != nil {
+               log.Fatalln(err)
+       }
+
+       DstAddrUDP = udpobfs.MustResolveUDPAddr(*dst)
+       LnUDP, err = net.ListenUDP("udp", udpobfs.MustResolveUDPAddr(*bind))
+       if err != nil {
+               log.Fatal(err)
+       }
+       LnTCP, err = net.ListenTCP("tcp", udpobfs.MustResolveTCPAddr(*bind))
+       if err != nil {
+               log.Fatal(err)
+       }
+
+       go func() {
+               var n int
+               var from net.Addr
+               var txs chan udpobfs.Buf
+               var buf *[udpobfs.BufLen]byte
+               for {
+                       buf = Bufs.Get().(*[udpobfs.BufLen]byte)
+                       n, from, _ = LnUDP.ReadFrom((*buf)[:])
+                       if n == 0 {
+                               continue
+                       }
+                       txs = Peers[from.String()]
+                       if txs != nil {
+                               txs <- udpobfs.Buf{Buf: buf, N: n}
+                       }
+               }
+       }()
+
+       for {
+               conn, err := LnTCP.Accept()
+               if err != nil {
+                       slog.Error(err.Error())
+                       continue
+               }
+               go newPeer(conn)
+       }
+}