Commit ac06c592 authored by Esad Akar, committed by GitHub

kademlia, hive: rate limiting on gossiping peers (send and receive) (#1654)

parent 479f6658
......@@ -66,6 +66,7 @@ require (
golang.org/x/sys v0.0.0-20210108172913-0df2131ae363
golang.org/x/term v0.0.0-20201210144234-2321bbc49cbf
golang.org/x/text v0.3.4 // indirect
golang.org/x/time v0.0.0-20191024005414-555d28b269f0
golang.org/x/tools v0.0.0-20200626171337-aa94e735be7f // indirect
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect
gopkg.in/ini.v1 v1.57.0 // indirect
......
......@@ -5,3 +5,4 @@
package hive
var MaxBatchSize = maxBatchSize
var LimitBurst = limitBurst
......@@ -12,7 +12,9 @@ package hive
import (
"context"
"errors"
"fmt"
"sync"
"time"
"github.com/ethersphere/bee/pkg/addressbook"
......@@ -22,6 +24,8 @@ import (
"github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/p2p/protobuf"
"github.com/ethersphere/bee/pkg/swarm"
"golang.org/x/time/rate"
)
const (
......@@ -32,6 +36,12 @@ const (
maxBatchSize = 30
)
var (
ErrRateLimitExceeded = errors.New("rate limit exceeded")
limitBurst = 4 * int(swarm.MaxBins)
limitRate = rate.Every(time.Minute)
)
type Service struct {
streamer p2p.Streamer
addressBook addressbook.GetPutter
......@@ -39,6 +49,8 @@ type Service struct {
networkID uint64
logger logging.Logger
metrics metrics
limiter map[string]*rate.Limiter
limiterLock sync.Mutex
}
func New(streamer p2p.Streamer, addressbook addressbook.GetPutter, networkID uint64, logger logging.Logger) *Service {
......@@ -48,6 +60,7 @@ func New(streamer p2p.Streamer, addressbook addressbook.GetPutter, networkID uin
addressBook: addressbook,
networkID: networkID,
metrics: newMetrics(),
limiter: make(map[string]*rate.Limiter),
}
}
......@@ -61,6 +74,8 @@ func (s *Service) Protocol() p2p.ProtocolSpec {
Handler: s.peersHandler,
},
},
DisconnectIn: s.disconnect,
DisconnectOut: s.disconnect,
}
}
......@@ -139,6 +154,11 @@ func (s *Service) peersHandler(ctx context.Context, peer p2p.Peer, stream p2p.St
s.metrics.PeersHandlerPeers.Add(float64(len(peersReq.Peers)))
if err := s.rateLimitPeer(peer.Address, len(peersReq.Peers)); err != nil {
_ = stream.Reset()
return err
}
// close the stream before processing in order to unblock the sending side
// fullclose is called asynchronously because there is no need to wait for confirmation,
// but we still want to handle a stream the other side never closes, to avoid zombie streams
......@@ -169,3 +189,32 @@ func (s *Service) peersHandler(ctx context.Context, peer p2p.Peer, stream p2p.St
return nil
}
func (s *Service) rateLimitPeer(peer swarm.Address, count int) error {
s.limiterLock.Lock()
defer s.limiterLock.Unlock()
addr := peer.String()
limiter, ok := s.limiter[addr]
if !ok {
limiter = rate.NewLimiter(limitRate, limitBurst)
s.limiter[addr] = limiter
}
if limiter.AllowN(time.Now(), count) {
return nil
}
return ErrRateLimitExceeded
}
func (s *Service) disconnect(peer p2p.Peer) error {
s.limiterLock.Lock()
defer s.limiterLock.Unlock()
delete(s.limiter, peer.Address.String())
return nil
}
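The handler keeps one token bucket per peer: the burst admits limitBurst peer records at once and the bucket refills at one record per minute, after which peersHandler resets the stream and returns ErrRateLimitExceeded. A minimal sketch of the same golang.org/x/time/rate pattern; the burst value of 128 assumes swarm.MaxBins is 32, which is not shown in this diff:

package main

import (
    "fmt"
    "time"

    "golang.org/x/time/rate"
)

func main() {
    // One token per minute, burst of 128 tokens (4 * 32, assuming swarm.MaxBins == 32).
    lim := rate.NewLimiter(rate.Every(time.Minute), 128)

    // A full burst of 128 peer records is admitted immediately...
    fmt.Println(lim.AllowN(time.Now(), 128)) // true

    // ...but the next record is rejected until a token refills, one per minute.
    fmt.Println(lim.AllowN(time.Now(), 1)) // false
}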
......@@ -7,6 +7,7 @@ package hive_test
import (
"bytes"
"context"
"errors"
"fmt"
"io/ioutil"
"math/rand"
......@@ -27,8 +28,76 @@ import (
"github.com/ethersphere/bee/pkg/p2p/streamtest"
"github.com/ethersphere/bee/pkg/statestore/mock"
"github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/swarm/test"
)
func TestHandlerRateLimit(t *testing.T) {
logger := logging.New(ioutil.Discard, 0)
statestore := mock.NewStateStore()
addressbook := ab.New(statestore)
networkID := uint64(1)
addressbookclean := ab.New(mock.NewStateStore())
// create a hive server that handles the incoming stream
server := hive.New(nil, addressbookclean, networkID, logger)
serverAddress := test.RandomAddress()
// set up the stream recorder to record stream data
serverRecorder := streamtest.New(
streamtest.WithProtocols(server.Protocol()),
streamtest.WithBaseAddr(serverAddress),
)
peers := make([]swarm.Address, hive.LimitBurst+1)
for i := range peers {
underlay, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/" + strconv.Itoa(i))
if err != nil {
t.Fatal(err)
}
pk, err := crypto.GenerateSecp256k1Key()
if err != nil {
t.Fatal(err)
}
signer := crypto.NewDefaultSigner(pk)
overlay, err := crypto.NewOverlayAddress(pk.PublicKey, networkID)
if err != nil {
t.Fatal(err)
}
bzzAddr, err := bzz.NewAddress(signer, underlay, overlay, networkID)
if err != nil {
t.Fatal(err)
}
err = addressbook.Put(bzzAddr.Overlay, *bzzAddr)
if err != nil {
t.Fatal(err)
}
peers[i] = bzzAddr.Overlay
}
// create a hive client that will do the broadcast
client := hive.New(serverRecorder, addressbook, networkID, logger)
err := client.BroadcastPeers(context.Background(), serverAddress, peers...)
if err != nil {
t.Fatal(err)
}
// get a record for this stream
rec, err := serverRecorder.Records(serverAddress, "hive", "1.0.0", "peers")
if err != nil {
t.Fatal(err)
}
lastRec := rec[len(rec)-1]
if !errors.Is(lastRec.Err(), hive.ErrRateLimitExceeded) {
t.Fatalf("got error %v, want %v", lastRec.Err(), hive.ErrRateLimitExceeded)
}
}
func TestBroadcastPeers(t *testing.T) {
rand.Seed(time.Now().UnixNano())
logger := logging.New(ioutil.Discard, 0)
......
......@@ -10,10 +10,13 @@ import (
"errors"
"fmt"
"math"
"math/big"
"math/bits"
"sync"
"time"
random "crypto/rand"
"github.com/ethersphere/bee/pkg/addressbook"
"github.com/ethersphere/bee/pkg/discovery"
"github.com/ethersphere/bee/pkg/logging"
......@@ -38,6 +41,7 @@ var (
shortRetry = 30 * time.Second
saturationPeers = 4
overSaturationPeers = 16
broadcastBinSize = 4
)
type binSaturationFunc func(bin uint8, peers, connected *pslice.PSlice) (saturated bool, oversaturated bool)
......@@ -668,18 +672,20 @@ func (k *Kad) connect(ctx context.Context, peer swarm.Address, ma ma.Multiaddr,
func (k *Kad) Announce(ctx context.Context, peer swarm.Address) error {
addrs := []swarm.Address{}
- _ = k.connectedPeers.EachBinRev(func(connectedPeer swarm.Address, _ uint8) (bool, bool, error) {
for bin := uint8(0); bin < swarm.MaxBins; bin++ {
connectedPeers, err := randomSubset(k.connectedPeers.BinPeers(bin), broadcastBinSize)
if err != nil {
return err
}
for _, connectedPeer := range connectedPeers {
if connectedPeer.Equal(peer) {
- return false, false, nil
continue
}
addrs = append(addrs, connectedPeer)
// this needs to run in a separate goroutine since a peer we are gossiping to might
// be slow, and since this function is called with the same context as the kademlia connect
// function, we might otherwise end up at
// `err := k.discovery.BroadcastPeers(ctx, peer, addrs...)` with an already expired context,
// falsely indicating that the peer connection has timed out.
k.wg.Add(1)
go func(connectedPeer swarm.Address) {
defer k.wg.Done()
......@@ -687,9 +693,8 @@ func (k *Kad) Announce(ctx context.Context, peer swarm.Address) error {
k.logger.Debugf("could not gossip peer %s to peer %s: %v", peer, connectedPeer, err)
}
}(connectedPeer)
- return false, false, nil
- })
}
}
if len(addrs) == 0 {
return nil
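The sending and receiving limits line up: Announce now gossips at most broadcastBinSize addresses per bin, so a single announce carries at most 4 * swarm.MaxBins addresses, the same figure hive uses for limitBurst. A rough check of that arithmetic, assuming swarm.MaxBins is 32 (the constant itself is outside this diff):

package main

import "fmt"

func main() {
    const (
        maxBins          = 32 // assumed value of swarm.MaxBins
        broadcastBinSize = 4  // addresses gossiped per bin by Announce
        limitBurstFactor = 4  // hive: limitBurst = 4 * MaxBins
    )
    // Largest batch a single Announce can send vs. the receiver's burst allowance.
    fmt.Println(broadcastBinSize*maxBins, limitBurstFactor*maxBins) // 128 128
}

So a newly connected peer can absorb one full announce inside the initial burst, and only repeated announcing from the same peer gets throttled.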
......@@ -1166,3 +1171,21 @@ func (k *Kad) Close() error {
return nil
}
func randomSubset(addrs []swarm.Address, count int) ([]swarm.Address, error) {
if count >= len(addrs) {
return addrs, nil
}
for i := 0; i < len(addrs); i++ {
b, err := random.Int(random.Reader, big.NewInt(int64(len(addrs))))
if err != nil {
return nil, err
}
j := int(b.Int64())
addrs[i], addrs[j] = addrs[j], addrs[i]
}
return addrs[:count], nil
}
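randomSubset shuffles the slice in place with crypto/rand and returns the first count elements; when count covers the whole slice, it is returned untouched. A small self-contained sketch of the same strategy, with plain strings standing in for swarm addresses:

package main

import (
    random "crypto/rand"
    "fmt"
    "math/big"
)

// randomSubset mirrors the kademlia helper over plain strings:
// shuffle in place with crypto/rand, then take the first count entries.
func randomSubset(items []string, count int) ([]string, error) {
    if count >= len(items) {
        return items, nil
    }
    for i := 0; i < len(items); i++ {
        b, err := random.Int(random.Reader, big.NewInt(int64(len(items))))
        if err != nil {
            return nil, err
        }
        j := int(b.Int64())
        items[i], items[j] = items[j], items[i]
    }
    return items[:count], nil
}

func main() {
    peers := []string{"a", "b", "c", "d", "e", "f"}
    subset, err := randomSubset(peers, 4)
    if err != nil {
        panic(err)
    }
    fmt.Println(subset) // four of the six entries, in random order
}

Note that the helper reorders its input in place; in Announce that is harmless because BinPeers returns a copy of the bin.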
......@@ -95,6 +95,28 @@ func (s *PSlice) EachBinRev(pf topology.EachPeerFunc) error {
return nil
}
func (s *PSlice) BinPeers(bin uint8) []swarm.Address {
s.RLock()
defer s.RUnlock()
b := int(bin)
if b >= len(s.bins) {
return nil
}
var bEnd int
if b == len(s.bins)-1 {
bEnd = len(s.peers)
} else {
bEnd = int(s.bins[b+1])
}
ret := make([]swarm.Address, bEnd-int(s.bins[b]))
copy(ret, s.peers[s.bins[b]:bEnd])
return ret
}
func (s *PSlice) Length() int {
s.RLock()
defer s.RUnlock()
......
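BinPeers leans on the PSlice layout: peers is one flat slice ordered by bin, and bins[b] holds the start offset of bin b, so bin b spans peers[bins[b]:bins[b+1]], with the last bin running to the end of peers. A small illustration of that indexing with made-up offsets, using plain strings instead of swarm addresses:

package main

import "fmt"

// binPeers mirrors the indexing in PSlice.BinPeers: bins[b] is the start
// offset of bin b inside the flat, bin-ordered peers slice.
func binPeers(peers []string, bins []uint, b int) []string {
    if b >= len(bins) {
        return nil // out-of-bounds bin
    }
    end := len(peers)
    if b != len(bins)-1 {
        end = int(bins[b+1])
    }
    out := make([]string, end-int(bins[b]))
    copy(out, peers[int(bins[b]):end])
    return out
}

func main() {
    peers := []string{"p0", "p1", "p2", "p3", "p4"}
    bins := []uint{0, 2, 2, 4} // bin 1 is empty, bin 3 holds the tail
    fmt.Println(binPeers(peers, bins, 0)) // [p0 p1]
    fmt.Println(binPeers(peers, bins, 1)) // []
    fmt.Println(binPeers(peers, bins, 3)) // [p4]
    fmt.Println(binPeers(peers, bins, 4)) // [] (out of range)
}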
......@@ -6,6 +6,7 @@ package pslice_test
import (
"errors"
"sort"
"testing"
"github.com/ethersphere/bee/pkg/swarm"
......@@ -237,6 +238,83 @@ func TestIterators(t *testing.T) {
testIteratorRev(t, ps, false, false, 0, []swarm.Address{})
}
func TestBinPeers(t *testing.T) {
for _, tc := range []struct {
peersCount []int
label string
}{
{
peersCount: []int{0, 0, 0, 0},
label: "bins-empty",
},
{
peersCount: []int{0, 2, 0, 4},
label: "some-bins-empty",
},
{
peersCount: []int{0, 0, 6, 0},
label: "some-bins-empty",
},
{
peersCount: []int{3, 4, 5, 6},
label: "full-bins",
},
} {
t.Run(tc.label, func(t *testing.T) {
binPeers := make([][]swarm.Address, len(tc.peersCount))
// prepare slice
ps := pslice.New(len(tc.peersCount))
for bin, peersCount := range tc.peersCount {
for i := 0; i < peersCount; i++ {
peer := test.RandomAddress()
binPeers[bin] = append(binPeers[bin], peer)
ps.Add(peer, uint8(bin))
}
}
// compare
for bin := range tc.peersCount {
if !isEqual(binPeers[bin], ps.BinPeers(uint8(bin))) {
t.Fatal("peers list do not match")
}
}
// out-of-bounds bin check
bins := ps.BinPeers(uint8(len(tc.peersCount)))
if bins != nil {
t.Fatal("peers must be nil for an out-of-bounds bin")
}
})
}
}
func isEqual(a, b []swarm.Address) bool {
if len(a) != len(b) {
return false
}
sort.Slice(a, func(i, j int) bool {
return a[i].String() < a[j].String()
})
sort.Slice(b, func(i, j int) bool {
return b[i].String() < b[j].String()
})
for i, addr := range a {
if !b[i].Equal(addr) {
return false
}
}
return true
}
// TestIteratorsJumpStop tests that the EachBin and EachBinRev iterators jump to next bin and stop as expected.
func TestIteratorsJumpStop(t *testing.T) {
ps := pslice.New(4)
......