]> Cypherpunks.ru repositories - ucspi.git/commitdiff
Buffer fixing and simplification
authorSergey Matveev <stargrave@stargrave.org>
Fri, 17 Sep 2021 19:51:14 +0000 (22:51 +0300)
committerSergey Matveev <stargrave@stargrave.org>
Fri, 17 Sep 2021 19:51:14 +0000 (22:51 +0300)
cmd/tlsc/main.go
cmd/tlss/main.go
conn.go

index 7f1651a9301f5d2a4d7eca353d958098a37509ca..aa1d142cd66661c9d97fd7da32dca4d65957a4eb 100644 (file)
@@ -102,12 +102,9 @@ func main() {
                }
        }
 
-       conn := &ucspi.Conn{R: os.NewFile(6, "R"), W: os.NewFile(7, "W")}
-       if conn.R == nil {
-               log.Fatalln("no 6 file descriptor")
-       }
-       if conn.W == nil {
-               log.Fatalln("no 7 file descriptor")
+       conn, err := ucspi.NewConn(os.NewFile(6, "R"), os.NewFile(7, "W"))
+       if err != nil {
+               log.Fatalln(err)
        }
        tlsConn := tls.Client(conn, cfg)
        if err := tlsConn.Handshake(); err != nil {
@@ -154,17 +151,18 @@ func main() {
        if err = cmd.Start(); err != nil {
                log.Fatalln(err)
        }
-       copiers := make(chan struct{})
+       worker := make(chan struct{})
        go func() {
                io.Copy(rw, tlsConn)
                rw.Close()
-               close(copiers)
+               close(worker)
        }()
        go func() {
                io.Copy(tlsConn, wr)
        }()
        _, err = cmd.Process.Wait()
-       <-copiers
+       <-worker
+       tlsConn.Close()
        if err != nil {
                log.Fatalln(err)
        }
index 7740ecd5cdf07443908de03885378f48437120e1..881d2210990b85d1d745b5b6ac0f0c78e749c2fd 100644 (file)
@@ -22,6 +22,7 @@ import (
        "crypto/x509"
        "flag"
        "fmt"
+       "io"
        "log"
        "os"
        "os/exec"
@@ -70,7 +71,7 @@ func main() {
                cfg.ClientAuth = tls.RequireAndVerifyClientCert
        }
 
-       conn := &ucspi.Conn{R: os.Stdin, W: os.Stdout}
+       conn, _ := ucspi.NewConn(os.Stdin, os.Stdout)
        tlsConn := tls.Server(conn, cfg)
        if err = tlsConn.Handshake(); err != nil {
                log.Fatalln(err)
@@ -80,10 +81,18 @@ func main() {
                dn = tlsConn.ConnectionState().PeerCertificates[0].Subject.String()
        }
 
+       rr, rw, err := os.Pipe()
+       if err != nil {
+               log.Fatalln(err)
+       }
+       wr, ww, err := os.Pipe()
+       if err != nil {
+               log.Fatalln(err)
+       }
        args := flag.Args()
        cmd := exec.Command(args[0], args[1:]...)
-       cmd.Stdin = tlsConn
-       cmd.Stdout = tlsConn
+       cmd.Stdin = rr
+       cmd.Stdout = ww
        cmd.Stderr = os.Stderr
        cmd.Env = append(os.Environ(), "PROTO=TLS")
        if dn != "" {
@@ -93,7 +102,19 @@ func main() {
        if err = cmd.Start(); err != nil {
                log.Fatalln(err)
        }
-       if _, err = cmd.Process.Wait(); err != nil {
+       worker := make(chan struct{})
+       go func() {
+               io.Copy(rw, tlsConn)
+       }()
+       go func() {
+               io.Copy(tlsConn, wr)
+               tlsConn.Close()
+               close(worker)
+       }()
+       err = cmd.Wait()
+       ww.Close()
+       <-worker
+       if err != nil {
                log.Fatalln(err)
        }
 }
diff --git a/conn.go b/conn.go
index 7a6ac4227f90b8ad8ab2882ebb3736489fecab61..480a2980c3ae4c47e76223429337f22a6cb6ee1f 100644 (file)
--- a/conn.go
+++ b/conn.go
@@ -18,7 +18,7 @@ along with this program.  If not, see <http://www.gnu.org/licenses/>.
 package ucspi
 
 import (
-       "io"
+       "errors"
        "net"
        "os"
        "time"
@@ -36,37 +36,33 @@ func (addr *Addr) Network() string { return "tcp" }
 func (addr *Addr) String() string { return addr.ip + ":" + addr.port }
 
 type Conn struct {
-       R   *os.File
-       W   *os.File
-       eof chan struct{}
+       R *os.File
+       W *os.File
 }
 
-type ReadResult struct {
-       n   int
-       err error
+func NewConn(r, w *os.File) (*Conn, error) {
+       if r == nil {
+               return nil, errors.New("no R file descriptor")
+       }
+       if w == nil {
+               return nil, errors.New("no W file descriptor")
+       }
+       return &Conn{R: r, W: w}, nil
 }
 
 func (conn *Conn) Read(b []byte) (int, error) {
-       c := make(chan ReadResult)
-       go func() {
-               n, err := conn.R.Read(b)
-               c <- ReadResult{n, err}
-       }()
-       select {
-       case res := <-c:
-               return res.n, res.err
-       case <-conn.eof:
-               return 0, io.EOF
-       }
+       return conn.R.Read(b)
 }
 
 func (conn *Conn) Write(b []byte) (int, error) { return conn.W.Write(b) }
 
 func (conn *Conn) Close() error {
-       if err := conn.R.Close(); err != nil {
-               return err
+       errR := conn.R.Close()
+       errW := conn.W.Close()
+       if errR != nil {
+               return errR
        }
-       return os.Stdin.Close()
+       return errW
 }
 
 func (conn *Conn) LocalAddr() net.Addr {
@@ -85,12 +81,6 @@ func (conn *Conn) SetDeadline(t time.Time) error {
 }
 
 func (conn *Conn) SetReadDeadline(t time.Time) error {
-       // An ugly hack to forcefully terminate pending read.
-       // net/http calls SetReadDeadline(aLongTimeAgo), but file
-       // descriptors are not capable to exit immediately that way.
-       if t.Equal(aLongTimeAgo) {
-               conn.eof <- struct{}{}
-       }
        return conn.R.SetReadDeadline(t)
 }