Commit 6160b927, authored by acud, committed by GitHub

kademlia: announce in a separate goroutine (#384)

* kademlia: announce to other peers in a separate goroutine
parent e3743a4d
...@@ -65,6 +65,7 @@ type Kad struct { ...@@ -65,6 +65,7 @@ type Kad struct {
logger logging.Logger // logger logger logging.Logger // logger
quit chan struct{} // quit channel quit chan struct{} // quit channel
done chan struct{} // signal that `manage` has quit done chan struct{} // signal that `manage` has quit
wg sync.WaitGroup
} }
type retryInfo struct { type retryInfo struct {
...@@ -91,8 +92,9 @@ func New(o Options) *Kad { ...@@ -91,8 +92,9 @@ func New(o Options) *Kad {
logger: o.Logger, logger: o.Logger,
quit: make(chan struct{}), quit: make(chan struct{}),
done: make(chan struct{}), done: make(chan struct{}),
wg: sync.WaitGroup{},
} }
k.wg.Add(1)
go k.manage() go k.manage()
return k return k
} }
...@@ -105,7 +107,9 @@ func (k *Kad) manage() { ...@@ -105,7 +107,9 @@ func (k *Kad) manage() {
start time.Time start time.Time
) )
defer k.wg.Done()
defer close(k.done) defer close(k.done)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
go func() { go func() {
<-k.quit <-k.quit
...@@ -321,12 +325,22 @@ func (k *Kad) announce(ctx context.Context, peer swarm.Address) error { ...@@ -321,12 +325,22 @@ func (k *Kad) announce(ctx context.Context, peer swarm.Address) error {
if connectedPeer.Equal(peer) { if connectedPeer.Equal(peer) {
return false, false, nil return false, false, nil
} }
addrs = append(addrs, connectedPeer) addrs = append(addrs, connectedPeer)
if err := k.discovery.BroadcastPeers(ctx, connectedPeer, peer); err != nil {
// we don't want to fail the whole process because of this, keep on gossiping // this needs to be in a separate goroutine since a peer we are gossipping to might
k.logger.Debugf("error gossiping peer %s to peer %s: %v", peer, connectedPeer, err) // be slow and since this function is called with the same context from kademlia connect
return false, false, nil // function, this might result in the unfortunate situation where we end up on
} // `err := k.discovery.BroadcastPeers(ctx, peer, addrs...)` with an already expired context
// indicating falsely, that the peer connection has timed out.
k.wg.Add(1)
go func(connectedPeer swarm.Address) {
defer k.wg.Done()
if err := k.discovery.BroadcastPeers(context.Background(), connectedPeer, peer); err != nil {
k.logger.Debugf("error gossiping peer %s to peer %s: %v", peer, connectedPeer, err)
}
}(connectedPeer)
return false, false, nil return false, false, nil
}) })
...@@ -338,6 +352,7 @@ func (k *Kad) announce(ctx context.Context, peer swarm.Address) error { ...@@ -338,6 +352,7 @@ func (k *Kad) announce(ctx context.Context, peer swarm.Address) error {
if err != nil { if err != nil {
_ = k.p2p.Disconnect(peer) _ = k.p2p.Disconnect(peer)
} }
return err return err
} }
...@@ -620,11 +635,26 @@ func (k *Kad) String() string { ...@@ -620,11 +635,26 @@ func (k *Kad) String() string {
// Close shuts down kademlia. // Close shuts down kademlia.
func (k *Kad) Close() error { func (k *Kad) Close() error {
k.logger.Info("kademlia shutting down")
close(k.quit) close(k.quit)
cc := make(chan struct{})
go func() {
defer close(cc)
k.wg.Wait()
}()
select {
case <-cc:
case <-time.After(10 * time.Second):
k.logger.Warning("kademlia shutting down with announce goroutines")
}
select { select {
case <-k.done: case <-k.done:
case <-time.After(3 * time.Second): case <-time.After(5 * time.Second):
k.logger.Warning("kademlia manage loop did not shut down properly") k.logger.Warning("kademlia manage loop did not shut down properly")
} }
return nil return nil
} }
...@@ -739,7 +739,7 @@ func waitCounter(t *testing.T, conns *int32, exp int32) { ...@@ -739,7 +739,7 @@ func waitCounter(t *testing.T, conns *int32, exp int32) {
// wait for discovery BroadcastPeers to happen // wait for discovery BroadcastPeers to happen
func waitBcast(t *testing.T, d *mock.Discovery, pivot swarm.Address, addrs ...swarm.Address) { func waitBcast(t *testing.T, d *mock.Discovery, pivot swarm.Address, addrs ...swarm.Address) {
t.Helper() t.Helper()
time.Sleep(50 * time.Millisecond)
for i := 0; i < 50; i++ { for i := 0; i < 50; i++ {
if d.Broadcasts() > 0 { if d.Broadcasts() > 0 {
recs, ok := d.AddresseeRecords(pivot) recs, ok := d.AddresseeRecords(pivot)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment