From 7c5da149c92ff9372de2b179f9f80932a004f7f525374a240d7663fcfa289395 Mon Sep 17 00:00:00 2001 From: Sergey Matveev Date: Fri, 20 Oct 2023 00:39:24 +0300 Subject: [PATCH] RWMutex is not that bad --- cmd/init/main.go | 34 +++++++++++++++++++--------------- cmd/resp/main.go | 22 +++++++++++++--------- 2 files changed, 32 insertions(+), 24 deletions(-) diff --git a/cmd/init/main.go b/cmd/init/main.go index 77b8213..3564dbf 100644 --- a/cmd/init/main.go +++ b/cmd/init/main.go @@ -44,6 +44,7 @@ var ( TLSConfig *tls.Config LnUDP *net.UDPConn Peers = make(map[string]chan udpobfs.Buf) + PeersM sync.RWMutex Bufs = sync.Pool{New: func() any { return new([udpobfs.BufLen]byte) }} ) @@ -83,8 +84,17 @@ func newPeer(localAddr net.Addr, dataInitial []byte) { } cryptoState := udpobfs.NewCryptoState(seed, true) txs := make(chan udpobfs.Buf) - rxFinished := make(chan struct{}) + PeersM.Lock() + Peers[localAddr.String()] = txs + PeersM.Unlock() var rxPkts, txPkts, rxBytes, txBytes int64 + { + txPkts++ + txBytes += int64(len(dataInitial)) + tmp := make([]byte, udpobfs.SeqLen+len(dataInitial)) + connUDP.WriteTo(cryptoState.Tx(tmp, dataInitial), DstAddrUDP) + } + rxFinished := make(chan struct{}) go func() { var n int var err error @@ -129,6 +139,7 @@ func newPeer(localAddr net.Addr, dataInitial []byte) { lastPing := now last := now var got []byte + var ok bool for { select { case <-ticker.C: @@ -142,8 +153,8 @@ func newPeer(localAddr net.Addr, dataInitial []byte) { cryptoState.Tx(buf[:udpobfs.SeqLen], nil), DstAddrUDP) lastPing = now } - case tx = <-txs: - if tx.Buf == nil { + case tx, ok = <-txs: + if !ok { return } got = cryptoState.Tx(buf[:udpobfs.SeqLen+tx.N], (*tx.Buf)[:tx.N]) @@ -156,13 +167,6 @@ func newPeer(localAddr net.Addr, dataInitial []byte) { } } }() - Peers[localAddr.String()] = txs - { - txPkts++ - txBytes += int64(len(dataInitial)) - tmp := make([]byte, udpobfs.SeqLen+len(dataInitial)) - connUDP.WriteTo(cryptoState.Tx(tmp, dataInitial), DstAddrUDP) - } go func() { defer connUDP.Close() ticker := time.NewTicker(udpobfs.LifetimeDuration) @@ -202,12 +206,10 @@ func newPeer(localAddr net.Addr, dataInitial []byte) { "rxBytes", rxBytes, "txPkts", txPkts, "txBytes", txBytes) + PeersM.Lock() delete(Peers, localAddr.String()) - txs <- udpobfs.Buf{Buf: nil} - go func() { - for range txs { - } - }() + PeersM.Unlock() + close(txs) } func main() { @@ -277,10 +279,12 @@ func main() { if n == 0 { continue } + PeersM.RLock() txs = Peers[from.String()] if txs != nil { txs <- udpobfs.Buf{Buf: buf, N: n} } + PeersM.RUnlock() if txs == nil { neu := make([]byte, n) copy(neu, (*buf)[:n]) diff --git a/cmd/resp/main.go b/cmd/resp/main.go index cc97336..7545299 100644 --- a/cmd/resp/main.go +++ b/cmd/resp/main.go @@ -37,6 +37,7 @@ var ( LnUDP *net.UDPConn LnTCP *net.TCPListener Peers = make(map[string]chan udpobfs.Buf) + PeersM sync.RWMutex Bufs = sync.Pool{New: func() any { return new([udpobfs.BufLen]byte) }} ) @@ -67,6 +68,9 @@ func newPeer(conn net.Conn) { } cryptoState := udpobfs.NewCryptoState(seed, false) txs := make(chan udpobfs.Buf) + PeersM.Lock() + Peers[remoteAddr.String()] = txs + PeersM.Unlock() txFinished := make(chan struct{}) var rxPkts, txPkts, rxBytes, txBytes int64 go func() { @@ -128,6 +132,7 @@ func newPeer(conn net.Conn) { buf := make([]byte, udpobfs.BufLen) var tx udpobfs.Buf var got []byte + var ok bool for { select { case <-txFinished: @@ -138,9 +143,9 @@ func newPeer(conn net.Conn) { localUDP.Close() return } - case tx = <-txs: - if tx.Buf == nil { - break + case tx, ok = <-txs: + if !ok { + return } if tx.N < udpobfs.SeqLen { logger.Warn("too short") @@ -162,7 +167,6 @@ func newPeer(conn net.Conn) { } } }() - Peers[remoteAddr.String()] = txs go func() { buf := make([]byte, 8) for { @@ -181,12 +185,10 @@ func newPeer(conn net.Conn) { "rxBytes", rxBytes, "txPkts", txPkts, "txBytes", txBytes) + PeersM.Lock() delete(Peers, remoteAddr.String()) - txs <- udpobfs.Buf{Buf: nil} - go func() { - for range txs { - } - }() + PeersM.Unlock() + close(txs) } func main() { @@ -240,10 +242,12 @@ func main() { if n == 0 { continue } + PeersM.RLock() txs = Peers[from.String()] if txs != nil { txs <- udpobfs.Buf{Buf: buf, N: n} } + PeersM.RUnlock() } }() -- 2.44.0