Commit ac06c592 authored by Esad Akar, committed by GitHub

kademlia, hive: rate limiting on gossiping peers (send and receive) (#1654)

parent 479f6658
@@ -66,6 +66,7 @@ require (
 	golang.org/x/sys v0.0.0-20210108172913-0df2131ae363
 	golang.org/x/term v0.0.0-20201210144234-2321bbc49cbf
 	golang.org/x/text v0.3.4 // indirect
+	golang.org/x/time v0.0.0-20191024005414-555d28b269f0
 	golang.org/x/tools v0.0.0-20200626171337-aa94e735be7f // indirect
 	gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect
 	gopkg.in/ini.v1 v1.57.0 // indirect
......
@@ -5,3 +5,4 @@
 package hive
 
 var MaxBatchSize = maxBatchSize
+var LimitBurst = limitBurst
@@ -12,7 +12,9 @@ package hive
 import (
 	"context"
+	"errors"
 	"fmt"
+	"sync"
 	"time"
 
 	"github.com/ethersphere/bee/pkg/addressbook"
@@ -22,6 +24,8 @@ import (
 	"github.com/ethersphere/bee/pkg/p2p"
 	"github.com/ethersphere/bee/pkg/p2p/protobuf"
 	"github.com/ethersphere/bee/pkg/swarm"
+
+	"golang.org/x/time/rate"
 )
 
 const (
@@ -32,6 +36,12 @@ const (
 	maxBatchSize = 30
 )
 
+var (
+	ErrRateLimitExceeded = errors.New("rate limit exceeded")
+	limitBurst           = 4 * int(swarm.MaxBins)
+	limitRate            = rate.Every(time.Minute)
+)
+
 type Service struct {
 	streamer    p2p.Streamer
 	addressBook addressbook.GetPutter
@@ -39,6 +49,8 @@ type Service struct {
 	networkID   uint64
 	logger      logging.Logger
 	metrics     metrics
+	limiter     map[string]*rate.Limiter
+	limiterLock sync.Mutex
 }
 
 func New(streamer p2p.Streamer, addressbook addressbook.GetPutter, networkID uint64, logger logging.Logger) *Service {
@@ -48,6 +60,7 @@ func New(streamer p2p.Streamer, addressbook addressbook.GetPutter, networkID uint64, logger logging.Logger) *Service {
 		addressBook: addressbook,
 		networkID:   networkID,
 		metrics:     newMetrics(),
+		limiter:     make(map[string]*rate.Limiter),
 	}
 }
@@ -61,6 +74,8 @@ func (s *Service) Protocol() p2p.ProtocolSpec {
 				Handler: s.peersHandler,
 			},
 		},
+		DisconnectIn:  s.disconnect,
+		DisconnectOut: s.disconnect,
 	}
 }
@@ -139,6 +154,11 @@ func (s *Service) peersHandler(ctx context.Context, peer p2p.Peer, stream p2p.Stream) error {
 	s.metrics.PeersHandlerPeers.Add(float64(len(peersReq.Peers)))
 
+	if err := s.rateLimitPeer(peer.Address, len(peersReq.Peers)); err != nil {
+		_ = stream.Reset()
+		return err
+	}
+
 	// close the stream before processing in order to unblock the sending side
 	// fullclose is called async because there is no need to wait for confirmation,
 	// but we still want to handle not closed stream from the other side to avoid zombie stream
@@ -169,3 +189,32 @@ func (s *Service) peersHandler(ctx context.Context, peer p2p.Peer, stream p2p.Stream) error {
 
 	return nil
 }
+
+func (s *Service) rateLimitPeer(peer swarm.Address, count int) error {
+	s.limiterLock.Lock()
+	defer s.limiterLock.Unlock()
+
+	addr := peer.String()
+
+	limiter, ok := s.limiter[addr]
+	if !ok {
+		limiter = rate.NewLimiter(limitRate, limitBurst)
+		s.limiter[addr] = limiter
+	}
+
+	if limiter.AllowN(time.Now(), count) {
+		return nil
+	}
+
+	return ErrRateLimitExceeded
+}
+
+func (s *Service) disconnect(peer p2p.Peer) error {
+	s.limiterLock.Lock()
+	defer s.limiterLock.Unlock()
+
+	delete(s.limiter, peer.Address.String())
+
+	return nil
+}
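
For context, a minimal standalone sketch (not part of the commit) of how the golang.org/x/time/rate limiter configured above behaves for a single peer: the bucket starts full with limitBurst tokens, refills at one token per minute, and AllowN drains one token per gossiped peer in a batch. The burst of 128 below assumes swarm.MaxBins is 32 and is used purely for illustration.

package main

import (
	"fmt"
	"time"

	"golang.org/x/time/rate"
)

func main() {
	// 4 * int(swarm.MaxBins); 32 bins assumed here for illustration only.
	burst := 4 * 32
	lim := rate.NewLimiter(rate.Every(time.Minute), burst)

	now := time.Now()
	fmt.Println(lim.AllowN(now, burst))                 // true: the initial burst of 128 peers is accepted
	fmt.Println(lim.AllowN(now, 1))                     // false: the bucket is empty, this batch would be rejected
	fmt.Println(lim.AllowN(now.Add(2*time.Minute), 2))  // true: two tokens have accrued after two minutes
}
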
@@ -7,6 +7,7 @@ package hive_test
 import (
 	"bytes"
 	"context"
+	"errors"
 	"fmt"
 	"io/ioutil"
 	"math/rand"
@@ -27,8 +28,76 @@ import (
 	"github.com/ethersphere/bee/pkg/p2p/streamtest"
 	"github.com/ethersphere/bee/pkg/statestore/mock"
 	"github.com/ethersphere/bee/pkg/swarm"
+	"github.com/ethersphere/bee/pkg/swarm/test"
 )
+
+func TestHandlerRateLimit(t *testing.T) {
+	logger := logging.New(ioutil.Discard, 0)
+	statestore := mock.NewStateStore()
+	addressbook := ab.New(statestore)
+	networkID := uint64(1)
+
+	addressbookclean := ab.New(mock.NewStateStore())
+
+	// create a hive server that handles the incoming stream
+	server := hive.New(nil, addressbookclean, networkID, logger)
+
+	serverAddress := test.RandomAddress()
+
+	// setup the stream recorder to record stream data
+	serverRecorder := streamtest.New(
+		streamtest.WithProtocols(server.Protocol()),
+		streamtest.WithBaseAddr(serverAddress),
+	)
+
+	peers := make([]swarm.Address, hive.LimitBurst+1)
+	for i := range peers {
+		underlay, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/" + strconv.Itoa(i))
+		if err != nil {
+			t.Fatal(err)
+		}
+		pk, err := crypto.GenerateSecp256k1Key()
+		if err != nil {
+			t.Fatal(err)
+		}
+		signer := crypto.NewDefaultSigner(pk)
+		overlay, err := crypto.NewOverlayAddress(pk.PublicKey, networkID)
+		if err != nil {
+			t.Fatal(err)
+		}
+		bzzAddr, err := bzz.NewAddress(signer, underlay, overlay, networkID)
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		err = addressbook.Put(bzzAddr.Overlay, *bzzAddr)
+		if err != nil {
+			t.Fatal(err)
+		}
+		peers[i] = bzzAddr.Overlay
+	}
+
+	// create a hive client that will do broadcast
+	client := hive.New(serverRecorder, addressbook, networkID, logger)
+	err := client.BroadcastPeers(context.Background(), serverAddress, peers...)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// get a record for this stream
+	rec, err := serverRecorder.Records(serverAddress, "hive", "1.0.0", "peers")
+	if err != nil {
+		t.Fatal(err)
+	}
+	lastRec := rec[len(rec)-1]
+	if !errors.Is(lastRec.Err(), hive.ErrRateLimitExceeded) {
+		t.Fatal(err)
+	}
+}
 
 func TestBroadcastPeers(t *testing.T) {
 	rand.Seed(time.Now().UnixNano())
 	logger := logging.New(ioutil.Discard, 0)
......
@@ -10,10 +10,13 @@ import (
 	"errors"
 	"fmt"
 	"math"
+	"math/big"
 	"math/bits"
 	"sync"
 	"time"
 
+	random "crypto/rand"
+
 	"github.com/ethersphere/bee/pkg/addressbook"
 	"github.com/ethersphere/bee/pkg/discovery"
 	"github.com/ethersphere/bee/pkg/logging"
@@ -38,6 +41,7 @@ var (
 	shortRetry          = 30 * time.Second
 	saturationPeers     = 4
 	overSaturationPeers = 16
+	broadcastBinSize    = 4
 )
 
 type binSaturationFunc func(bin uint8, peers, connected *pslice.PSlice) (saturated bool, oversaturated bool)
@@ -668,28 +672,29 @@ func (k *Kad) connect(ctx context.Context, peer swarm.Address, ma ma.Multiaddr,
 func (k *Kad) Announce(ctx context.Context, peer swarm.Address) error {
 	addrs := []swarm.Address{}
 
-	_ = k.connectedPeers.EachBinRev(func(connectedPeer swarm.Address, _ uint8) (bool, bool, error) {
-		if connectedPeer.Equal(peer) {
-			return false, false, nil
-		}
+	for bin := uint8(0); bin < swarm.MaxBins; bin++ {
 
-		addrs = append(addrs, connectedPeer)
+		connectedPeers, err := randomSubset(k.connectedPeers.BinPeers(bin), broadcastBinSize)
+		if err != nil {
+			return err
+		}
 
-		// this needs to be in a separate goroutine since a peer we are gossipping to might
-		// be slow and since this function is called with the same context from kademlia connect
-		// function, this might result in the unfortunate situation where we end up on
-		// `err := k.discovery.BroadcastPeers(ctx, peer, addrs...)` with an already expired context
-		// indicating falsely, that the peer connection has timed out.
-		k.wg.Add(1)
-		go func(connectedPeer swarm.Address) {
-			defer k.wg.Done()
-			if err := k.discovery.BroadcastPeers(context.Background(), connectedPeer, peer); err != nil {
-				k.logger.Debugf("could not gossip peer %s to peer %s: %v", peer, connectedPeer, err)
-			}
-		}(connectedPeer)
+		for _, connectedPeer := range connectedPeers {
+			if connectedPeer.Equal(peer) {
+				continue
+			}
 
-		return false, false, nil
-	})
+			addrs = append(addrs, connectedPeer)
+
+			// this needs to be in a separate goroutine since a peer we are gossipping to might
+			// be slow and since this function is called with the same context from kademlia connect
+			// function, this might result in the unfortunate situation where we end up on
+			// `err := k.discovery.BroadcastPeers(ctx, peer, addrs...)` with an already expired context
+			// indicating falsely, that the peer connection has timed out.
+			k.wg.Add(1)
+			go func(connectedPeer swarm.Address) {
+				defer k.wg.Done()
+				if err := k.discovery.BroadcastPeers(context.Background(), connectedPeer, peer); err != nil {
+					k.logger.Debugf("could not gossip peer %s to peer %s: %v", peer, connectedPeer, err)
+				}
+			}(connectedPeer)
+		}
+	}
 
 	if len(addrs) == 0 {
 		return nil
@@ -1166,3 +1171,21 @@ func (k *Kad) Close() error {
 
 	return nil
 }
+
+func randomSubset(addrs []swarm.Address, count int) ([]swarm.Address, error) {
+	if count >= len(addrs) {
+		return addrs, nil
+	}
+
+	for i := 0; i < len(addrs); i++ {
+		b, err := random.Int(random.Reader, big.NewInt(int64(len(addrs))))
+		if err != nil {
+			return nil, err
+		}
+		j := int(b.Int64())
+		addrs[i], addrs[j] = addrs[j], addrs[i]
+	}
+
+	return addrs[:count], nil
+}
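
As a side note, randomSubset above shuffles the whole slice in place with crypto/rand and then returns its first count elements. Below is a minimal standalone sketch of the same shuffle-then-prefix idea, written against plain strings so it runs on its own; sampleStrings is an illustrative name and, unlike the committed helper, it copies the input so the caller's ordering is preserved.

package main

import (
	random "crypto/rand"
	"fmt"
	"math/big"
)

// sampleStrings returns count randomly chosen items by shuffling a copy
// of the input and taking the prefix; names and behavior are illustrative.
func sampleStrings(items []string, count int) ([]string, error) {
	if count >= len(items) {
		return items, nil
	}
	out := make([]string, len(items))
	copy(out, items) // avoid mutating the caller's slice
	for i := range out {
		b, err := random.Int(random.Reader, big.NewInt(int64(len(out))))
		if err != nil {
			return nil, err
		}
		j := int(b.Int64())
		out[i], out[j] = out[j], out[i]
	}
	return out[:count], nil
}

func main() {
	peers := []string{"a", "b", "c", "d", "e", "f"}
	subset, err := sampleStrings(peers, 4) // broadcastBinSize is 4 in the commit
	if err != nil {
		panic(err)
	}
	fmt.Println(subset) // e.g. [e b a f]
}
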
@@ -95,6 +95,28 @@ func (s *PSlice) EachBinRev(pf topology.EachPeerFunc) error {
 
 	return nil
 }
+
+func (s *PSlice) BinPeers(bin uint8) []swarm.Address {
+	s.RLock()
+	defer s.RUnlock()
+
+	b := int(bin)
+	if b >= len(s.bins) {
+		return nil
+	}
+
+	var bEnd int
+	if b == len(s.bins)-1 {
+		bEnd = len(s.peers)
+	} else {
+		bEnd = int(s.bins[b+1])
+	}
+
+	ret := make([]swarm.Address, bEnd-int(s.bins[b]))
+	copy(ret, s.peers[s.bins[b]:bEnd])
+
+	return ret
+}
 
 func (s *PSlice) Length() int {
 	s.RLock()
 	defer s.RUnlock()
......
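
To make the BinPeers slicing easier to follow: a PSlice keeps every peer in one flat slice and records, per bin, the offset where that bin starts; a bin's peers are the range from its own offset to the next bin's offset (or the end of the slice for the last bin). The toy model below uses illustrative names that only mirror the fields BinPeers works with; it is not the library's code.

package main

import "fmt"

// binPeers slices a flat peers list using a per-bin start-offset table,
// returning nil for an out-of-range bin, just as the sketch above describes.
func binPeers(peers []string, bins []uint, bin int) []string {
	if bin >= len(bins) {
		return nil
	}
	end := len(peers)
	if bin != len(bins)-1 {
		end = int(bins[bin+1])
	}
	out := make([]string, end-int(bins[bin]))
	copy(out, peers[bins[bin]:end]) // copy so callers cannot mutate the backing array
	return out
}

func main() {
	peers := []string{"p0", "p1", "p2", "p3", "p4"}
	bins := []uint{0, 2, 2, 4} // bin 0: p0,p1; bin 1: empty; bin 2: p2,p3; bin 3: p4
	for b := range bins {
		fmt.Println(b, binPeers(peers, bins, b))
	}
}
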
@@ -6,6 +6,7 @@ package pslice_test
 import (
 	"errors"
+	"sort"
 	"testing"
 
 	"github.com/ethersphere/bee/pkg/swarm"
@@ -237,6 +238,83 @@ func TestIterators(t *testing.T) {
 
 	testIteratorRev(t, ps, false, false, 0, []swarm.Address{})
 }
+
+func TestBinPeers(t *testing.T) {
+	for _, tc := range []struct {
+		peersCount []int
+		label      string
+	}{
+		{
+			peersCount: []int{0, 0, 0, 0},
+			label:      "bins-empty",
+		},
+		{
+			peersCount: []int{0, 2, 0, 4},
+			label:      "some-bins-empty",
+		},
+		{
+			peersCount: []int{0, 0, 6, 0},
+			label:      "some-bins-empty",
+		},
+		{
+			peersCount: []int{3, 4, 5, 6},
+			label:      "full-bins",
+		},
+	} {
+		t.Run(tc.label, func(t *testing.T) {
+			binPeers := make([][]swarm.Address, len(tc.peersCount))
+
+			// prepare slice
+			ps := pslice.New(len(tc.peersCount))
+			for bin, peersCount := range tc.peersCount {
+				for i := 0; i < peersCount; i++ {
+					peer := test.RandomAddress()
+					binPeers[bin] = append(binPeers[bin], peer)
+					ps.Add(peer, uint8(bin))
+				}
+			}
+
+			// compare
+			for bin := range tc.peersCount {
+				if !isEqual(binPeers[bin], ps.BinPeers(uint8(bin))) {
+					t.Fatal("peers list do not match")
+				}
+			}
+
+			// out of bound bin check
+			bins := ps.BinPeers(uint8(len(tc.peersCount)))
+			if bins != nil {
+				t.Fatal("peers must be nil for out of bound bin")
+			}
+		})
+	}
+}
+
+func isEqual(a, b []swarm.Address) bool {
+	if len(a) != len(b) {
+		return false
+	}
+
+	sort.Slice(a, func(i, j int) bool {
+		return a[i].String() < a[j].String()
+	})
+
+	sort.Slice(b, func(i, j int) bool {
+		return b[i].String() < b[j].String()
+	})
+
+	for i, addr := range a {
+		if !b[i].Equal(addr) {
+			return false
+		}
+	}
+
+	return true
+}
 
 // TestIteratorsJumpStop tests that the EachBin and EachBinRev iterators jump to next bin and stop as expected.
 func TestIteratorsJumpStop(t *testing.T) {
 	ps := pslice.New(4)
......