]> Cypherpunks.ru repositories - gostls13.git/commitdiff
crypto/tls: add CloseWrite method to Conn
authorBen Burkert <ben@benburkert.com>
Mon, 17 Oct 2016 21:47:48 +0000 (14:47 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Wed, 26 Oct 2016 23:05:40 +0000 (23:05 +0000)
The CloseWrite method sends a close_notify alert record to the other
side of the connection. This record indicates that the sender has
finished sending on the connection. Unlike the Close method, the sender
may still read from the connection until it recieves a close_notify
record (or the underlying connection is closed). This is analogous to a
TCP half-close.

This is a rework of CL 25159 with fixes for the unstable test.

Updates #8579

Change-Id: I47608d2f82a88baff07a90fd64c280ed16a60d5e
Reviewed-on: https://go-review.googlesource.com/31318
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>

src/crypto/tls/conn.go
src/crypto/tls/tls_test.go

index 8d6f7e77add4df5ca456b4445d68672f5d1b235a..28d111afc02010fdadac9af7cb03225328a571c6 100644 (file)
@@ -64,6 +64,13 @@ type Conn struct {
        // the first transmitted Finished message is the tls-unique
        // channel-binding value.
        clientFinishedIsFirst bool
+
+       // closeNotifyErr is any error from sending the alertCloseNotify record.
+       closeNotifyErr error
+       // closeNotifySent is true if the Conn attempted to send an
+       // alertCloseNotify record.
+       closeNotifySent bool
+
        // clientFinished and serverFinished contain the Finished message sent
        // by the client or server in the most recent handshake. This is
        // retained to support the renegotiation extension and tls-unique
@@ -1000,7 +1007,10 @@ func (c *Conn) readHandshake() (interface{}, error) {
        return m, nil
 }
 
-var errClosed = errors.New("tls: use of closed connection")
+var (
+       errClosed   = errors.New("tls: use of closed connection")
+       errShutdown = errors.New("tls: protocol is shutdown")
+)
 
 // Write writes data to the connection.
 func (c *Conn) Write(b []byte) (int, error) {
@@ -1031,6 +1041,10 @@ func (c *Conn) Write(b []byte) (int, error) {
                return 0, alertInternalError
        }
 
+       if c.closeNotifySent {
+               return 0, errShutdown
+       }
+
        // SSL 3.0 and TLS 1.0 are susceptible to a chosen-plaintext
        // attack when using block mode ciphers due to predictable IVs.
        // This can be prevented by splitting each Application Data
@@ -1194,7 +1208,7 @@ func (c *Conn) Close() error {
        c.handshakeMutex.Lock()
        defer c.handshakeMutex.Unlock()
        if c.handshakeComplete {
-               alertErr = c.sendAlert(alertCloseNotify)
+               alertErr = c.closeNotify()
        }
 
        if err := c.conn.Close(); err != nil {
@@ -1203,6 +1217,32 @@ func (c *Conn) Close() error {
        return alertErr
 }
 
+var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake complete")
+
+// CloseWrite shuts down the writing side of the connection. It should only be
+// called once the handshake has completed and does not call CloseWrite on the
+// underlying connection. Most callers should just use Close.
+func (c *Conn) CloseWrite() error {
+       c.handshakeMutex.Lock()
+       defer c.handshakeMutex.Unlock()
+       if !c.handshakeComplete {
+               return errEarlyCloseWrite
+       }
+
+       return c.closeNotify()
+}
+
+func (c *Conn) closeNotify() error {
+       c.out.Lock()
+       defer c.out.Unlock()
+
+       if !c.closeNotifySent {
+               c.closeNotifyErr = c.sendAlertLocked(alertCloseNotify)
+               c.closeNotifySent = true
+       }
+       return c.closeNotifyErr
+}
+
 // Handshake runs the client or server handshake
 // protocol if it has not yet been run.
 // Most uses of this package need not call Handshake
index 99153a5dad4f7e3bda0b257e8d58bc863ec5719a..b4fedd59c0ffbf30a849993158bf1eb9785ae887 100644 (file)
@@ -11,6 +11,7 @@ import (
        "fmt"
        "internal/testenv"
        "io"
+       "io/ioutil"
        "math"
        "math/rand"
        "net"
@@ -458,6 +459,112 @@ func TestConnCloseBreakingWrite(t *testing.T) {
        }
 }
 
+func TestConnCloseWrite(t *testing.T) {
+       ln := newLocalListener(t)
+       defer ln.Close()
+
+       clientDoneChan := make(chan struct{})
+
+       serverCloseWrite := func() error {
+               sconn, err := ln.Accept()
+               if err != nil {
+                       return fmt.Errorf("accept: %v", err)
+               }
+               defer sconn.Close()
+
+               serverConfig := testConfig.Clone()
+               srv := Server(sconn, serverConfig)
+               if err := srv.Handshake(); err != nil {
+                       return fmt.Errorf("handshake: %v", err)
+               }
+               defer srv.Close()
+
+               data, err := ioutil.ReadAll(srv)
+               if err != nil {
+                       return err
+               }
+               if len(data) > 0 {
+                       return fmt.Errorf("Read data = %q; want nothing", data)
+               }
+
+               if err := srv.CloseWrite(); err != nil {
+                       return fmt.Errorf("server CloseWrite: %v", err)
+               }
+
+               // Wait for clientCloseWrite to finish, so we know we
+               // tested the CloseWrite before we defer the
+               // sconn.Close above, which would also cause the
+               // client to unblock like CloseWrite.
+               <-clientDoneChan
+               return nil
+       }
+
+       clientCloseWrite := func() error {
+               defer close(clientDoneChan)
+
+               clientConfig := testConfig.Clone()
+               conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
+               if err != nil {
+                       return err
+               }
+               if err := conn.Handshake(); err != nil {
+                       return err
+               }
+               defer conn.Close()
+
+               if err := conn.CloseWrite(); err != nil {
+                       return fmt.Errorf("client CloseWrite: %v", err)
+               }
+
+               if _, err := conn.Write([]byte{0}); err != errShutdown {
+                       return fmt.Errorf("CloseWrite error = %v; want errShutdown", err)
+               }
+
+               data, err := ioutil.ReadAll(conn)
+               if err != nil {
+                       return err
+               }
+               if len(data) > 0 {
+                       return fmt.Errorf("Read data = %q; want nothing", data)
+               }
+               return nil
+       }
+
+       errChan := make(chan error, 2)
+
+       go func() { errChan <- serverCloseWrite() }()
+       go func() { errChan <- clientCloseWrite() }()
+
+       for i := 0; i < 2; i++ {
+               select {
+               case err := <-errChan:
+                       if err != nil {
+                               t.Fatal(err)
+                       }
+               case <-time.After(10 * time.Second):
+                       t.Fatal("deadlock")
+               }
+       }
+
+       // Also test CloseWrite being called before the handshake is
+       // finished:
+       {
+               ln2 := newLocalListener(t)
+               defer ln2.Close()
+
+               netConn, err := net.Dial("tcp", ln2.Addr().String())
+               if err != nil {
+                       t.Fatal(err)
+               }
+               defer netConn.Close()
+               conn := Client(netConn, testConfig.Clone())
+
+               if err := conn.CloseWrite(); err != errEarlyCloseWrite {
+                       t.Errorf("CloseWrite error = %v; want errEarlyCloseWrite", err)
+               }
+       }
+}
+
 func TestClone(t *testing.T) {
        var c1 Config
        v := reflect.ValueOf(&c1).Elem()