Cypherpunks.ru repositories - udpobfs.git/blob - cmd/resp/main.go
Unify copyright comment format
[udpobfs.git] / cmd / resp / main.go
1 // udpobfs -- simple point-to-point UDP obfuscation proxy
2 // Copyright (C) 2023-2024 Sergey Matveev <stargrave@stargrave.org>
3 //
4 // This program is free software: you can redistribute it and/or modify
5 // it under the terms of the GNU General Public License as published by
6 // the Free Software Foundation, version 3 of the License.
7 //
8 // This program is distributed in the hope that it will be useful,
9 // but WITHOUT ANY WARRANTY; without even the implied warranty of
10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 // GNU General Public License for more details.
12 //
13 // You should have received a copy of the GNU General Public License
14 // along with this program.  If not, see <http://www.gnu.org/licenses/>.
15
16 package main
17
18 import (
19         "crypto/tls"
20         "flag"
21         "io"
22         "log"
23         "log/slog"
24         "net"
25         "os"
26         "sync"
27         "time"
28
29         "go.cypherpunks.ru/udpobfs/v2"
30 )
31
32 var (
33         DstAddrUDP *net.UDPAddr
34         TLSConfig  *tls.Config
35         LnUDP      *net.UDPConn
36         LnTCP      *net.TCPListener
37         Peers      = make(map[string]chan udpobfs.Buf)
38         PeersM     sync.RWMutex
39         Bufs       = sync.Pool{New: func() any { return new([udpobfs.BufLen]byte) }}
40 )
41
42 func newPeer(conn net.Conn) {
43         logger := slog.With("remote", conn.RemoteAddr().String())
44         logger.Info("connected")
45         defer conn.Close()
46         remoteAddr := udpobfs.MustResolveUDPAddr(conn.RemoteAddr().String())
47         localUDP, err := net.DialUDP("udp", nil, DstAddrUDP)
48         if err != nil {
49                 log.Fatal(err)
50         }
51         logger = logger.With("local", localUDP.LocalAddr().String())
52         connTLS := tls.Server(conn, TLSConfig)
53         err = connTLS.Handshake()
54         if err != nil {
55                 logger.Error(err.Error())
56                 return
57         }
58         defer connTLS.Close()
59         tlsState := connTLS.ConnectionState()
60         logger = logger.With("cn", tlsState.PeerCertificates[0].Subject.CommonName)
61         logger.Info("authenticated")
62         seed, err := tlsState.ExportKeyingMaterial(udpobfs.App, nil, udpobfs.SeedLen)
63         if err != nil {
64                 logger.Error(err.Error())
65                 return
66         }
67         cryptoState := udpobfs.NewCryptoState(seed, false)
68         txs := make(chan udpobfs.Buf)
69         PeersM.Lock()
70         Peers[remoteAddr.String()] = txs
71         PeersM.Unlock()
72         txFinished := make(chan struct{})
73         var rxPkts, txPkts, rxBytes, txBytes int64
74         go func() {
75                 pkts := make(chan []byte)
76                 pktRead := make(chan struct{})
77                 go func() {
78                         var n int
79                         var err error
80                         rx := make([]byte, udpobfs.BufLen)
81                         for {
82                                 localUDP.SetReadDeadline(time.Now().Add(time.Minute))
83                                 n, err = localUDP.Read(rx)
84                                 if n != 0 {
85                                         pkts <- rx[:n]
86                                         <-pktRead
87                                 }
88                                 if err != nil {
89                                         break
90                                 }
91                         }
92                         close(pkts)
93                 }()
94                 ticker := time.NewTicker(udpobfs.PingDuration)
95                 defer ticker.Stop()
96                 var pkt []byte
97                 var ok bool
98                 tx := make([]byte, udpobfs.BufLen)
99                 now := time.Now()
100                 lastPing := now
101                 var got []byte
102                 for {
103                         select {
104                         case <-ticker.C:
105                                 now = time.Now()
106                                 if now.Sub(lastPing) > udpobfs.PingDuration {
107                                         LnUDP.WriteTo(
108                                                 cryptoState.Tx(tx[:udpobfs.SeqLen], nil), remoteAddr)
109                                         lastPing = now
110                                 }
111                         case pkt, ok = <-pkts:
112                                 if !ok {
113                                         close(txFinished)
114                                         return
115                                 }
116                                 got = cryptoState.Tx(tx[:udpobfs.SeqLen+len(pkt)], pkt)
117                                 pktRead <- struct{}{}
118                                 LnUDP.WriteTo(got, remoteAddr)
119                                 txPkts++
120                                 txBytes += int64(len(pkt))
121                                 lastPing = time.Now()
122                         }
123                 }
124         }()
125         go func() {
126                 ticker := time.NewTicker(2 * udpobfs.LifetimeDuration)
127                 defer ticker.Stop()
128                 now := time.Now()
129                 last := now
130                 buf := make([]byte, udpobfs.BufLen)
131                 var tx udpobfs.Buf
132                 var got []byte
133                 var ok bool
134                 for {
135                         select {
136                         case <-txFinished:
137                                 return
138                         case <-ticker.C:
139                                 now = time.Now()
140                                 if now.Sub(last) > 2*udpobfs.LifetimeDuration {
141                                         localUDP.Close()
142                                         return
143                                 }
144                         case tx, ok = <-txs:
145                                 if !ok {
146                                         return
147                                 }
148                                 if tx.N < udpobfs.SeqLen {
149                                         logger.Warn("too short")
150                                         Bufs.Put(tx.Buf)
151                                         continue
152                                 }
153                                 got = cryptoState.Rx(buf[:tx.N], (*tx.Buf)[:tx.N])
154                                 Bufs.Put(tx.Buf)
155                                 if got == nil {
156                                         logger.Warn("bad MAC")
157                                         continue
158                                 }
159                                 if len(got) != 0 {
160                                         rxPkts++
161                                         rxBytes += int64(len(got))
162                                         localUDP.Write(got)
163                                 }
164                                 last = time.Now()
165                         }
166                 }
167         }()
168         go func() {
169                 buf := make([]byte, 8)
170                 for {
171                         connTLS.SetReadDeadline(time.Now().Add(2 * udpobfs.LifetimeDuration))
172                         if _, err = io.ReadFull(connTLS, buf); err != nil {
173                                 break
174                         }
175                         if _, err = connTLS.Write(buf); err != nil {
176                                 break
177                         }
178                 }
179         }()
180         <-txFinished
181         logger.Info("finishing",
182                 "rxPkts", rxPkts,
183                 "rxBytes", rxBytes,
184                 "txPkts", txPkts,
185                 "txBytes", txBytes)
186         PeersM.Lock()
187         delete(Peers, remoteAddr.String())
188         PeersM.Unlock()
189         close(txs)
190 }
191
192 func main() {
193         bind := flag.String("bind", "[::]:1194", "Address to bind to")
194         dst := flag.String("dst", "[2001:db8::1234]:1194", "Address to connect to")
195         keypairPath := flag.String("keypair", "keypair.pem", "X.509 keypair")
196         caPath := flag.String("ca", "ca.pem", "CA certificate")
197         flag.Parse()
198         log.SetFlags(log.Lshortfile)
199         slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, nil)))
200
201         crtRaw, _, err := udpobfs.CertificateFromFile(*keypairPath)
202         if err != nil {
203                 log.Fatal(err)
204         }
205         prv, err := udpobfs.PrivateKeyFromFile(*keypairPath)
206         if err != nil {
207                 log.Fatal(err)
208         }
209         TLSConfig = &tls.Config{
210                 MinVersion: tls.VersionTLS13,
211                 ClientAuth: tls.RequireAndVerifyClientCert,
212                 Certificates: []tls.Certificate{{
213                         Certificate: [][]byte{crtRaw},
214                         PrivateKey:  prv,
215                 }},
216         }
217         _, TLSConfig.ClientCAs, err = udpobfs.CertPoolFromFile(*caPath)
218         if err != nil {
219                 log.Fatalln(err)
220         }
221
222         DstAddrUDP = udpobfs.MustResolveUDPAddr(*dst)
223         LnUDP, err = net.ListenUDP("udp", udpobfs.MustResolveUDPAddr(*bind))
224         if err != nil {
225                 log.Fatal(err)
226         }
227         LnTCP, err = net.ListenTCP("tcp", udpobfs.MustResolveTCPAddr(*bind))
228         if err != nil {
229                 log.Fatal(err)
230         }
231
232         go func() {
233                 var n int
234                 var from net.Addr
235                 var txs chan udpobfs.Buf
236                 var buf *[udpobfs.BufLen]byte
237                 for {
238                         buf = Bufs.Get().(*[udpobfs.BufLen]byte)
239                         n, from, _ = LnUDP.ReadFrom((*buf)[:])
240                         if n == 0 {
241                                 continue
242                         }
243                         PeersM.RLock()
244                         txs = Peers[from.String()]
245                         if txs != nil {
246                                 txs <- udpobfs.Buf{Buf: buf, N: n}
247                         }
248                         PeersM.RUnlock()
249                 }
250         }()
251
252         for {
253                 conn, err := LnTCP.Accept()
254                 if err != nil {
255                         slog.Error(err.Error())
256                         continue
257                 }
258                 go newPeer(conn)
259         }
260 }