Commit 479f6658 authored by Anatolie Lupacescu's avatar Anatolie Lupacescu Committed by GitHub

fix: add minimum payment threshold validation (#1635)

parent 17f7837b
...@@ -432,7 +432,9 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey, ...@@ -432,7 +432,9 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey,
pricer := pricer.NewFixedPricer(swarmAddress, 1000000000) pricer := pricer.NewFixedPricer(swarmAddress, 1000000000)
pricing := pricing.New(p2ps, logger, paymentThreshold) minThreshold := pricer.MostExpensive()
pricing := pricing.New(p2ps, logger, paymentThreshold, minThreshold)
if err = p2ps.AddProtocol(pricing.Protocol()); err != nil { if err = p2ps.AddProtocol(pricing.Protocol()); err != nil {
return nil, fmt.Errorf("pricing service: %w", err) return nil, fmt.Errorf("pricing service: %w", err)
......
...@@ -5,6 +5,8 @@ ...@@ -5,6 +5,8 @@
package pricer package pricer
import ( import (
"math/big"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
) )
...@@ -39,3 +41,10 @@ func (pricer *FixedPricer) PeerPrice(peer, chunk swarm.Address) uint64 { ...@@ -39,3 +41,10 @@ func (pricer *FixedPricer) PeerPrice(peer, chunk swarm.Address) uint64 {
func (pricer *FixedPricer) Price(chunk swarm.Address) uint64 { func (pricer *FixedPricer) Price(chunk swarm.Address) uint64 {
return pricer.PeerPrice(pricer.overlay, chunk) return pricer.PeerPrice(pricer.overlay, chunk)
} }
func (pricer *FixedPricer) MostExpensive() *big.Int {
poPrice := new(big.Int).SetUint64(pricer.poPrice)
maxPO := new(big.Int).SetUint64(uint64(swarm.MaxPO))
tenTimesMaxPO := new(big.Int).Mul(big.NewInt(10), maxPO)
return new(big.Int).Mul(tenTimesMaxPO, poPrice)
}
...@@ -6,6 +6,7 @@ package pricing ...@@ -6,6 +6,7 @@ package pricing
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"math/big" "math/big"
"time" "time"
...@@ -23,6 +24,11 @@ const ( ...@@ -23,6 +24,11 @@ const (
streamName = "pricing" streamName = "pricing"
) )
var (
// ErrThresholdTooLow says that the proposed payment threshold is too low for even a single reserve.
ErrThresholdTooLow = errors.New("threshold too low")
)
var _ Interface = (*Service)(nil) var _ Interface = (*Service)(nil)
// Interface is the main interface of the pricing protocol // Interface is the main interface of the pricing protocol
...@@ -44,14 +50,16 @@ type Service struct { ...@@ -44,14 +50,16 @@ type Service struct {
streamer p2p.Streamer streamer p2p.Streamer
logger logging.Logger logger logging.Logger
paymentThreshold *big.Int paymentThreshold *big.Int
minPaymentThreshold *big.Int
paymentThresholdObserver PaymentThresholdObserver paymentThresholdObserver PaymentThresholdObserver
} }
func New(streamer p2p.Streamer, logger logging.Logger, paymentThreshold *big.Int) *Service { func New(streamer p2p.Streamer, logger logging.Logger, paymentThreshold *big.Int, minThreshold *big.Int) *Service {
return &Service{ return &Service{
streamer: streamer, streamer: streamer,
logger: logger, logger: logger,
paymentThreshold: paymentThreshold, paymentThreshold: paymentThreshold,
minPaymentThreshold: minThreshold,
} }
} }
...@@ -89,6 +97,11 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e ...@@ -89,6 +97,11 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e
paymentThreshold := big.NewInt(0).SetBytes(req.PaymentThreshold) paymentThreshold := big.NewInt(0).SetBytes(req.PaymentThreshold)
s.logger.Tracef("received payment threshold announcement from peer %v of %d", p.Address, paymentThreshold) s.logger.Tracef("received payment threshold announcement from peer %v of %d", p.Address, paymentThreshold)
if paymentThreshold.Cmp(s.minPaymentThreshold) < 0 {
s.logger.Tracef("payment threshold from peer %v of %d too small, need at least %d", p.Address, paymentThreshold, s.minPaymentThreshold)
return p2p.NewDisconnectError(ErrThresholdTooLow)
}
if paymentThreshold.Cmp(big.NewInt(0)) == 0 { if paymentThreshold.Cmp(big.NewInt(0)) == 0 {
return err return err
} }
......
...@@ -7,11 +7,13 @@ package pricing_test ...@@ -7,11 +7,13 @@ package pricing_test
import ( import (
"bytes" "bytes"
"context" "context"
"errors"
"io/ioutil" "io/ioutil"
"math/big" "math/big"
"testing" "testing"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/p2p/protobuf" "github.com/ethersphere/bee/pkg/p2p/protobuf"
"github.com/ethersphere/bee/pkg/p2p/streamtest" "github.com/ethersphere/bee/pkg/p2p/streamtest"
"github.com/ethersphere/bee/pkg/pricing" "github.com/ethersphere/bee/pkg/pricing"
...@@ -37,7 +39,7 @@ func TestAnnouncePaymentThreshold(t *testing.T) { ...@@ -37,7 +39,7 @@ func TestAnnouncePaymentThreshold(t *testing.T) {
testThreshold := big.NewInt(100000) testThreshold := big.NewInt(100000)
observer := &testThresholdObserver{} observer := &testThresholdObserver{}
recipient := pricing.New(nil, logger, testThreshold) recipient := pricing.New(nil, logger, testThreshold, big.NewInt(1000))
recipient.SetPaymentThresholdObserver(observer) recipient.SetPaymentThresholdObserver(observer)
peerID := swarm.MustParseHexAddress("9ee7add7") peerID := swarm.MustParseHexAddress("9ee7add7")
...@@ -47,9 +49,9 @@ func TestAnnouncePaymentThreshold(t *testing.T) { ...@@ -47,9 +49,9 @@ func TestAnnouncePaymentThreshold(t *testing.T) {
streamtest.WithBaseAddr(peerID), streamtest.WithBaseAddr(peerID),
) )
payer := pricing.New(recorder, logger, testThreshold) payer := pricing.New(recorder, logger, testThreshold, big.NewInt(1000))
paymentThreshold := big.NewInt(10000) paymentThreshold := big.NewInt(100000)
err := payer.AnnouncePaymentThreshold(context.Background(), peerID, paymentThreshold) err := payer.AnnouncePaymentThreshold(context.Background(), peerID, paymentThreshold)
if err != nil { if err != nil {
...@@ -96,3 +98,59 @@ func TestAnnouncePaymentThreshold(t *testing.T) { ...@@ -96,3 +98,59 @@ func TestAnnouncePaymentThreshold(t *testing.T) {
t.Fatalf("observer called with wrong peer. got %v, want %v", observer.peer, peerID) t.Fatalf("observer called with wrong peer. got %v, want %v", observer.peer, peerID)
} }
} }
func TestAnnouncePaymentWithInsufficientThreshold(t *testing.T) {
logger := logging.New(ioutil.Discard, 0)
testThreshold := big.NewInt(100_000)
observer := &testThresholdObserver{}
minThreshold := big.NewInt(1_000_000) // above requested threashold
recipient := pricing.New(nil, logger, testThreshold, minThreshold)
recipient.SetPaymentThresholdObserver(observer)
peerID := swarm.MustParseHexAddress("9ee7add7")
recorder := streamtest.New(
streamtest.WithProtocols(recipient.Protocol()),
streamtest.WithBaseAddr(peerID),
)
payer := pricing.New(recorder, logger, testThreshold, minThreshold)
paymentThreshold := big.NewInt(100_000)
err := payer.AnnouncePaymentThreshold(context.Background(), peerID, paymentThreshold)
if err != nil {
t.Fatal(err)
}
records, err := recorder.Records(peerID, "pricing", "1.0.0", "pricing")
if err != nil {
t.Fatal(err)
}
if l := len(records); l != 1 {
t.Fatalf("got %v records, want %v", l, 1)
}
record := records[0]
if record.Err() == nil {
t.Fatal("expected error")
}
payerErr, ok := record.Err().(*p2p.DisconnectError)
if !ok {
t.Fatalf("wanted %v, got %v", p2p.DisconnectError{}, record.Err())
}
if !errors.Is(payerErr, pricing.ErrThresholdTooLow) {
t.Fatalf("wanted error %v, got %v", pricing.ErrThresholdTooLow, err)
}
if observer.called {
t.Fatal("unexpected call to the observer")
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment