]> Cypherpunks.ru repositories - gostls13.git/commitdiff
database/sql: prevent Tx statement from committing after rollback
authorDaniel Theophanes <kardianos@gmail.com>
Fri, 24 Jan 2020 14:40:49 +0000 (06:40 -0800)
committerDaniel Theophanes <kardianos@gmail.com>
Mon, 20 Apr 2020 17:45:50 +0000 (17:45 +0000)
It was possible for a Tx that was aborted for rollback
asynchronously to execute a query after the rollback had completed
on the database, which often would auto commit the query outside
of the transaction.

By W-locking the tx.closemu prior to issuing the rollback
connection it ensures any Tx query either fails or finishes
on the Tx, and never after the Tx has rolled back.

Fixes #34775
Fixes #32942

Change-Id: I017b7932082f2f4ead70bae08b61ed9068ac1d01
Reviewed-on: https://go-review.googlesource.com/c/go/+/216240
Reviewed-by: Emmanuel Odeke <emm.odeke@gmail.com>
src/database/sql/fakedb_test.go
src/database/sql/sql.go
src/database/sql/sql_test.go

index b6e9a5707e9e6b16cbdb7cafc34da30cdd0eab4b..7605a2a6d23e0327bcccaf1d8a8cd9d616a26a23 100644 (file)
@@ -390,6 +390,7 @@ func setStrictFakeConnClose(t *testing.T) {
 
 func (c *fakeConn) ResetSession(ctx context.Context) error {
        c.dirtySession = false
+       c.currTx = nil
        if c.isBad() {
                return driver.ErrBadConn
        }
@@ -734,6 +735,9 @@ var hookExecBadConn func() bool
 func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) {
        panic("Using ExecContext")
 }
+
+var errFakeConnSessionDirty = errors.New("fakedb: session is dirty")
+
 func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
        if s.panic == "Exec" {
                panic(s.panic)
@@ -746,7 +750,7 @@ func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (d
                return nil, driver.ErrBadConn
        }
        if s.c.isDirtyAndMark() {
-               return nil, errors.New("fakedb: session is dirty")
+               return nil, errFakeConnSessionDirty
        }
 
        err := checkSubsetTypes(s.c.db.allowAny, args)
@@ -860,7 +864,7 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (
                return nil, driver.ErrBadConn
        }
        if s.c.isDirtyAndMark() {
-               return nil, errors.New("fakedb: session is dirty")
+               return nil, errFakeConnSessionDirty
        }
 
        err := checkSubsetTypes(s.c.db.allowAny, args)
@@ -893,6 +897,37 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (
                                }
                        }
                }
+               if s.table == "tx_status" && s.colName[0] == "tx_status" {
+                       txStatus := "autocommit"
+                       if s.c.currTx != nil {
+                               txStatus = "transaction"
+                       }
+                       cursor := &rowsCursor{
+                               parentMem: s.c,
+                               posRow:    -1,
+                               rows: [][]*row{
+                                       []*row{
+                                               {
+                                                       cols: []interface{}{
+                                                               txStatus,
+                                                       },
+                                               },
+                                       },
+                               },
+                               cols: [][]string{
+                                       []string{
+                                               "tx_status",
+                                       },
+                               },
+                               colType: [][]string{
+                                       []string{
+                                               "string",
+                                       },
+                               },
+                               errPos: -1,
+                       }
+                       return cursor, nil
+               }
 
                t.mu.Lock()
 
