]> Cypherpunks.ru repositories - govpn.git/blobdiff - src/cypherpunks.ru/govpn/client/tcp.go
Upgrade Client
[govpn.git] / src / cypherpunks.ru / govpn / client / tcp.go
index 40d81ca818081789c3a9f87ac4e4a2f063c63a76..056b40e36e5ca2e0b6f38a1e53cb381c7f9bad71 100644 (file)
@@ -22,35 +22,62 @@ import (
        "bytes"
        "fmt"
        "net"
+       "os"
        "sync/atomic"
        "time"
 
+       "github.com/Sirupsen/logrus"
+       "github.com/pkg/errors"
+
        "cypherpunks.ru/govpn"
 )
 
 func (c *Client) startTCP() {
-       remote, err := net.ResolveTCPAddr("tcp", c.config.RemoteAddress)
-       if err != nil {
-               c.Error <- fmt.Errorf("Can not resolve remote address: %s", err)
-               return
+       var conn net.Conn
+       l := c.logger.WithField("func", logFuncPrefix+"Client.startTCP")
+       // initialize using a file descriptor
+       if c.config.FileDescriptor > 0 {
+               l.WithField("fd", c.config.FileDescriptor).Debug("Connect using file descriptor")
+               var err error
+               conn, err = net.FileConn(os.NewFile(uintptr(c.config.FileDescriptor), fmt.Sprintf("fd[%s]", c.config.RemoteAddress)))
+               if err != nil {
+                       c.Error <- errors.Wrapf(err, "net.FileConn fd:%d", c.config.FileDescriptor)
+                       return
+               }
+       } else {
+               // TODO move resolution into the loop, as the name might change over time
+               l.WithField("fd", c.config.RemoteAddress).Debug("Connect using TCP")
+               remote, err := net.ResolveTCPAddr("tcp", c.config.RemoteAddress)
+               if err != nil {
+                       c.Error <- errors.Wrapf(err, "net.ResolveTCPAdd %s", c.config.RemoteAddress)
+                       return
+               }
+               l.WithField("remote", remote.String()).Debug("dial")
+               conn, err = net.DialTCP("tcp", nil, remote)
+               if err != nil {
+                       c.Error <- errors.Wrapf(err, "net.DialTCP: %s", remote.String())
+                       return
+               }
        }
-       conn, err := net.DialTCP("tcp", nil, remote)
+       l.WithFields(c.config.LogFields()).Info("Connected")
+       c.handleTCP(conn)
+}
+
+func (c *Client) handleTCP(conn net.Conn) {
+       hs, err := govpn.HandshakeStart(c.config.RemoteAddress, conn, c.config.Peer)
        if err != nil {
-               c.Error <- fmt.Errorf("Can not connect to address: %s", err)
+               govpn.CloseLog(conn, c.logger, c.LogFields())
+               c.Error <- errors.Wrap(err, "govpn.HandshakeStart")
                return
        }
-       govpn.Printf(`[connected remote="%s"]`, c.config.RemoteAddress)
-       c.handleTCP(conn)
-}
+       buf := make([]byte, 2*(govpn.EnclessEnlargeSize+c.config.Peer.MTU)+c.config.Peer.MTU)
 
-func (c *Client) handleTCP(conn *net.TCPConn) {
-       hs := govpn.HandshakeStart(c.config.RemoteAddress, conn, c.config.Peer)
-       buf := make([]byte, 2*(govpn.EnclessEnlargeSize+c.config.MTU)+c.config.MTU)
        var n int
-       var err error
        var prev int
        var peer *govpn.Peer
+       var deadLine time.Time
        var terminator chan struct{}
+       fields := logrus.Fields{"func": logFuncPrefix + "Client.handleTCP"}
 HandshakeCycle:
        for {
                select {
@@ -59,36 +86,42 @@ HandshakeCycle:
                default:
                }
                if prev == len(buf) {
-                       govpn.Printf(`[packet-timeouted remote="%s"]`, c.config.RemoteAddress)
+                       c.logger.WithFields(fields).WithFields(c.LogFields()).Debug("Packet timeouted")
                        c.timeouted <- struct{}{}
                        break HandshakeCycle
                }
 
-               if err = conn.SetReadDeadline(time.Now().Add(c.config.Peer.Timeout)); err != nil {
-                       c.Error <- err
+               deadLine = time.Now().Add(c.config.Peer.Timeout)
+               if err = conn.SetReadDeadline(deadLine); err != nil {
+                       c.Error <- errors.Wrapf(err, "conn.SetReadDeadline %s", deadLine.String())
                        break HandshakeCycle
                }
                n, err = conn.Read(buf[prev:])
                if err != nil {
-                       govpn.Printf(`[connection-timeouted remote="%s"]`, c.config.RemoteAddress)
+                       c.logger.WithFields(fields).WithFields(c.LogFields()).Debug("Packet timeouted")
                        c.timeouted <- struct{}{}
                        break HandshakeCycle
                }
 
                prev += n
-               peerID := c.idsCache.Find(buf[:prev])
-               if peerID == nil {
+               _, err = c.idsCache.Find(buf[:prev])
+               if err != nil {
+                       c.logger.WithFields(fields).WithFields(c.LogFields()).WithError(err).Debug("Couldn't find peer in ids")
                        continue
                }
-               peer = hs.Client(buf[:prev])
+               peer, err = hs.Client(buf[:prev])
                prev = 0
-               if peer == nil {
+               if err != nil {
+                       c.logger.WithFields(fields).WithError(err).WithFields(c.LogFields()).Debug("Can't create new peer")
                        continue
                }
-               govpn.Printf(`[handshake-completed remote="%s"]`, c.config.RemoteAddress)
+               c.logger.WithFields(fields).WithFields(c.LogFields()).Info("Handshake completed")
                c.knownPeers = govpn.KnownPeers(map[string]**govpn.Peer{c.config.RemoteAddress: &peer})
                if c.firstUpCall {
-                       go govpn.ScriptCall(c.config.UpPath, c.config.InterfaceName, c.config.RemoteAddress)
+                       if err = c.postUpAction(); err != nil {
+                               c.Error <- errors.Wrap(err, "c.postUpAction")
+                               break HandshakeCycle
+                       }
                        c.firstUpCall = false
                }
                hs.Zero()
@@ -113,17 +146,17 @@ TransportCycle:
                default:
                }
                if prev == len(buf) {
-                       govpn.Printf(`[packet-timeouted remote="%s"]`, c.config.RemoteAddress)
+                       c.logger.WithFields(c.LogFields()).Debug("Packet timeouted")
                        c.timeouted <- struct{}{}
                        break TransportCycle
                }
                if err = conn.SetReadDeadline(time.Now().Add(c.config.Peer.Timeout)); err != nil {
-                       c.Error <- err
+                       c.Error <- errors.Wrap(err, "conn.SetReadDeadline")
                        break TransportCycle
                }
                n, err = conn.Read(buf[prev:])
                if err != nil {
-                       govpn.Printf(`[connection-timeouted remote="%s"]`, c.config.RemoteAddress)
+                       c.logger.WithError(err).WithFields(c.LogFields()).Debug("Connection timeouted")
                        c.timeouted <- struct{}{}
                        break TransportCycle
                }
@@ -137,12 +170,12 @@ TransportCycle:
                        continue
                }
                if !peer.PktProcess(buf[:i+govpn.NonceSize], c.tap, false) {
-                       govpn.Printf(`[packet-unauthenticated remote="%s"]`, c.config.RemoteAddress)
+                       c.logger.WithFields(c.LogFields()).Debug("Packet unauthenticated")
                        c.timeouted <- struct{}{}
                        break TransportCycle
                }
                if atomic.LoadUint64(&peer.BytesIn)+atomic.LoadUint64(&peer.BytesOut) > govpn.MaxBytesPerKey {
-                       govpn.Printf(`[rehandshake-required remote="%s"]`, c.config.RemoteAddress)
+                       c.logger.WithFields(c.LogFields()).Debug("Rehandshake required")
                        c.rehandshaking <- struct{}{}
                        break TransportCycle
                }
@@ -155,6 +188,9 @@ TransportCycle:
        }
        peer.Zero()
        if err = conn.Close(); err != nil {
-               c.Error <- err
+               c.Error <- errors.Wrap(err, "conn.Close")
+       }
+       if err = c.tap.Close(); err != nil {
+               c.Error <- errors.Wrap(err, logFuncPrefix+"Client.tap.Close")
        }
 }