]> Cypherpunks.ru repositories - nncp.git/blobdiff - src/sp.go
Fix files closing race when call is finished
[nncp.git] / src / sp.go
index 6d31375e137ff3a044d0bb5c1ad8ea4c481b4dd0..0b039e6bffb9f80f6eecff80b36d8ce985552215 100644 (file)
--- a/src/sp.go
+++ b/src/sp.go
@@ -236,6 +236,7 @@ type SPState struct {
        fdsLock        sync.RWMutex
        fileHashers    map[string]*HasherAndOffset
        checkerQueues  SPCheckerQueues
+       progressBars   map[string]struct{}
        sync.RWMutex
 }
 
@@ -257,11 +258,6 @@ func (state *SPState) SetDead() {
                for range state.pings {
                }
        }()
-       go func() {
-               for _, s := range state.fds {
-                       s.fd.Close()
-               }
-       }()
 }
 
 func (state *SPState) NotAlive() bool {
@@ -441,6 +437,7 @@ func (state *SPState) StartI(conn ConnDeadlined) error {
        state.pings = make(chan struct{})
        state.infosTheir = make(map[[32]byte]*SPInfo)
        state.infosOurSeen = make(map[[32]byte]uint8)
+       state.progressBars = make(map[string]struct{})
        state.started = started
        state.rxLock = rxLock
        state.txLock = txLock
@@ -558,6 +555,7 @@ func (state *SPState) StartR(conn ConnDeadlined) error {
        state.pings = make(chan struct{})
        state.infosOurSeen = make(map[[32]byte]uint8)
        state.infosTheir = make(map[[32]byte]*SPInfo)
+       state.progressBars = make(map[string]struct{})
        state.started = started
        state.xxOnly = xxOnly
 
@@ -792,13 +790,20 @@ func (state *SPState) StartWorkers(
                                pingTicker.Stop()
                                return
                        case now := <-deadlineTicker.C:
-                               if (now.Sub(state.RxLastNonPing) >= state.onlineDeadline &&
-                                       now.Sub(state.TxLastNonPing) >= state.onlineDeadline) ||
-                                       (state.maxOnlineTime > 0 && state.mustFinishAt.Before(now)) ||
-                                       (now.Sub(state.RxLastSeen) >= 2*PingTimeout) {
-                                       state.SetDead()
-                                       conn.Close() // #nosec G104
+                               if now.Sub(state.RxLastNonPing) >= state.onlineDeadline &&
+                                       now.Sub(state.TxLastNonPing) >= state.onlineDeadline {
+                                       goto Deadlined
                                }
+                               if state.maxOnlineTime > 0 && state.mustFinishAt.Before(now) {
+                                       goto Deadlined
+                               }
+                               if now.Sub(state.RxLastSeen) >= 2*PingTimeout {
+                                       goto Deadlined
+                               }
+                               break
+                       Deadlined:
+                               state.SetDead()
+                               conn.Close() // #nosec G104
                        case now := <-pingTicker.C:
                                if now.After(state.TxLastSeen.Add(PingTimeout)) {
                                        state.wg.Add(1)
@@ -989,6 +994,7 @@ func (state *SPState) StartWorkers(
                                        LE{"FullSize", fullSize},
                                )
                                if state.Ctx.ShowPrgrs {
+                                       state.progressBars[pktName] = struct{}{}
                                        Progress("Tx", lesp)
                                }
                                state.Lock()
@@ -1002,6 +1008,9 @@ func (state *SPState) StartWorkers(
                                                } else {
                                                        state.queueTheir = state.queueTheir[:0]
                                                }
+                                               if state.Ctx.ShowPrgrs {
+                                                       delete(state.progressBars, pktName)
+                                               }
                                        } else {
                                                state.queueTheir[0].freq.Offset += uint64(len(buf))
                                        }
@@ -1138,6 +1147,12 @@ func (state *SPState) Wait() {
        if txDuration > 0 {
                state.TxSpeed = state.TxBytes / txDuration
        }
+       for _, s := range state.fds {
+               s.fd.Close()
+       }
+       for pktName := range state.progressBars {
+               ProgressKill(pktName)
+       }
 }
 
 func (state *SPState) ProcessSP(payload []byte) ([][]byte, error) {
@@ -1439,11 +1454,15 @@ func (state *SPState) ProcessSP(payload []byte) ([][]byte, error) {
                        }
                        lesp = append(lesp, LE{"FullSize", fullsize})
                        if state.Ctx.ShowPrgrs {
+                               state.progressBars[pktName] = struct{}{}
                                Progress("Rx", lesp)
                        }
                        if fullsize != ourSize {
                                continue
                        }
+                       if state.Ctx.ShowPrgrs {
+                               delete(state.progressBars, pktName)
+                       }
                        logMsg = func(les LEs) string {
                                return fmt.Sprintf(
                                        "Got packet %s %d%% (%s / %s)",