]> Cypherpunks.ru repositories - udpobfs.git/blob - cmd/init/main.go
Unify copyright comment format
[udpobfs.git] / cmd / init / main.go
1 // udpobfs -- simple point-to-point UDP obfuscation proxy
2 // Copyright (C) 2023-2024 Sergey Matveev <stargrave@stargrave.org>
3 //
4 // This program is free software: you can redistribute it and/or modify
5 // it under the terms of the GNU General Public License as published by
6 // the Free Software Foundation, version 3 of the License.
7 //
8 // This program is distributed in the hope that it will be useful,
9 // but WITHOUT ANY WARRANTY; without even the implied warranty of
10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 // GNU General Public License for more details.
12 //
13 // You should have received a copy of the GNU General Public License
14 // along with this program.  If not, see <http://www.gnu.org/licenses/>.
15
16 package main
17
18 import (
19         "bytes"
20         "crypto/rand"
21         "crypto/sha256"
22         "crypto/tls"
23         "crypto/x509"
24         "encoding/hex"
25         "errors"
26         "flag"
27         "io"
28         "log"
29         "log/slog"
30         "net"
31         "os"
32         "sync"
33         "time"
34
35         "go.cypherpunks.ru/udpobfs/v2"
36         "lukechampine.com/blake3"
37 )
38
39 var (
40         DstAddrUDP *net.UDPAddr
41         DstAddrTCP *net.TCPAddr
42         TLSConfig  *tls.Config
43         LnUDP      *net.UDPConn
44         Peers      = make(map[string]chan udpobfs.Buf)
45         PeersM     sync.RWMutex
46         Bufs       = sync.Pool{New: func() any { return new([udpobfs.BufLen]byte) }}
47 )
48
49 func newPeer(localAddr net.Addr, dataInitial []byte) {
50         logger := slog.With("remote", localAddr.String())
51         logger.Info("connected")
52         conn, err := net.DialTCP("tcp", nil, DstAddrTCP)
53         if err != nil {
54                 slog.Warn(err.Error())
55                 return
56         }
57         defer conn.Close()
58         srcAddr, err := net.ResolveUDPAddr("udp", conn.LocalAddr().String())
59         if err != nil {
60                 log.Fatal(err)
61         }
62         logger = logger.With("local", srcAddr.String())
63         connTLS := tls.Client(conn, TLSConfig)
64         err = connTLS.Handshake()
65         if err != nil {
66                 logger.Error(err.Error())
67                 return
68         }
69         defer connTLS.Close()
70         tlsState := connTLS.ConnectionState()
71         logger = logger.With("cn", tlsState.PeerCertificates[0].Subject.CommonName)
72         logger.Info("authenticated")
73         seed, err := tlsState.ExportKeyingMaterial(udpobfs.App, nil, udpobfs.SeedLen)
74         if err != nil {
75                 logger.Error(err.Error())
76                 return
77         }
78         connUDP, err := net.ListenUDP("udp", srcAddr)
79         if err != nil {
80                 logger.Error(err.Error())
81                 return
82         }
83         cryptoState := udpobfs.NewCryptoState(seed, true)
84         txs := make(chan udpobfs.Buf)
85         PeersM.Lock()
86         Peers[localAddr.String()] = txs
87         PeersM.Unlock()
88         var rxPkts, txPkts, rxBytes, txBytes int64
89         {
90                 txPkts++
91                 txBytes += int64(len(dataInitial))
92                 tmp := make([]byte, udpobfs.SeqLen+len(dataInitial))
93                 connUDP.WriteTo(cryptoState.Tx(tmp, dataInitial), DstAddrUDP)
94         }
95         rxFinished := make(chan struct{})
96         go func() {
97                 var n int
98                 var err error
99                 rx := make([]byte, udpobfs.BufLen)
100                 tx := make([]byte, udpobfs.BufLen)
101                 var got []byte
102                 for {
103                         connUDP.SetReadDeadline(time.Now().Add(2 * udpobfs.LifetimeDuration))
104                         n, err = connUDP.Read(rx)
105                         if n == 0 {
106                                 if err != nil {
107                                         break
108                                 }
109                                 continue
110                         }
111                         if n < udpobfs.SeqLen {
112                                 logger.Warn("too short")
113                                 continue
114                         }
115                         got = cryptoState.Rx(tx[:n], rx[:n])
116                         if got == nil {
117                                 logger.Warn("bad MAC")
118                                 continue
119                         }
120                         if len(got) != 0 {
121                                 rxPkts++
122                                 rxBytes += int64(len(got))
123                                 LnUDP.WriteTo(got, localAddr)
124                         }
125                         if err != nil {
126                                 break
127                         }
128                 }
129                 close(rxFinished)
130         }()
131         go func() {
132                 buf := make([]byte, udpobfs.BufLen)
133                 var tx udpobfs.Buf
134                 ticker := time.NewTicker(udpobfs.PingDuration)
135                 defer ticker.Stop()
136                 now := time.Now()
137                 lastPing := now
138                 last := now
139                 var got []byte
140                 var ok bool
141                 for {
142                         select {
143                         case <-ticker.C:
144                                 now = time.Now()
145                                 if now.Sub(last) > 2*udpobfs.LifetimeDuration {
146                                         connUDP.Close()
147                                         return
148                                 }
149                                 if now.Sub(lastPing) > udpobfs.PingDuration {
150                                         _, err = connUDP.WriteTo(
151                                                 cryptoState.Tx(buf[:udpobfs.SeqLen], nil), DstAddrUDP)
152                                         lastPing = now
153                                 }
154                         case tx, ok = <-txs:
155                                 if !ok {
156                                         return
157                                 }
158                                 got = cryptoState.Tx(buf[:udpobfs.SeqLen+tx.N], (*tx.Buf)[:tx.N])
159                                 Bufs.Put(tx.Buf)
160                                 connUDP.WriteTo(got, DstAddrUDP)
161                                 txPkts++
162                                 txBytes += int64(len(got))
163                                 lastPing = time.Now()
164                                 last = lastPing
165                         }
166                 }
167         }()
168         go func() {
169                 defer connUDP.Close()
170                 ticker := time.NewTicker(udpobfs.LifetimeDuration)
171                 defer ticker.Stop()
172                 our := make([]byte, 8)
173                 their := make([]byte, 8)
174                 key := make([]byte, 32)
175                 if _, err = io.ReadFull(rand.Reader, key); err != nil {
176                         log.Fatal(err)
177                 }
178                 rnd := blake3.New(32, key).XOF()
179                 var err error
180                 for {
181                         select {
182                         case <-ticker.C:
183                                 if _, err = io.ReadFull(rnd, our); err != nil {
184                                         log.Fatal(err)
185                                 }
186                                 if _, err = connTLS.Write(our); err != nil {
187                                         return
188                                 }
189                                 if _, err = io.ReadFull(connTLS, their); err != nil {
190                                         return
191                                 }
192                                 if !bytes.Equal(our, their) {
193                                         logger.Error("pong mismatch")
194                                         return
195                                 }
196                         case <-rxFinished:
197                                 return
198                         }
199                 }
200         }()
201         <-rxFinished
202         logger.Info("finishing",
203                 "rxPkts", rxPkts,
204                 "rxBytes", rxBytes,
205                 "txPkts", txPkts,
206                 "txBytes", txBytes)
207         PeersM.Lock()
208         delete(Peers, localAddr.String())
209         PeersM.Unlock()
210         close(txs)
211 }
212
213 func main() {
214         bind := flag.String("bind", "[::]:1194", "Address to bind to")
215         dst := flag.String("dst", "[2001:db8::1234]:1194", "Address to connect to")
216         keypairPath := flag.String("keypair", "keypair.pem", "X.509 keypair")
217         caPath := flag.String("ca", "ca.pem", "CA certificate")
218         serverHash := flag.String("hash", "", "Expected server's SPKI SHA256 fingerprint")
219         serverName := flag.String("name", "example.com", "Expected server's hostname")
220         flag.Parse()
221         log.SetFlags(log.Lshortfile)
222         slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, nil)))
223
224         crtRaw, _, err := udpobfs.CertificateFromFile(*keypairPath)
225         if err != nil {
226                 log.Fatal(err)
227         }
228         prv, err := udpobfs.PrivateKeyFromFile(*keypairPath)
229         if err != nil {
230                 log.Fatal(err)
231         }
232         TLSConfig = &tls.Config{
233                 MinVersion: tls.VersionTLS13,
234                 Certificates: []tls.Certificate{{
235                         Certificate: [][]byte{crtRaw},
236                         PrivateKey:  prv,
237                 }},
238                 ServerName: *serverName,
239         }
240         _, TLSConfig.RootCAs, err = udpobfs.CertPoolFromFile(*caPath)
241         if err != nil {
242                 log.Fatalln(err)
243         }
244
245         if *serverHash != "" {
246                 hshOur, err := hex.DecodeString(*serverHash)
247                 if err != nil {
248                         log.Fatal(err)
249                 }
250                 TLSConfig.VerifyPeerCertificate = func(
251                         rawCerts [][]byte,
252                         verifiedChains [][]*x509.Certificate,
253                 ) error {
254                         spki := verifiedChains[0][0].RawSubjectPublicKeyInfo
255                         hshTheir := sha256.Sum256(spki)
256                         if !bytes.Equal(hshOur, hshTheir[:]) {
257                                 return errors.New("server certificate's SPKI hash mismatch")
258                         }
259                         return nil
260                 }
261         }
262
263         DstAddrUDP = udpobfs.MustResolveUDPAddr(*dst)
264         DstAddrTCP = udpobfs.MustResolveTCPAddr(*dst)
265         LnUDP, err = net.ListenUDP("udp", udpobfs.MustResolveUDPAddr(*bind))
266         if err != nil {
267                 log.Fatal(err)
268         }
269
270         var n int
271         var from net.Addr
272         var txs chan udpobfs.Buf
273         var buf *[udpobfs.BufLen]byte
274         for {
275                 buf = Bufs.Get().(*[udpobfs.BufLen]byte)
276                 n, from, _ = LnUDP.ReadFrom((*buf)[:])
277                 if n == 0 {
278                         continue
279                 }
280                 PeersM.RLock()
281                 txs = Peers[from.String()]
282                 if txs != nil {
283                         txs <- udpobfs.Buf{Buf: buf, N: n}
284                 }
285                 PeersM.RUnlock()
286                 if txs == nil {
287                         neu := make([]byte, n)
288                         copy(neu, (*buf)[:n])
289                         Bufs.Put(buf)
290                         go newPeer(from, neu)
291                 }
292         }
293 }