]> Cypherpunks.ru repositories - govpn.git/blobdiff - govpn.go
Well, performance is not so high actually
[govpn.git] / govpn.go
index 60c8080ffb98739c63ff957d43f9b4b6afbda814..7b15dd429910ab039016347a8dff3966fd0fa1e9 100644 (file)
--- a/govpn.go
+++ b/govpn.go
@@ -1,5 +1,5 @@
 /*
-govpn -- high-performance secure virtual private network daemon
+govpn -- simple secure virtual private network daemon
 Copyright (C) 2014 Sergey Matveev <stargrave@stargrave.org>
 
 This program is free software: you can redistribute it and/or modify
@@ -15,104 +15,132 @@ GNU General Public License for more details.
 You should have received a copy of the GNU General Public License
 along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
+
+// Simple secure virtual private network daemon
 package main
 
 import (
+       "bytes"
        "encoding/binary"
        "encoding/hex"
        "flag"
        "fmt"
+       "io"
+       "io/ioutil"
        "log"
        "net"
+       "os"
+       "os/exec"
+       "os/signal"
        "time"
 
-       "code.google.com/p/go.crypto/poly1305"
-       "code.google.com/p/go.crypto/salsa20"
-       "github.com/chon219/water"
+       "golang.org/x/crypto/poly1305"
+       "golang.org/x/crypto/salsa20"
+)
+
+var (
+       remoteAddr = flag.String("remote", "", "Remote server address")
+       bindAddr   = flag.String("bind", "", "Bind to address")
+       ifaceName  = flag.String("iface", "tap0", "TAP network interface")
+       keyPath    = flag.String("key", "", "Path to authentication key file")
+       upPath     = flag.String("up", "", "Path to up-script")
+       downPath   = flag.String("down", "", "Path to down-script")
+       mtu        = flag.Int("mtu", 1500, "MTU")
+       nonceDiff  = flag.Int("noncediff", 1, "Allow nonce difference")
+       timeoutP   = flag.Int("timeout", 60, "Timeout seconds")
+       verboseP   = flag.Bool("v", false, "Increase verbosity")
 )
 
 const (
-       NonceSize    = 8
-       AliveTimeout = time.Second * 90
+       NonceSize = 8
+       KeySize   = 32
        // S20BS is Salsa20's internal blocksize in bytes
-       S20BS = 64
+       S20BS         = 64
+       HeartBeatSize = 12
+       HeartBeatMark = "\x00\x00\x00HEARTBEAT"
+       // Maximal amount of bytes transfered with single key (4 GiB)
+       MaxBytesPerKey = 4294967296
 )
 
-type Peer struct {
-       addr      *net.UDPAddr
-       lastPing  time.Time
-       key       *[32]byte // encryption key
-       nonceOur  uint64    // nonce for our messages
-       nonceRecv uint64    // latest received nonce from remote peer
+type TAP interface {
+       io.Reader
+       io.Writer
 }
 
-func (p *Peer) IsAlive() bool {
-       if (p == nil) || (p.lastPing.Add(AliveTimeout).Before(time.Now())) {
-               return false
-       }
-       return true
-}
-
-func (p *Peer) SetAlive() {
-       p.lastPing = time.Now()
+type Peer struct {
+       addr      *net.UDPAddr
+       key       *[KeySize]byte // encryption key
+       nonceOur  uint64         // nonce for our messages
+       nonceRecv uint64         // latest received nonce from remote peer
 }
 
 type UDPPkt struct {
        addr *net.UDPAddr
-       data []byte
+       size int
 }
 
-var (
-       remoteAddr = flag.String("remote", "", "Remote server address")
-       bindAddr   = flag.String("bind", "", "Bind to address")
-       ifaceName  = flag.String("iface", "tap0", "TAP network interface")
-       keyHex     = flag.String("key", "", "Authentication key")
-       mtu        = flag.Int("mtu", 1500, "MTU")
-       verbose    = flag.Bool("v", false, "Increase verbosity")
-)
+func ScriptCall(path *string) {
+       if *path == "" {
+               return
+       }
+       cmd := exec.Command(*path, *ifaceName)
+       var out bytes.Buffer
+       cmd.Stdout = &out
+       if err := cmd.Run(); err != nil {
+               fmt.Println(time.Now(), "script error: ", err.Error(), string(out.Bytes()))
+       }
+}
 
 func main() {
        flag.Parse()
+       timeout := *timeoutP
+       verbose := *verboseP
+       noncediff := uint64(*nonceDiff)
        log.SetFlags(log.Ldate | log.Lmicroseconds | log.Lshortfile)
 
        // Key decoding
-       if len(*keyHex) != 64 {
-               panic("Key is required argument (64 hex characters)")
+       keyData, err := ioutil.ReadFile(*keyPath)
+       if err != nil {
+               panic("Unable to read keyfile: " + err.Error())
+       }
+       if len(keyData) < 64 {
+               panic("Key must be 64 hex characters long")
        }
-       keyDecoded, err := hex.DecodeString(*keyHex)
+       keyDecoded, err := hex.DecodeString(string(keyData[0:64]))
        if err != nil {
-               panic(err)
+               panic("Unable to decode the key: " + err.Error())
        }
-       key := new([32]byte)
+       key := new([KeySize]byte)
        copy(key[:], keyDecoded)
+       keyDecoded = nil
+       keyData = nil
 
        // Interface listening
        maxIfacePktSize := *mtu - poly1305.TagSize - NonceSize
        log.Println("Max MTU", maxIfacePktSize, "on interface", *ifaceName)
-       iface, err := water.NewTAP(*ifaceName)
-       if err != nil {
-               panic(err)
-       }
-       ethSink := make(chan []byte)
+       iface := NewTAP(*ifaceName)
+       ethBuf := make([]byte, maxIfacePktSize)
+       ethSink := make(chan int)
+       ethSinkReady := make(chan bool)
        go func() {
                for {
-                       buf := make([]byte, maxIfacePktSize)
-                       n, err := iface.Read(buf)
+                       <-ethSinkReady
+                       n, err := iface.Read(ethBuf)
                        if err != nil {
                                panic(err)
                        }
-                       ethSink <- buf[:n]
+                       ethSink <- n
                }
        }()
+       ethSinkReady <- true
 
        // Network address parsing
-       if (len(*bindAddr) > 1 && len(*remoteAddr) > 1) || (len(*bindAddr) == 0 && len(*remoteAddr) == 0) {
+       if (len(*bindAddr) > 1 && len(*remoteAddr) > 1) ||
+               (len(*bindAddr) == 0 && len(*remoteAddr) == 0) {
                panic("Either -bind or -remote must be specified only")
        }
-
        var conn *net.UDPConn
        var remote *net.UDPAddr
-
        serverMode := false
        bindTo := "0.0.0.0:0"
 
@@ -137,41 +165,80 @@ func main() {
                }
        }
 
-       udpSink := make(chan UDPPkt)
-       go func(conn *net.UDPConn, sink chan<- UDPPkt) {
+       udpBuf := make([]byte, *mtu)
+       udpSink := make(chan *UDPPkt)
+       udpSinkReady := make(chan bool)
+       go func(conn *net.UDPConn) {
                for {
-                       data := make([]byte, *mtu)
-                       n, addr, err := conn.ReadFromUDP(data)
+                       <-udpSinkReady
+                       conn.SetReadDeadline(time.Now().Add(time.Second))
+                       n, addr, err := conn.ReadFromUDP(udpBuf)
                        if err != nil {
-                               fmt.Print("B")
+                               if verbose {
+                                       fmt.Print("B")
+                               }
+                               udpSink <- nil
+                       } else {
+                               udpSink <- &UDPPkt{addr, n}
                        }
-                       sink <- UDPPkt{addr, data[:n]}
                }
-       }(conn, udpSink)
+       }(conn)
+       udpSinkReady <- true
 
        // Process packets
-       var udpPkt UDPPkt
-       var ethPkt []byte
+       var udpPkt *UDPPkt
+       var udpPktData []byte
+       var ethPktSize int
+       var frame []byte
        var addr string
-       var peer Peer
+       var peer *Peer
        var p *Peer
-       var buf []byte
 
+       timeouts := 0
+       bytes := 0
        states := make(map[string]*Handshake)
        nonce := make([]byte, NonceSize)
-       keyAuth := new([32]byte)
+       keyAuth := new([KeySize]byte)
        tag := new([poly1305.TagSize]byte)
+       buf := make([]byte, *mtu+S20BS)
+       emptyKey := make([]byte, KeySize)
 
        if !serverMode {
-               log.Println("starting handshake with", *remoteAddr)
                states[remote.String()] = HandshakeStart(conn, remote, key)
        }
 
+       heartbeat := time.Tick(time.Second * time.Duration(timeout/3))
+       go func() { <-heartbeat }()
+       heartbeatMark := []byte(HeartBeatMark)
+
+       termSignal := make(chan os.Signal, 1)
+       signal.Notify(termSignal, os.Interrupt, os.Kill)
+
+       finished := false
        for {
-               buf = make([]byte, *mtu+S20BS)
+               if finished {
+                       break
+               }
+               if !serverMode && bytes > MaxBytesPerKey {
+                       states[remote.String()] = HandshakeStart(conn, remote, key)
+                       bytes = 0
+               }
                select {
+               case <-termSignal:
+                       finished = true
+               case <-heartbeat:
+                       go func() { ethSink <- -1 }()
                case udpPkt = <-udpSink:
-                       if isValidHandshakePkt(udpPkt.data) {
+                       timeouts++
+                       if !serverMode && timeouts >= timeout {
+                               finished = true
+                       }
+                       if udpPkt == nil {
+                               udpSinkReady <- true
+                               continue
+                       }
+                       udpPktData = udpBuf[:udpPkt.size]
+                       if isValidHandshakePkt(udpPktData) {
                                addr = udpPkt.addr.String()
                                state, exists := states[addr]
                                if serverMode {
@@ -179,72 +246,96 @@ func main() {
                                                state = &Handshake{addr: udpPkt.addr}
                                                states[addr] = state
                                        }
-                                       p = state.Server(conn, key, udpPkt.data)
+                                       p = state.Server(noncediff, conn, key, udpPktData)
                                } else {
                                        if !exists {
                                                fmt.Print("[HS?]")
+                                               udpSinkReady <- true
                                                continue
                                        }
-                                       p = state.Client(conn, key, udpPkt.data)
+                                       p = state.Client(noncediff, conn, key, udpPktData)
                                }
                                if p != nil {
                                        fmt.Print("[HS-OK]")
-                                       peer = *p
+                                       if peer == nil {
+                                               go ScriptCall(upPath)
+                                       }
+                                       peer = p
                                        delete(states, addr)
                                }
+                               udpSinkReady <- true
                                continue
                        }
-                       if !peer.IsAlive() {
+                       if peer == nil {
+                               udpSinkReady <- true
                                continue
                        }
-                       nonceRecv, _ := binary.Uvarint(udpPkt.data[:8])
-                       if peer.nonceRecv >= nonceRecv {
+                       nonceRecv, _ := binary.Uvarint(udpPktData[:8])
+                       if nonceRecv < peer.nonceRecv-noncediff {
                                fmt.Print("R")
+                               udpSinkReady <- true
                                continue
                        }
-                       copy(tag[:], udpPkt.data[len(udpPkt.data)-poly1305.TagSize:])
-                       copy(buf[S20BS:], udpPkt.data[NonceSize:len(udpPkt.data)-poly1305.TagSize])
+                       copy(buf[:KeySize], emptyKey)
+                       copy(tag[:], udpPktData[udpPkt.size-poly1305.TagSize:])
+                       copy(buf[S20BS:], udpPktData[NonceSize:udpPkt.size-poly1305.TagSize])
                        salsa20.XORKeyStream(
-                               buf[:S20BS+len(udpPkt.data)-poly1305.TagSize],
-                               buf[:S20BS+len(udpPkt.data)-poly1305.TagSize],
-                               udpPkt.data[:NonceSize],
+                               buf[:S20BS+udpPkt.size-poly1305.TagSize],
+                               buf[:S20BS+udpPkt.size-poly1305.TagSize],
+                               udpPktData[:NonceSize],
                                peer.key,
                        )
-                       copy(keyAuth[:], buf[:32])
-                       if !poly1305.Verify(tag, udpPkt.data[:len(udpPkt.data)-poly1305.TagSize], keyAuth) {
+                       copy(keyAuth[:], buf[:KeySize])
+                       if !poly1305.Verify(tag, udpPktData[:udpPkt.size-poly1305.TagSize], keyAuth) {
+                               udpSinkReady <- true
                                fmt.Print("T")
                                continue
                        }
+                       udpSinkReady <- true
                        peer.nonceRecv = nonceRecv
-                       peer.SetAlive()
-                       if _, err := iface.Write(buf[S20BS : S20BS+len(udpPkt.data)-NonceSize-poly1305.TagSize]); err != nil {
-                               log.Println("Error writing to iface")
+                       timeouts = 0
+                       frame = buf[S20BS : S20BS+udpPkt.size-NonceSize-poly1305.TagSize]
+                       bytes += len(frame)
+                       if string(frame[0:HeartBeatSize]) == HeartBeatMark {
+                               continue
+                       }
+                       if _, err := iface.Write(frame); err != nil {
+                               log.Println("Error writing to iface: ", err)
                        }
-                       if *verbose {
+                       if verbose {
                                fmt.Print("r")
                        }
-               case ethPkt = <-ethSink:
-                       if len(ethPkt) > maxIfacePktSize {
+               case ethPktSize = <-ethSink:
+                       if ethPktSize > maxIfacePktSize {
                                panic("Too large packet on interface")
                        }
-                       if !peer.IsAlive() {
+                       if peer == nil {
+                               ethSinkReady <- true
                                continue
                        }
                        peer.nonceOur = peer.nonceOur + 2
                        binary.PutUvarint(nonce, peer.nonceOur)
-                       copy(buf[S20BS:], ethPkt)
+                       copy(buf[:KeySize], emptyKey)
+                       if ethPktSize > -1 {
+                               copy(buf[S20BS:], ethBuf[:ethPktSize])
+                               ethSinkReady <- true
+                       } else {
+                               copy(buf[S20BS:], heartbeatMark)
+                               ethPktSize = HeartBeatSize
+                       }
                        salsa20.XORKeyStream(buf, buf, nonce, peer.key)
                        copy(buf[S20BS-NonceSize:S20BS], nonce)
-                       copy(keyAuth[:], buf[:32])
-                       dataToSend := buf[S20BS-NonceSize : S20BS+len(ethPkt)]
+                       copy(keyAuth[:], buf[:KeySize])
+                       dataToSend := buf[S20BS-NonceSize : S20BS+ethPktSize]
                        poly1305.Sum(tag, dataToSend, keyAuth)
-                       _, err := conn.WriteTo(append(dataToSend, tag[:]...), peer.addr)
-                       if err != nil {
+                       bytes += len(dataToSend)
+                       if _, err := conn.WriteTo(append(dataToSend, tag[:]...), peer.addr); err != nil {
                                log.Println("Error sending UDP", err)
                        }
-                       if *verbose {
+                       if verbose {
                                fmt.Print("w")
                        }
                }
        }
+       ScriptCall(downPath)
 }