Commit ea882b52 authored by Esad's avatar Esad Committed by GitHub

feat: rate limit hive broadcast calls (#2235)

parent 3c91281f
...@@ -23,10 +23,9 @@ import ( ...@@ -23,10 +23,9 @@ import (
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/p2p" "github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/p2p/protobuf" "github.com/ethersphere/bee/pkg/p2p/protobuf"
"github.com/ethersphere/bee/pkg/ratelimit"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
ma "github.com/multiformats/go-multiaddr" ma "github.com/multiformats/go-multiaddr"
"golang.org/x/time/rate"
) )
const ( const (
...@@ -38,9 +37,10 @@ const ( ...@@ -38,9 +37,10 @@ const (
) )
var ( var (
limitBurst = 4 * int(swarm.MaxBins)
limitRate = time.Minute
ErrRateLimitExceeded = errors.New("rate limit exceeded") ErrRateLimitExceeded = errors.New("rate limit exceeded")
limitBurst = 4 * int(swarm.MaxBins)
limitRate = rate.Every(time.Minute)
) )
type Service struct { type Service struct {
...@@ -50,8 +50,9 @@ type Service struct { ...@@ -50,8 +50,9 @@ type Service struct {
networkID uint64 networkID uint64
logger logging.Logger logger logging.Logger
metrics metrics metrics metrics
limiter map[string]*rate.Limiter inLimiter *ratelimit.Limiter
limiterLock sync.Mutex outLimiter *ratelimit.Limiter
clearMtx sync.Mutex
} }
func New(streamer p2p.Streamer, addressbook addressbook.GetPutter, networkID uint64, logger logging.Logger) *Service { func New(streamer p2p.Streamer, addressbook addressbook.GetPutter, networkID uint64, logger logging.Logger) *Service {
...@@ -61,7 +62,8 @@ func New(streamer p2p.Streamer, addressbook addressbook.GetPutter, networkID uin ...@@ -61,7 +62,8 @@ func New(streamer p2p.Streamer, addressbook addressbook.GetPutter, networkID uin
addressBook: addressbook, addressBook: addressbook,
networkID: networkID, networkID: networkID,
metrics: newMetrics(), metrics: newMetrics(),
limiter: make(map[string]*rate.Limiter), inLimiter: ratelimit.New(limitRate, limitBurst),
outLimiter: ratelimit.New(limitRate, limitBurst),
} }
} }
...@@ -89,6 +91,12 @@ func (s *Service) BroadcastPeers(ctx context.Context, addressee swarm.Address, p ...@@ -89,6 +91,12 @@ func (s *Service) BroadcastPeers(ctx context.Context, addressee swarm.Address, p
if max > len(peers) { if max > len(peers) {
max = len(peers) max = len(peers)
} }
// If broadcasting limit is exceeded, return early
if !s.outLimiter.Allow(addressee.ByteString(), max) {
return nil
}
if err := s.sendPeers(ctx, addressee, peers[:max]); err != nil { if err := s.sendPeers(ctx, addressee, peers[:max]); err != nil {
return err return err
} }
...@@ -158,9 +166,9 @@ func (s *Service) peersHandler(ctx context.Context, peer p2p.Peer, stream p2p.St ...@@ -158,9 +166,9 @@ func (s *Service) peersHandler(ctx context.Context, peer p2p.Peer, stream p2p.St
s.metrics.PeersHandlerPeers.Add(float64(len(peersReq.Peers))) s.metrics.PeersHandlerPeers.Add(float64(len(peersReq.Peers)))
if err := s.rateLimitPeer(peer.Address, len(peersReq.Peers)); err != nil { if !s.inLimiter.Allow(peer.Address.ByteString(), len(peersReq.Peers)) {
_ = stream.Reset() _ = stream.Reset()
return err return ErrRateLimitExceeded
} }
// close the stream before processing in order to unblock the sending side // close the stream before processing in order to unblock the sending side
...@@ -200,31 +208,13 @@ func (s *Service) peersHandler(ctx context.Context, peer p2p.Peer, stream p2p.St ...@@ -200,31 +208,13 @@ func (s *Service) peersHandler(ctx context.Context, peer p2p.Peer, stream p2p.St
return nil return nil
} }
func (s *Service) rateLimitPeer(peer swarm.Address, count int) error {
s.limiterLock.Lock()
defer s.limiterLock.Unlock()
addr := peer.ByteString()
limiter, ok := s.limiter[addr]
if !ok {
limiter = rate.NewLimiter(limitRate, limitBurst)
s.limiter[addr] = limiter
}
if limiter.AllowN(time.Now(), count) {
return nil
}
return ErrRateLimitExceeded
}
func (s *Service) disconnect(peer p2p.Peer) error { func (s *Service) disconnect(peer p2p.Peer) error {
s.limiterLock.Lock()
defer s.limiterLock.Unlock()
delete(s.limiter, peer.Address.String()) s.clearMtx.Lock()
defer s.clearMtx.Unlock()
s.inLimiter.Clear(peer.Address.ByteString())
s.outLimiter.Clear(peer.Address.ByteString())
return nil return nil
} }
...@@ -7,7 +7,6 @@ package hive_test ...@@ -7,7 +7,6 @@ package hive_test
import ( import (
"bytes" "bytes"
"context" "context"
"errors"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"math/rand" "math/rand"
...@@ -98,8 +97,9 @@ func TestHandlerRateLimit(t *testing.T) { ...@@ -98,8 +97,9 @@ func TestHandlerRateLimit(t *testing.T) {
} }
lastRec := rec[len(rec)-1] lastRec := rec[len(rec)-1]
if !errors.Is(lastRec.Err(), hive.ErrRateLimitExceeded) {
t.Fatal(err) if lastRec.Err() != nil {
t.Fatal("want nil error")
} }
} }
......
// Copyright 2021 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package ratelimit provides a mechanism to rate limit requests based on a string key,
// refill rate and burst amount. Under the hood, it's a token bucket of size burst amount,
// that refills at the refill rate.
package ratelimit
import (
"sync"
"time"
"golang.org/x/time/rate"
)
type Limiter struct {
mtx sync.Mutex
limiter map[string]*rate.Limiter
rate rate.Limit
burst int
}
// New returns a new Limiter object with refresh rate and burst amount
func New(r time.Duration, burst int) *Limiter {
return &Limiter{
limiter: make(map[string]*rate.Limiter),
rate: rate.Every(r),
burst: burst,
}
}
// Allow checks if the limiter that belongs to 'key' has not exceeded the limit.
func (l *Limiter) Allow(key string, count int) bool {
l.mtx.Lock()
defer l.mtx.Unlock()
limiter, ok := l.limiter[key]
if !ok {
limiter = rate.NewLimiter(l.rate, l.burst)
l.limiter[key] = limiter
}
return limiter.AllowN(time.Now(), count)
}
// Clear deletes the limiter that belongs to 'key'
func (l *Limiter) Clear(key string) {
l.mtx.Lock()
defer l.mtx.Unlock()
delete(l.limiter, key)
}
// Copyright 2021 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ratelimit_test
import (
"testing"
"time"
"github.com/ethersphere/bee/pkg/ratelimit"
)
func TestRateLimit(t *testing.T) {
var (
key1 = "test1"
key2 = "test2"
rate = time.Second
burst = 10
)
limiter := ratelimit.New(rate, burst)
if !limiter.Allow(key1, burst) {
t.Fatal("want allowed")
}
if limiter.Allow(key1, burst) {
t.Fatalf("want not allowed")
}
limiter.Clear(key1)
if !limiter.Allow(key1, burst) {
t.Fatal("want allowed")
}
if !limiter.Allow(key2, burst) {
t.Fatal("want allowed")
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment