// the first transmitted Finished message is the tls-unique
// channel-binding value.
clientFinishedIsFirst bool
+
+ // closeNotifyErr is any error from sending the alertCloseNotify record.
+ closeNotifyErr error
+ // closeNotifySent is true if the Conn attempted to send an
+ // alertCloseNotify record.
+ closeNotifySent bool
+
// clientFinished and serverFinished contain the Finished message sent
// by the client or server in the most recent handshake. This is
// retained to support the renegotiation extension and tls-unique
return m, nil
}
-var errClosed = errors.New("tls: use of closed connection")
+var (
+ errClosed = errors.New("tls: use of closed connection")
+ errShutdown = errors.New("tls: protocol is shutdown")
+)
// Write writes data to the connection.
func (c *Conn) Write(b []byte) (int, error) {
return 0, alertInternalError
}
+ if c.closeNotifySent {
+ return 0, errShutdown
+ }
+
// SSL 3.0 and TLS 1.0 are susceptible to a chosen-plaintext
// attack when using block mode ciphers due to predictable IVs.
// This can be prevented by splitting each Application Data
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
if c.handshakeComplete {
- alertErr = c.sendAlert(alertCloseNotify)
+ alertErr = c.closeNotify()
}
if err := c.conn.Close(); err != nil {
return alertErr
}
+var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake complete")
+
+// CloseWrite shuts down the writing side of the connection. It should only be
+// called once the handshake has completed and does not call CloseWrite on the
+// underlying connection. Most callers should just use Close.
+func (c *Conn) CloseWrite() error {
+ c.handshakeMutex.Lock()
+ defer c.handshakeMutex.Unlock()
+ if !c.handshakeComplete {
+ return errEarlyCloseWrite
+ }
+
+ return c.closeNotify()
+}
+
+func (c *Conn) closeNotify() error {
+ c.out.Lock()
+ defer c.out.Unlock()
+
+ if !c.closeNotifySent {
+ c.closeNotifyErr = c.sendAlertLocked(alertCloseNotify)
+ c.closeNotifySent = true
+ }
+ return c.closeNotifyErr
+}
+
// Handshake runs the client or server handshake
// protocol if it has not yet been run.
// Most uses of this package need not call Handshake
"fmt"
"internal/testenv"
"io"
+ "io/ioutil"
"math"
"math/rand"
"net"
}
}
+func TestConnCloseWrite(t *testing.T) {
+ ln := newLocalListener(t)
+ defer ln.Close()
+
+ clientDoneChan := make(chan struct{})
+
+ serverCloseWrite := func() error {
+ sconn, err := ln.Accept()
+ if err != nil {
+ return fmt.Errorf("accept: %v", err)
+ }
+ defer sconn.Close()
+
+ serverConfig := testConfig.Clone()
+ srv := Server(sconn, serverConfig)
+ if err := srv.Handshake(); err != nil {
+ return fmt.Errorf("handshake: %v", err)
+ }
+ defer srv.Close()
+
+ data, err := ioutil.ReadAll(srv)
+ if err != nil {
+ return err
+ }
+ if len(data) > 0 {
+ return fmt.Errorf("Read data = %q; want nothing", data)
+ }
+
+ if err := srv.CloseWrite(); err != nil {
+ return fmt.Errorf("server CloseWrite: %v", err)
+ }
+
+ // Wait for clientCloseWrite to finish, so we know we
+ // tested the CloseWrite before we defer the
+ // sconn.Close above, which would also cause the
+ // client to unblock like CloseWrite.
+ <-clientDoneChan
+ return nil
+ }
+
+ clientCloseWrite := func() error {
+ defer close(clientDoneChan)
+
+ clientConfig := testConfig.Clone()
+ conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
+ if err != nil {
+ return err
+ }
+ if err := conn.Handshake(); err != nil {
+ return err
+ }
+ defer conn.Close()
+
+ if err := conn.CloseWrite(); err != nil {
+ return fmt.Errorf("client CloseWrite: %v", err)
+ }
+
+ if _, err := conn.Write([]byte{0}); err != errShutdown {
+ return fmt.Errorf("CloseWrite error = %v; want errShutdown", err)
+ }
+
+ data, err := ioutil.ReadAll(conn)
+ if err != nil {
+ return err
+ }
+ if len(data) > 0 {
+ return fmt.Errorf("Read data = %q; want nothing", data)
+ }
+ return nil
+ }
+
+ errChan := make(chan error, 2)
+
+ go func() { errChan <- serverCloseWrite() }()
+ go func() { errChan <- clientCloseWrite() }()
+
+ for i := 0; i < 2; i++ {
+ select {
+ case err := <-errChan:
+ if err != nil {
+ t.Fatal(err)
+ }
+ case <-time.After(10 * time.Second):
+ t.Fatal("deadlock")
+ }
+ }
+
+ // Also test CloseWrite being called before the handshake is
+ // finished:
+ {
+ ln2 := newLocalListener(t)
+ defer ln2.Close()
+
+ netConn, err := net.Dial("tcp", ln2.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer netConn.Close()
+ conn := Client(netConn, testConfig.Clone())
+
+ if err := conn.CloseWrite(); err != errEarlyCloseWrite {
+ t.Errorf("CloseWrite error = %v; want errEarlyCloseWrite", err)
+ }
+ }
+}
+
func TestClone(t *testing.T) {
var c1 Config
v := reflect.ValueOf(&c1).Elem()