]> Cypherpunks.ru repositories - udpobfs.git/blobdiff - cmd/resp/main.go
RWMutex is not that bad
[udpobfs.git] / cmd / resp / main.go
index cc973368855855ed6671693cb55cc6239712c7e3ceb045a0311970977eccdb16..75452990263b7392253f3389a32572e538d2b9797c00f5e6297f431e4e9bcce7 100644 (file)
@@ -37,6 +37,7 @@ var (
        LnUDP      *net.UDPConn
        LnTCP      *net.TCPListener
        Peers      = make(map[string]chan udpobfs.Buf)
+       PeersM     sync.RWMutex
        Bufs       = sync.Pool{New: func() any { return new([udpobfs.BufLen]byte) }}
 )
 
@@ -67,6 +68,9 @@ func newPeer(conn net.Conn) {
        }
        cryptoState := udpobfs.NewCryptoState(seed, false)
        txs := make(chan udpobfs.Buf)
+       PeersM.Lock()
+       Peers[remoteAddr.String()] = txs
+       PeersM.Unlock()
        txFinished := make(chan struct{})
        var rxPkts, txPkts, rxBytes, txBytes int64
        go func() {
@@ -128,6 +132,7 @@ func newPeer(conn net.Conn) {
                buf := make([]byte, udpobfs.BufLen)
                var tx udpobfs.Buf
                var got []byte
+               var ok bool
                for {
                        select {
                        case <-txFinished:
@@ -138,9 +143,9 @@ func newPeer(conn net.Conn) {
                                        localUDP.Close()
                                        return
                                }
-                       case tx = <-txs:
-                               if tx.Buf == nil {
-                                       break
+                       case tx, ok = <-txs:
+                               if !ok {
+                                       return
                                }
                                if tx.N < udpobfs.SeqLen {
                                        logger.Warn("too short")
@@ -162,7 +167,6 @@ func newPeer(conn net.Conn) {
                        }
                }
        }()
-       Peers[remoteAddr.String()] = txs
        go func() {
                buf := make([]byte, 8)
                for {
@@ -181,12 +185,10 @@ func newPeer(conn net.Conn) {
                "rxBytes", rxBytes,
                "txPkts", txPkts,
                "txBytes", txBytes)
+       PeersM.Lock()
        delete(Peers, remoteAddr.String())
-       txs <- udpobfs.Buf{Buf: nil}
-       go func() {
-               for range txs {
-               }
-       }()
+       PeersM.Unlock()
+       close(txs)
 }
 
 func main() {
@@ -240,10 +242,12 @@ func main() {
                        if n == 0 {
                                continue
                        }
+                       PeersM.RLock()
                        txs = Peers[from.String()]
                        if txs != nil {
                                txs <- udpobfs.Buf{Buf: buf, N: n}
                        }
+                       PeersM.RUnlock()
                }
        }()