]> Cypherpunks.ru repositories - udpobfs.git/blob - cmd/resp/main.go
RWMutex is not that bad
[udpobfs.git] / cmd / resp / main.go
1 /*
2 udpobfs -- simple point-to-point UDP obfuscation proxy
3 Copyright (C) 2023 Sergey Matveev <stargrave@stargrave.org>
4
5 This program is free software: you can redistribute it and/or modify
6 it under the terms of the GNU General Public License as published by
7 the Free Software Foundation, version 3 of the License.
8
9 This program is distributed in the hope that it will be useful,
10 but WITHOUT ANY WARRANTY; without even the implied warranty of
11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 GNU General Public License for more details.
13
14 You should have received a copy of the GNU General Public License
15 along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 */
17
18 package main
19
20 import (
21         "crypto/tls"
22         "flag"
23         "io"
24         "log"
25         "log/slog"
26         "net"
27         "os"
28         "sync"
29         "time"
30
31         "go.cypherpunks.ru/udpobfs/v2"
32 )
33
34 var (
35         DstAddrUDP *net.UDPAddr
36         TLSConfig  *tls.Config
37         LnUDP      *net.UDPConn
38         LnTCP      *net.TCPListener
39         Peers      = make(map[string]chan udpobfs.Buf)
40         PeersM     sync.RWMutex
41         Bufs       = sync.Pool{New: func() any { return new([udpobfs.BufLen]byte) }}
42 )
43
// newPeer serves a single client session for the lifetime of conn.
//
// It authenticates the client with mutual TLS on conn and derives the
// obfuscation state from the TLS session's exported keying material. It then
// relays datagrams between the client and DstAddrUDP:
//   - client -> destination: datagrams arrive via the channel this function
//     registers in Peers (keyed by the client's address, the same address
//     as its TCP connection). They are de-obfuscated and written to a
//     dedicated UDP socket connected to DstAddrUDP.
//   - destination -> client: replies read from that socket are obfuscated
//     and sent to the client through LnUDP.
//
// The session ends when reading from the destination socket fails. That
// happens after a minute without replies, or when the socket is closed
// because no valid client datagram arrived for 2*udpobfs.LifetimeDuration.
func newPeer(conn net.Conn) {
	logger := slog.With("remote", conn.RemoteAddr().String())
	logger.Info("connected")
	defer conn.Close()
	remoteAddr := udpobfs.MustResolveUDPAddr(conn.RemoteAddr().String())
	localUDP, err := net.DialUDP("udp", nil, DstAddrUDP)
	if err != nil {
		// This failure concerns a single session; do not bring the whole
		// server down (log.Fatal) because of it.
		logger.Error(err.Error())
		return
	}
	// Close on every exit path, including handshake failure and the
	// read-timeout end of the session.
	defer localUDP.Close()
	logger = logger.With("local", localUDP.LocalAddr().String())
	connTLS := tls.Server(conn, TLSConfig)
	// Bound the handshake so a client that never completes it cannot hold
	// this goroutine and its sockets forever.
	connTLS.SetDeadline(time.Now().Add(time.Minute))
	err = connTLS.Handshake()
	if err != nil {
		logger.Error(err.Error())
		return
	}
	connTLS.SetDeadline(time.Time{})
	defer connTLS.Close()
	tlsState := connTLS.ConnectionState()
	logger = logger.With("cn", tlsState.PeerCertificates[0].Subject.CommonName)
	logger.Info("authenticated")
	seed, err := tlsState.ExportKeyingMaterial(udpobfs.App, nil, udpobfs.SeedLen)
	if err != nil {
		logger.Error(err.Error())
		return
	}
	cryptoState := udpobfs.NewCryptoState(seed, false)

	peerKey := remoteAddr.String()
	txs := make(chan udpobfs.Buf)
	PeersM.Lock()
	Peers[peerKey] = txs
	PeersM.Unlock()

	txFinished := make(chan struct{}) // closed when destination reads stop
	rxFinished := make(chan struct{}) // closed when the txs consumer exits
	var rxPkts, txPkts, rxBytes, txBytes int64

	// Destination -> client. Replies are obfuscated and sent via LnUDP. A
	// packet with no payload is sent as a ping when nothing was sent for
	// longer than udpobfs.PingDuration at a ticker tick.
	go func() {
		pkts := make(chan []byte)
		pktRead := make(chan struct{})
		go func() {
			defer close(pkts)
			// The obfuscated packet is SeqLen bytes longer than the payload
			// (see the tx[:SeqLen+len(pkt)] slice below), and tx holds only
			// BufLen bytes. Cap the payload so a maximum-size datagram
			// cannot push that slice out of range and panic.
			rx := make([]byte, udpobfs.BufLen-udpobfs.SeqLen)
			for {
				localUDP.SetReadDeadline(time.Now().Add(time.Minute))
				n, err := localUDP.Read(rx)
				if n != 0 {
					// rx is reused: wait until the consumer has
					// encrypted this packet before reading the next.
					pkts <- rx[:n]
					<-pktRead
				}
				if err != nil {
					return
				}
			}
		}()
		ticker := time.NewTicker(udpobfs.PingDuration)
		defer ticker.Stop()
		tx := make([]byte, udpobfs.BufLen)
		lastPing := time.Now()
		for {
			select {
			case <-ticker.C:
				now := time.Now()
				if now.Sub(lastPing) > udpobfs.PingDuration {
					LnUDP.WriteTo(
						cryptoState.Tx(tx[:udpobfs.SeqLen], nil), remoteAddr)
					lastPing = now
				}
			case pkt, ok := <-pkts:
				if !ok {
					close(txFinished)
					return
				}
				got := cryptoState.Tx(tx[:udpobfs.SeqLen+len(pkt)], pkt)
				pktRead <- struct{}{}
				LnUDP.WriteTo(got, remoteAddr)
				txPkts++
				txBytes += int64(len(pkt))
				lastPing = time.Now()
			}
		}
	}()

	// Client -> destination. This goroutine must keep receiving from txs
	// until txs is closed. main's UDP reader sends to txs while holding
	// PeersM.RLock. Abandoning txs while it is still registered in Peers
	// would block that reader forever and deadlock the PeersM.Lock below.
	go func() {
		defer close(rxFinished)
		ticker := time.NewTicker(2 * udpobfs.LifetimeDuration)
		defer ticker.Stop()
		last := time.Now()
		buf := make([]byte, udpobfs.BufLen)
		expired := false
		for {
			select {
			case <-ticker.C:
				if !expired && time.Since(last) > 2*udpobfs.LifetimeDuration {
					// Closing localUDP makes the destination reader fail,
					// which ends the session via txFinished. Keep draining
					// txs until the session unregisters it.
					expired = true
					localUDP.Close()
				}
			case in, ok := <-txs:
				if !ok {
					return
				}
				if expired {
					Bufs.Put(in.Buf)
					continue
				}
				if in.N < udpobfs.SeqLen {
					logger.Warn("too short")
					Bufs.Put(in.Buf)
					continue
				}
				got := cryptoState.Rx(buf[:in.N], (*in.Buf)[:in.N])
				Bufs.Put(in.Buf)
				if got == nil {
					logger.Warn("bad MAC")
					continue
				}
				if len(got) != 0 {
					rxPkts++
					rxBytes += int64(len(got))
					localUDP.Write(got)
				}
				last = time.Now()
			}
		}
	}()

	// Control-channel keepalive: echo 8-byte messages back on the TLS
	// connection. Stops on any error or when nothing arrives within
	// 2*udpobfs.LifetimeDuration; its exit does not end the session.
	go func() {
		buf := make([]byte, 8)
		for {
			connTLS.SetReadDeadline(time.Now().Add(2 * udpobfs.LifetimeDuration))
			if _, err := io.ReadFull(connTLS, buf); err != nil {
				return
			}
			if _, err := connTLS.Write(buf); err != nil {
				return
			}
		}
	}()

	<-txFinished
	PeersM.Lock()
	// A newer session from the same address may have replaced our entry;
	// only remove the channel we registered.
	if Peers[peerKey] == txs {
		delete(Peers, peerKey)
	}
	PeersM.Unlock()
	// No sender can reach txs any more, so closing it is safe. It also
	// terminates the client -> destination goroutine.
	close(txs)
	// Wait for that goroutine so the rx counters are read race-free.
	<-rxFinished
	logger.Info("finishing",
		"rxPkts", rxPkts,
		"rxBytes", rxBytes,
		"txPkts", txPkts,
		"txBytes", txBytes)
}
193
194 func main() {
195         bind := flag.String("bind", "[::]:1194", "Address to bind to")
196         dst := flag.String("dst", "[2001:db8::1234]:1194", "Address to connect to")
197         keypairPath := flag.String("keypair", "keypair.pem", "X.509 keypair")
198         caPath := flag.String("ca", "ca.pem", "CA certificate")
199         flag.Parse()
200         log.SetFlags(log.Lshortfile)
201         slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, nil)))
202
203         crtRaw, _, err := udpobfs.CertificateFromFile(*keypairPath)
204         if err != nil {
205                 log.Fatal(err)
206         }
207         prv, err := udpobfs.PrivateKeyFromFile(*keypairPath)
208         if err != nil {
209                 log.Fatal(err)
210         }
211         TLSConfig = &tls.Config{
212                 MinVersion: tls.VersionTLS13,
213                 ClientAuth: tls.RequireAndVerifyClientCert,
214                 Certificates: []tls.Certificate{{
215                         Certificate: [][]byte{crtRaw},
216                         PrivateKey:  prv,
217                 }},
218         }
219         _, TLSConfig.ClientCAs, err = udpobfs.CertPoolFromFile(*caPath)
220         if err != nil {
221                 log.Fatalln(err)
222         }
223
224         DstAddrUDP = udpobfs.MustResolveUDPAddr(*dst)
225         LnUDP, err = net.ListenUDP("udp", udpobfs.MustResolveUDPAddr(*bind))
226         if err != nil {
227                 log.Fatal(err)
228         }
229         LnTCP, err = net.ListenTCP("tcp", udpobfs.MustResolveTCPAddr(*bind))
230         if err != nil {
231                 log.Fatal(err)
232         }
233
234         go func() {
235                 var n int
236                 var from net.Addr
237                 var txs chan udpobfs.Buf
238                 var buf *[udpobfs.BufLen]byte
239                 for {
240                         buf = Bufs.Get().(*[udpobfs.BufLen]byte)
241                         n, from, _ = LnUDP.ReadFrom((*buf)[:])
242                         if n == 0 {
243                                 continue
244                         }
245                         PeersM.RLock()
246                         txs = Peers[from.String()]
247                         if txs != nil {
248                                 txs <- udpobfs.Buf{Buf: buf, N: n}
249                         }
250                         PeersM.RUnlock()
251                 }
252         }()
253
254         for {
255                 conn, err := LnTCP.Accept()
256                 if err != nil {
257                         slog.Error(err.Error())
258                         continue
259                 }
260                 go newPeer(conn)
261         }
262 }