index 2fae0f02ff4ad8b84bb97c8e679843fd4a7b6f97..3710264dcf6097dcd25db6f51662185726a8a1d8 100644 (file)
@@ -2061,14 +2061,10 @@ func (tx *Tx) isDone() bool {
 // that has already been committed or rolled back.
 var ErrTxDone = errors.New("sql: transaction has already been committed or rolled back")
 
-// close returns the connection to the pool and
-// must only be called by Tx.rollback or Tx.Commit.
-func (tx *Tx) close(err error) {
-       tx.cancel()
-
-       tx.closemu.Lock()
-       defer tx.closemu.Unlock()
-
+// closeLocked returns the connection to the pool and
+// must only be called by Tx.rollback or Tx.Commit while
+// closemu is Locked and tx already canceled.
+func (tx *Tx) closeLocked(err error) {
        tx.releaseConn(err)
        tx.dc = nil
        tx.txi = nil
@@ -2135,6 +2131,15 @@ func (tx *Tx) Commit() error {
        if !atomic.CompareAndSwapInt32(&tx.done, 0, 1) {
                return ErrTxDone
        }
+
+       // Cancel the Tx to release any active R-closemu locks.
+       // This is safe to do because tx.done has already transitioned
+       // from 0 to 1. Hold the W-closemu lock prior to rollback
+       // to ensure no other connection has an active query.
+       tx.cancel()
+       tx.closemu.Lock()
+       defer tx.closemu.Unlock()
+
        var err error
        withLock(tx.dc, func() {
                err = tx.txi.Commit()
@@ -2142,16 +2147,31 @@ func (tx *Tx) Commit() error {
        if err != driver.ErrBadConn {
                tx.closePrepared()
        }
-       tx.close(err)
+       tx.closeLocked(err)
        return err
 }
 
+var rollbackHook func()
+
 // rollback aborts the transaction and optionally forces the pool to discard
 // the connection.
 func (tx *Tx) rollback(discardConn bool) error {
        if !atomic.CompareAndSwapInt32(&tx.done, 0, 1) {
                return ErrTxDone
        }
+
+       if rollbackHook != nil {
+               rollbackHook()
+       }
+
+       // Cancel the Tx to release any active R-closemu locks.
+       // This is safe to do because tx.done has already transitioned
+       // from 0 to 1. Hold the W-closemu lock prior to rollback
+       // to ensure no other connection has an active query.
+       tx.cancel()
+       tx.closemu.Lock()
+       defer tx.closemu.Unlock()
+
        var err error
        withLock(tx.dc, func() {
                err = tx.txi.Rollback()
@@ -2162,7 +2182,7 @@ func (tx *Tx) rollback(discardConn bool) error {
        if discardConn {
                err = driver.ErrBadConn
        }
-       tx.close(err)
+       tx.closeLocked(err)
        return err
 }
 
index e9b5c8a228df70145c3b543e4e5a2367de9ffb12..41265caed8060cfc72511b474656fd17666e7e20 100644 (file)
@@ -80,6 +80,11 @@ func newTestDBConnector(t testing.TB, fc *fakeConnector, name string) *DB {
                exec(t, db, "CREATE|magicquery|op=string,millis=int32")
                exec(t, db, "INSERT|magicquery|op=sleep,millis=10")
        }
+       if name == "tx_status" {
+               // Magic table name and column, known by fakedb_test.go.
+               exec(t, db, "CREATE|tx_status|tx_status=string")
+               exec(t, db, "INSERT|tx_status|tx_status=invalid")
+       }
        return db
 }
 
@@ -2707,6 +2712,70 @@ func TestManyErrBadConn(t *testing.T) {
        }
 }
 
+// Issue 34755: Ensure that a Tx cannot commit after a rollback.
+func TestTxCannotCommitAfterRollback(t *testing.T) {
+       db := newTestDB(t, "tx_status")
+       defer closeDB(t, db)
+
+       // First check query reporting is correct.
+       var txStatus string
+       err := db.QueryRow("SELECT|tx_status|tx_status|").Scan(&txStatus)
+       if err != nil {
+               t.Fatal(err)
+       }
+       if g, w := txStatus, "autocommit"; g != w {
+               t.Fatalf("tx_status=%q, wanted %q", g, w)
+       }
+
+       ctx, cancel := context.WithCancel(context.Background())
+       defer cancel()
+
+       tx, err := db.BeginTx(ctx, nil)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       // Ignore dirty session for this test.
+       // A failing test should trigger the dirty session flag as well,
+       // but that isn't exactly what this should test for.
+       tx.txi.(*fakeTx).c.skipDirtySession = true
+
+       defer tx.Rollback()
+
+       err = tx.QueryRow("SELECT|tx_status|tx_status|").Scan(&txStatus)
+       if err != nil {
+               t.Fatal(err)
+       }
+       if g, w := txStatus, "transaction"; g != w {
+               t.Fatalf("tx_status=%q, wanted %q", g, w)
+       }
+
+       // 1. Begin a transaction.
+       // 2. (A) Start a query, (B) begin Tx rollback through a ctx cancel.
+       // 3. Check if 2.A has committed in Tx (pass) or outside of Tx (fail).
+       sendQuery := make(chan struct{})
+       hookTxGrabConn = func() {
+               cancel()
+               <-sendQuery
+       }
+       rollbackHook = func() {
+               close(sendQuery)
+       }
+       defer func() {
+               hookTxGrabConn = nil
+               rollbackHook = nil
+       }()
+
+       err = tx.QueryRow("SELECT|tx_status|tx_status|").Scan(&txStatus)
+       if err != nil {
+               // A failure here would be expected if skipDirtySession was not set to true above.
+               t.Fatal(err)
+       }
+       if g, w := txStatus, "transaction"; g != w {
+               t.Fatalf("tx_status=%q, wanted %q", g, w)
+       }
+}
+
 // Issue32530 encounters an issue where a connection may
 // expire right after it comes out of a used connection pool
 // even when a new connection is requested.