Cypherpunks.ru repositories - govpn.git/blob - src/cypherpunks.ru/govpn/client/tcp.go
Upgrade Client
[govpn.git] / src / cypherpunks.ru / govpn / client / tcp.go
1 /*
2 GoVPN -- simple secure free software virtual private network daemon
3 Copyright (C) 2014-2017 Sergey Matveev <stargrave@stargrave.org>
4
5 This program is free software: you can redistribute it and/or modify
6 it under the terms of the GNU General Public License as published by
7 the Free Software Foundation, either version 3 of the License, or
8 (at your option) any later version.
9
10 This program is distributed in the hope that it will be useful,
11 but WITHOUT ANY WARRANTY; without even the implied warranty of
12 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 GNU General Public License for more details.
14
15 You should have received a copy of the GNU General Public License
16 along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 */
18
19 package client
20
21 import (
22         "bytes"
23         "fmt"
24         "net"
25         "os"
26         "sync/atomic"
27         "time"
28
29         "github.com/Sirupsen/logrus"
30         "github.com/pkg/errors"
31
32         "cypherpunks.ru/govpn"
33 )
34
35 func (c *Client) startTCP() {
36         var conn net.Conn
37         l := c.logger.WithField("func", logFuncPrefix+"Client.startTCP")
38         // initialize using a file descriptor
39         if c.config.FileDescriptor > 0 {
40                 l.WithField("fd", c.config.FileDescriptor).Debug("Connect using file descriptor")
41                 var err error
42                 conn, err = net.FileConn(os.NewFile(uintptr(c.config.FileDescriptor), fmt.Sprintf("fd[%s]", c.config.RemoteAddress)))
43                 if err != nil {
44                         c.Error <- errors.Wrapf(err, "net.FileConn fd:%d", c.config.FileDescriptor)
45                         return
46                 }
47         } else {
48                 // TODO move resolution into the loop, as the name might change over time
49                 l.WithField("fd", c.config.RemoteAddress).Debug("Connect using TCP")
50                 remote, err := net.ResolveTCPAddr("tcp", c.config.RemoteAddress)
51                 if err != nil {
52                         c.Error <- errors.Wrapf(err, "net.ResolveTCPAdd %s", c.config.RemoteAddress)
53                         return
54                 }
55                 l.WithField("remote", remote.String()).Debug("dial")
56                 conn, err = net.DialTCP("tcp", nil, remote)
57                 if err != nil {
58                         c.Error <- errors.Wrapf(err, "net.DialTCP: %s", remote.String())
59                         return
60                 }
61         }
62         l.WithFields(c.config.LogFields()).Info("Connected")
63         c.handleTCP(conn)
64 }
65
66 func (c *Client) handleTCP(conn net.Conn) {
67         hs, err := govpn.HandshakeStart(c.config.RemoteAddress, conn, c.config.Peer)
68         if err != nil {
69                 govpn.CloseLog(conn, c.logger, c.LogFields())
70                 c.Error <- errors.Wrap(err, "govpn.HandshakeStart")
71                 return
72         }
73         buf := make([]byte, 2*(govpn.EnclessEnlargeSize+c.config.Peer.MTU)+c.config.Peer.MTU)
74
75         var n int
76         var prev int
77         var peer *govpn.Peer
78         var deadLine time.Time
79         var terminator chan struct{}
80         fields := logrus.Fields{"func": logFuncPrefix + "Client.handleTCP"}
81 HandshakeCycle:
82         for {
83                 select {
84                 case <-c.termination:
85                         break HandshakeCycle
86                 default:
87                 }
88                 if prev == len(buf) {
89                         c.logger.WithFields(fields).WithFields(c.LogFields()).Debug("Packet timeouted")
90                         c.timeouted <- struct{}{}
91                         break HandshakeCycle
92                 }
93
94                 deadLine = time.Now().Add(c.config.Peer.Timeout)
95                 if err = conn.SetReadDeadline(deadLine); err != nil {
96                         c.Error <- errors.Wrapf(err, "conn.SetReadDeadline %s", deadLine.String())
97                         break HandshakeCycle
98                 }
99                 n, err = conn.Read(buf[prev:])
100                 if err != nil {
101                         c.logger.WithFields(fields).WithFields(c.LogFields()).Debug("Packet timeouted")
102                         c.timeouted <- struct{}{}
103                         break HandshakeCycle
104                 }
105
106                 prev += n
107                 _, err = c.idsCache.Find(buf[:prev])
108                 if err != nil {
109                         c.logger.WithFields(fields).WithFields(c.LogFields()).WithError(err).Debug("Couldn't find peer in ids")
110                         continue
111                 }
112                 peer, err = hs.Client(buf[:prev])
113                 prev = 0
114                 if err != nil {
115                         c.logger.WithFields(fields).WithError(err).WithFields(c.LogFields()).Debug("Can't create new peer")
116                         continue
117                 }
118                 c.logger.WithFields(fields).WithFields(c.LogFields()).Info("Handshake completed")
119                 c.knownPeers = govpn.KnownPeers(map[string]**govpn.Peer{c.config.RemoteAddress: &peer})
120                 if c.firstUpCall {
121                         if err = c.postUpAction(); err != nil {
122                                 c.Error <- errors.Wrap(err, "c.postUpAction")
123                                 break HandshakeCycle
124                         }
125                         c.firstUpCall = false
126                 }
127                 hs.Zero()
128                 terminator = make(chan struct{})
129                 go govpn.PeerTapProcessor(peer, c.tap, terminator)
130                 break HandshakeCycle
131         }
132         if hs != nil {
133                 hs.Zero()
134         }
135         if peer == nil {
136                 return
137         }
138
139         prev = 0
140         var i int
141 TransportCycle:
142         for {
143                 select {
144                 case <-c.termination:
145                         break TransportCycle
146                 default:
147                 }
148                 if prev == len(buf) {
149                         c.logger.WithFields(c.LogFields()).Debug("Packet timeouted")
150                         c.timeouted <- struct{}{}
151                         break TransportCycle
152                 }
153                 if err = conn.SetReadDeadline(time.Now().Add(c.config.Peer.Timeout)); err != nil {
154                         c.Error <- errors.Wrap(err, "conn.SetReadDeadline")
155                         break TransportCycle
156                 }
157                 n, err = conn.Read(buf[prev:])
158                 if err != nil {
159                         c.logger.WithError(err).WithFields(c.LogFields()).Debug("Connection timeouted")
160                         c.timeouted <- struct{}{}
161                         break TransportCycle
162                 }
163                 prev += n
164         CheckMore:
165                 if prev < govpn.MinPktLength {
166                         continue
167                 }
168                 i = bytes.Index(buf[:prev], peer.NonceExpect)
169                 if i == -1 {
170                         continue
171                 }
172                 if !peer.PktProcess(buf[:i+govpn.NonceSize], c.tap, false) {
173                         c.logger.WithFields(c.LogFields()).Debug("Packet unauthenticated")
174                         c.timeouted <- struct{}{}
175                         break TransportCycle
176                 }
177                 if atomic.LoadUint64(&peer.BytesIn)+atomic.LoadUint64(&peer.BytesOut) > govpn.MaxBytesPerKey {
178                         c.logger.WithFields(c.LogFields()).Debug("Rehandshake required")
179                         c.rehandshaking <- struct{}{}
180                         break TransportCycle
181                 }
182                 copy(buf, buf[i+govpn.NonceSize:prev])
183                 prev = prev - i - govpn.NonceSize
184                 goto CheckMore
185         }
186         if terminator != nil {
187                 terminator <- struct{}{}
188         }
189         peer.Zero()
190         if err = conn.Close(); err != nil {
191                 c.Error <- errors.Wrap(err, "conn.Close")
192         }
193         if err = c.tap.Close(); err != nil {
194                 c.Error <- errors.Wrap(err, logFuncPrefix+"Client.tap.Close")
195         }
196 }