]> Cypherpunks.ru repositories - govpn.git/commitdiff
Upgrade Client
authorBruno Clermont <bruno@robotinfra.com>
Wed, 8 Feb 2017 10:39:19 +0000 (18:39 +0800)
committerSergey Matveev <stargrave@stargrave.org>
Wed, 15 Feb 2017 06:46:23 +0000 (09:46 +0300)
- wrap errors
- switch to logrus
- add Android support: allow connection with file descriptor
- move `govpn/client.Protocol` to `govpn.Protocol`
- improve usage as a library: switch from Up/Down as executed script to Go function
- add `PreUp` step
- allow metrics to be consumed by library user
- use a generic channel to stop client
- log failure to close resources
- close TAP when not used anymore

src/cypherpunks.ru/govpn/action.go [new file with mode: 0644]
src/cypherpunks.ru/govpn/client/client.go
src/cypherpunks.ru/govpn/client/proxy.go
src/cypherpunks.ru/govpn/client/tcp.go
src/cypherpunks.ru/govpn/client/udp.go
src/cypherpunks.ru/govpn/cmd/govpn-client/main.go
src/cypherpunks.ru/govpn/common.go
src/cypherpunks.ru/govpn/conf.go

diff --git a/src/cypherpunks.ru/govpn/action.go b/src/cypherpunks.ru/govpn/action.go
new file mode 100644 (file)
index 0000000..5ad35b2
--- /dev/null
@@ -0,0 +1,71 @@
+/*
+GoVPN -- simple secure free software virtual private network daemon
+Copyright (C) 2014-2016 Sergey Matveev <stargrave@stargrave.org>
+
+This program is free software: you can redistribute it and/or modify
+it under the terms of the GNU General Public License as published by
+the Free Software Foundation, either version 3 of the License, or
+(at your option) any later version.
+
+This program is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+GNU General Public License for more details.
+
+You should have received a copy of the GNU General Public License
+along with this program.  If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package govpn
+
+import (
+       "os"
+       "os/exec"
+
+       "github.com/pkg/errors"
+)
+
+// PeerContext hold info about a peer that connect or disconnect
+// used for Up, PreUp and Down
+type PeerContext struct {
+       RemoteAddress string
+       Protocol      Protocol
+       Config        PeerConf
+}
+
+// TunnelAction is an action for either client or server that is
+// executed when tunnel goes down
+type TunnelAction func(PeerContext) error
+
+// TunnelPreUpAction is an action for client or server that is executed
+// after user is authenticated
+type TunnelPreUpAction func(PeerContext) (*TAP, error)
+
+// RunScriptAction convert the path to a script into a TunnelAction
+func RunScriptAction(path *string) TunnelAction {
+       if path == nil {
+               return nil
+       }
+       return func(ctx PeerContext) error {
+               _, err := ScriptCall(*path, ctx.Config.Iface, ctx.RemoteAddress)
+               return errors.Wrapf(err, "ScriptCall path=%q interface=%q remote=%q", *path, ctx.Config.Iface, ctx.RemoteAddress)
+       }
+}
+
+// ScriptCall call external program/script.
+// You have to specify path to it and (inteface name as a rule) something
+// that will be the first argument when calling it. Function will return
+// it's output and possible error.
+func ScriptCall(path, ifaceName, remoteAddr string) ([]byte, error) {
+       if path == "" {
+               return nil, nil
+       }
+       if _, err := os.Stat(path); err != nil && os.IsNotExist(err) {
+               return nil, errors.Wrap(err, "os.Path")
+       }
+       cmd := exec.Command(path)
+       cmd.Env = append(cmd.Env, environmentKeyInterface+"="+ifaceName)
+       cmd.Env = append(cmd.Env, environmentKeyRemote+"="+remoteAddr)
+       out, err := cmd.CombinedOutput()
+       return out, errors.Wrap(err, "cmd.CombinedOutput")
+}
index a4c49079157da40428bc5d534f7989afd0dad1e5..d04dc6b1171d055871ad582a7dd31c3736c65bef 100644 (file)
@@ -19,57 +19,69 @@ along with this program.  If not, see <http://www.gnu.org/licenses/>.
 package client
 
 import (
-       "errors"
        "fmt"
-       "net"
-       "os"
        "time"
 
+       "github.com/Sirupsen/logrus"
        "github.com/agl/ed25519"
+       "github.com/pkg/errors"
 
        "cypherpunks.ru/govpn"
 )
 
-// Protocol is a GoVPN supported protocol: either UDP, TCP or both
-type Protocol int
-
-const (
-       // ProtocolUDP is UDP transport protocol
-       ProtocolUDP Protocol = iota
-       // ProtocolTCP is TCP transport protocol
-       ProtocolTCP
-)
+const logFuncPrefix = "govpn/client."
 
 // Configuration holds GoVPN client configuration
 type Configuration struct {
        PrivateKey          *[ed25519.PrivateKeySize]byte
        Peer                *govpn.PeerConf
-       Protocol            Protocol
-       InterfaceName       string
+       Protocol            govpn.Protocol
        ProxyAddress        string
        ProxyAuthentication string
        RemoteAddress       string
-       UpPath              string
-       DownPath            string
-       StatsAddress        string
        NoReconnect         bool
-       MTU                 int
+       // FileDescriptor allow to create a Client from a pre-existing file descriptor.
+       // Required for Android. requires TCP protocol
+       FileDescriptor int
 }
 
 // Validate returns an error if a configuration is invalid
 func (c *Configuration) Validate() error {
-       if c.MTU > govpn.MTUMax {
-               return fmt.Errorf("Invalid MTU %d, maximum allowable is %d", c.MTU, govpn.MTUMax)
+       if c.Peer.MTU > govpn.MTUMax {
+               return errors.Errorf("Invalid MTU %d, maximum allowable is %d", c.Peer.MTU, govpn.MTUMax)
        }
        if len(c.RemoteAddress) == 0 {
                return errors.New("Missing RemoteAddress")
        }
-       if len(c.InterfaceName) == 0 {
-               return errors.New("Missing InterfaceName")
+       if len(c.Peer.Iface) == 0 && c.Peer.PreUp == nil {
+               return errors.New("Missing InterfaceName *or* PreUp")
+       }
+       if c.Protocol != govpn.ProtocolTCP && c.Protocol != govpn.ProtocolUDP {
+               return errors.Errorf("Invalid protocol %d for client", c.Protocol)
+       }
+       if c.FileDescriptor > 0 && c.Protocol != govpn.ProtocolTCP {
+               return errors.Errorf("Connect with file descriptor requires protocol %s", govpn.ProtocolTCP.String())
        }
        return nil
 }
 
+// LogFields return a logrus compatible logging context
+func (c *Configuration) LogFields() logrus.Fields {
+       const prefix = "client_conf_"
+       f := c.Peer.LogFields(prefix)
+       f[prefix+"protocol"] = c.Protocol.String()
+       f[prefix+"no_reconnect"] = c.NoReconnect
+       if len(c.ProxyAddress) > 0 {
+               f[prefix+"proxy"] = c.ProxyAddress
+       }
+       if c.FileDescriptor > 0 {
+               f[prefix+"remote"] = fmt.Sprintf("fd:%d(%s)", c.FileDescriptor, c.RemoteAddress)
+       } else {
+               f[prefix+"remote"] = c.RemoteAddress
+       }
+       return f
+}
+
 func (c *Configuration) isProxy() bool {
        return len(c.ProxyAddress) > 0
 }
@@ -79,36 +91,96 @@ type Client struct {
        idsCache      *govpn.MACCache
        tap           *govpn.TAP
        knownPeers    govpn.KnownPeers
-       statsPort     net.Listener
        timeouted     chan struct{}
        rehandshaking chan struct{}
        termination   chan struct{}
        firstUpCall   bool
-       termSignal    chan os.Signal
+       termSignal    chan interface{}
        config        Configuration
+       logger        *logrus.Logger
 
        // Error channel receives any kind of routine errors
        Error chan error
 }
 
+// LogFields return a logrus compatible logging context
+func (c *Client) LogFields() logrus.Fields {
+       const prefix = "client_"
+       f := logrus.Fields{
+               prefix + "remote": c.config.RemoteAddress,
+       }
+       if c.tap != nil {
+               f[prefix+"interface"] = c.tap.Name
+       }
+       if c.config.Peer != nil {
+               f[prefix+"id"] = c.config.Peer.ID.String()
+       }
+       return f
+}
+
+func (c *Client) postDownAction() error {
+       if c.config.Peer.Down == nil {
+               return nil
+       }
+       err := c.config.Peer.Down(govpn.PeerContext{
+               RemoteAddress: c.config.RemoteAddress,
+               Protocol:      c.config.Protocol,
+               Config:        *c.config.Peer,
+       })
+       return errors.Wrap(err, "c.config.Peer.Down")
+}
+
+func (c *Client) postUpAction() error {
+       if c.config.Peer.Up == nil {
+               return nil
+       }
+       err := c.config.Peer.Up(govpn.PeerContext{
+               RemoteAddress: c.config.RemoteAddress,
+               Protocol:      c.config.Protocol,
+               Config:        *c.config.Peer,
+       })
+       return errors.Wrap(err, "c.config.Peer.Up")
+}
+
+// KnownPeers return GoVPN peers. Always 1.
+// used to get client statistics.
+func (c *Client) KnownPeers() *govpn.KnownPeers {
+       return &c.knownPeers
+}
+
 // MainCycle main loop of a connecting/connected client
 func (c *Client) MainCycle() {
        var err error
-       c.tap, err = govpn.TAPListen(c.config.InterfaceName, c.config.MTU)
-       if err != nil {
-               c.Error <- fmt.Errorf("Can not listen on TUN/TAP interface: %s", err.Error())
-               return
+       l := c.logger.WithFields(logrus.Fields{"func": logFuncPrefix + "Client.MainCycle"})
+       l.WithFields(c.LogFields()).WithFields(c.config.LogFields()).Info("Starting...")
+
+       // if available, run PreUp, it might create interface
+       if c.config.Peer.PreUp != nil {
+               l.Debug("Running PreUp")
+               if c.tap, err = c.config.Peer.PreUp(govpn.PeerContext{
+                       RemoteAddress: c.config.RemoteAddress,
+                       Protocol:      c.config.Protocol,
+                       Config:        *c.config.Peer,
+               }); err != nil {
+                       c.Error <- errors.Wrap(err, "c.config.Peer.PreUp")
+                       return
+               }
+               l.Debug("PreUp success")
+       } else {
+               l.Debug("No PreUp to run")
        }
 
-       if len(c.config.StatsAddress) > 0 {
-               c.statsPort, err = net.Listen("tcp", c.config.StatsAddress)
+       // if tap wasn't set by PreUp, listen here
+       if c.tap == nil {
+               l.WithField("asking", c.config.Peer.Iface).Debug("No interface, try to listen")
+               c.tap, err = govpn.TAPListen(c.config.Peer.Iface, c.config.Peer.MTU)
                if err != nil {
-                       c.Error <- fmt.Errorf("Can't listen on stats port: %s", err.Error())
+                       c.Error <- errors.Wrapf(err, "govpn.TAPListen inteface:%s mtu:%d", c.config.Peer.Iface, c.config.Peer.MTU)
                        return
                }
-               c.knownPeers = govpn.KnownPeers(make(map[string]**govpn.Peer))
-               go govpn.StatsProcessor(c.statsPort, &c.knownPeers)
        }
+       c.config.Peer.Iface = c.tap.Name
+       l.WithFields(c.LogFields()).Debug("Got interface, start main loop")
 
 MainCycle:
        for {
@@ -116,9 +188,11 @@ MainCycle:
                c.rehandshaking = make(chan struct{})
                c.termination = make(chan struct{})
                switch c.config.Protocol {
-               case ProtocolUDP:
+               case govpn.ProtocolUDP:
+                       l.Debug("Start UDP")
                        go c.startUDP()
-               case ProtocolTCP:
+               case govpn.ProtocolTCP:
+                       l.Debug("Start TCP")
                        if c.config.isProxy() {
                                go c.proxyTCP()
                        } else {
@@ -127,16 +201,18 @@ MainCycle:
                }
                select {
                case <-c.termSignal:
-                       govpn.BothPrintf(`[finish remote="%s"]`, c.config.RemoteAddress)
+                       l.WithFields(c.LogFields()).Debug("Finish")
                        c.termination <- struct{}{}
                        // empty value signals that everything is fine
                        c.Error <- nil
                        break MainCycle
                case <-c.timeouted:
                        if c.config.NoReconnect {
+                               l.Debug("No reconnect, stop")
+                               c.Error <- nil
                                break MainCycle
                        }
-                       govpn.BothPrintf(`[sleep seconds="%d"]`, c.config.Peer.Timeout/time.Second)
+                       l.WithField("timeout", c.config.Peer.Timeout.String()).Debug("Sleep")
                        time.Sleep(c.config.Peer.Timeout)
                case <-c.rehandshaking:
                }
@@ -144,26 +220,28 @@ MainCycle:
                close(c.rehandshaking)
                close(c.termination)
        }
-       if _, err = govpn.ScriptCall(
-               c.config.DownPath,
-               c.config.InterfaceName,
-               c.config.RemoteAddress,
-       ); err != nil {
-               c.Error <- err
+       l.WithFields(c.config.LogFields()).Debug("Run post down action")
+       if err = c.postDownAction(); err != nil {
+               c.Error <- errors.Wrap(err, "c.postDownAction")
        }
 }
 
 // NewClient returns a configured GoVPN client, to trigger connection
 // MainCycle must be executed.
-func NewClient(conf Configuration, verifier *govpn.Verifier, termSignal chan os.Signal) *Client {
+func NewClient(conf Configuration, logger *logrus.Logger, termSignal chan interface{}) (*Client, error) {
        client := Client{
                idsCache:    govpn.NewMACCache(),
                firstUpCall: true,
                config:      conf,
                termSignal:  termSignal,
                Error:       make(chan error, 1),
+               knownPeers:  govpn.KnownPeers(make(map[string]**govpn.Peer)),
+               logger:      logger,
+       }
+       govpn.SetLogger(client.logger)
+       confs := map[govpn.PeerID]*govpn.PeerConf{*conf.Peer.ID: conf.Peer}
+       if err := client.idsCache.Update(&confs); err != nil {
+               return nil, errors.Wrap(err, "client.idsCache.Update")
        }
-       confs := map[govpn.PeerID]*govpn.PeerConf{*verifier.ID: conf.Peer}
-       client.idsCache.Update(&confs)
-       return &client
+       return &client, nil
 }
index be7d1149fa3e4d44e257a195fe92151c11e8ec8d..9ee1208d9bac3fee97af5cb7b283d9059362f1a3 100644 (file)
@@ -21,22 +21,23 @@ package client
 import (
        "bufio"
        "encoding/base64"
-       "fmt"
        "net"
        "net/http"
 
+       "github.com/pkg/errors"
+
        "cypherpunks.ru/govpn"
 )
 
 func (c *Client) proxyTCP() {
        proxyAddr, err := net.ResolveTCPAddr("tcp", c.config.ProxyAddress)
        if err != nil {
-               c.Error <- err
+               c.Error <- errors.Wrapf(err, "net.ResolveTCPAddr %s", c.config.ProxyAddress)
                return
        }
        conn, err := net.DialTCP("tcp", nil, proxyAddr)
        if err != nil {
-               c.Error <- err
+               c.Error <- errors.Wrapf(err, "net.DialTCP %s", proxyAddr.String())
                return
        }
        req := "CONNECT " + c.config.ProxyAddress + " HTTP/1.1\n"
@@ -46,15 +47,25 @@ func (c *Client) proxyTCP() {
                req += base64.StdEncoding.EncodeToString([]byte(c.config.ProxyAuthentication)) + "\n"
        }
        req += "\n"
-       conn.Write([]byte(req))
+       if _, err = conn.Write([]byte(req)); err != nil {
+               govpn.CloseLog(conn, c.logger, c.LogFields())
+               c.Error <- errors.Wrap(err, "conn.Write")
+               return
+       }
        resp, err := http.ReadResponse(
                bufio.NewReader(conn),
                &http.Request{Method: "CONNECT"},
        )
-       if err != nil || resp.StatusCode != http.StatusOK {
-               c.Error <- fmt.Errorf("Unexpected response from proxy: %s", err.Error())
+       if err != nil {
+               govpn.CloseLog(conn, c.logger, c.LogFields())
+               c.Error <- errors.Wrap(err, "http.ReadResponse CONNECT")
+               return
+       }
+       if resp.StatusCode != http.StatusOK {
+               govpn.CloseLog(conn, c.logger, c.LogFields())
+               c.Error <- errors.Errorf("Unexpected response from proxy: %s", http.StatusText(resp.StatusCode))
                return
        }
-       govpn.Printf(`[proxy-connected remote="%s" addr="%s"]`, c.config.RemoteAddress, *proxyAddr)
+       c.logger.WithField("func", logFuncPrefix+"Client.proxyTCP").WithFields(c.config.LogFields()).Debug("Proxy connected")
        go c.handleTCP(conn)
 }
index 40d81ca818081789c3a9f87ac4e4a2f063c63a76..056b40e36e5ca2e0b6f38a1e53cb381c7f9bad71 100644 (file)
@@ -22,35 +22,62 @@ import (
        "bytes"
        "fmt"
        "net"
+       "os"
        "sync/atomic"
        "time"
 
+       "github.com/Sirupsen/logrus"
+       "github.com/pkg/errors"
+
        "cypherpunks.ru/govpn"
 )
 
 func (c *Client) startTCP() {
-       remote, err := net.ResolveTCPAddr("tcp", c.config.RemoteAddress)
-       if err != nil {
-               c.Error <- fmt.Errorf("Can not resolve remote address: %s", err)
-               return
+       var conn net.Conn
+       l := c.logger.WithField("func", logFuncPrefix+"Client.startTCP")
+       // initialize using a file descriptor
+       if c.config.FileDescriptor > 0 {
+               l.WithField("fd", c.config.FileDescriptor).Debug("Connect using file descriptor")
+               var err error
+               conn, err = net.FileConn(os.NewFile(uintptr(c.config.FileDescriptor), fmt.Sprintf("fd[%s]", c.config.RemoteAddress)))
+               if err != nil {
+                       c.Error <- errors.Wrapf(err, "net.FileConn fd:%d", c.config.FileDescriptor)
+                       return
+               }
+       } else {
+               // TODO move resolution into the loop, as the name might change over time
+               l.WithField("fd", c.config.RemoteAddress).Debug("Connect using TCP")
+               remote, err := net.ResolveTCPAddr("tcp", c.config.RemoteAddress)
+               if err != nil {
+                       c.Error <- errors.Wrapf(err, "net.ResolveTCPAdd %s", c.config.RemoteAddress)
+                       return
+               }
+               l.WithField("remote", remote.String()).Debug("dial")
+               conn, err = net.DialTCP("tcp", nil, remote)
+               if err != nil {
+                       c.Error <- errors.Wrapf(err, "net.DialTCP: %s", remote.String())
+                       return
+               }
        }
-       conn, err := net.DialTCP("tcp", nil, remote)
+       l.WithFields(c.config.LogFields()).Info("Connected")
+       c.handleTCP(conn)
+}
+
+func (c *Client) handleTCP(conn net.Conn) {
+       hs, err := govpn.HandshakeStart(c.config.RemoteAddress, conn, c.config.Peer)
        if err != nil {
-               c.Error <- fmt.Errorf("Can not connect to address: %s", err)
+               govpn.CloseLog(conn, c.logger, c.LogFields())
+               c.Error <- errors.Wrap(err, "govpn.HandshakeStart")
                return
        }
-       govpn.Printf(`[connected remote="%s"]`, c.config.RemoteAddress)
-       c.handleTCP(conn)
-}
+       buf := make([]byte, 2*(govpn.EnclessEnlargeSize+c.config.Peer.MTU)+c.config.Peer.MTU)
 
-func (c *Client) handleTCP(conn *net.TCPConn) {
-       hs := govpn.HandshakeStart(c.config.RemoteAddress, conn, c.config.Peer)
-       buf := make([]byte, 2*(govpn.EnclessEnlargeSize+c.config.MTU)+c.config.MTU)
        var n int
-       var err error
        var prev int
        var peer *govpn.Peer
+       var deadLine time.Time
        var terminator chan struct{}
+       fields := logrus.Fields{"func": logFuncPrefix + "Client.handleTCP"}
 HandshakeCycle:
        for {
                select {
@@ -59,36 +86,42 @@ HandshakeCycle:
                default:
                }
                if prev == len(buf) {
-                       govpn.Printf(`[packet-timeouted remote="%s"]`, c.config.RemoteAddress)
+                       c.logger.WithFields(fields).WithFields(c.LogFields()).Debug("Packet timeouted")
                        c.timeouted <- struct{}{}
                        break HandshakeCycle
                }
 
-               if err = conn.SetReadDeadline(time.Now().Add(c.config.Peer.Timeout)); err != nil {
-                       c.Error <- err
+               deadLine = time.Now().Add(c.config.Peer.Timeout)
+               if err = conn.SetReadDeadline(deadLine); err != nil {
+                       c.Error <- errors.Wrapf(err, "conn.SetReadDeadline %s", deadLine.String())
                        break HandshakeCycle
                }
                n, err = conn.Read(buf[prev:])
                if err != nil {
-                       govpn.Printf(`[connection-timeouted remote="%s"]`, c.config.RemoteAddress)
+                       c.logger.WithFields(fields).WithFields(c.LogFields()).Debug("Packet timeouted")
                        c.timeouted <- struct{}{}
                        break HandshakeCycle
                }
 
                prev += n
-               peerID := c.idsCache.Find(buf[:prev])
-               if peerID == nil {
+               _, err = c.idsCache.Find(buf[:prev])
+               if err != nil {
+                       c.logger.WithFields(fields).WithFields(c.LogFields()).WithError(err).Debug("Couldn't find peer in ids")
                        continue
                }
-               peer = hs.Client(buf[:prev])
+               peer, err = hs.Client(buf[:prev])
                prev = 0
-               if peer == nil {
+               if err != nil {
+                       c.logger.WithFields(fields).WithError(err).WithFields(c.LogFields()).Debug("Can't create new peer")
                        continue
                }
-               govpn.Printf(`[handshake-completed remote="%s"]`, c.config.RemoteAddress)
+               c.logger.WithFields(fields).WithFields(c.LogFields()).Info("Handshake completed")
                c.knownPeers = govpn.KnownPeers(map[string]**govpn.Peer{c.config.RemoteAddress: &peer})
                if c.firstUpCall {
-                       go govpn.ScriptCall(c.config.UpPath, c.config.InterfaceName, c.config.RemoteAddress)
+                       if err = c.postUpAction(); err != nil {
+                               c.Error <- errors.Wrap(err, "c.postUpAction")
+                               break HandshakeCycle
+                       }
                        c.firstUpCall = false
                }
                hs.Zero()
@@ -113,17 +146,17 @@ TransportCycle:
                default:
                }
                if prev == len(buf) {
-                       govpn.Printf(`[packet-timeouted remote="%s"]`, c.config.RemoteAddress)
+                       c.logger.WithFields(c.LogFields()).Debug("Packet timeouted")
                        c.timeouted <- struct{}{}
                        break TransportCycle
                }
                if err = conn.SetReadDeadline(time.Now().Add(c.config.Peer.Timeout)); err != nil {
-                       c.Error <- err
+                       c.Error <- errors.Wrap(err, "conn.SetReadDeadline")
                        break TransportCycle
                }
                n, err = conn.Read(buf[prev:])
                if err != nil {
-                       govpn.Printf(`[connection-timeouted remote="%s"]`, c.config.RemoteAddress)
+                       c.logger.WithError(err).WithFields(c.LogFields()).Debug("Connection timeouted")
                        c.timeouted <- struct{}{}
                        break TransportCycle
                }
@@ -137,12 +170,12 @@ TransportCycle:
                        continue
                }
                if !peer.PktProcess(buf[:i+govpn.NonceSize], c.tap, false) {
-                       govpn.Printf(`[packet-unauthenticated remote="%s"]`, c.config.RemoteAddress)
+                       c.logger.WithFields(c.LogFields()).Debug("Packet unauthenticated")
                        c.timeouted <- struct{}{}
                        break TransportCycle
                }
                if atomic.LoadUint64(&peer.BytesIn)+atomic.LoadUint64(&peer.BytesOut) > govpn.MaxBytesPerKey {
-                       govpn.Printf(`[rehandshake-required remote="%s"]`, c.config.RemoteAddress)
+                       c.logger.WithFields(c.LogFields()).Debug("Rehandshake required")
                        c.rehandshaking <- struct{}{}
                        break TransportCycle
                }
@@ -155,6 +188,9 @@ TransportCycle:
        }
        peer.Zero()
        if err = conn.Close(); err != nil {
-               c.Error <- err
+               c.Error <- errors.Wrap(err, "conn.Close")
+       }
+       if err = c.tap.Close(); err != nil {
+               c.Error <- errors.Wrap(err, logFuncPrefix+"Client.tap.Close")
        }
 }
index bb7045ac711778f25e8a9f8d839dea3ae9f11d92..24702f3b6837e3e30a9162ec1c6459adc8f806b9 100644 (file)
@@ -1,6 +1,6 @@
 /*
 GoVPN -- simple secure free software virtual private network daemon
-Copyright (C) 2014-2017 Sergey Matveev <stargrave@stargrave.org>
+Copyright (C) 2014-2016 Sergey Matveev <stargrave@stargrave.org>
 
 This program is free software: you can redistribute it and/or modify
 it under the terms of the GNU General Public License as published by
@@ -19,34 +19,51 @@ along with this program.  If not, see <http://www.gnu.org/licenses/>.
 package client
 
 import (
-       "fmt"
        "net"
        "sync/atomic"
        "time"
 
+       "github.com/pkg/errors"
+
        "cypherpunks.ru/govpn"
 )
 
 func (c *Client) startUDP() {
+       l := c.logger.WithField("func", "startUDP")
+
+       // TODO move resolution into the loop, as the name might change over time
+       l.Debug("Resolve UDP address")
        remote, err := net.ResolveUDPAddr("udp", c.config.RemoteAddress)
        if err != nil {
-               c.Error <- fmt.Errorf("Can not resolve remote address: %s", err)
+               c.Error <- errors.Wrapf(err, "net.ResolveUDPAddr %s", c.config.RemoteAddress)
                return
        }
+       l.WithField("remote", remote.String()).Debug("dial")
        conn, err := net.DialUDP("udp", nil, remote)
        if err != nil {
-               c.Error <- fmt.Errorf("Can not connect remote address: %s", err)
+               c.Error <- errors.Wrapf(err, "net.DialUDP %s", c.config.RemoteAddress)
+               return
+       }
+       l.WithFields(c.config.LogFields()).Info("Connected")
+
+       l.Debug("Handshake start")
+       hs, err := govpn.HandshakeStart(c.config.RemoteAddress, conn, c.config.Peer)
+       if err != nil {
+               govpn.CloseLog(conn, c.logger, c.LogFields())
+               c.Error <- errors.Wrap(err, "govpn.HandshakeStart")
                return
        }
-       govpn.Printf(`[connected remote="%s"]`, c.config.RemoteAddress)
+       l.Debug("Handshake done")
 
-       hs := govpn.HandshakeStart(c.config.RemoteAddress, conn, c.config.Peer)
-       buf := make([]byte, c.config.MTU*2)
+       buf := make([]byte, c.config.Peer.MTU*2)
        var n int
        var timeouts int
        var peer *govpn.Peer
+       var deadLine time.Time
        var terminator chan struct{}
        timeout := int(c.config.Peer.Timeout.Seconds())
+       l.Debug("Main cycle")
+
 MainCycle:
        for {
                select {
@@ -55,48 +72,58 @@ MainCycle:
                default:
                }
 
-               if err = conn.SetReadDeadline(time.Now().Add(time.Second)); err != nil {
-                       c.Error <- err
+               deadLine = time.Now().Add(time.Second)
+               if err = conn.SetReadDeadline(deadLine); err != nil {
+                       c.Error <- errors.Wrapf(err, "conn.SetReadDeadline %s", deadLine.String())
                        break MainCycle
                }
+               l.Debug("conn.Read")
                n, err = conn.Read(buf)
-               if timeouts == timeout {
-                       govpn.Printf(`[connection-timeouted remote="%s"]`, c.config.RemoteAddress)
+               if timeouts >= timeout {
+                       l.WithFields(c.LogFields()).Debug("Connection timeouted")
                        c.timeouted <- struct{}{}
                        break
                }
                if err != nil {
+                       l.WithError(err).WithFields(c.LogFields()).Debug("Can't read from connection")
                        timeouts++
                        continue
                }
                if peer != nil {
+                       c.logger.WithFields(c.LogFields()).Debug("No peer yet, process packet")
                        if peer.PktProcess(buf[:n], c.tap, true) {
+                               l.WithFields(c.LogFields()).Debug("Packet processed")
                                timeouts = 0
                        } else {
-                               govpn.Printf(`[packet-unauthenticated remote="%s"]`, c.config.RemoteAddress)
+                               l.WithFields(c.LogFields()).Debug("Packet unauthenticated")
                                timeouts++
                        }
                        if atomic.LoadUint64(&peer.BytesIn)+atomic.LoadUint64(&peer.BytesOut) > govpn.MaxBytesPerKey {
-                               govpn.Printf(`[rehandshake-required remote="%s"]`, c.config.RemoteAddress)
+                               l.WithFields(c.LogFields()).Debug("Rehandshake required")
                                c.rehandshaking <- struct{}{}
                                break MainCycle
                        }
                        continue
                }
-               if c.idsCache.Find(buf[:n]) == nil {
-                       govpn.Printf(`[identity-invalid remote="%s"]`, c.config.RemoteAddress)
+               if _, err = c.idsCache.Find(buf[:n]); err != nil {
+                       l.WithError(err).WithFields(c.LogFields()).Debug("Identity invalid")
                        continue
                }
                timeouts = 0
-               peer = hs.Client(buf[:n])
+               peer, err = hs.Client(buf[:n])
+               if err != nil {
+                       c.Error <- errors.Wrap(err, "hs.Client")
+                       continue
+               }
+               // no error, but handshake not yet completed
                if peer == nil {
                        continue
                }
-               govpn.Printf(`[handshake-completed remote="%s"]`, c.config.RemoteAddress)
+               l.WithFields(c.LogFields()).Info("Handshake completed")
                c.knownPeers = govpn.KnownPeers(map[string]**govpn.Peer{c.config.RemoteAddress: &peer})
-               if c.firstUpCall {
-                       go govpn.ScriptCall(c.config.UpPath, c.config.InterfaceName, c.config.RemoteAddress)
-                       c.firstUpCall = false
+               if err = c.postUpAction(); err != nil {
+                       c.Error <- errors.Wrap(err, "c.postUpAction")
+                       continue
                }
                hs.Zero()
                terminator = make(chan struct{})
@@ -109,6 +136,9 @@ MainCycle:
                hs.Zero()
        }
        if err = conn.Close(); err != nil {
-               c.Error <- err
+               c.Error <- errors.Wrap(err, "conn.Close")
+       }
+       if err = c.tap.Close(); err != nil {
+               c.Error <- errors.Wrap(err, logFuncPrefix+"Client.tap.Close")
        }
 }
index 8fcd988ecc8c6d1bb8bf3a688482df4751fa8c28..f6530dfe8505b70aaf86e8979472d52d0d26417b 100644 (file)
@@ -1,6 +1,6 @@
 /*
 GoVPN -- simple secure free software virtual private network daemon
-Copyright (C) 2014-2017 Sergey Matveev <stargrave@stargrave.org>
+Copyright (C) 2014-2016 Sergey Matveev <stargrave@stargrave.org>
 
 This program is free software: you can redistribute it and/or modify
 it under the terms of the GNU General Public License as published by
@@ -22,11 +22,10 @@ package main
 import (
        "flag"
        "fmt"
-       "log"
-       "os"
-       "os/signal"
        "time"
 
+       "github.com/Sirupsen/logrus"
+
        "cypherpunks.ru/govpn"
        "cypherpunks.ru/govpn/client"
 )
@@ -54,8 +53,10 @@ func main() {
                syslog      = flag.Bool("syslog", false, "Enable logging to syslog")
                version     = flag.Bool("version", false, "Print version information")
                warranty    = flag.Bool("warranty", false, "Print warranty information")
-               protocol    client.Protocol
+               logLevel    = flag.String("log_level", "warning", "Log level")
+               protocol    govpn.Protocol
                err         error
+               fields      = logrus.Fields{"func": "main"}
        )
 
        flag.Parse()
@@ -67,41 +68,46 @@ func main() {
                fmt.Println(govpn.VersionGet())
                return
        }
-       log.SetFlags(log.Ldate | log.Lmicroseconds | log.Lshortfile)
+
+       logger, err := govpn.NewLogger(*logLevel, *syslog)
+       if err != nil {
+               logrus.WithFields(fields).WithError(err).Fatal("Couldn't initialize logging")
+       }
 
        if *egdPath != "" {
-               log.Println("Using", *egdPath, "EGD")
+               logger.WithField("egd_path", *egdPath).WithFields(fields).Debug("Init EGD")
                govpn.EGDInit(*egdPath)
        }
 
-       switch *proto {
-       case "udp":
-               protocol = client.ProtocolUDP
-       case "tcp":
-               protocol = client.ProtocolTCP
-       default:
-               log.Fatalln("Unknown protocol specified")
+       if protocol, err = govpn.NewProtocolFromString(*proto); err != nil {
+               logger.WithError(err).WithFields(fields).WithField("proto", *proto).Fatal("Invalid protocol")
        }
 
-       if *proxyAddr != "" && protocol == client.ProtocolUDP {
-               log.Fatalln("HTTP proxy is supported only in TCP mode")
+       if *proxyAddr != "" && protocol == govpn.ProtocolUDP {
+               logrus.WithFields(fields).WithFields(logrus.Fields{
+                       "proxy": *proxyAddr,
+                       "proto": *proto,
+               }).Fatal("HTTP proxy is supported only in TCP mode")
        }
 
        if *verifierRaw == "" {
-               log.Fatalln("-verifier is required")
+               logger.Fatalln("-verifier is required")
        }
        verifier, err := govpn.VerifierFromString(*verifierRaw)
        if err != nil {
-               log.Fatalln("Invalid -verifier:", err)
+               logger.WithError(err).Fatal("Invalid -verifier")
        }
        key, err := govpn.KeyRead(*keyPath)
        if err != nil {
-               log.Fatalln("Invalid -key:", err)
+               logger.WithError(err).Fatal("Invalid -key")
+       }
+       priv, err := verifier.PasswordApply(key)
+       if err != nil {
+               logger.WithError(err).Fatal("Can't PasswordApply")
        }
-       priv := verifier.PasswordApply(key)
        if *encless {
-               if protocol != client.ProtocolTCP {
-                       log.Fatalln("Currently encryptionless mode works only with TCP")
+               if protocol != govpn.ProtocolTCP {
+                       logger.Fatal("Currently encryptionless mode works only with TCP")
                }
                *noisy = true
        }
@@ -118,33 +124,30 @@ func main() {
                        Encless:  *encless,
                        Verifier: verifier,
                        DSAPriv:  priv,
+                       Up:       govpn.RunScriptAction(upPath),
+                       Down:     govpn.RunScriptAction(downPath),
                },
                Protocol:            protocol,
-               InterfaceName:       *ifaceName,
                ProxyAddress:        *proxyAddr,
                ProxyAuthentication: *proxyAuth,
                RemoteAddress:       *remoteAddr,
-               UpPath:              *upPath,
-               DownPath:            *downPath,
-               StatsAddress:        *stats,
                NoReconnect:         *noreconnect,
-               MTU:                 *mtu,
        }
        if err = conf.Validate(); err != nil {
-               log.Fatalln("Invalid settings:", err)
+               logger.WithError(err).Fatal("Invalid settings")
        }
 
-       log.Println(govpn.VersionGet())
+       c, err := client.NewClient(conf, logger, govpn.CatchSignalShutdown())
+       if err != nil {
+               logger.WithError(err).Fatal("Can't initialize client")
+       }
 
-       if *syslog {
-               govpn.SyslogEnable()
+       if *stats != "" {
+               go govpn.StatsProcessor(*stats, c.KnownPeers())
        }
 
-       termSignal := make(chan os.Signal, 1)
-       signal.Notify(termSignal, os.Interrupt, os.Kill)
-       c := client.NewClient(conf, verifier, termSignal)
        go c.MainCycle()
        if err = <-c.Error; err != nil {
-               log.Fatalln(err)
+               logger.WithError(err).Fatal("Fatal error")
        }
 }
index 56b0f696527231738a49f1eae7f15988327f76fa..67774b1d3ccbaab4ee27548451a269e9fb26eadc 100644 (file)
@@ -19,15 +19,24 @@ along with this program.  If not, see <http://www.gnu.org/licenses/>.
 package govpn
 
 import (
-       "log"
-       "os"
-       "os/exec"
+       "encoding/hex"
+       "encoding/json"
        "runtime"
+       "strings"
+       "time"
+
+       "github.com/pkg/errors"
 )
 
 const (
-       TimeoutDefault = 60
-       EtherSize      = 14
+       // ProtocolUDP is UDP transport protocol
+       ProtocolUDP Protocol = iota
+       // ProtocolTCP is TCP transport protocol
+       ProtocolTCP
+       // ProtocolALL is TCP+UDP transport protocol
+       ProtocolALL
+
+       EtherSize = 14
        // MTUMax is maximum MTU size of Ethernet packet
        MTUMax = 9000 + EtherSize + 1
        // MTUDefault is default MTU size of Ethernet packet
@@ -35,35 +44,92 @@ const (
 
        ENV_IFACE  = "GOVPN_IFACE"
        ENV_REMOTE = "GOVPN_REMOTE"
+
+       wrapNewProtocolFromString = "NewProtocolFromString"
 )
 
 var (
        // Version holds release string set at build time
-       Version string
+       Version      string
+       protocolText = map[Protocol]string{
+               ProtocolUDP: "udp",
+               ProtocolTCP: "tcp",
+               ProtocolALL: "all",
+       }
+       // TimeoutDefault is default timeout value for various network operations
+       TimeoutDefault = 60 * time.Second
 )
 
-// ScriptCall calls external program/script.
-// You have to specify path to it and (inteface name as a rule) something
-// that will be the first argument when calling it. Function will return
-// it's output and possible error.
-func ScriptCall(path, ifaceName, remoteAddr string) ([]byte, error) {
-       if path == "" {
-               return nil, nil
+// Protocol is a GoVPN supported protocol: either UDP, TCP or both
+type Protocol int
+
+// String converts a Protocol into a string
+func (p Protocol) String() string {
+       return protocolText[p]
+}
+
+// MarshalJSON returns a JSON string from a protocol
+func (p Protocol) MarshalJSON() ([]byte, error) {
+       str := p.String()
+       output, err := json.Marshal(&str)
+       return output, errors.Wrap(err, "json.Marshal")
+}
+
+// UnmarshalJSON converts a JSON string into a Protocol
+func (p *Protocol) UnmarshalJSON(encoded []byte) error {
+       var str string
+       if err := json.Unmarshal(encoded, &str); err != nil {
+               return errors.Wrapf(err, "Can't unmarshall to string %q", hex.EncodeToString(encoded))
        }
-       if _, err := os.Stat(path); err != nil && os.IsNotExist(err) {
-               return nil, err
+       proto, err := NewProtocolFromString(str)
+       if err != nil {
+               return errors.Wrap(err, wrapNewProtocolFromString)
        }
-       cmd := exec.Command(path)
-       cmd.Env = append(cmd.Env, ENV_IFACE+"="+ifaceName)
-       cmd.Env = append(cmd.Env, ENV_REMOTE+"="+remoteAddr)
-       out, err := cmd.CombinedOutput()
+       *p = proto
+       return nil
+}
+
+// UnmarshalYAML converts a YAML string into a Protocol
+func (p *Protocol) UnmarshalYAML(unmarshal func(interface{}) error) error {
+       var str string
+       err := unmarshal(&str)
        if err != nil {
-               log.Println("Script error", path, err, string(out))
+               return errors.Wrap(err, "unmarshall")
+       }
+
+       proto, err := NewProtocolFromString(str)
+       if err != nil {
+               return errors.Wrap(err, wrapNewProtocolFromString)
+       }
+       *p = proto
+       return nil
+}
+
+// NewProtocolFromString converts a string into a govpn.Protocol
+func NewProtocolFromString(p string) (Protocol, error) {
+       lowP := strings.ToLower(p)
+       for k, v := range protocolText {
+               if strings.ToLower(v) == lowP {
+                       return k, nil
+               }
        }
-       return out, err
+
+       choices := make([]string, len(protocolText))
+       var index = 0
+       for k, v := range protocolText {
+               if v == p {
+                       z := k
+                       p = &z
+                       return nil
+               }
+               choices[index] = v
+               index++
+       }
+
+       return Protocol(-1), errors.Errorf("Invalid protocol %q: %s", p, strings.Join(choices, ","))
 }
 
-// Zero each byte.
+// SliceZero zeros each byte.
 func SliceZero(data []byte) {
        for i := 0; i < len(data); i++ {
                data[i] = 0
index daa00d194d55f88d8d28095ed913dcfede2ee75a..825a0cf87327d48c10b8e551d24954dde9c49e21 100644 (file)
@@ -1,6 +1,6 @@
 /*
 GoVPN -- simple secure free software virtual private network daemon
-Copyright (C) 2014-2017 Sergey Matveev <stargrave@stargrave.org>
+Copyright (C) 2014-2016 Sergey Matveev <stargrave@stargrave.org>
 
 This program is free software: you can redistribute it and/or modify
 it under the terms of the GNU General Public License as published by
@@ -21,27 +21,48 @@ package govpn
 import (
        "time"
 
+       "github.com/Sirupsen/logrus"
        "github.com/agl/ed25519"
 )
 
 // PeerConf is configuration of a single GoVPN Peer (client)
 type PeerConf struct {
-       ID          *PeerID       `yaml:"-"`
-       Name        string        `yaml:"name"`
-       Iface       string        `yaml:"iface"`
-       MTU         int           `yaml:"mtu"`
-       Up          string        `yaml:"up"`
-       Down        string        `yaml:"down"`
-       TimeoutInt  int           `yaml:"timeout"`
-       Timeout     time.Duration `yaml:"-"`
-       Noise       bool          `yaml:"noise"`
-       CPR         int           `yaml:"cpr"`
-       Encless     bool          `yaml:"encless"`
-       TimeSync    int           `yaml:"timesync"`
-       VerifierRaw string        `yaml:"verifier"`
-
-       // This is passphrase verifier
-       Verifier *Verifier `yaml:"-"`
+       ID       *PeerID
+       Name     string
+       Iface    string
+       MTU      int
+       PreUp    TunnelPreUpAction
+       Up       TunnelAction
+       Down     TunnelAction
+       Timeout  time.Duration
+       Noise    bool
+       CPR      int
+       Encless  bool
+       TimeSync int
+
+       // This is passphrase verifier, client side only
+       Verifier *Verifier
        // This field exists only on client's side
-       DSAPriv *[ed25519.PrivateKeySize]byte `yaml:"-"`
+       DSAPriv *[ed25519.PrivateKeySize]byte
+}
+
+// LogFields return a logrus compatible logging context
+func (pc *PeerConf) LogFields(rootPrefix string) logrus.Fields {
+       p := rootPrefix + "peerconf_"
+       output := logrus.Fields{
+               p + "peer_name": pc.Name,
+               p + "mtu":       pc.MTU,
+               p + "noise":     pc.Noise,
+               p + "pcr":       pc.CPR,
+               p + "encless":   pc.Encless,
+               p + "timesync":  pc.TimeSync,
+               p + "timeout":   pc.Timeout.String(),
+               p + "pre_up":    pc.PreUp != nil,
+               p + "up":        pc.Up != nil,
+               p + "down":      pc.Down != nil,
+       }
+       if pc.ID != nil {
+               output[p+"id"] = pc.ID.String()
+       }
+       return output
 }