]> Cypherpunks.ru repositories - udpobfs.git/blob - main.go
b4c317e48200e7ffe0e120cf032c9c10257b65c58ef2f0a0fe363c64ae779035
[udpobfs.git] / main.go
1 /*
2 udpobfs -- simple point-to-point UDP obfuscation proxy
3 Copyright (C) 2023 Sergey Matveev <stargrave@stargrave.org>
4
5 This program is free software: you can redistribute it and/or modify
6 it under the terms of the GNU General Public License as published by
7 the Free Software Foundation, version 3 of the License.
8
9 This program is distributed in the hope that it will be useful,
10 but WITHOUT ANY WARRANTY; without even the implied warranty of
11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 GNU General Public License for more details.
13
14 You should have received a copy of the GNU General Public License
15 along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 */
17
18 package main
19
20 import (
21         "bufio"
22         "crypto/rand"
23         "crypto/subtle"
24         "encoding/base32"
25         "flag"
26         "fmt"
27         "io"
28         "log"
29         "net"
30         "os"
31         "os/signal"
32         "syscall"
33
34         "golang.org/x/crypto/blowfish"
35         "golang.org/x/crypto/chacha20"
36         "golang.org/x/crypto/poly1305"
37         "golang.org/x/crypto/sha3"
38 )
39
// KeyLen is the length in bytes of the raw shared key: -keygen generates
// this many random bytes, and every key read from stdin must decode to
// exactly this length.
const KeyLen = 32
41
// mustWrite writes all of data to w. If the write reports an error or
// accepts fewer bytes than requested, the program is terminated through
// log.Fatal.
func mustWrite(w io.Writer, data []byte) {
	written, err := w.Write(data)
	if err == nil && written == len(data) {
		return
	}
	log.Fatal("non full write")
}
47
// incr treats buf as a big-endian unsigned integer and adds one to it in
// place. It reports true when the value wrapped around to all zeroes; an
// empty buf also reports true.
func incr(buf []byte) (overflow bool) {
	pos := len(buf)
	for pos > 0 {
		pos--
		buf[pos]++
		if buf[pos] != 0 {
			// No carry out of this byte, so the increment is complete.
			return false
		}
	}
	return true
}
58
59 type State struct {
60         ourKey, theirKey     []byte
61         ourObfs, theirObfs   *blowfish.Cipher
62         ourNonce, theirNonce []byte
63         ourSeq, theirSeq     []byte
64 }
65
// Base32Codec is the standard Base32 alphabet without padding characters.
// It is used to print generated keys and to decode keys read from stdin.
var Base32Codec = base32.StdEncoding.WithPadding(base32.NoPadding)
67
68 func main() {
69         keygen := flag.Bool("keygen", false, "Generate random key")
70         responder := flag.Bool("responder", false, "Are we responder?")
71         bind := flag.String("bind", "[::]:1194", "Address to bind to")
72         dst := flag.String("dst", "[2001:db8::1234]::1194", "Address to connect to")
73         flag.Parse()
74         log.SetFlags(log.Ldate | log.Lmicroseconds | log.Lshortfile)
75
76         if *keygen {
77                 key := make([]byte, KeyLen)
78                 if _, err := io.ReadFull(rand.Reader, key); err != nil {
79                         log.Fatal(err)
80                 }
81                 fmt.Println(Base32Codec.EncodeToString(key))
82                 return
83         }
84
85         var state *State
86         stateReady := make(chan struct{})
87         go func() {
88                 first := true
89                 s := bufio.NewScanner(os.Stdin)
90                 h := sha3.NewShake128()
91                 for s.Scan() {
92                         key, err := Base32Codec.DecodeString(s.Text())
93                         if err != nil {
94                                 log.Fatal(err)
95                         }
96                         if len(key) != KeyLen {
97                                 log.Fatal("wrong key length")
98                         }
99                         h.Reset()
100                         mustWrite(h, []byte("go.cypherpunks.ru/udpobfs"))
101                         mustWrite(h, key)
102                         iEncKey := make([]byte, chacha20.KeySize)
103                         iBlkKey := make([]byte, 32)
104                         rEncKey := make([]byte, chacha20.KeySize)
105                         rBlkKey := make([]byte, 32)
106                         if _, err := io.ReadFull(h, iEncKey); err != nil {
107                                 log.Fatal(err)
108                         }
109                         if _, err := io.ReadFull(h, iBlkKey); err != nil {
110                                 log.Fatal(err)
111                         }
112                         if _, err := io.ReadFull(h, rEncKey); err != nil {
113                                 log.Fatal(err)
114                         }
115                         if _, err := io.ReadFull(h, rBlkKey); err != nil {
116                                 log.Fatal(err)
117                         }
118                         iObfs, err := blowfish.NewCipher(iBlkKey)
119                         if err != nil {
120                                 log.Fatal(err)
121                         }
122                         rObfs, err := blowfish.NewCipher(rBlkKey)
123                         if err != nil {
124                                 log.Fatal(err)
125                         }
126                         var newState State
127                         if *responder {
128                                 newState = State{
129                                         ourKey:     rEncKey,
130                                         theirKey:   iEncKey,
131                                         ourObfs:    rObfs,
132                                         theirObfs:  iObfs,
133                                         ourNonce:   make([]byte, chacha20.NonceSize),
134                                         theirNonce: make([]byte, chacha20.NonceSize),
135                                 }
136                         } else {
137                                 newState = State{
138                                         ourKey:     iEncKey,
139                                         theirKey:   rEncKey,
140                                         ourObfs:    iObfs,
141                                         theirObfs:  rObfs,
142                                         ourNonce:   make([]byte, chacha20.NonceSize),
143                                         theirNonce: make([]byte, chacha20.NonceSize),
144                                 }
145                         }
146                         newState.ourSeq = newState.ourNonce[4:]
147                         newState.theirSeq = newState.theirNonce[4:]
148                         state = &newState
149                         if first {
150                                 close(stateReady)
151                                 first = false
152                         }
153                 }
154                 if s.Err() != nil {
155                         log.Fatal(s.Err())
156                 }
157         }()
158         <-stateReady
159
160         addr, err := net.ResolveUDPAddr("udp", *bind)
161         if err != nil {
162                 log.Fatal(err)
163         }
164         connBind, err := net.ListenUDP("udp", addr)
165         if err != nil {
166                 log.Fatal(err)
167         }
168
169         var connLocal *net.UDPConn
170         var connRemote *net.UDPConn
171         var addrLocal *net.UDPAddr
172         var addrRemote *net.UDPAddr
173         addr, err = net.ResolveUDPAddr("udp", *dst)
174         if err != nil {
175                 log.Fatal(err)
176         }
177         if *responder {
178                 connLocal, err = net.DialUDP("udp", nil, addr)
179         } else {
180                 connRemote, err = net.DialUDP("udp", nil, addr)
181         }
182         if err != nil {
183                 log.Fatal(err)
184         }
185         log.Println(*bind, "->", *dst)
186
187         go func() {
188                 rx := make([]byte, 1<<14)
189                 tx := make([]byte, 1<<14)
190                 var n int
191                 var polyKey [32]byte
192                 var s *chacha20.Cipher
193                 var p *poly1305.MAC
194                 tag := make([]byte, poly1305.TagSize)
195                 var from *net.UDPAddr
196                 for {
197                         if *responder {
198                                 n, err = connLocal.Read(rx)
199                         } else {
200                                 n, from, err = connBind.ReadFromUDP(rx)
201                         }
202                         if err != nil {
203                                 log.Fatal(err)
204                         }
205                         if *responder && addrRemote == nil {
206                                 continue
207                         }
208                         if !*responder && (addrLocal == nil ||
209                                 from.Port != addrLocal.Port || !from.IP.Equal(addrLocal.IP)) {
210                                 addrLocal = from
211                         }
212                         if incr(state.ourSeq[5:]) {
213                                 incr(state.ourSeq[:5])
214                         }
215                         copy(tx, state.ourSeq[5:])
216                         s, err = chacha20.NewUnauthenticatedCipher(state.ourKey, state.ourNonce)
217                         if err != nil {
218                                 log.Fatal(err)
219                         }
220                         clear(polyKey[:])
221                         s.XORKeyStream(polyKey[:], polyKey[:])
222                         s.SetCounter(1)
223                         s.XORKeyStream(tx[8:], rx[:n])
224                         p = poly1305.New(&polyKey)
225                         mustWrite(p, state.ourSeq)
226                         mustWrite(p, tx[8:8+n])
227                         p.Sum(tag[:0])
228                         copy(tx[3:8], tag)
229                         state.ourObfs.Encrypt(tx[:8], tx[:8])
230                         if *responder {
231                                 connBind.WriteTo(tx[:8+n], addrRemote)
232                         } else {
233                                 connRemote.Write(tx[:8+n])
234                         }
235                 }
236         }()
237         go func(state **State) {
238                 rx := make([]byte, 1<<14)
239                 tx := make([]byte, 1<<14)
240                 var n int
241                 var polyKey [32]byte
242                 var s *chacha20.Cipher
243                 var p *poly1305.MAC
244                 var from *net.UDPAddr
245                 tag := make([]byte, poly1305.TagSize)
246                 nonce := make([]byte, chacha20.NonceSize)
247                 seq := nonce[4:]
248                 var seqOur, seqTheir uint32
249                 for {
250                         if *responder {
251                                 n, from, err = connBind.ReadFromUDP(rx)
252                         } else {
253                                 n, err = connRemote.Read(rx)
254                         }
255                         if err != nil {
256                                 log.Fatal(err)
257                         }
258                         if n < 8 {
259                                 log.Println("too short")
260                                 continue
261                         }
262                         if *responder && (addrRemote == nil ||
263                                 from.Port != addrRemote.Port || !from.IP.Equal(addrRemote.IP)) {
264                                 addrRemote = from
265                         }
266                         (*state).theirObfs.Decrypt(rx[:8], rx[:8])
267                         seqOur = uint32((*state).theirSeq[0])<<16 |
268                                 uint32((*state).theirSeq[1])<<8 |
269                                 uint32((*state).theirSeq[2])
270                         seqTheir = uint32(rx[0])<<16 | uint32(rx[1])<<8 | uint32(rx[2])
271                         if seqOur == seqTheir {
272                                 log.Println("replay")
273                                 continue
274                         }
275                         copy(seq, (*state).theirNonce[:5])
276                         copy(seq[5:], rx[:3])
277                         if seqTheir < seqOur && incr(seq[:5]) {
278                                 log.Fatal("seq is overflowed")
279                         }
280                         s, err = chacha20.NewUnauthenticatedCipher((*state).theirKey, nonce)
281                         if err != nil {
282                                 log.Fatal(err)
283                         }
284                         clear(polyKey[:])
285                         s.XORKeyStream(polyKey[:], polyKey[:])
286                         s.SetCounter(1)
287                         p = poly1305.New(&polyKey)
288                         mustWrite(p, seq)
289                         mustWrite(p, rx[8:n])
290                         p.Sum(tag[:0])
291                         if subtle.ConstantTimeCompare(tag[:5], rx[3:8]) != 1 {
292                                 log.Print("bad MAC")
293                                 continue
294                         }
295                         copy((*state).theirSeq, seq)
296                         s.XORKeyStream(tx, rx[8:n])
297                         if *responder {
298                                 connLocal.Write(tx[:n-8])
299                         } else {
300                                 connBind.WriteTo(tx[:n-8], addrLocal)
301                         }
302                 }
303         }(&state)
304         exit := make(chan os.Signal, 1)
305         signal.Notify(exit, syscall.SIGTERM, syscall.SIGINT)
306         <-exit
307 }