]> Cypherpunks.ru repositories - udpobfs.git/blob - cmd/init/main.go
v2
[udpobfs.git] / cmd / init / main.go
1 /*
2 udpobfs -- simple point-to-point UDP obfuscation proxy
3 Copyright (C) 2023 Sergey Matveev <stargrave@stargrave.org>
4
5 This program is free software: you can redistribute it and/or modify
6 it under the terms of the GNU General Public License as published by
7 the Free Software Foundation, version 3 of the License.
8
9 This program is distributed in the hope that it will be useful,
10 but WITHOUT ANY WARRANTY; without even the implied warranty of
11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 GNU General Public License for more details.
13
14 You should have received a copy of the GNU General Public License
15 along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 */
17
18 package main
19
20 import (
21         "bytes"
22         "crypto/rand"
23         "crypto/sha256"
24         "crypto/tls"
25         "crypto/x509"
26         "encoding/hex"
27         "errors"
28         "flag"
29         "io"
30         "log"
31         "log/slog"
32         "net"
33         "os"
34         "sync"
35         "time"
36
37         "go.cypherpunks.ru/udpobfs/v2"
38         "lukechampine.com/blake3"
39 )
40
41 var (
42         DstAddrUDP *net.UDPAddr
43         DstAddrTCP *net.TCPAddr
44         TLSConfig  *tls.Config
45         LnUDP      *net.UDPConn
46         Peers      = make(map[string]chan udpobfs.Buf)
47         Bufs       = sync.Pool{New: func() any { return new([udpobfs.BufLen]byte) }}
48 )
49
50 func newPeer(localAddr net.Addr, dataInitial []byte) {
51         logger := slog.With("remote", localAddr.String())
52         logger.Info("connected")
53         conn, err := net.DialTCP("tcp", nil, DstAddrTCP)
54         if err != nil {
55                 slog.Warn(err.Error())
56                 return
57         }
58         defer conn.Close()
59         srcAddr, err := net.ResolveUDPAddr("udp", conn.LocalAddr().String())
60         if err != nil {
61                 log.Fatal(err)
62         }
63         logger = logger.With("local", srcAddr.String())
64         connTLS := tls.Client(conn, TLSConfig)
65         err = connTLS.Handshake()
66         if err != nil {
67                 logger.Error(err.Error())
68                 return
69         }
70         defer connTLS.Close()
71         tlsState := connTLS.ConnectionState()
72         logger = logger.With("cn", tlsState.PeerCertificates[0].Subject.CommonName)
73         logger.Info("authenticated")
74         seed, err := tlsState.ExportKeyingMaterial(udpobfs.App, nil, udpobfs.SeedLen)
75         if err != nil {
76                 logger.Error(err.Error())
77                 return
78         }
79         connUDP, err := net.ListenUDP("udp", srcAddr)
80         if err != nil {
81                 logger.Error(err.Error())
82                 return
83         }
84         cryptoState := udpobfs.NewCryptoState(seed, true)
85         txs := make(chan udpobfs.Buf)
86         rxFinished := make(chan struct{})
87         var rxPkts, txPkts, rxBytes, txBytes int64
88         go func() {
89                 var n int
90                 var err error
91                 rx := make([]byte, udpobfs.BufLen)
92                 tx := make([]byte, udpobfs.BufLen)
93                 var got []byte
94                 for {
95                         connUDP.SetReadDeadline(time.Now().Add(2 * udpobfs.LifetimeDuration))
96                         n, err = connUDP.Read(rx)
97                         if n == 0 {
98                                 if err != nil {
99                                         break
100                                 }
101                                 continue
102                         }
103                         if n < udpobfs.SeqLen {
104                                 logger.Warn("too short")
105                                 continue
106                         }
107                         got = cryptoState.Rx(tx[:n], rx[:n])
108                         if got == nil {
109                                 logger.Warn("bad MAC")
110                                 continue
111                         }
112                         if len(got) != 0 {
113                                 rxPkts++
114                                 rxBytes += int64(len(got))
115                                 LnUDP.WriteTo(got, localAddr)
116                         }
117                         if err != nil {
118                                 break
119                         }
120                 }
121                 close(rxFinished)
122         }()
123         go func() {
124                 buf := make([]byte, udpobfs.BufLen)
125                 var tx udpobfs.Buf
126                 ticker := time.NewTicker(udpobfs.PingDuration)
127                 defer ticker.Stop()
128                 now := time.Now()
129                 lastPing := now
130                 last := now
131                 var got []byte
132                 for {
133                         select {
134                         case <-ticker.C:
135                                 now = time.Now()
136                                 if now.Sub(last) > 2*udpobfs.LifetimeDuration {
137                                         connUDP.Close()
138                                         return
139                                 }
140                                 if now.Sub(lastPing) > udpobfs.PingDuration {
141                                         _, err = connUDP.WriteTo(
142                                                 cryptoState.Tx(buf[:udpobfs.SeqLen], nil), DstAddrUDP)
143                                         lastPing = now
144                                 }
145                         case tx = <-txs:
146                                 if tx.Buf == nil {
147                                         return
148                                 }
149                                 got = cryptoState.Tx(buf[:udpobfs.SeqLen+tx.N], (*tx.Buf)[:tx.N])
150                                 Bufs.Put(tx.Buf)
151                                 connUDP.WriteTo(got, DstAddrUDP)
152                                 txPkts++
153                                 txBytes += int64(len(got))
154                                 lastPing = time.Now()
155                                 last = lastPing
156                         }
157                 }
158         }()
159         Peers[localAddr.String()] = txs
160         {
161                 txPkts++
162                 txBytes += int64(len(dataInitial))
163                 tmp := make([]byte, udpobfs.SeqLen+len(dataInitial))
164                 connUDP.WriteTo(cryptoState.Tx(tmp, dataInitial), DstAddrUDP)
165         }
166         go func() {
167                 defer connUDP.Close()
168                 ticker := time.NewTicker(udpobfs.LifetimeDuration)
169                 defer ticker.Stop()
170                 our := make([]byte, 8)
171                 their := make([]byte, 8)
172                 key := make([]byte, 32)
173                 if _, err = io.ReadFull(rand.Reader, key); err != nil {
174                         log.Fatal(err)
175                 }
176                 rnd := blake3.New(32, key).XOF()
177                 var err error
178                 for {
179                         select {
180                         case <-ticker.C:
181                                 if _, err = io.ReadFull(rnd, our); err != nil {
182                                         log.Fatal(err)
183                                 }
184                                 if _, err = connTLS.Write(our); err != nil {
185                                         return
186                                 }
187                                 if _, err = io.ReadFull(connTLS, their); err != nil {
188                                         return
189                                 }
190                                 if !bytes.Equal(our, their) {
191                                         logger.Error("pong mismatch")
192                                         return
193                                 }
194                         case <-rxFinished:
195                                 return
196                         }
197                 }
198         }()
199         <-rxFinished
200         logger.Info("finishing",
201                 "rxPkts", rxPkts,
202                 "rxBytes", rxBytes,
203                 "txPkts", txPkts,
204                 "txBytes", txBytes)
205         delete(Peers, localAddr.String())
206         txs <- udpobfs.Buf{Buf: nil}
207         go func() {
208                 for range txs {
209                 }
210         }()
211 }
212
213 func main() {
214         bind := flag.String("bind", "[::]:1194", "Address to bind to")
215         dst := flag.String("dst", "[2001:db8::1234]:1194", "Address to connect to")
216         keypairPath := flag.String("keypair", "keypair.pem", "X.509 keypair")
217         caPath := flag.String("ca", "ca.pem", "CA certificate")
218         serverHash := flag.String("hash", "", "Expected server's SPKI SHA256 fingerprint")
219         serverName := flag.String("name", "example.com", "Expected server's hostname")
220         flag.Parse()
221         log.SetFlags(log.Lshortfile)
222         slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, nil)))
223
224         crtRaw, _, err := udpobfs.CertificateFromFile(*keypairPath)
225         if err != nil {
226                 log.Fatal(err)
227         }
228         prv, err := udpobfs.PrivateKeyFromFile(*keypairPath)
229         if err != nil {
230                 log.Fatal(err)
231         }
232         TLSConfig = &tls.Config{
233                 MinVersion: tls.VersionTLS13,
234                 Certificates: []tls.Certificate{{
235                         Certificate: [][]byte{crtRaw},
236                         PrivateKey:  prv,
237                 }},
238                 ServerName: *serverName,
239         }
240         _, TLSConfig.RootCAs, err = udpobfs.CertPoolFromFile(*caPath)
241         if err != nil {
242                 log.Fatalln(err)
243         }
244
245         if *serverHash != "" {
246                 hshOur, err := hex.DecodeString(*serverHash)
247                 if err != nil {
248                         log.Fatal(err)
249                 }
250                 TLSConfig.VerifyPeerCertificate = func(
251                         rawCerts [][]byte,
252                         verifiedChains [][]*x509.Certificate,
253                 ) error {
254                         spki := verifiedChains[0][0].RawSubjectPublicKeyInfo
255                         hshTheir := sha256.Sum256(spki)
256                         if !bytes.Equal(hshOur, hshTheir[:]) {
257                                 return errors.New("server certificate's SPKI hash mismatch")
258                         }
259                         return nil
260                 }
261         }
262
263         DstAddrUDP = udpobfs.MustResolveUDPAddr(*dst)
264         DstAddrTCP = udpobfs.MustResolveTCPAddr(*dst)
265         LnUDP, err = net.ListenUDP("udp", udpobfs.MustResolveUDPAddr(*bind))
266         if err != nil {
267                 log.Fatal(err)
268         }
269
270         var n int
271         var from net.Addr
272         var txs chan udpobfs.Buf
273         var buf *[udpobfs.BufLen]byte
274         for {
275                 buf = Bufs.Get().(*[udpobfs.BufLen]byte)
276                 n, from, _ = LnUDP.ReadFrom((*buf)[:])
277                 if n == 0 {
278                         continue
279                 }
280                 txs = Peers[from.String()]
281                 if txs != nil {
282                         txs <- udpobfs.Buf{Buf: buf, N: n}
283                 }
284                 if txs == nil {
285                         neu := make([]byte, n)
286                         copy(neu, (*buf)[:n])
287                         Bufs.Put(buf)
288                         go newPeer(from, neu)
289                 }
290         }
291 }