From: Sergey Matveev Date: Fri, 17 Sep 2021 19:51:14 +0000 (+0300) Subject: Buffer fixing and simplification X-Git-Tag: v0.1.0~7 X-Git-Url: http://www.git.cypherpunks.ru/?p=ucspi.git;a=commitdiff_plain;h=72fc7e6d14a5113b013514eecd3c5b3485671631 Buffer fixing and simplification --- diff --git a/cmd/tlsc/main.go b/cmd/tlsc/main.go index 7f1651a..aa1d142 100644 --- a/cmd/tlsc/main.go +++ b/cmd/tlsc/main.go @@ -102,12 +102,9 @@ func main() { } } - conn := &ucspi.Conn{R: os.NewFile(6, "R"), W: os.NewFile(7, "W")} - if conn.R == nil { - log.Fatalln("no 6 file descriptor") - } - if conn.W == nil { - log.Fatalln("no 7 file descriptor") + conn, err := ucspi.NewConn(os.NewFile(6, "R"), os.NewFile(7, "W")) + if err != nil { + log.Fatalln(err) } tlsConn := tls.Client(conn, cfg) if err := tlsConn.Handshake(); err != nil { @@ -154,17 +151,18 @@ func main() { if err = cmd.Start(); err != nil { log.Fatalln(err) } - copiers := make(chan struct{}) + worker := make(chan struct{}) go func() { io.Copy(rw, tlsConn) rw.Close() - close(copiers) + close(worker) }() go func() { io.Copy(tlsConn, wr) }() _, err = cmd.Process.Wait() - <-copiers + <-worker + tlsConn.Close() if err != nil { log.Fatalln(err) } diff --git a/cmd/tlss/main.go b/cmd/tlss/main.go index 7740ecd..881d221 100644 --- a/cmd/tlss/main.go +++ b/cmd/tlss/main.go @@ -22,6 +22,7 @@ import ( "crypto/x509" "flag" "fmt" + "io" "log" "os" "os/exec" @@ -70,7 +71,7 @@ func main() { cfg.ClientAuth = tls.RequireAndVerifyClientCert } - conn := &ucspi.Conn{R: os.Stdin, W: os.Stdout} + conn, _ := ucspi.NewConn(os.Stdin, os.Stdout) tlsConn := tls.Server(conn, cfg) if err = tlsConn.Handshake(); err != nil { log.Fatalln(err) @@ -80,10 +81,18 @@ func main() { dn = tlsConn.ConnectionState().PeerCertificates[0].Subject.String() } + rr, rw, err := os.Pipe() + if err != nil { + log.Fatalln(err) + } + wr, ww, err := os.Pipe() + if err != nil { + log.Fatalln(err) + } args := flag.Args() cmd := exec.Command(args[0], args[1:]...) - cmd.Stdin = tlsConn - cmd.Stdout = tlsConn + cmd.Stdin = rr + cmd.Stdout = ww cmd.Stderr = os.Stderr cmd.Env = append(os.Environ(), "PROTO=TLS") if dn != "" { @@ -93,7 +102,19 @@ func main() { if err = cmd.Start(); err != nil { log.Fatalln(err) } - if _, err = cmd.Process.Wait(); err != nil { + worker := make(chan struct{}) + go func() { + io.Copy(rw, tlsConn) + }() + go func() { + io.Copy(tlsConn, wr) + tlsConn.Close() + close(worker) + }() + err = cmd.Wait() + ww.Close() + <-worker + if err != nil { log.Fatalln(err) } } diff --git a/conn.go b/conn.go index 7a6ac42..480a298 100644 --- a/conn.go +++ b/conn.go @@ -18,7 +18,7 @@ along with this program. If not, see . package ucspi import ( - "io" + "errors" "net" "os" "time" @@ -36,37 +36,33 @@ func (addr *Addr) Network() string { return "tcp" } func (addr *Addr) String() string { return addr.ip + ":" + addr.port } type Conn struct { - R *os.File - W *os.File - eof chan struct{} + R *os.File + W *os.File } -type ReadResult struct { - n int - err error +func NewConn(r, w *os.File) (*Conn, error) { + if r == nil { + return nil, errors.New("no R file descriptor") + } + if w == nil { + return nil, errors.New("no W file descriptor") + } + return &Conn{R: r, W: w}, nil } func (conn *Conn) Read(b []byte) (int, error) { - c := make(chan ReadResult) - go func() { - n, err := conn.R.Read(b) - c <- ReadResult{n, err} - }() - select { - case res := <-c: - return res.n, res.err - case <-conn.eof: - return 0, io.EOF - } + return conn.R.Read(b) } func (conn *Conn) Write(b []byte) (int, error) { return conn.W.Write(b) } func (conn *Conn) Close() error { - if err := conn.R.Close(); err != nil { - return err + errR := conn.R.Close() + errW := conn.W.Close() + if errR != nil { + return errR } - return os.Stdin.Close() + return errW } func (conn *Conn) LocalAddr() net.Addr { @@ -85,12 +81,6 @@ func (conn *Conn) SetDeadline(t time.Time) error { } func (conn *Conn) SetReadDeadline(t time.Time) error { - // An ugly hack to forcefully terminate pending read. - // net/http calls SetReadDeadline(aLongTimeAgo), but file - // descriptors are not capable to exit immediately that way. - if t.Equal(aLongTimeAgo) { - conn.eof <- struct{}{} - } return conn.R.SetReadDeadline(t) }