]> Cypherpunks.ru repositories - udpobfs.git/blob - cmd/init/main.go
3564dbf0ed737d1a4c718e6774742fddbf2ffce3c45b5fbf264fea4c00584b48
[udpobfs.git] / cmd / init / main.go
1 /*
2 udpobfs -- simple point-to-point UDP obfuscation proxy
3 Copyright (C) 2023 Sergey Matveev <stargrave@stargrave.org>
4
5 This program is free software: you can redistribute it and/or modify
6 it under the terms of the GNU General Public License as published by
7 the Free Software Foundation, version 3 of the License.
8
9 This program is distributed in the hope that it will be useful,
10 but WITHOUT ANY WARRANTY; without even the implied warranty of
11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 GNU General Public License for more details.
13
14 You should have received a copy of the GNU General Public License
15 along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 */
17
18 package main
19
20 import (
21         "bytes"
22         "crypto/rand"
23         "crypto/sha256"
24         "crypto/tls"
25         "crypto/x509"
26         "encoding/hex"
27         "errors"
28         "flag"
29         "io"
30         "log"
31         "log/slog"
32         "net"
33         "os"
34         "sync"
35         "time"
36
37         "go.cypherpunks.ru/udpobfs/v2"
38         "lukechampine.com/blake3"
39 )
40
41 var (
42         DstAddrUDP *net.UDPAddr
43         DstAddrTCP *net.TCPAddr
44         TLSConfig  *tls.Config
45         LnUDP      *net.UDPConn
46         Peers      = make(map[string]chan udpobfs.Buf)
47         PeersM     sync.RWMutex
48         Bufs       = sync.Pool{New: func() any { return new([udpobfs.BufLen]byte) }}
49 )
50
51 func newPeer(localAddr net.Addr, dataInitial []byte) {
52         logger := slog.With("remote", localAddr.String())
53         logger.Info("connected")
54         conn, err := net.DialTCP("tcp", nil, DstAddrTCP)
55         if err != nil {
56                 slog.Warn(err.Error())
57                 return
58         }
59         defer conn.Close()
60         srcAddr, err := net.ResolveUDPAddr("udp", conn.LocalAddr().String())
61         if err != nil {
62                 log.Fatal(err)
63         }
64         logger = logger.With("local", srcAddr.String())
65         connTLS := tls.Client(conn, TLSConfig)
66         err = connTLS.Handshake()
67         if err != nil {
68                 logger.Error(err.Error())
69                 return
70         }
71         defer connTLS.Close()
72         tlsState := connTLS.ConnectionState()
73         logger = logger.With("cn", tlsState.PeerCertificates[0].Subject.CommonName)
74         logger.Info("authenticated")
75         seed, err := tlsState.ExportKeyingMaterial(udpobfs.App, nil, udpobfs.SeedLen)
76         if err != nil {
77                 logger.Error(err.Error())
78                 return
79         }
80         connUDP, err := net.ListenUDP("udp", srcAddr)
81         if err != nil {
82                 logger.Error(err.Error())
83                 return
84         }
85         cryptoState := udpobfs.NewCryptoState(seed, true)
86         txs := make(chan udpobfs.Buf)
87         PeersM.Lock()
88         Peers[localAddr.String()] = txs
89         PeersM.Unlock()
90         var rxPkts, txPkts, rxBytes, txBytes int64
91         {
92                 txPkts++
93                 txBytes += int64(len(dataInitial))
94                 tmp := make([]byte, udpobfs.SeqLen+len(dataInitial))
95                 connUDP.WriteTo(cryptoState.Tx(tmp, dataInitial), DstAddrUDP)
96         }
97         rxFinished := make(chan struct{})
98         go func() {
99                 var n int
100                 var err error
101                 rx := make([]byte, udpobfs.BufLen)
102                 tx := make([]byte, udpobfs.BufLen)
103                 var got []byte
104                 for {
105                         connUDP.SetReadDeadline(time.Now().Add(2 * udpobfs.LifetimeDuration))
106                         n, err = connUDP.Read(rx)
107                         if n == 0 {
108                                 if err != nil {
109                                         break
110                                 }
111                                 continue
112                         }
113                         if n < udpobfs.SeqLen {
114                                 logger.Warn("too short")
115                                 continue
116                         }
117                         got = cryptoState.Rx(tx[:n], rx[:n])
118                         if got == nil {
119                                 logger.Warn("bad MAC")
120                                 continue
121                         }
122                         if len(got) != 0 {
123                                 rxPkts++
124                                 rxBytes += int64(len(got))
125                                 LnUDP.WriteTo(got, localAddr)
126                         }
127                         if err != nil {
128                                 break
129                         }
130                 }
131                 close(rxFinished)
132         }()
133         go func() {
134                 buf := make([]byte, udpobfs.BufLen)
135                 var tx udpobfs.Buf
136                 ticker := time.NewTicker(udpobfs.PingDuration)
137                 defer ticker.Stop()
138                 now := time.Now()
139                 lastPing := now
140                 last := now
141                 var got []byte
142                 var ok bool
143                 for {
144                         select {
145                         case <-ticker.C:
146                                 now = time.Now()
147                                 if now.Sub(last) > 2*udpobfs.LifetimeDuration {
148                                         connUDP.Close()
149                                         return
150                                 }
151                                 if now.Sub(lastPing) > udpobfs.PingDuration {
152                                         _, err = connUDP.WriteTo(
153                                                 cryptoState.Tx(buf[:udpobfs.SeqLen], nil), DstAddrUDP)
154                                         lastPing = now
155                                 }
156                         case tx, ok = <-txs:
157                                 if !ok {
158                                         return
159                                 }
160                                 got = cryptoState.Tx(buf[:udpobfs.SeqLen+tx.N], (*tx.Buf)[:tx.N])
161                                 Bufs.Put(tx.Buf)
162                                 connUDP.WriteTo(got, DstAddrUDP)
163                                 txPkts++
164                                 txBytes += int64(len(got))
165                                 lastPing = time.Now()
166                                 last = lastPing
167                         }
168                 }
169         }()
170         go func() {
171                 defer connUDP.Close()
172                 ticker := time.NewTicker(udpobfs.LifetimeDuration)
173                 defer ticker.Stop()
174                 our := make([]byte, 8)
175                 their := make([]byte, 8)
176                 key := make([]byte, 32)
177                 if _, err = io.ReadFull(rand.Reader, key); err != nil {
178                         log.Fatal(err)
179                 }
180                 rnd := blake3.New(32, key).XOF()
181                 var err error
182                 for {
183                         select {
184                         case <-ticker.C:
185                                 if _, err = io.ReadFull(rnd, our); err != nil {
186                                         log.Fatal(err)
187                                 }
188                                 if _, err = connTLS.Write(our); err != nil {
189                                         return
190                                 }
191                                 if _, err = io.ReadFull(connTLS, their); err != nil {
192                                         return
193                                 }
194                                 if !bytes.Equal(our, their) {
195                                         logger.Error("pong mismatch")
196                                         return
197                                 }
198                         case <-rxFinished:
199                                 return
200                         }
201                 }
202         }()
203         <-rxFinished
204         logger.Info("finishing",
205                 "rxPkts", rxPkts,
206                 "rxBytes", rxBytes,
207                 "txPkts", txPkts,
208                 "txBytes", txBytes)
209         PeersM.Lock()
210         delete(Peers, localAddr.String())
211         PeersM.Unlock()
212         close(txs)
213 }
214
215 func main() {
216         bind := flag.String("bind", "[::]:1194", "Address to bind to")
217         dst := flag.String("dst", "[2001:db8::1234]:1194", "Address to connect to")
218         keypairPath := flag.String("keypair", "keypair.pem", "X.509 keypair")
219         caPath := flag.String("ca", "ca.pem", "CA certificate")
220         serverHash := flag.String("hash", "", "Expected server's SPKI SHA256 fingerprint")
221         serverName := flag.String("name", "example.com", "Expected server's hostname")
222         flag.Parse()
223         log.SetFlags(log.Lshortfile)
224         slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, nil)))
225
226         crtRaw, _, err := udpobfs.CertificateFromFile(*keypairPath)
227         if err != nil {
228                 log.Fatal(err)
229         }
230         prv, err := udpobfs.PrivateKeyFromFile(*keypairPath)
231         if err != nil {
232                 log.Fatal(err)
233         }
234         TLSConfig = &tls.Config{
235                 MinVersion: tls.VersionTLS13,
236                 Certificates: []tls.Certificate{{
237                         Certificate: [][]byte{crtRaw},
238                         PrivateKey:  prv,
239                 }},
240                 ServerName: *serverName,
241         }
242         _, TLSConfig.RootCAs, err = udpobfs.CertPoolFromFile(*caPath)
243         if err != nil {
244                 log.Fatalln(err)
245         }
246
247         if *serverHash != "" {
248                 hshOur, err := hex.DecodeString(*serverHash)
249                 if err != nil {
250                         log.Fatal(err)
251                 }
252                 TLSConfig.VerifyPeerCertificate = func(
253                         rawCerts [][]byte,
254                         verifiedChains [][]*x509.Certificate,
255                 ) error {
256                         spki := verifiedChains[0][0].RawSubjectPublicKeyInfo
257                         hshTheir := sha256.Sum256(spki)
258                         if !bytes.Equal(hshOur, hshTheir[:]) {
259                                 return errors.New("server certificate's SPKI hash mismatch")
260                         }
261                         return nil
262                 }
263         }
264
265         DstAddrUDP = udpobfs.MustResolveUDPAddr(*dst)
266         DstAddrTCP = udpobfs.MustResolveTCPAddr(*dst)
267         LnUDP, err = net.ListenUDP("udp", udpobfs.MustResolveUDPAddr(*bind))
268         if err != nil {
269                 log.Fatal(err)
270         }
271
272         var n int
273         var from net.Addr
274         var txs chan udpobfs.Buf
275         var buf *[udpobfs.BufLen]byte
276         for {
277                 buf = Bufs.Get().(*[udpobfs.BufLen]byte)
278                 n, from, _ = LnUDP.ReadFrom((*buf)[:])
279                 if n == 0 {
280                         continue
281                 }
282                 PeersM.RLock()
283                 txs = Peers[from.String()]
284                 if txs != nil {
285                         txs <- udpobfs.Buf{Buf: buf, N: n}
286                 }
287                 PeersM.RUnlock()
288                 if txs == nil {
289                         neu := make([]byte, n)
290                         copy(neu, (*buf)[:n])
291                         Bufs.Put(buf)
292                         go newPeer(from, neu)
293                 }
294         }
295 }