}
}
- conn := &ucspi.Conn{R: os.NewFile(6, "R"), W: os.NewFile(7, "W")}
- if conn.R == nil {
- log.Fatalln("no 6 file descriptor")
- }
- if conn.W == nil {
- log.Fatalln("no 7 file descriptor")
+ conn, err := ucspi.NewConn(os.NewFile(6, "R"), os.NewFile(7, "W"))
+ if err != nil {
+ log.Fatalln(err)
}
tlsConn := tls.Client(conn, cfg)
if err := tlsConn.Handshake(); err != nil {
if err = cmd.Start(); err != nil {
log.Fatalln(err)
}
- copiers := make(chan struct{})
+ worker := make(chan struct{})
go func() {
io.Copy(rw, tlsConn)
rw.Close()
- close(copiers)
+ close(worker)
}()
go func() {
io.Copy(tlsConn, wr)
}()
_, err = cmd.Process.Wait()
- <-copiers
+ <-worker
+ tlsConn.Close()
if err != nil {
log.Fatalln(err)
}
"crypto/x509"
"flag"
"fmt"
+ "io"
"log"
"os"
"os/exec"
cfg.ClientAuth = tls.RequireAndVerifyClientCert
}
- conn := &ucspi.Conn{R: os.Stdin, W: os.Stdout}
+ conn, _ := ucspi.NewConn(os.Stdin, os.Stdout)
tlsConn := tls.Server(conn, cfg)
if err = tlsConn.Handshake(); err != nil {
log.Fatalln(err)
dn = tlsConn.ConnectionState().PeerCertificates[0].Subject.String()
}
+ rr, rw, err := os.Pipe()
+ if err != nil {
+ log.Fatalln(err)
+ }
+ wr, ww, err := os.Pipe()
+ if err != nil {
+ log.Fatalln(err)
+ }
args := flag.Args()
cmd := exec.Command(args[0], args[1:]...)
- cmd.Stdin = tlsConn
- cmd.Stdout = tlsConn
+ cmd.Stdin = rr
+ cmd.Stdout = ww
cmd.Stderr = os.Stderr
cmd.Env = append(os.Environ(), "PROTO=TLS")
if dn != "" {
if err = cmd.Start(); err != nil {
log.Fatalln(err)
}
- if _, err = cmd.Process.Wait(); err != nil {
+ worker := make(chan struct{})
+ go func() {
+ io.Copy(rw, tlsConn)
+ }()
+ go func() {
+ io.Copy(tlsConn, wr)
+ tlsConn.Close()
+ close(worker)
+ }()
+ err = cmd.Wait()
+ ww.Close()
+ <-worker
+ if err != nil {
log.Fatalln(err)
}
}
package ucspi
import (
- "io"
+ "errors"
"net"
"os"
"time"
func (addr *Addr) String() string { return addr.ip + ":" + addr.port }
type Conn struct {
- R *os.File
- W *os.File
- eof chan struct{}
+ R *os.File
+ W *os.File
}
-type ReadResult struct {
- n int
- err error
+func NewConn(r, w *os.File) (*Conn, error) {
+ if r == nil {
+ return nil, errors.New("no R file descriptor")
+ }
+ if w == nil {
+ return nil, errors.New("no W file descriptor")
+ }
+ return &Conn{R: r, W: w}, nil
}
func (conn *Conn) Read(b []byte) (int, error) {
- c := make(chan ReadResult)
- go func() {
- n, err := conn.R.Read(b)
- c <- ReadResult{n, err}
- }()
- select {
- case res := <-c:
- return res.n, res.err
- case <-conn.eof:
- return 0, io.EOF
- }
+ return conn.R.Read(b)
}
func (conn *Conn) Write(b []byte) (int, error) { return conn.W.Write(b) }
func (conn *Conn) Close() error {
- if err := conn.R.Close(); err != nil {
- return err
+ errR := conn.R.Close()
+ errW := conn.W.Close()
+ if errR != nil {
+ return errR
}
- return os.Stdin.Close()
+ return errW
}
func (conn *Conn) LocalAddr() net.Addr {
}
func (conn *Conn) SetReadDeadline(t time.Time) error {
- // An ugly hack to forcefully terminate pending read.
- // net/http calls SetReadDeadline(aLongTimeAgo), but file
- // descriptors are not capable to exit immediately that way.
- if t.Equal(aLongTimeAgo) {
- conn.eof <- struct{}{}
- }
return conn.R.SetReadDeadline(t)
}