]> Cypherpunks.ru repositories - udpobfs.git/blob - cmd/resp/main.go
v2
[udpobfs.git] / cmd / resp / main.go
1 /*
2 udpobfs -- simple point-to-point UDP obfuscation proxy
3 Copyright (C) 2023 Sergey Matveev <stargrave@stargrave.org>
4
5 This program is free software: you can redistribute it and/or modify
6 it under the terms of the GNU General Public License as published by
7 the Free Software Foundation, version 3 of the License.
8
9 This program is distributed in the hope that it will be useful,
10 but WITHOUT ANY WARRANTY; without even the implied warranty of
11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 GNU General Public License for more details.
13
14 You should have received a copy of the GNU General Public License
15 along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 */
17
18 package main
19
20 import (
21         "crypto/tls"
22         "flag"
23         "io"
24         "log"
25         "log/slog"
26         "net"
27         "os"
28         "sync"
29         "time"
30
31         "go.cypherpunks.ru/udpobfs/v2"
32 )
33
34 var (
35         DstAddrUDP *net.UDPAddr
36         TLSConfig  *tls.Config
37         LnUDP      *net.UDPConn
38         LnTCP      *net.TCPListener
39         Peers      = make(map[string]chan udpobfs.Buf)
40         Bufs       = sync.Pool{New: func() any { return new([udpobfs.BufLen]byte) }}
41 )
42
43 func newPeer(conn net.Conn) {
44         logger := slog.With("remote", conn.RemoteAddr().String())
45         logger.Info("connected")
46         defer conn.Close()
47         remoteAddr := udpobfs.MustResolveUDPAddr(conn.RemoteAddr().String())
48         localUDP, err := net.DialUDP("udp", nil, DstAddrUDP)
49         if err != nil {
50                 log.Fatal(err)
51         }
52         logger = logger.With("local", localUDP.LocalAddr().String())
53         connTLS := tls.Server(conn, TLSConfig)
54         err = connTLS.Handshake()
55         if err != nil {
56                 logger.Error(err.Error())
57                 return
58         }
59         defer connTLS.Close()
60         tlsState := connTLS.ConnectionState()
61         logger = logger.With("cn", tlsState.PeerCertificates[0].Subject.CommonName)
62         logger.Info("authenticated")
63         seed, err := tlsState.ExportKeyingMaterial(udpobfs.App, nil, udpobfs.SeedLen)
64         if err != nil {
65                 logger.Error(err.Error())
66                 return
67         }
68         cryptoState := udpobfs.NewCryptoState(seed, false)
69         txs := make(chan udpobfs.Buf)
70         txFinished := make(chan struct{})
71         var rxPkts, txPkts, rxBytes, txBytes int64
72         go func() {
73                 pkts := make(chan []byte)
74                 pktRead := make(chan struct{})
75                 go func() {
76                         var n int
77                         var err error
78                         rx := make([]byte, udpobfs.BufLen)
79                         for {
80                                 localUDP.SetReadDeadline(time.Now().Add(time.Minute))
81                                 n, err = localUDP.Read(rx)
82                                 if n != 0 {
83                                         pkts <- rx[:n]
84                                         <-pktRead
85                                 }
86                                 if err != nil {
87                                         break
88                                 }
89                         }
90                         close(pkts)
91                 }()
92                 ticker := time.NewTicker(udpobfs.PingDuration)
93                 defer ticker.Stop()
94                 var pkt []byte
95                 var ok bool
96                 tx := make([]byte, udpobfs.BufLen)
97                 now := time.Now()
98                 lastPing := now
99                 var got []byte
100                 for {
101                         select {
102                         case <-ticker.C:
103                                 now = time.Now()
104                                 if now.Sub(lastPing) > udpobfs.PingDuration {
105                                         LnUDP.WriteTo(
106                                                 cryptoState.Tx(tx[:udpobfs.SeqLen], nil), remoteAddr)
107                                         lastPing = now
108                                 }
109                         case pkt, ok = <-pkts:
110                                 if !ok {
111                                         close(txFinished)
112                                         return
113                                 }
114                                 got = cryptoState.Tx(tx[:udpobfs.SeqLen+len(pkt)], pkt)
115                                 pktRead <- struct{}{}
116                                 LnUDP.WriteTo(got, remoteAddr)
117                                 txPkts++
118                                 txBytes += int64(len(pkt))
119                                 lastPing = time.Now()
120                         }
121                 }
122         }()
123         go func() {
124                 ticker := time.NewTicker(2 * udpobfs.LifetimeDuration)
125                 defer ticker.Stop()
126                 now := time.Now()
127                 last := now
128                 buf := make([]byte, udpobfs.BufLen)
129                 var tx udpobfs.Buf
130                 var got []byte
131                 for {
132                         select {
133                         case <-txFinished:
134                                 return
135                         case <-ticker.C:
136                                 now = time.Now()
137                                 if now.Sub(last) > 2*udpobfs.LifetimeDuration {
138                                         localUDP.Close()
139                                         return
140                                 }
141                         case tx = <-txs:
142                                 if tx.Buf == nil {
143                                         break
144                                 }
145                                 if tx.N < udpobfs.SeqLen {
146                                         logger.Warn("too short")
147                                         Bufs.Put(tx.Buf)
148                                         continue
149                                 }
150                                 got = cryptoState.Rx(buf[:tx.N], (*tx.Buf)[:tx.N])
151                                 Bufs.Put(tx.Buf)
152                                 if got == nil {
153                                         logger.Warn("bad MAC")
154                                         continue
155                                 }
156                                 if len(got) != 0 {
157                                         rxPkts++
158                                         rxBytes += int64(len(got))
159                                         localUDP.Write(got)
160                                 }
161                                 last = time.Now()
162                         }
163                 }
164         }()
165         Peers[remoteAddr.String()] = txs
166         go func() {
167                 buf := make([]byte, 8)
168                 for {
169                         connTLS.SetReadDeadline(time.Now().Add(2 * udpobfs.LifetimeDuration))
170                         if _, err = io.ReadFull(connTLS, buf); err != nil {
171                                 break
172                         }
173                         if _, err = connTLS.Write(buf); err != nil {
174                                 break
175                         }
176                 }
177         }()
178         <-txFinished
179         logger.Info("finishing",
180                 "rxPkts", rxPkts,
181                 "rxBytes", rxBytes,
182                 "txPkts", txPkts,
183                 "txBytes", txBytes)
184         delete(Peers, remoteAddr.String())
185         txs <- udpobfs.Buf{Buf: nil}
186         go func() {
187                 for range txs {
188                 }
189         }()
190 }
191
192 func main() {
193         bind := flag.String("bind", "[::]:1194", "Address to bind to")
194         dst := flag.String("dst", "[2001:db8::1234]:1194", "Address to connect to")
195         keypairPath := flag.String("keypair", "keypair.pem", "X.509 keypair")
196         caPath := flag.String("ca", "ca.pem", "CA certificate")
197         flag.Parse()
198         log.SetFlags(log.Lshortfile)
199         slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, nil)))
200
201         crtRaw, _, err := udpobfs.CertificateFromFile(*keypairPath)
202         if err != nil {
203                 log.Fatal(err)
204         }
205         prv, err := udpobfs.PrivateKeyFromFile(*keypairPath)
206         if err != nil {
207                 log.Fatal(err)
208         }
209         TLSConfig = &tls.Config{
210                 MinVersion: tls.VersionTLS13,
211                 ClientAuth: tls.RequireAndVerifyClientCert,
212                 Certificates: []tls.Certificate{{
213                         Certificate: [][]byte{crtRaw},
214                         PrivateKey:  prv,
215                 }},
216         }
217         _, TLSConfig.ClientCAs, err = udpobfs.CertPoolFromFile(*caPath)
218         if err != nil {
219                 log.Fatalln(err)
220         }
221
222         DstAddrUDP = udpobfs.MustResolveUDPAddr(*dst)
223         LnUDP, err = net.ListenUDP("udp", udpobfs.MustResolveUDPAddr(*bind))
224         if err != nil {
225                 log.Fatal(err)
226         }
227         LnTCP, err = net.ListenTCP("tcp", udpobfs.MustResolveTCPAddr(*bind))
228         if err != nil {
229                 log.Fatal(err)
230         }
231
232         go func() {
233                 var n int
234                 var from net.Addr
235                 var txs chan udpobfs.Buf
236                 var buf *[udpobfs.BufLen]byte
237                 for {
238                         buf = Bufs.Get().(*[udpobfs.BufLen]byte)
239                         n, from, _ = LnUDP.ReadFrom((*buf)[:])
240                         if n == 0 {
241                                 continue
242                         }
243                         txs = Peers[from.String()]
244                         if txs != nil {
245                                 txs <- udpobfs.Buf{Buf: buf, N: n}
246                         }
247                 }
248         }()
249
250         for {
251                 conn, err := LnTCP.Accept()
252                 if err != nil {
253                         slog.Error(err.Error())
254                         continue
255                 }
256                 go newPeer(conn)
257         }
258 }