Commit 3217c6e1 authored by metacertain's avatar metacertain Committed by GitHub

Pricing no protocol updates (#1563)

pushsync, retrieval, pricing: Pricing no protocol updates
parent 78aba90f
...@@ -324,11 +324,9 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey, ...@@ -324,11 +324,9 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey,
return nil, fmt.Errorf("invalid payment threshold: %s", paymentThreshold) return nil, fmt.Errorf("invalid payment threshold: %s", paymentThreshold)
} }
pricer := pricer.New(logger, stateStore, swarmAddress, 1000000000) pricer := pricer.NewFixedPricer(swarmAddress, 1000000000)
pricer.SetTopology(kad)
pricing := pricing.New(p2ps, logger, paymentThreshold, pricer) pricing := pricing.New(p2ps, logger, paymentThreshold)
pricing.SetPriceTableObserver(pricer)
if err = p2ps.AddProtocol(pricing.Protocol()); err != nil { if err = p2ps.AddProtocol(pricing.Protocol()); err != nil {
return nil, fmt.Errorf("pricing service: %w", err) return nil, fmt.Errorf("pricing service: %w", err)
......
// Copyright 2021 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package pricer
import (
"github.com/ethersphere/bee/pkg/swarm"
)
func (s *Pricer) PeerPricePO(peer swarm.Address, po uint8) (uint64, error) {
return s.peerPricePO(peer, po)
}
...@@ -5,116 +5,25 @@ ...@@ -5,116 +5,25 @@
package mock package mock
import ( import (
"github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
) )
type Service struct { type MockPricer struct {
peerPrice uint64 peerPrice uint64
price uint64 price uint64
peerPriceFunc func(peer, chunk swarm.Address) uint64
priceForPeerFunc func(peer, chunk swarm.Address) uint64
priceTableFunc func() (priceTable []uint64)
notifyPriceTableFunc func(peer swarm.Address, priceTable []uint64) error
priceHeadlerFunc func(p2p.Headers, swarm.Address) p2p.Headers
notifyPeerPriceFunc func(peer swarm.Address, price uint64, index uint8) error
} }
// WithReserveFunc sets the mock Reserve function func NewMockService(price, peerPrice uint64) *MockPricer {
func WithPeerPriceFunc(f func(peer, chunk swarm.Address) uint64) Option { return &MockPricer{
return optionFunc(func(s *Service) { peerPrice: peerPrice,
s.peerPriceFunc = f price: price,
})
}
// WithReleaseFunc sets the mock Release function
func WithPriceForPeerFunc(f func(peer, chunk swarm.Address) uint64) Option {
return optionFunc(func(s *Service) {
s.priceForPeerFunc = f
})
}
func WithPrice(p uint64) Option {
return optionFunc(func(s *Service) {
s.price = p
})
}
func WithPeerPrice(p uint64) Option {
return optionFunc(func(s *Service) {
s.peerPrice = p
})
}
// WithPriceTableFunc sets the mock Release function
func WithPriceTableFunc(f func() (priceTable []uint64)) Option {
return optionFunc(func(s *Service) {
s.priceTableFunc = f
})
}
func WithPriceHeadlerFunc(f func(headers p2p.Headers, addr swarm.Address) p2p.Headers) Option {
return optionFunc(func(s *Service) {
s.priceHeadlerFunc = f
})
}
func NewMockService(opts ...Option) *Service {
mock := new(Service)
mock.price = 10
mock.peerPrice = 10
for _, o := range opts {
o.apply(mock)
} }
return mock
} }
func (pricer *Service) PeerPrice(peer, chunk swarm.Address) uint64 { func (pricer *MockPricer) PeerPrice(peer, chunk swarm.Address) uint64 {
if pricer.peerPriceFunc != nil {
return pricer.peerPriceFunc(peer, chunk)
}
return pricer.peerPrice return pricer.peerPrice
} }
func (pricer *Service) PriceForPeer(peer, chunk swarm.Address) uint64 { func (pricer *MockPricer) Price(chunk swarm.Address) uint64 {
if pricer.priceForPeerFunc != nil {
return pricer.priceForPeerFunc(peer, chunk)
}
return pricer.price return pricer.price
} }
func (pricer *Service) PriceTable() (priceTable []uint64) {
if pricer.priceTableFunc != nil {
return pricer.priceTableFunc()
}
return nil
}
func (pricer *Service) NotifyPriceTable(peer swarm.Address, priceTable []uint64) error {
if pricer.notifyPriceTableFunc != nil {
return pricer.notifyPriceTableFunc(peer, priceTable)
}
return nil
}
func (pricer *Service) PriceHeadler(headers p2p.Headers, addr swarm.Address) p2p.Headers {
if pricer.priceHeadlerFunc != nil {
return pricer.priceHeadlerFunc(headers, addr)
}
return p2p.Headers{}
}
func (pricer *Service) NotifyPeerPrice(peer swarm.Address, price uint64, index uint8) error {
if pricer.notifyPeerPriceFunc != nil {
return pricer.notifyPeerPriceFunc(peer, price, index)
}
return nil
}
// Option is the option passed to the mock accounting service
type Option interface {
apply(*Service)
}
type optionFunc func(*Service)
func (f optionFunc) apply(r *Service) { f(r) }
...@@ -5,334 +5,37 @@ ...@@ -5,334 +5,37 @@
package pricer package pricer
import ( import (
"errors"
"fmt"
"sync"
"github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/pricer/headerutils"
"github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/topology"
)
const (
priceTablePrefix string = "pricetable_"
) )
var _ Interface = (*Pricer)(nil) // Pricer returns pricing information for chunk hashes.
// Pricer returns pricing information for chunk hashes and proximity orders
type Interface interface { type Interface interface {
// PriceTable returns pricetable stored for the node
PriceTable() []uint64
// PeerPrice is the price the peer charges for a given chunk hash. // PeerPrice is the price the peer charges for a given chunk hash.
PeerPrice(peer, chunk swarm.Address) uint64 PeerPrice(peer, chunk swarm.Address) uint64
// PriceForPeer is the price we charge a peer for a given chunk hash. // Price is the price we charge for a given chunk hash.
PriceForPeer(peer, chunk swarm.Address) uint64 Price(chunk swarm.Address) uint64
// NotifyPriceTable saves a provided pricetable for a peer to store
NotifyPriceTable(peer swarm.Address, priceTable []uint64) error
// NotifyPeerPrice changes a value that belongs to an index in a peer pricetable
NotifyPeerPrice(peer swarm.Address, price uint64, index uint8) error
// PriceHeadler creates response headers with pricing information
PriceHeadler(p2p.Headers, swarm.Address) p2p.Headers
}
var (
ErrPersistingBalancePeer = errors.New("failed to persist pricetable for peer")
)
type pricingPeer struct {
lock sync.Mutex
} }
type Pricer struct { // FixedPricer is a Pricer that has a fixed price for chunks.
pricingPeersMu sync.Mutex type FixedPricer struct {
pricingPeers map[string]*pricingPeer
logger logging.Logger
store storage.StateStorer
overlay swarm.Address overlay swarm.Address
topology topology.Driver
poPrice uint64 poPrice uint64
} }
func New(logger logging.Logger, store storage.StateStorer, overlay swarm.Address, poPrice uint64) *Pricer { // NewFixedPricer returns a new FixedPricer with a given price.
return &Pricer{ func NewFixedPricer(overlay swarm.Address, poPrice uint64) *FixedPricer {
logger: logger, return &FixedPricer{
pricingPeers: make(map[string]*pricingPeer),
store: store,
overlay: overlay, overlay: overlay,
poPrice: poPrice, poPrice: poPrice,
} }
} }
// PriceTable returns the pricetable stored for the node // PeerPrice implements Pricer.
// If not available, the default pricetable is provided func (pricer *FixedPricer) PeerPrice(peer, chunk swarm.Address) uint64 {
func (s *Pricer) PriceTable() (priceTable []uint64) { return uint64(swarm.MaxPO-swarm.Proximity(peer.Bytes(), chunk.Bytes())+1) * pricer.poPrice
err := s.store.Get(priceTableKey(), &priceTable)
if err != nil {
priceTable = s.defaultPriceTable()
}
return priceTable
}
// peerPriceTable returns the price table stored for the given peer.
// If we can't get price table from store, we return the default price table
func (s *Pricer) peerPriceTable(peer swarm.Address) (priceTable []uint64) {
err := s.store.Get(peerPriceTableKey(peer), &priceTable)
if err != nil {
priceTable = s.defaultPriceTable() // get default pricetable
}
return priceTable
}
// PriceForPeer returns the price for the PO of a chunk from the table stored for the node.
// Taking into consideration that the peer might be an in-neighborhood peer,
// if the chunk is at least neighborhood depth proximate to both the node and the peer, the price is 0
func (s *Pricer) PriceForPeer(peer, chunk swarm.Address) uint64 {
proximity := swarm.Proximity(s.overlay.Bytes(), chunk.Bytes())
neighborhoodDepth := s.neighborhoodDepth()
if proximity >= neighborhoodDepth {
peerproximity := swarm.Proximity(peer.Bytes(), chunk.Bytes())
if peerproximity >= neighborhoodDepth {
return 0
}
}
price, err := s.pricePO(proximity)
if err != nil {
price = s.defaultPrice(proximity)
}
return price
}
// priceWithIndexForPeer returns price for a chunk for a given peer,
// and the index of PO in pricetable which is used
func (s *Pricer) priceWithIndexForPeer(peer, chunk swarm.Address) (price uint64, index uint8) {
proximity := swarm.Proximity(s.overlay.Bytes(), chunk.Bytes())
neighborhoodDepth := s.neighborhoodDepth()
priceTable := s.PriceTable()
if int(proximity) >= len(priceTable) {
proximity = uint8(len(priceTable) - 1)
}
if proximity >= neighborhoodDepth {
proximity = neighborhoodDepth
peerproximity := swarm.Proximity(peer.Bytes(), chunk.Bytes())
if peerproximity >= neighborhoodDepth {
return 0, proximity
}
}
return priceTable[proximity], proximity
}
// pricePO returns the price for a PO from the table stored for the node.
func (s *Pricer) pricePO(po uint8) (uint64, error) {
priceTable := s.PriceTable()
proximity := po
if int(po) >= len(priceTable) {
proximity = uint8(len(priceTable) - 1)
}
return priceTable[proximity], nil
}
// PeerPrice returns the price for the PO of a chunk from the table stored for the given peer.
// Taking into consideration that the peer might be an in-neighborhood peer,
// if the chunk is at least neighborhood depth proximate to both the node and the peer, the price is 0
func (s *Pricer) PeerPrice(peer, chunk swarm.Address) uint64 {
proximity := swarm.Proximity(peer.Bytes(), chunk.Bytes())
// Determine neighborhood depth presumed by peer based on pricetable rows
var priceTable []uint64
err := s.store.Get(peerPriceTableKey(peer), &priceTable)
peerNeighborhoodDepth := uint8(len(priceTable) - 1)
if err != nil {
peerNeighborhoodDepth = s.neighborhoodDepth()
}
// determine whether the chunk is within presumed neighborhood depth of peer
if proximity >= peerNeighborhoodDepth {
// determine if the chunk is within presumed neighborhood depth of peer to us
selfproximity := swarm.Proximity(s.overlay.Bytes(), chunk.Bytes())
if selfproximity >= peerNeighborhoodDepth {
return 0
}
}
price, err := s.peerPricePO(peer, proximity)
if err != nil {
price = s.defaultPrice(proximity)
}
return price
}
// peerPricePO returns the price for a PO from the table stored for the given peer.
func (s *Pricer) peerPricePO(peer swarm.Address, po uint8) (uint64, error) {
var priceTable []uint64
err := s.store.Get(peerPriceTableKey(peer), &priceTable)
if err != nil {
if !errors.Is(err, storage.ErrNotFound) {
return 0, err
}
priceTable = s.defaultPriceTable()
}
proximity := po
if int(po) >= len(priceTable) {
proximity = uint8(len(priceTable) - 1)
}
return priceTable[proximity], nil
}
// peerPriceTableKey returns the price table storage key for the given peer.
func peerPriceTableKey(peer swarm.Address) string {
return fmt.Sprintf("%s%s", priceTablePrefix, peer.String())
}
// priceTableKey returns the price table storage key for own price table
func priceTableKey() string {
return fmt.Sprintf("%s%s", priceTablePrefix, "self")
}
func (s *Pricer) getPricingPeer(peer swarm.Address) (*pricingPeer, error) {
s.pricingPeersMu.Lock()
defer s.pricingPeersMu.Unlock()
peerData, ok := s.pricingPeers[peer.String()]
if !ok {
peerData = &pricingPeer{}
s.pricingPeers[peer.String()] = peerData
}
return peerData, nil
}
func (s *Pricer) storePriceTable(peer swarm.Address, priceTable []uint64) error {
s.logger.Tracef("Storing pricetable %v for peer %v", priceTable, peer)
err := s.store.Put(peerPriceTableKey(peer), priceTable)
if err != nil {
return err
}
return nil
}
// NotifyPriceTable should be called to notify pricer of changes in the peers pricetable
func (s *Pricer) NotifyPriceTable(peer swarm.Address, priceTable []uint64) error {
pricingPeer, err := s.getPricingPeer(peer)
if err != nil {
return err
}
pricingPeer.lock.Lock()
defer pricingPeer.lock.Unlock()
return s.storePriceTable(peer, priceTable)
}
func (s *Pricer) NotifyPeerPrice(peer swarm.Address, price uint64, index uint8) error {
if price == 0 {
return nil
}
pricingPeer, err := s.getPricingPeer(peer)
if err != nil {
return err
}
pricingPeer.lock.Lock()
defer pricingPeer.lock.Unlock()
priceTable := s.peerPriceTable(peer)
currentIndexDepth := uint8(len(priceTable)) - 1
if index <= currentIndexDepth {
// Simple case, already have index depth, single value change
priceTable[index] = price
return s.storePriceTable(peer, priceTable)
}
// Complicated case, index is larger than depth of already known table
newPriceTable := make([]uint64, index+1)
// Copy previous content
_ = copy(newPriceTable, priceTable)
// Check how many rows are missing
numberOfMissingRows := index - currentIndexDepth
for i := uint8(0); i < numberOfMissingRows; i++ {
currentrow := index - i
newPriceTable[currentrow] = price + uint64(i)*s.poPrice
s.logger.Debugf("Guessing price %v for extending pricetable %v for peer %v", newPriceTable[currentrow], newPriceTable, peer)
}
return s.storePriceTable(peer, newPriceTable)
}
func (s *Pricer) defaultPriceTable() []uint64 {
neighborhoodDepth := s.neighborhoodDepth()
priceTable := make([]uint64, neighborhoodDepth+1)
for i := uint8(0); i <= neighborhoodDepth; i++ {
priceTable[i] = uint64(neighborhoodDepth-i+1) * s.poPrice
}
return priceTable
}
func (s *Pricer) defaultPrice(po uint8) uint64 {
neighborhoodDepth := s.neighborhoodDepth()
if po > neighborhoodDepth {
po = neighborhoodDepth
}
return uint64(neighborhoodDepth-po+1) * s.poPrice
}
func (s *Pricer) neighborhoodDepth() uint8 {
var neighborhoodDepth uint8
if s.topology != nil {
neighborhoodDepth = s.topology.NeighborhoodDepth()
}
return neighborhoodDepth
}
func (s *Pricer) PriceHeadler(receivedHeaders p2p.Headers, peerAddress swarm.Address) (returnHeaders p2p.Headers) {
chunkAddress, receivedPrice, err := headerutils.ParsePricingHeaders(receivedHeaders)
if err != nil {
return p2p.Headers{
"error": []byte("Error reading pricing headers"),
}
}
s.logger.Debugf("price headler: received target %v with price as %v, from peer %s", chunkAddress, receivedPrice, peerAddress)
checkPrice, index := s.priceWithIndexForPeer(peerAddress, chunkAddress)
returnHeaders, err = headerutils.MakePricingResponseHeaders(checkPrice, chunkAddress, index)
if err != nil {
return p2p.Headers{
"error": []byte("Error creating response pricing headers"),
}
}
s.logger.Debugf("price headler: response target %v with price as %v, for peer %s", chunkAddress, checkPrice, peerAddress)
return returnHeaders
} }
func (s *Pricer) SetTopology(top topology.Driver) { // Price implements Pricer.
s.topology = top func (pricer *FixedPricer) Price(chunk swarm.Address) uint64 {
return pricer.PeerPrice(pricer.overlay, chunk)
} }
// Copyright 2021 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package pricer_test
import (
"encoding/binary"
"io/ioutil"
"reflect"
"testing"
mockkad "github.com/ethersphere/bee/pkg/kademlia/mock"
"github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/pricer"
"github.com/ethersphere/bee/pkg/pricer/headerutils"
"github.com/ethersphere/bee/pkg/statestore/mock"
"github.com/ethersphere/bee/pkg/swarm"
)
func TestPriceTableDefaultTables(t *testing.T) {
logger := logging.New(ioutil.Discard, 0)
store := mock.NewStateStore()
overlay := swarm.MustParseHexAddress("e5ef617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700")
defer store.Close()
pricer := pricer.New(logger, store, overlay, 10)
kad := mockkad.NewMockKademlia(mockkad.WithDepthCalls(0, 0, 2, 3, 5, 9))
pricer.SetTopology(kad)
table0 := []uint64{10}
table1 := []uint64{30, 20, 10}
table2 := []uint64{40, 30, 20, 10}
table3 := []uint64{60, 50, 40, 30, 20, 10}
table4 := []uint64{100, 90, 80, 70, 60, 50, 40, 30, 20, 10}
getTable0 := pricer.PriceTable()
if !reflect.DeepEqual(table0, getTable0) {
t.Fatalf("returned table does not match, got %+v expected %+v", getTable0, table0)
}
getTable1 := pricer.PriceTable()
if !reflect.DeepEqual(table1, getTable1) {
t.Fatalf("returned table does not match, got %+v expected %+v", getTable1, table1)
}
getTable2 := pricer.PriceTable()
if !reflect.DeepEqual(table2, getTable2) {
t.Fatalf("returned table does not match, got %+v expected %+v", getTable2, table2)
}
getTable3 := pricer.PriceTable()
if !reflect.DeepEqual(table3, getTable3) {
t.Fatalf("returned table does not match, got %+v expected %+v", getTable3, table3)
}
getTable4 := pricer.PriceTable()
if !reflect.DeepEqual(table4, getTable4) {
t.Fatalf("returned table does not match, got %+v expected %+v", getTable4, table4)
}
}
func TestPeerPrice(t *testing.T) {
logger := logging.New(ioutil.Discard, 0)
store := mock.NewStateStore()
defer store.Close()
overlay := swarm.MustParseHexAddress("0000617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700")
peer := swarm.MustParseHexAddress("e5ef617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700")
chunksByPOToPeer := []swarm.Address{
swarm.MustParseHexAddress("05ef617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700"),
swarm.MustParseHexAddress("95ef617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700"),
swarm.MustParseHexAddress("c5ef617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700"),
swarm.MustParseHexAddress("f0ef617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700"),
swarm.MustParseHexAddress("e6ef617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700"),
}
pricer := pricer.New(logger, store, overlay, 10)
kad := mockkad.NewMockKademlia(mockkad.WithDepth(3))
pricer.SetTopology(kad)
peerTable := []uint64{55, 45, 35, 25, 15}
err := pricer.NotifyPriceTable(peer, peerTable)
if err != nil {
t.Fatal(err)
}
for i, ch := range chunksByPOToPeer {
getPrice := pricer.PeerPrice(peer, ch)
if getPrice != peerTable[i] {
t.Fatalf("unexpected PeerPrice, got %v expected %v", getPrice, peerTable[i])
}
}
neighborhoodPeer := swarm.MustParseHexAddress("00ef617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700")
neighborhoodPeerTable := []uint64{36, 26, 16, 6}
err = pricer.NotifyPriceTable(neighborhoodPeer, neighborhoodPeerTable)
if err != nil {
t.Fatal(err)
}
chunksByPOToNeighborhoodPeer := []swarm.Address{
swarm.MustParseHexAddress("95ef617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700"),
swarm.MustParseHexAddress("55ef617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700"),
swarm.MustParseHexAddress("35ef617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700"),
swarm.MustParseHexAddress("10ef617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700"),
}
neighborhoodPeerExpectedPrices := []uint64{36, 26, 16, 0}
for i, ch := range chunksByPOToNeighborhoodPeer {
getPrice := pricer.PeerPrice(neighborhoodPeer, ch)
if getPrice != neighborhoodPeerExpectedPrices[i] {
t.Fatalf("unexpected PeerPrice, got %v expected %v", getPrice, neighborhoodPeerExpectedPrices[i])
}
}
}
func TestPriceForPeer(t *testing.T) {
logger := logging.New(ioutil.Discard, 0)
store := mock.NewStateStore()
defer store.Close()
peer := swarm.MustParseHexAddress("0000617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700")
overlay := swarm.MustParseHexAddress("e5ef617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700")
chunksByPOToOverlay := []swarm.Address{
swarm.MustParseHexAddress("05ef617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700"),
swarm.MustParseHexAddress("95ef617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700"),
swarm.MustParseHexAddress("c5ef617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700"),
swarm.MustParseHexAddress("f0ef617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700"),
swarm.MustParseHexAddress("e6ef617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700"),
}
pricer := pricer.New(logger, store, overlay, 10)
kad := mockkad.NewMockKademlia(mockkad.WithDepth(4))
pricer.SetTopology(kad)
defaultTable := []uint64{50, 40, 30, 20, 10}
for i, ch := range chunksByPOToOverlay {
getPrice := pricer.PriceForPeer(peer, ch)
if getPrice != defaultTable[i] {
t.Fatalf("unexpected price for peer, got %v expected %v", getPrice, defaultTable[i])
}
}
neighborhoodPeer := swarm.MustParseHexAddress("e7ef617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700")
expectedPricesInNeighborhood := []uint64{50, 40, 30, 20, 0}
for i, ch := range chunksByPOToOverlay {
getPrice := pricer.PriceForPeer(neighborhoodPeer, ch)
if getPrice != expectedPricesInNeighborhood[i] {
t.Fatalf("unexpected price for peer, got %v expected %v", getPrice, expectedPricesInNeighborhood[i])
}
}
}
func TestNotifyPriceTable(t *testing.T) {
logger := logging.New(ioutil.Discard, 0)
store := mock.NewStateStore()
overlay := swarm.MustParseHexAddress("e5ef617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700")
defer store.Close()
pricer := pricer.New(logger, store, overlay, 10)
kad := mockkad.NewMockKademlia(mockkad.WithDepth(0))
pricer.SetTopology(kad)
peer := swarm.MustParseHexAddress("ffff617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700")
peerTable := []uint64{66, 55, 44, 33, 22, 11}
err := pricer.NotifyPriceTable(peer, peerTable)
if err != nil {
t.Fatal(err)
}
for i := 0; i < len(peerTable); i++ {
getPrice, err := pricer.PeerPricePO(peer, uint8(i))
if err != nil {
t.Fatal(err)
}
if getPrice != peerTable[i] {
t.Fatalf("unexpected PeerPricePO, got %v expected %v", getPrice, peerTable[i])
}
}
}
func TestNotifyPeerPrice(t *testing.T) {
logger := logging.New(ioutil.Discard, 0)
store := mock.NewStateStore()
overlay := swarm.MustParseHexAddress("e5ef617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700")
defer store.Close()
pricer := pricer.New(logger, store, overlay, 10)
kad := mockkad.NewMockKademlia(mockkad.WithDepth(0))
pricer.SetTopology(kad)
peer := swarm.MustParseHexAddress("ffff617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700")
peerTable := []uint64{33, 22, 11}
err := pricer.NotifyPriceTable(peer, peerTable)
if err != nil {
t.Fatal(err)
}
for i := 0; i < len(peerTable); i++ {
getPrice, err := pricer.PeerPricePO(peer, uint8(i))
if err != nil {
t.Fatal(err)
}
if getPrice != peerTable[i] {
t.Fatalf("unexpected PeerPricePO, got %v expected %v", getPrice, peerTable[i])
}
}
err = pricer.NotifyPeerPrice(peer, 48, 8)
if err != nil {
t.Fatal(err)
}
expectedTable := []uint64{33, 22, 11, 98, 88, 78, 68, 58, 48}
for i := 0; i < len(expectedTable); i++ {
getPrice, err := pricer.PeerPricePO(peer, uint8(i))
if err != nil {
t.Fatal(err)
}
if getPrice != expectedTable[i] {
t.Fatalf("unexpected PeerPricePO, got %v expected %v", getPrice, expectedTable[i])
}
}
err = pricer.NotifyPeerPrice(peer, 43, 9)
if err != nil {
t.Fatal(err)
}
expectedTable = []uint64{33, 22, 11, 98, 88, 78, 68, 58, 48, 43}
for i := 0; i < len(expectedTable); i++ {
getPrice, err := pricer.PeerPricePO(peer, uint8(i))
if err != nil {
t.Fatal(err)
}
if getPrice != expectedTable[i] {
t.Fatalf("unexpected PeerPricePO, got %v expected %v", getPrice, expectedTable[i])
}
}
err = pricer.NotifyPeerPrice(peer, 60, 5)
if err != nil {
t.Fatal(err)
}
expectedTable = []uint64{33, 22, 11, 98, 88, 60, 68, 58, 48, 43}
for i := 0; i < len(expectedTable); i++ {
getPrice, err := pricer.PeerPricePO(peer, uint8(i))
if err != nil {
t.Fatal(err)
}
if getPrice != expectedTable[i] {
t.Fatalf("unexpected PeerPricePO, got %v expected %v", getPrice, expectedTable[i])
}
}
}
func TestPricerHeadler(t *testing.T) {
logger := logging.New(ioutil.Discard, 0)
store := mock.NewStateStore()
overlay := swarm.MustParseHexAddress("e5ef617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700")
defer store.Close()
pricer := pricer.New(logger, store, overlay, 10)
kad := mockkad.NewMockKademlia(mockkad.WithDepth(2))
pricer.SetTopology(kad)
peer := swarm.MustParseHexAddress("07e0d4ba628ad700fff48b16b52953487a22589362e34efdd7e0d4ba628ad700")
chunkAddress := swarm.MustParseHexAddress("7a22589362e34efdd7e0d4ba628ad7007a22589362e34efdd7e0d4ba628ad700")
requestHeaders, err := headerutils.MakePricingHeaders(50, chunkAddress)
if err != nil {
t.Fatal(err)
}
responseHeaders := pricer.PriceHeadler(requestHeaders, peer)
if !reflect.DeepEqual(requestHeaders["target"], responseHeaders["target"]) {
t.Fatalf("targets don't match, got %v, want %v", responseHeaders["target"], requestHeaders["target"])
}
chunkPriceInRequest := make([]byte, 8)
binary.BigEndian.PutUint64(chunkPriceInRequest, uint64(50))
chunkPriceInResponse := make([]byte, 8)
binary.BigEndian.PutUint64(chunkPriceInResponse, uint64(30))
if !reflect.DeepEqual(requestHeaders["price"], chunkPriceInRequest) {
t.Fatalf("targets don't match, got %v, want %v", responseHeaders["price"], chunkPriceInRequest)
}
if !reflect.DeepEqual(responseHeaders["price"], chunkPriceInResponse) {
t.Fatalf("targets don't match, got %v, want %v", responseHeaders["price"], chunkPriceInResponse)
}
}
func TestPricerHeadlerBadHeaders(t *testing.T) {
logger := logging.New(ioutil.Discard, 0)
store := mock.NewStateStore()
overlay := swarm.MustParseHexAddress("e5ef617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700")
defer store.Close()
pricer := pricer.New(logger, store, overlay, 10)
kad := mockkad.NewMockKademlia(mockkad.WithDepth(2))
pricer.SetTopology(kad)
peer := swarm.MustParseHexAddress("07e0d4ba628ad700fff48b16b52953487a22589362e34efdd7e0d4ba628ad700")
requestHeaders := p2p.Headers{
"irrelevantfield": []byte{},
}
responseHeaders := pricer.PriceHeadler(requestHeaders, peer)
if responseHeaders["target"] != nil {
t.Fatal("only error should be returned")
}
if responseHeaders["price"] != nil {
t.Fatal("only error should be returned")
}
if responseHeaders["index"] != nil {
t.Fatal("only error should be returned")
}
if responseHeaders["error"] == nil {
t.Fatal("error should be returned")
}
}
...@@ -24,7 +24,6 @@ const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package ...@@ -24,7 +24,6 @@ const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package
type AnnouncePaymentThreshold struct { type AnnouncePaymentThreshold struct {
PaymentThreshold []byte `protobuf:"bytes,1,opt,name=PaymentThreshold,proto3" json:"PaymentThreshold,omitempty"` PaymentThreshold []byte `protobuf:"bytes,1,opt,name=PaymentThreshold,proto3" json:"PaymentThreshold,omitempty"`
ProximityPrice []uint64 `protobuf:"varint,2,rep,packed,name=ProximityPrice,proto3" json:"ProximityPrice,omitempty"`
} }
func (m *AnnouncePaymentThreshold) Reset() { *m = AnnouncePaymentThreshold{} } func (m *AnnouncePaymentThreshold) Reset() { *m = AnnouncePaymentThreshold{} }
...@@ -67,13 +66,6 @@ func (m *AnnouncePaymentThreshold) GetPaymentThreshold() []byte { ...@@ -67,13 +66,6 @@ func (m *AnnouncePaymentThreshold) GetPaymentThreshold() []byte {
return nil return nil
} }
func (m *AnnouncePaymentThreshold) GetProximityPrice() []uint64 {
if m != nil {
return m.ProximityPrice
}
return nil
}
func init() { func init() {
proto.RegisterType((*AnnouncePaymentThreshold)(nil), "pricing.AnnouncePaymentThreshold") proto.RegisterType((*AnnouncePaymentThreshold)(nil), "pricing.AnnouncePaymentThreshold")
} }
...@@ -81,17 +73,15 @@ func init() { ...@@ -81,17 +73,15 @@ func init() {
func init() { proto.RegisterFile("pricing.proto", fileDescriptor_ec4cc93d045d43d0) } func init() { proto.RegisterFile("pricing.proto", fileDescriptor_ec4cc93d045d43d0) }
var fileDescriptor_ec4cc93d045d43d0 = []byte{ var fileDescriptor_ec4cc93d045d43d0 = []byte{
// 150 bytes of a gzipped FileDescriptorProto // 122 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x2d, 0x28, 0xca, 0x4c, 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x2d, 0x28, 0xca, 0x4c,
0xce, 0xcc, 0x4b, 0xd7, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x87, 0x72, 0x95, 0xf2, 0xb8, 0xce, 0xcc, 0x4b, 0xd7, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x87, 0x72, 0x95, 0xdc, 0xb8,
0x24, 0x1c, 0xf3, 0xf2, 0xf2, 0x4b, 0xf3, 0x92, 0x53, 0x03, 0x12, 0x2b, 0x73, 0x53, 0xf3, 0x4a, 0x24, 0x1c, 0xf3, 0xf2, 0xf2, 0x4b, 0xf3, 0x92, 0x53, 0x03, 0x12, 0x2b, 0x73, 0x53, 0xf3, 0x4a,
0x42, 0x32, 0x8a, 0x52, 0x8b, 0x33, 0xf2, 0x73, 0x52, 0x84, 0xb4, 0xb8, 0x04, 0xd0, 0xc5, 0x24, 0x42, 0x32, 0x8a, 0x52, 0x8b, 0x33, 0xf2, 0x73, 0x52, 0x84, 0xb4, 0xb8, 0x04, 0xd0, 0xc5, 0x24,
0x18, 0x15, 0x18, 0x35, 0x78, 0x82, 0x30, 0xc4, 0x85, 0xd4, 0xb8, 0xf8, 0x02, 0x8a, 0xf2, 0x2b, 0x18, 0x15, 0x18, 0x35, 0x78, 0x82, 0x30, 0xc4, 0x9d, 0x64, 0x4e, 0x3c, 0x92, 0x63, 0xbc, 0xf0,
0x32, 0x73, 0x33, 0x4b, 0x2a, 0x03, 0x8a, 0x32, 0x93, 0x53, 0x25, 0x98, 0x14, 0x98, 0x35, 0x58, 0x48, 0x8e, 0xf1, 0xc1, 0x23, 0x39, 0xc6, 0x09, 0x8f, 0xe5, 0x18, 0x2e, 0x3c, 0x96, 0x63, 0xb8,
0x82, 0xd0, 0x44, 0x9d, 0x64, 0x4e, 0x3c, 0x92, 0x63, 0xbc, 0xf0, 0x48, 0x8e, 0xf1, 0xc1, 0x23, 0xf1, 0x58, 0x8e, 0x21, 0x8a, 0xa9, 0x20, 0x29, 0x89, 0x0d, 0x6c, 0xab, 0x31, 0x20, 0x00, 0x00,
0x39, 0xc6, 0x09, 0x8f, 0xe5, 0x18, 0x2e, 0x3c, 0x96, 0x63, 0xb8, 0xf1, 0x58, 0x8e, 0x21, 0x8a, 0xff, 0xff, 0x50, 0xca, 0x0e, 0x0a, 0x86, 0x00, 0x00, 0x00,
0xa9, 0x20, 0x29, 0x89, 0x0d, 0xec, 0x3a, 0x63, 0x40, 0x00, 0x00, 0x00, 0xff, 0xff, 0x90, 0x49,
0x1c, 0x6f, 0xae, 0x00, 0x00, 0x00,
} }
func (m *AnnouncePaymentThreshold) Marshal() (dAtA []byte, err error) { func (m *AnnouncePaymentThreshold) Marshal() (dAtA []byte, err error) {
...@@ -114,24 +104,6 @@ func (m *AnnouncePaymentThreshold) MarshalToSizedBuffer(dAtA []byte) (int, error ...@@ -114,24 +104,6 @@ func (m *AnnouncePaymentThreshold) MarshalToSizedBuffer(dAtA []byte) (int, error
_ = i _ = i
var l int var l int
_ = l _ = l
if len(m.ProximityPrice) > 0 {
dAtA2 := make([]byte, len(m.ProximityPrice)*10)
var j1 int
for _, num := range m.ProximityPrice {
for num >= 1<<7 {
dAtA2[j1] = uint8(uint64(num)&0x7f | 0x80)
num >>= 7
j1++
}
dAtA2[j1] = uint8(num)
j1++
}
i -= j1
copy(dAtA[i:], dAtA2[:j1])
i = encodeVarintPricing(dAtA, i, uint64(j1))
i--
dAtA[i] = 0x12
}
if len(m.PaymentThreshold) > 0 { if len(m.PaymentThreshold) > 0 {
i -= len(m.PaymentThreshold) i -= len(m.PaymentThreshold)
copy(dAtA[i:], m.PaymentThreshold) copy(dAtA[i:], m.PaymentThreshold)
...@@ -163,13 +135,6 @@ func (m *AnnouncePaymentThreshold) Size() (n int) { ...@@ -163,13 +135,6 @@ func (m *AnnouncePaymentThreshold) Size() (n int) {
if l > 0 { if l > 0 {
n += 1 + l + sovPricing(uint64(l)) n += 1 + l + sovPricing(uint64(l))
} }
if len(m.ProximityPrice) > 0 {
l = 0
for _, e := range m.ProximityPrice {
l += sovPricing(uint64(e))
}
n += 1 + sovPricing(uint64(l)) + l
}
return n return n
} }
...@@ -242,82 +207,6 @@ func (m *AnnouncePaymentThreshold) Unmarshal(dAtA []byte) error { ...@@ -242,82 +207,6 @@ func (m *AnnouncePaymentThreshold) Unmarshal(dAtA []byte) error {
m.PaymentThreshold = []byte{} m.PaymentThreshold = []byte{}
} }
iNdEx = postIndex iNdEx = postIndex
case 2:
if wireType == 0 {
var v uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowPricing
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
v |= uint64(b&0x7F) << shift
if b < 0x80 {
break
}
}
m.ProximityPrice = append(m.ProximityPrice, v)
} else if wireType == 2 {
var packedLen int
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowPricing
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
packedLen |= int(b&0x7F) << shift
if b < 0x80 {
break
}
}
if packedLen < 0 {
return ErrInvalidLengthPricing
}
postIndex := iNdEx + packedLen
if postIndex < 0 {
return ErrInvalidLengthPricing
}
if postIndex > l {
return io.ErrUnexpectedEOF
}
var elementCount int
var count int
for _, integer := range dAtA[iNdEx:postIndex] {
if integer < 128 {
count++
}
}
elementCount = count
if elementCount != 0 && len(m.ProximityPrice) == 0 {
m.ProximityPrice = make([]uint64, 0, elementCount)
}
for iNdEx < postIndex {
var v uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowPricing
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
v |= uint64(b&0x7F) << shift
if b < 0x80 {
break
}
}
m.ProximityPrice = append(m.ProximityPrice, v)
}
} else {
return fmt.Errorf("proto: wrong wireType = %d for field ProximityPrice", wireType)
}
default: default:
iNdEx = preIndex iNdEx = preIndex
skippy, err := skipPricing(dAtA[iNdEx:]) skippy, err := skipPricing(dAtA[iNdEx:])
......
...@@ -10,5 +10,4 @@ option go_package = "pb"; ...@@ -10,5 +10,4 @@ option go_package = "pb";
message AnnouncePaymentThreshold { message AnnouncePaymentThreshold {
bytes PaymentThreshold = 1; bytes PaymentThreshold = 1;
repeated uint64 ProximityPrice = 2;
} }
...@@ -13,7 +13,6 @@ import ( ...@@ -13,7 +13,6 @@ import (
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/p2p" "github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/p2p/protobuf" "github.com/ethersphere/bee/pkg/p2p/protobuf"
"github.com/ethersphere/bee/pkg/pricer"
"github.com/ethersphere/bee/pkg/pricing/pb" "github.com/ethersphere/bee/pkg/pricing/pb"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
) )
...@@ -29,7 +28,6 @@ var _ Interface = (*Service)(nil) ...@@ -29,7 +28,6 @@ var _ Interface = (*Service)(nil)
// Interface is the main interface of the pricing protocol // Interface is the main interface of the pricing protocol
type Interface interface { type Interface interface {
AnnouncePaymentThreshold(ctx context.Context, peer swarm.Address, paymentThreshold *big.Int) error AnnouncePaymentThreshold(ctx context.Context, peer swarm.Address, paymentThreshold *big.Int) error
AnnouncePaymentThresholdAndPriceTable(ctx context.Context, peer swarm.Address, paymentThreshold *big.Int) error
} }
// PriceTableObserver is used for being notified of price table updates // PriceTableObserver is used for being notified of price table updates
...@@ -46,16 +44,13 @@ type Service struct { ...@@ -46,16 +44,13 @@ type Service struct {
streamer p2p.Streamer streamer p2p.Streamer
logger logging.Logger logger logging.Logger
paymentThreshold *big.Int paymentThreshold *big.Int
pricer pricer.Interface
paymentThresholdObserver PaymentThresholdObserver paymentThresholdObserver PaymentThresholdObserver
priceTableObserver PriceTableObserver
} }
func New(streamer p2p.Streamer, logger logging.Logger, paymentThreshold *big.Int, pricer pricer.Interface) *Service { func New(streamer p2p.Streamer, logger logging.Logger, paymentThreshold *big.Int) *Service {
return &Service{ return &Service{
streamer: streamer, streamer: streamer,
logger: logger, logger: logger,
pricer: pricer,
paymentThreshold: paymentThreshold, paymentThreshold: paymentThreshold,
} }
} }
...@@ -94,15 +89,6 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e ...@@ -94,15 +89,6 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e
paymentThreshold := big.NewInt(0).SetBytes(req.PaymentThreshold) paymentThreshold := big.NewInt(0).SetBytes(req.PaymentThreshold)
s.logger.Tracef("received payment threshold announcement from peer %v of %d", p.Address, paymentThreshold) s.logger.Tracef("received payment threshold announcement from peer %v of %d", p.Address, paymentThreshold)
if req.ProximityPrice != nil {
s.logger.Tracef("received pricetable announcement from peer %v of %v", p.Address, req.ProximityPrice)
err = s.priceTableObserver.NotifyPriceTable(p.Address, req.ProximityPrice)
if err != nil {
s.logger.Debugf("error receiving pricetable from peer %v: %w", p.Address, err)
s.logger.Errorf("error receiving pricetable from peer %v: %w", p.Address, err)
}
}
if paymentThreshold.Cmp(big.NewInt(0)) == 0 { if paymentThreshold.Cmp(big.NewInt(0)) == 0 {
return err return err
} }
...@@ -110,7 +96,7 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e ...@@ -110,7 +96,7 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e
} }
func (s *Service) init(ctx context.Context, p p2p.Peer) error { func (s *Service) init(ctx context.Context, p p2p.Peer) error {
err := s.AnnouncePaymentThresholdAndPriceTable(ctx, p.Address, s.paymentThreshold) err := s.AnnouncePaymentThreshold(ctx, p.Address, s.paymentThreshold)
if err != nil { if err != nil {
s.logger.Warningf("could not send payment threshold announcement to peer %v", p.Address) s.logger.Warningf("could not send payment threshold announcement to peer %v", p.Address)
} }
...@@ -143,39 +129,7 @@ func (s *Service) AnnouncePaymentThreshold(ctx context.Context, peer swarm.Addre ...@@ -143,39 +129,7 @@ func (s *Service) AnnouncePaymentThreshold(ctx context.Context, peer swarm.Addre
return err return err
} }
// AnnouncePaymentThresholdAndPriceTable announces own payment threshold and pricetable to peer
func (s *Service) AnnouncePaymentThresholdAndPriceTable(ctx context.Context, peer swarm.Address, paymentThreshold *big.Int) error {
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
stream, err := s.streamer.NewStream(ctx, peer, nil, protocolName, protocolVersion, streamName)
if err != nil {
return err
}
defer func() {
if err != nil {
_ = stream.Reset()
} else {
go stream.FullClose()
}
}()
s.logger.Tracef("sending payment threshold announcement to peer %v of %d", peer, paymentThreshold)
w := protobuf.NewWriter(stream)
err = w.WriteMsgWithContext(ctx, &pb.AnnouncePaymentThreshold{
PaymentThreshold: paymentThreshold.Bytes(),
ProximityPrice: s.pricer.PriceTable(),
})
return err
}
// SetPaymentThresholdObserver sets the PaymentThresholdObserver to be used when receiving a new payment threshold // SetPaymentThresholdObserver sets the PaymentThresholdObserver to be used when receiving a new payment threshold
func (s *Service) SetPaymentThresholdObserver(observer PaymentThresholdObserver) { func (s *Service) SetPaymentThresholdObserver(observer PaymentThresholdObserver) {
s.paymentThresholdObserver = observer s.paymentThresholdObserver = observer
} }
// SetPriceTableObserver sets the PriceTableObserver to be used when receiving a new pricetable
func (s *Service) SetPriceTableObserver(observer PriceTableObserver) {
s.priceTableObserver = observer
}
...@@ -9,13 +9,11 @@ import ( ...@@ -9,13 +9,11 @@ import (
"context" "context"
"io/ioutil" "io/ioutil"
"math/big" "math/big"
"reflect"
"testing" "testing"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/p2p/protobuf" "github.com/ethersphere/bee/pkg/p2p/protobuf"
"github.com/ethersphere/bee/pkg/p2p/streamtest" "github.com/ethersphere/bee/pkg/p2p/streamtest"
pricermock "github.com/ethersphere/bee/pkg/pricer/mock"
"github.com/ethersphere/bee/pkg/pricing" "github.com/ethersphere/bee/pkg/pricing"
"github.com/ethersphere/bee/pkg/pricing/pb" "github.com/ethersphere/bee/pkg/pricing/pb"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
...@@ -27,12 +25,6 @@ type testThresholdObserver struct { ...@@ -27,12 +25,6 @@ type testThresholdObserver struct {
paymentThreshold *big.Int paymentThreshold *big.Int
} }
type testPriceTableObserver struct {
called bool
peer swarm.Address
priceTable []uint64
}
func (t *testThresholdObserver) NotifyPaymentThreshold(peerAddr swarm.Address, paymentThreshold *big.Int) error { func (t *testThresholdObserver) NotifyPaymentThreshold(peerAddr swarm.Address, paymentThreshold *big.Int) error {
t.called = true t.called = true
t.peer = peerAddr t.peer = peerAddr
...@@ -40,21 +32,12 @@ func (t *testThresholdObserver) NotifyPaymentThreshold(peerAddr swarm.Address, p ...@@ -40,21 +32,12 @@ func (t *testThresholdObserver) NotifyPaymentThreshold(peerAddr swarm.Address, p
return nil return nil
} }
func (t *testPriceTableObserver) NotifyPriceTable(peerAddr swarm.Address, priceTable []uint64) error {
t.called = true
t.peer = peerAddr
t.priceTable = priceTable
return nil
}
func TestAnnouncePaymentThreshold(t *testing.T) { func TestAnnouncePaymentThreshold(t *testing.T) {
logger := logging.New(ioutil.Discard, 0) logger := logging.New(ioutil.Discard, 0)
testThreshold := big.NewInt(100000) testThreshold := big.NewInt(100000)
observer := &testThresholdObserver{} observer := &testThresholdObserver{}
pricerMockService := pricermock.NewMockService() recipient := pricing.New(nil, logger, testThreshold)
recipient := pricing.New(nil, logger, testThreshold, pricerMockService)
recipient.SetPaymentThresholdObserver(observer) recipient.SetPaymentThresholdObserver(observer)
peerID := swarm.MustParseHexAddress("9ee7add7") peerID := swarm.MustParseHexAddress("9ee7add7")
...@@ -64,7 +47,7 @@ func TestAnnouncePaymentThreshold(t *testing.T) { ...@@ -64,7 +47,7 @@ func TestAnnouncePaymentThreshold(t *testing.T) {
streamtest.WithBaseAddr(peerID), streamtest.WithBaseAddr(peerID),
) )
payer := pricing.New(recorder, logger, testThreshold, pricerMockService) payer := pricing.New(recorder, logger, testThreshold)
paymentThreshold := big.NewInt(10000) paymentThreshold := big.NewInt(10000)
...@@ -113,93 +96,3 @@ func TestAnnouncePaymentThreshold(t *testing.T) { ...@@ -113,93 +96,3 @@ func TestAnnouncePaymentThreshold(t *testing.T) {
t.Fatalf("observer called with wrong peer. got %v, want %v", observer.peer, peerID) t.Fatalf("observer called with wrong peer. got %v, want %v", observer.peer, peerID)
} }
} }
func TestAnnouncePaymentThresholdAndPriceTable(t *testing.T) {
logger := logging.New(ioutil.Discard, 0)
testThreshold := big.NewInt(100000)
observer1 := &testThresholdObserver{}
observer2 := &testPriceTableObserver{}
table := []uint64{50, 25, 12, 6}
priceTableFunc := func() []uint64 {
return table
}
pricerMockService := pricermock.NewMockService(pricermock.WithPriceTableFunc(priceTableFunc))
recipient := pricing.New(nil, logger, testThreshold, pricerMockService)
recipient.SetPaymentThresholdObserver(observer1)
recipient.SetPriceTableObserver(observer2)
peerID := swarm.MustParseHexAddress("9ee7add7")
recorder := streamtest.New(
streamtest.WithProtocols(recipient.Protocol()),
streamtest.WithBaseAddr(peerID),
)
payer := pricing.New(recorder, logger, testThreshold, pricerMockService)
paymentThreshold := big.NewInt(10000)
err := payer.AnnouncePaymentThresholdAndPriceTable(context.Background(), peerID, paymentThreshold)
if err != nil {
t.Fatal(err)
}
records, err := recorder.Records(peerID, "pricing", "1.0.0", "pricing")
if err != nil {
t.Fatal(err)
}
if l := len(records); l != 1 {
t.Fatalf("got %v records, want %v", l, 1)
}
record := records[0]
messages, err := protobuf.ReadMessages(
bytes.NewReader(record.In()),
func() protobuf.Message { return new(pb.AnnouncePaymentThreshold) },
)
if err != nil {
t.Fatal(err)
}
if len(messages) != 1 {
t.Fatalf("got %v messages, want %v", len(messages), 1)
}
sentPaymentThreshold := big.NewInt(0).SetBytes(messages[0].(*pb.AnnouncePaymentThreshold).PaymentThreshold)
if sentPaymentThreshold.Cmp(paymentThreshold) != 0 {
t.Fatalf("got message with amount %v, want %v", sentPaymentThreshold, paymentThreshold)
}
sentPriceTable := messages[0].(*pb.AnnouncePaymentThreshold).ProximityPrice
if !reflect.DeepEqual(sentPriceTable, table) {
t.Fatalf("got message with table %v, want %v", sentPriceTable, table)
}
if !observer1.called {
t.Fatal("expected threshold observer to be called")
}
if observer1.paymentThreshold.Cmp(paymentThreshold) != 0 {
t.Fatalf("observer called with wrong paymentThreshold. got %d, want %d", observer1.paymentThreshold, paymentThreshold)
}
if !observer1.peer.Equal(peerID) {
t.Fatalf("threshold observer called with wrong peer. got %v, want %v", observer1.peer, peerID)
}
if !observer2.called {
t.Fatal("expected table observer to be called")
}
if !reflect.DeepEqual(observer2.priceTable, table) {
t.Fatalf("table observer called with wrong priceTable. got %d, want %d", observer2.priceTable, table)
}
if !observer2.peer.Equal(peerID) {
t.Fatalf("table observer called with wrong peer. got %v, want %v", observer2.peer, peerID)
}
}
...@@ -19,7 +19,6 @@ import ( ...@@ -19,7 +19,6 @@ import (
"github.com/ethersphere/bee/pkg/p2p" "github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/p2p/protobuf" "github.com/ethersphere/bee/pkg/p2p/protobuf"
"github.com/ethersphere/bee/pkg/pricer" "github.com/ethersphere/bee/pkg/pricer"
"github.com/ethersphere/bee/pkg/pricer/headerutils"
"github.com/ethersphere/bee/pkg/pushsync/pb" "github.com/ethersphere/bee/pkg/pushsync/pb"
"github.com/ethersphere/bee/pkg/soc" "github.com/ethersphere/bee/pkg/soc"
"github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage"
...@@ -98,7 +97,6 @@ func (s *PushSync) Protocol() p2p.ProtocolSpec { ...@@ -98,7 +97,6 @@ func (s *PushSync) Protocol() p2p.ProtocolSpec {
{ {
Name: streamName, Name: streamName,
Handler: s.handler, Handler: s.handler,
Headler: s.pricer.PriceHeadler,
}, },
}, },
} }
...@@ -134,14 +132,7 @@ func (ps *PushSync) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) ...@@ -134,14 +132,7 @@ func (ps *PushSync) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream)
return swarm.ErrInvalidChunk return swarm.ErrInvalidChunk
} }
// Get price we charge for upstream peer read at headler. price := ps.pricer.Price(chunk.Address())
responseHeaders := stream.ResponseHeaders()
price, err := headerutils.ParsePriceHeader(responseHeaders)
// if not found in returned header, compute the price we charge for this chunk.
if err != nil {
ps.logger.Warningf("pushsync: peer %v no price in previously issued response headers: %v", p.Address, err)
price = ps.pricer.PriceForPeer(p.Address, chunk.Address())
}
// if the peer is closer to the chunk, we were selected for replication. Return early. // if the peer is closer to the chunk, we were selected for replication. Return early.
if dcmp, _ := swarm.DistanceCmp(chunk.Address().Bytes(), p.Address.Bytes(), ps.address.Bytes()); dcmp == 1 { if dcmp, _ := swarm.DistanceCmp(chunk.Address().Bytes(), p.Address.Bytes(), ps.address.Bytes()); dcmp == 1 {
...@@ -153,6 +144,8 @@ func (ps *PushSync) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) ...@@ -153,6 +144,8 @@ func (ps *PushSync) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream)
return ps.accounting.Debit(p.Address, price) return ps.accounting.Debit(p.Address, price)
} }
fmt.Println("YAY")
fmt.Println(ps.address)
return ErrOutOfDepthReplication return ErrOutOfDepthReplication
} }
...@@ -204,34 +197,22 @@ func (ps *PushSync) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) ...@@ -204,34 +197,22 @@ func (ps *PushSync) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream)
}() }()
// price for neighborhood replication // price for neighborhood replication
const receiptPrice uint64 = 0 receiptPrice := ps.pricer.PeerPrice(peer, chunk.Address())
headers, err := headerutils.MakePricingHeaders(receiptPrice, chunk.Address()) err = ps.accounting.Reserve(ctx, peer, receiptPrice)
if err != nil { if err != nil {
err = fmt.Errorf("make pricing headers: %w", err) err = fmt.Errorf("reserve balance for peer %s: %w", peer.String(), err)
return return
} }
defer ps.accounting.Release(peer, receiptPrice)
streamer, err := ps.streamer.NewStream(ctx, peer, headers, protocolName, protocolVersion, streamName) streamer, err := ps.streamer.NewStream(ctx, peer, nil, protocolName, protocolVersion, streamName)
if err != nil { if err != nil {
err = fmt.Errorf("new stream for peer %s: %w", peer.String(), err) err = fmt.Errorf("new stream for peer %s: %w", peer.String(), err)
return return
} }
defer streamer.Close() defer streamer.Close()
returnedHeaders := streamer.Headers()
_, returnedPrice, returnedIndex, err := headerutils.ParsePricingResponseHeaders(returnedHeaders)
if err != nil {
err = fmt.Errorf("push price headers read returned: %w", err)
return
}
// check if returned price matches presumed price, if not, return early.
if returnedPrice != receiptPrice {
err = ps.pricer.NotifyPeerPrice(peer, returnedPrice, returnedIndex)
return
}
w := protobuf.NewWriter(streamer) w := protobuf.NewWriter(streamer)
ctx, cancel := context.WithTimeout(ctx, timeToWaitForPushsyncToNeighbor) ctx, cancel := context.WithTimeout(ctx, timeToWaitForPushsyncToNeighbor)
defer cancel() defer cancel()
...@@ -245,6 +226,8 @@ func (ps *PushSync) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) ...@@ -245,6 +226,8 @@ func (ps *PushSync) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream)
return return
} }
_ = ps.accounting.Credit(peer, receiptPrice)
}(peer) }(peer)
return false, false, nil return false, false, nil
...@@ -341,34 +324,14 @@ func (ps *PushSync) pushToClosest(ctx context.Context, ch swarm.Chunk) (rr *pb.R ...@@ -341,34 +324,14 @@ func (ps *PushSync) pushToClosest(ctx context.Context, ch swarm.Chunk) (rr *pb.R
// compute the price we pay for this receipt and reserve it for the rest of this function // compute the price we pay for this receipt and reserve it for the rest of this function
receiptPrice := ps.pricer.PeerPrice(peer, ch.Address()) receiptPrice := ps.pricer.PeerPrice(peer, ch.Address())
headers, err := headerutils.MakePricingHeaders(receiptPrice, ch.Address()) streamer, err := ps.streamer.NewStream(ctx, peer, nil, protocolName, protocolVersion, streamName)
if err != nil {
return nil, err
}
streamer, err := ps.streamer.NewStream(ctx, peer, headers, protocolName, protocolVersion, streamName)
if err != nil { if err != nil {
lastErr = fmt.Errorf("new stream for peer %s: %w", peer.String(), err) lastErr = fmt.Errorf("new stream for peer %s: %w", peer.String(), err)
continue continue
} }
deferFuncs = append(deferFuncs, func() { go streamer.FullClose() }) deferFuncs = append(deferFuncs, func() { go streamer.FullClose() })
returnedHeaders := streamer.Headers() // Reserve to see whether we can make the request
_, returnedPrice, returnedIndex, err := headerutils.ParsePricingResponseHeaders(returnedHeaders)
if err != nil {
return nil, fmt.Errorf("push price headers: read returned: %w", err)
}
// check if returned price matches presumed price, if not, update price
if returnedPrice != receiptPrice {
err = ps.pricer.NotifyPeerPrice(peer, returnedPrice, returnedIndex) // save priceHeaders["price"] corresponding row for peer
if err != nil {
return nil, err
}
receiptPrice = returnedPrice
}
// Reserve to see whether we can make the request based on actual price
err = ps.accounting.Reserve(ctx, peer, receiptPrice) err = ps.accounting.Reserve(ctx, peer, receiptPrice)
if err != nil { if err != nil {
return nil, fmt.Errorf("reserve balance for peer %s: %w", peer.String(), err) return nil, fmt.Errorf("reserve balance for peer %s: %w", peer.String(), err)
......
...@@ -22,7 +22,6 @@ import ( ...@@ -22,7 +22,6 @@ import (
"github.com/ethersphere/bee/pkg/p2p" "github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/p2p/protobuf" "github.com/ethersphere/bee/pkg/p2p/protobuf"
"github.com/ethersphere/bee/pkg/p2p/streamtest" "github.com/ethersphere/bee/pkg/p2p/streamtest"
"github.com/ethersphere/bee/pkg/pricer/headerutils"
pricermock "github.com/ethersphere/bee/pkg/pricer/mock" pricermock "github.com/ethersphere/bee/pkg/pricer/mock"
"github.com/ethersphere/bee/pkg/pushsync" "github.com/ethersphere/bee/pkg/pushsync"
"github.com/ethersphere/bee/pkg/pushsync/pb" "github.com/ethersphere/bee/pkg/pushsync/pb"
...@@ -107,65 +106,6 @@ func TestSendChunkAndReceiveReceipt(t *testing.T) { ...@@ -107,65 +106,6 @@ func TestSendChunkAndReceiveReceipt(t *testing.T) {
} }
} }
func TestSendChunkAfterPriceUpdate(t *testing.T) {
// chunk data to upload
chunk := testingc.FixtureChunk("7000")
// create a pivot node and a mocked closest node
pivotNode := swarm.MustParseHexAddress("0000000000000000000000000000000000000000000000000000000000000000") // base is 0000
closestPeer := swarm.MustParseHexAddress("6000000000000000000000000000000000000000000000000000000000000000") // binary 0110 -> po 1
// peer is the node responding to the chunk receipt message
// mock should return ErrWantSelf since there's no one to forward to
serverPrice := uint64(17)
serverPrices := pricerParameters{price: serverPrice, peerPrice: fixedPrice}
psPeer, storerPeer, _, peerAccounting := createPushSyncNode(t, closestPeer, serverPrices, nil, nil, defaultSigner, mock.WithClosestPeerErr(topology.ErrWantSelf))
defer storerPeer.Close()
recorder := streamtest.New(streamtest.WithProtocols(psPeer.Protocol()), streamtest.WithBaseAddr(pivotNode))
// pivot node needs the streamer since the chunk is intercepted by
// the chunk worker, then gets sent by opening a new stream
psPivot, storerPivot, _, pivotAccounting := createPushSyncNode(t, pivotNode, defaultPrices, recorder, nil, defaultSigner, mock.WithClosestPeer(closestPeer))
defer storerPivot.Close()
// Trigger the sending of chunk to the closest node
receipt, err := psPivot.PushChunkToClosest(context.Background(), chunk)
if err != nil {
t.Fatal(err)
}
if !chunk.Address().Equal(receipt.Address) {
t.Fatal("invalid receipt")
}
// this intercepts the outgoing delivery message
waitOnRecordAndTest(t, closestPeer, recorder, chunk.Address(), chunk.Data())
// this intercepts the incoming receipt message
waitOnRecordAndTest(t, closestPeer, recorder, chunk.Address(), nil)
balance, err := pivotAccounting.Balance(closestPeer)
if err != nil {
t.Fatal(err)
}
if balance.Int64() != -int64(serverPrice) {
t.Fatalf("unexpected balance on pivot. want %d got %d", -int64(serverPrice), balance)
}
balance, err = peerAccounting.Balance(pivotNode)
if err != nil {
t.Fatal(err)
}
if balance.Int64() != int64(serverPrice) {
t.Fatalf("unexpected balance on peer. want %d got %d", int64(serverPrice), balance)
}
}
// TestReplicateBeforeReceipt tests that a chunk is pushed and a receipt is received. // TestReplicateBeforeReceipt tests that a chunk is pushed and a receipt is received.
// Also the storer node initiates a pushsync to N closest nodes of the chunk as it's sending back the receipt. // Also the storer node initiates a pushsync to N closest nodes of the chunk as it's sending back the receipt.
// The second storer should only store it and not forward it. The balance of all nodes is tested. // The second storer should only store it and not forward it. The balance of all nodes is tested.
...@@ -174,8 +114,6 @@ func TestReplicateBeforeReceipt(t *testing.T) { ...@@ -174,8 +114,6 @@ func TestReplicateBeforeReceipt(t *testing.T) {
// chunk data to upload // chunk data to upload
chunk := testingc.FixtureChunk("7000") // base 0111 chunk := testingc.FixtureChunk("7000") // base 0111
neighborPrice := pricerParameters{price: 0, peerPrice: 0}
// create a pivot node and a mocked closest node // create a pivot node and a mocked closest node
pivotNode := swarm.MustParseHexAddress("0000000000000000000000000000000000000000000000000000000000000000") // base is 0000 pivotNode := swarm.MustParseHexAddress("0000000000000000000000000000000000000000000000000000000000000000") // base is 0000
closestPeer := swarm.MustParseHexAddress("6000000000000000000000000000000000000000000000000000000000000000") // binary 0110 closestPeer := swarm.MustParseHexAddress("6000000000000000000000000000000000000000000000000000000000000000") // binary 0110
...@@ -187,9 +125,13 @@ func TestReplicateBeforeReceipt(t *testing.T) { ...@@ -187,9 +125,13 @@ func TestReplicateBeforeReceipt(t *testing.T) {
_, storerEmpty, _, _ := createPushSyncNode(t, emptyPeer, defaultPrices, nil, nil, defaultSigner) _, storerEmpty, _, _ := createPushSyncNode(t, emptyPeer, defaultPrices, nil, nil, defaultSigner)
defer storerEmpty.Close() defer storerEmpty.Close()
wFunc := func(addr swarm.Address) bool {
return true
}
// node that is connected to closestPeer // node that is connected to closestPeer
// will receieve chunk from closestPeer // will receieve chunk from closestPeer
psSecond, storerSecond, _, secondAccounting := createPushSyncNode(t, secondPeer, neighborPrice, nil, nil, defaultSigner, mock.WithPeers(emptyPeer)) psSecond, storerSecond, _, secondAccounting := createPushSyncNode(t, secondPeer, defaultPrices, nil, nil, defaultSigner, mock.WithPeers(emptyPeer), mock.WithIsWithinFunc(wFunc))
defer storerSecond.Close() defer storerSecond.Close()
secondRecorder := streamtest.New(streamtest.WithProtocols(psSecond.Protocol()), streamtest.WithBaseAddr(closestPeer)) secondRecorder := streamtest.New(streamtest.WithProtocols(psSecond.Protocol()), streamtest.WithBaseAddr(closestPeer))
...@@ -253,16 +195,16 @@ func TestReplicateBeforeReceipt(t *testing.T) { ...@@ -253,16 +195,16 @@ func TestReplicateBeforeReceipt(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
if balance.Int64() != 0 { if balance.Int64() != int64(fixedPrice) {
t.Fatalf("unexpected balance on second storer. want %d got %d", 0, balance) t.Fatalf("unexpected balance on second storer. want %d got %d", int64(fixedPrice), balance)
} }
balance, err = storerAccounting.Balance(secondPeer) balance, err = storerAccounting.Balance(secondPeer)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if balance.Int64() != 0 { if balance.Int64() != -int64(fixedPrice) {
t.Fatalf("unexpected balance on storer node. want %d got %d", 0, balance) t.Fatalf("unexpected balance on storer node. want %d got %d", -int64(fixedPrice), balance)
} }
} }
...@@ -566,96 +508,6 @@ func TestHandler(t *testing.T) { ...@@ -566,96 +508,6 @@ func TestHandler(t *testing.T) {
} }
} }
func TestHandlerWithUpdate(t *testing.T) {
// chunk data to upload
chunk := testingc.FixtureChunk("7000")
serverPrice := uint64(17)
serverPrices := pricerParameters{price: serverPrice, peerPrice: fixedPrice}
// create a pivot node and a mocked closest node
triggerPeer := swarm.MustParseHexAddress("0000000000000000000000000000000000000000000000000000000000000000")
pivotPeer := swarm.MustParseHexAddress("5000000000000000000000000000000000000000000000000000000000000000")
closestPeer := swarm.MustParseHexAddress("6000000000000000000000000000000000000000000000000000000000000000")
// Create the closest peer with default prices (10)
psClosestPeer, closestStorerPeerDB, _, closestAccounting := createPushSyncNode(t, closestPeer, defaultPrices, nil, nil, defaultSigner, mock.WithClosestPeerErr(topology.ErrWantSelf))
defer closestStorerPeerDB.Close()
closestRecorder := streamtest.New(streamtest.WithProtocols(psClosestPeer.Protocol()), streamtest.WithBaseAddr(pivotPeer))
// creating the pivot peer who will act as a forwarder node with a higher price (17)
psPivot, storerPivotDB, _, pivotAccounting := createPushSyncNode(t, pivotPeer, serverPrices, closestRecorder, nil, defaultSigner, mock.WithClosestPeer(closestPeer))
defer storerPivotDB.Close()
pivotRecorder := streamtest.New(streamtest.WithProtocols(psPivot.Protocol()), streamtest.WithBaseAddr(triggerPeer))
// Creating the trigger peer with default price (10)
psTriggerPeer, triggerStorerDB, _, triggerAccounting := createPushSyncNode(t, triggerPeer, defaultPrices, pivotRecorder, nil, defaultSigner, mock.WithClosestPeer(pivotPeer))
defer triggerStorerDB.Close()
receipt, err := psTriggerPeer.PushChunkToClosest(context.Background(), chunk)
if err != nil {
t.Fatal(err)
}
if !chunk.Address().Equal(receipt.Address) {
t.Fatal("invalid receipt")
}
// In pivot peer, intercept the incoming delivery chunk from the trigger peer and check for correctness
waitOnRecordAndTest(t, pivotPeer, pivotRecorder, chunk.Address(), chunk.Data())
// Pivot peer will forward the chunk to its closest peer. Intercept the incoming stream from pivot node and check
// for the correctness of the chunk
waitOnRecordAndTest(t, closestPeer, closestRecorder, chunk.Address(), chunk.Data())
// Similarly intercept the same incoming stream to see if the closest peer is sending a proper receipt
waitOnRecordAndTest(t, closestPeer, closestRecorder, chunk.Address(), nil)
// In the received stream, check if a receipt is sent from pivot peer and check for its correctness.
waitOnRecordAndTest(t, pivotPeer, pivotRecorder, chunk.Address(), nil)
balance, err := triggerAccounting.Balance(pivotPeer)
if err != nil {
t.Fatal(err)
}
// balance on triggering peer towards the forwarder should show negative serverPrice (17)
if balance.Int64() != -int64(serverPrice) {
t.Fatalf("unexpected balance on trigger. want %d got %d", -int64(serverPrice), balance)
}
// we need to check here for pivotPeer instead of triggerPeer because during streamtest the peer in the handler is actually the receiver
// balance on forwarding peer for the triggering peer should show serverPrice (17)
balance, err = pivotAccounting.Balance(triggerPeer)
if err != nil {
t.Fatal(err)
}
if balance.Int64() != int64(serverPrice) {
t.Fatalf("unexpected balance on pivot. want %d got %d", int64(serverPrice), balance)
}
balance, err = pivotAccounting.Balance(closestPeer)
if err != nil {
t.Fatal(err)
}
// balance of the forwarder peer for the closest peer should show negative default price (10)
if balance.Int64() != -int64(fixedPrice) {
t.Fatalf("unexpected balance on pivot. want %d got %d", -int64(fixedPrice), balance)
}
balance, err = closestAccounting.Balance(pivotPeer)
if err != nil {
t.Fatal(err)
}
// balance of the closest peer for the forwarder peer should show the default price (10)
if balance.Int64() != int64(fixedPrice) {
t.Fatalf("unexpected balance on closest. want %d got %d", int64(fixedPrice), balance)
}
}
func TestSignsReceipt(t *testing.T) { func TestSignsReceipt(t *testing.T) {
// chunk data to upload // chunk data to upload
...@@ -711,13 +563,7 @@ func createPushSyncNode(t *testing.T, addr swarm.Address, prices pricerParameter ...@@ -711,13 +563,7 @@ func createPushSyncNode(t *testing.T, addr swarm.Address, prices pricerParameter
mtag := tags.NewTags(mockStatestore, logger) mtag := tags.NewTags(mockStatestore, logger)
mockAccounting := accountingmock.NewAccounting() mockAccounting := accountingmock.NewAccounting()
headlerFunc := func(h p2p.Headers, a swarm.Address) p2p.Headers { mockPricer := pricermock.NewMockService(prices.price, prices.peerPrice)
target, _ := headerutils.ParseTargetHeader(h)
headers, _ := headerutils.MakePricingResponseHeaders(prices.price, target, 0)
return headers
}
mockPricer := pricermock.NewMockService(pricermock.WithPriceHeadlerFunc(headlerFunc), pricermock.WithPrice(prices.price), pricermock.WithPeerPrice(prices.peerPrice))
recorderDisconnecter := streamtest.NewRecorderDisconnecter(recorder) recorderDisconnecter := streamtest.NewRecorderDisconnecter(recorder)
if unwrap == nil { if unwrap == nil {
......
...@@ -220,7 +220,7 @@ func newTestNetStore(t *testing.T, recoveryFunc recovery.Callback) storage.Store ...@@ -220,7 +220,7 @@ func newTestNetStore(t *testing.T, recoveryFunc recovery.Callback) storage.Store
mockStorer := storemock.NewStorer() mockStorer := storemock.NewStorer()
serverMockAccounting := accountingmock.NewAccounting() serverMockAccounting := accountingmock.NewAccounting()
pricerMock := pricermock.NewMockService() pricerMock := pricermock.NewMockService(10, 10)
peerID := swarm.MustParseHexAddress("deadbeef") peerID := swarm.MustParseHexAddress("deadbeef")
ps := mockPeerSuggester{eachPeerRevFunc: func(f topology.EachPeerFunc) error { ps := mockPeerSuggester{eachPeerRevFunc: func(f topology.EachPeerFunc) error {
_, _, _ = f(peerID, 0) _, _, _ = f(peerID, 0)
......
...@@ -21,7 +21,6 @@ import ( ...@@ -21,7 +21,6 @@ import (
"github.com/ethersphere/bee/pkg/p2p" "github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/p2p/protobuf" "github.com/ethersphere/bee/pkg/p2p/protobuf"
"github.com/ethersphere/bee/pkg/pricer" "github.com/ethersphere/bee/pkg/pricer"
"github.com/ethersphere/bee/pkg/pricer/headerutils"
pb "github.com/ethersphere/bee/pkg/retrieval/pb" pb "github.com/ethersphere/bee/pkg/retrieval/pb"
"github.com/ethersphere/bee/pkg/soc" "github.com/ethersphere/bee/pkg/soc"
"github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage"
...@@ -81,7 +80,6 @@ func (s *Service) Protocol() p2p.ProtocolSpec { ...@@ -81,7 +80,6 @@ func (s *Service) Protocol() p2p.ProtocolSpec {
{ {
Name: streamName, Name: streamName,
Handler: s.handler, Handler: s.handler,
Headler: s.pricer.PriceHeadler,
}, },
}, },
} }
...@@ -204,23 +202,16 @@ func (s *Service) retrieveChunk(ctx context.Context, addr swarm.Address, sp *ski ...@@ -204,23 +202,16 @@ func (s *Service) retrieveChunk(ctx context.Context, addr swarm.Address, sp *ski
sp.Add(peer) sp.Add(peer)
// compute the price we presume to pay for this chunk for price header // compute the peer's price for this chunk for price header
chunkPrice := s.pricer.PeerPrice(peer, addr) chunkPrice := s.pricer.PeerPrice(peer, addr)
headers, err := headerutils.MakePricingHeaders(chunkPrice, addr)
if err != nil {
return nil, swarm.Address{}, err
}
s.logger.Tracef("retrieval: requesting chunk %s from peer %s", addr, peer) s.logger.Tracef("retrieval: requesting chunk %s from peer %s", addr, peer)
stream, err := s.streamer.NewStream(ctx, peer, headers, protocolName, protocolVersion, streamName) stream, err := s.streamer.NewStream(ctx, peer, nil, protocolName, protocolVersion, streamName)
if err != nil { if err != nil {
s.metrics.TotalErrors.Inc() s.metrics.TotalErrors.Inc()
return nil, peer, fmt.Errorf("new stream: %w", err) return nil, peer, fmt.Errorf("new stream: %w", err)
} }
returnedHeaders := stream.Headers()
defer func() { defer func() {
if err != nil { if err != nil {
_ = stream.Reset() _ = stream.Reset()
...@@ -229,21 +220,7 @@ func (s *Service) retrieveChunk(ctx context.Context, addr swarm.Address, sp *ski ...@@ -229,21 +220,7 @@ func (s *Service) retrieveChunk(ctx context.Context, addr swarm.Address, sp *ski
} }
}() }()
_, returnedPrice, returnedIndex, err := headerutils.ParsePricingResponseHeaders(returnedHeaders) // Reserve to see whether we can request the chunk
if err != nil {
return nil, peer, fmt.Errorf("retrieval headers: read returned: %w", err)
}
// check if returned price matches presumed price, if not, update price
if returnedPrice != chunkPrice {
err = s.pricer.NotifyPeerPrice(peer, returnedPrice, returnedIndex) // save priceHeaders["price"] corresponding row for peer
if err != nil {
return nil, peer, err
}
chunkPrice = returnedPrice
}
// Reserve to see whether we can request the chunk based on actual price
err = s.accounting.Reserve(ctx, peer, chunkPrice) err = s.accounting.Reserve(ctx, peer, chunkPrice)
if err != nil { if err != nil {
return nil, peer, err return nil, peer, err
...@@ -383,15 +360,8 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e ...@@ -383,15 +360,8 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e
s.logger.Tracef("retrieval protocol debiting peer %s", p.Address.String()) s.logger.Tracef("retrieval protocol debiting peer %s", p.Address.String())
// to get price Read in headler, chunkPrice := s.pricer.Price(chunk.Address())
returnedHeaders := stream.ResponseHeaders()
chunkPrice, err := headerutils.ParsePriceHeader(returnedHeaders)
if err != nil {
// if not found in returned header, compute the price we charge for this chunk and
s.logger.Warningf("retrieval: peer %v no price in previously issued response headers: %v", p.Address, err)
chunkPrice = s.pricer.PriceForPeer(p.Address, chunk.Address())
}
// debit price from p's balance // debit price from p's balance
return s.accounting.Debit(p.Address, chunkPrice) return s.accounting.Debit(p.Address, chunkPrice)
} }
...@@ -11,7 +11,6 @@ import ( ...@@ -11,7 +11,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"math/big"
"sync" "sync"
"testing" "testing"
"time" "time"
...@@ -22,7 +21,6 @@ import ( ...@@ -22,7 +21,6 @@ import (
"github.com/ethersphere/bee/pkg/p2p" "github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/p2p/protobuf" "github.com/ethersphere/bee/pkg/p2p/protobuf"
"github.com/ethersphere/bee/pkg/p2p/streamtest" "github.com/ethersphere/bee/pkg/p2p/streamtest"
"github.com/ethersphere/bee/pkg/pricer/headerutils"
pricermock "github.com/ethersphere/bee/pkg/pricer/mock" pricermock "github.com/ethersphere/bee/pkg/pricer/mock"
"github.com/ethersphere/bee/pkg/retrieval" "github.com/ethersphere/bee/pkg/retrieval"
pb "github.com/ethersphere/bee/pkg/retrieval/pb" pb "github.com/ethersphere/bee/pkg/retrieval/pb"
...@@ -33,17 +31,15 @@ import ( ...@@ -33,17 +31,15 @@ import (
"github.com/ethersphere/bee/pkg/topology" "github.com/ethersphere/bee/pkg/topology"
) )
var testTimeout = 5 * time.Second var (
testTimeout = 5 * time.Second
defaultPrice = uint64(10)
)
// TestDelivery tests that a naive request -> delivery flow works. // TestDelivery tests that a naive request -> delivery flow works.
func TestDelivery(t *testing.T) { func TestDelivery(t *testing.T) {
var ( var (
chunk = testingc.FixtureChunk("0033") chunk = testingc.FixtureChunk("0033")
headlerFunc = func(h p2p.Headers, a swarm.Address) p2p.Headers {
headers, _ := headerutils.MakePricingResponseHeaders(10, chunk.Address(), 0)
return headers
}
logger = logging.New(ioutil.Discard, 0) logger = logging.New(ioutil.Discard, 0)
mockStorer = storemock.NewStorer() mockStorer = storemock.NewStorer()
clientMockAccounting = accountingmock.NewAccounting() clientMockAccounting = accountingmock.NewAccounting()
...@@ -51,8 +47,7 @@ func TestDelivery(t *testing.T) { ...@@ -51,8 +47,7 @@ func TestDelivery(t *testing.T) {
clientAddr = swarm.MustParseHexAddress("9ee7add8") clientAddr = swarm.MustParseHexAddress("9ee7add8")
serverAddr = swarm.MustParseHexAddress("9ee7add7") serverAddr = swarm.MustParseHexAddress("9ee7add7")
price = uint64(10) pricerMock = pricermock.NewMockService(defaultPrice, defaultPrice)
pricerMock = pricermock.NewMockService(pricermock.WithPriceHeadlerFunc(headlerFunc))
) )
// put testdata in the mock store of the server // put testdata in the mock store of the server
...@@ -132,136 +127,21 @@ func TestDelivery(t *testing.T) { ...@@ -132,136 +127,21 @@ func TestDelivery(t *testing.T) {
} }
clientBalance, _ := clientMockAccounting.Balance(serverAddr) clientBalance, _ := clientMockAccounting.Balance(serverAddr)
if clientBalance.Int64() != -int64(price) { if clientBalance.Int64() != -int64(defaultPrice) {
t.Fatalf("unexpected balance on client. want %d got %d", -price, clientBalance) t.Fatalf("unexpected balance on client. want %d got %d", -defaultPrice, clientBalance)
}
serverBalance, _ := serverMockAccounting.Balance(clientAddr)
if serverBalance.Int64() != int64(price) {
t.Fatalf("unexpected balance on server. want %d got %d", price, serverBalance)
}
}
// TestDelivery tests that a naive request -> delivery flow works.
func TestDeliveryWithPriceUpdate(t *testing.T) {
var (
price = uint64(10)
serverPrice = uint64(17)
chunk = testingc.FixtureChunk("0033")
headlerFunc = func(h p2p.Headers, a swarm.Address) p2p.Headers {
headers, _ := headerutils.MakePricingResponseHeaders(serverPrice, chunk.Address(), 5)
return headers
}
logger = logging.New(ioutil.Discard, 0)
mockStorer = storemock.NewStorer()
clientMockAccounting = accountingmock.NewAccounting()
serverMockAccounting = accountingmock.NewAccounting()
clientAddr = swarm.MustParseHexAddress("9ee7add8")
serverAddr = swarm.MustParseHexAddress("9ee7add7")
clientPricerMock = pricermock.NewMockService(pricermock.WithPeerPrice(price))
serverPricerMock = pricermock.NewMockService(pricermock.WithPriceHeadlerFunc(headlerFunc), pricermock.WithPrice(serverPrice))
)
// put testdata in the mock store of the server
_, err := mockStorer.Put(context.Background(), storage.ModePutUpload, chunk)
if err != nil {
t.Fatal(err)
}
// create the server that will handle the request and will serve the response
server := retrieval.New(swarm.MustParseHexAddress("0034"), mockStorer, nil, nil, logger, serverMockAccounting, serverPricerMock, nil)
recorder := streamtest.New(
streamtest.WithProtocols(server.Protocol()),
streamtest.WithBaseAddr(clientAddr),
)
// client mock storer does not store any data at this point
// but should be checked at at the end of the test for the
// presence of the chunk address key and value to ensure delivery
// was successful
clientMockStorer := storemock.NewStorer()
ps := mockPeerSuggester{eachPeerRevFunc: func(f topology.EachPeerFunc) error {
_, _, _ = f(serverAddr, 0)
return nil
}}
client := retrieval.New(clientAddr, clientMockStorer, recorder, ps, logger, clientMockAccounting, clientPricerMock, nil)
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
v, err := client.RetrieveChunk(ctx, chunk.Address())
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(v.Data(), chunk.Data()) {
t.Fatalf("request and response data not equal. got %s want %s", v, chunk.Data())
}
records, err := recorder.Records(serverAddr, "retrieval", "1.0.0", "retrieval")
if err != nil {
t.Fatal(err)
}
if l := len(records); l != 1 {
t.Fatalf("got %v records, want %v", l, 1)
}
record := records[0]
messages, err := protobuf.ReadMessages(
bytes.NewReader(record.In()),
func() protobuf.Message { return new(pb.Request) },
)
if err != nil {
t.Fatal(err)
}
var reqs []string
for _, m := range messages {
reqs = append(reqs, hex.EncodeToString(m.(*pb.Request).Addr))
}
if len(reqs) != 1 {
t.Fatalf("got too many requests. want 1 got %d", len(reqs))
}
messages, err = protobuf.ReadMessages(
bytes.NewReader(record.Out()),
func() protobuf.Message { return new(pb.Delivery) },
)
if err != nil {
t.Fatal(err)
}
var gotDeliveries []string
for _, m := range messages {
gotDeliveries = append(gotDeliveries, string(m.(*pb.Delivery).Data))
}
if len(gotDeliveries) != 1 {
t.Fatalf("got too many deliveries. want 1 got %d", len(gotDeliveries))
}
clientBalance, _ := clientMockAccounting.Balance(serverAddr)
if clientBalance.Cmp(big.NewInt(-int64(serverPrice))) != 0 {
t.Fatalf("unexpected balance on client. want %d got %d", -serverPrice, clientBalance)
} }
serverBalance, _ := serverMockAccounting.Balance(clientAddr) serverBalance, _ := serverMockAccounting.Balance(clientAddr)
if serverBalance.Cmp(big.NewInt(int64(serverPrice))) != 0 { if serverBalance.Int64() != int64(defaultPrice) {
t.Fatalf("unexpected balance on server. want %d got %d", serverPrice, serverBalance) t.Fatalf("unexpected balance on server. want %d got %d", defaultPrice, serverBalance)
} }
} }
func TestRetrieveChunk(t *testing.T) { func TestRetrieveChunk(t *testing.T) {
var ( var (
headlerFunc = func(h p2p.Headers, a swarm.Address) p2p.Headers {
target, _ := headerutils.ParseTargetHeader(h)
headers, _ := headerutils.MakePricingResponseHeaders(10, target, 0)
return headers
}
logger = logging.New(ioutil.Discard, 0) logger = logging.New(ioutil.Discard, 0)
pricer = pricermock.NewMockService(pricermock.WithPriceHeadlerFunc(headlerFunc)) pricer = pricermock.NewMockService(defaultPrice, defaultPrice)
) )
// requesting a chunk from downstream peer is expected // requesting a chunk from downstream peer is expected
...@@ -362,14 +242,7 @@ func TestRetrievePreemptiveRetry(t *testing.T) { ...@@ -362,14 +242,7 @@ func TestRetrievePreemptiveRetry(t *testing.T) {
chunk := testingc.FixtureChunk("0025") chunk := testingc.FixtureChunk("0025")
someOtherChunk := testingc.FixtureChunk("0033") someOtherChunk := testingc.FixtureChunk("0033")
headlerFunc := func(h p2p.Headers, a swarm.Address) p2p.Headers { pricerMock := pricermock.NewMockService(defaultPrice, defaultPrice)
target, _ := headerutils.ParseTargetHeader(h)
headers, _ := headerutils.MakePricingResponseHeaders(10, target, 0)
return headers
}
price := uint64(10)
pricerMock := pricermock.NewMockService(pricermock.WithPriceHeadlerFunc(headlerFunc))
clientAddress := swarm.MustParseHexAddress("1010") clientAddress := swarm.MustParseHexAddress("1010")
...@@ -559,12 +432,12 @@ func TestRetrievePreemptiveRetry(t *testing.T) { ...@@ -559,12 +432,12 @@ func TestRetrievePreemptiveRetry(t *testing.T) {
clientServer1Balance, _ := clientMockAccounting.Balance(serverAddress1) clientServer1Balance, _ := clientMockAccounting.Balance(serverAddress1)
if clientServer1Balance.Int64() != 0 { if clientServer1Balance.Int64() != 0 {
t.Fatalf("unexpected balance on client. want %d got %d", -price, clientServer1Balance) t.Fatalf("unexpected balance on client. want %d got %d", -defaultPrice, clientServer1Balance)
} }
clientServer2Balance, _ := clientMockAccounting.Balance(serverAddress2) clientServer2Balance, _ := clientMockAccounting.Balance(serverAddress2)
if clientServer2Balance.Int64() != -int64(price) { if clientServer2Balance.Int64() != -int64(defaultPrice) {
t.Fatalf("unexpected balance on client. want %d got %d", -price, clientServer2Balance) t.Fatalf("unexpected balance on client. want %d got %d", -defaultPrice, clientServer2Balance)
} }
// wait and check balance again // wait and check balance again
...@@ -572,13 +445,13 @@ func TestRetrievePreemptiveRetry(t *testing.T) { ...@@ -572,13 +445,13 @@ func TestRetrievePreemptiveRetry(t *testing.T) {
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
clientServer1Balance, _ = clientMockAccounting.Balance(serverAddress1) clientServer1Balance, _ = clientMockAccounting.Balance(serverAddress1)
if clientServer1Balance.Int64() != -int64(price) { if clientServer1Balance.Int64() != -int64(defaultPrice) {
t.Fatalf("unexpected balance on client. want %d got %d", -price, clientServer1Balance) t.Fatalf("unexpected balance on client. want %d got %d", -defaultPrice, clientServer1Balance)
} }
clientServer2Balance, _ = clientMockAccounting.Balance(serverAddress2) clientServer2Balance, _ = clientMockAccounting.Balance(serverAddress2)
if clientServer2Balance.Int64() != -int64(price) { if clientServer2Balance.Int64() != -int64(defaultPrice) {
t.Fatalf("unexpected balance on client. want %d got %d", -price, clientServer2Balance) t.Fatalf("unexpected balance on client. want %d got %d", -defaultPrice, clientServer2Balance)
} }
}) })
......
...@@ -18,6 +18,7 @@ type mock struct { ...@@ -18,6 +18,7 @@ type mock struct {
closestPeerErr error closestPeerErr error
peersErr error peersErr error
addPeersErr error addPeersErr error
isWithinFunc func(c swarm.Address) bool
marshalJSONFunc func() ([]byte, error) marshalJSONFunc func() ([]byte, error)
mtx sync.Mutex mtx sync.Mutex
} }
...@@ -52,6 +53,12 @@ func WithMarshalJSONFunc(f func() ([]byte, error)) Option { ...@@ -52,6 +53,12 @@ func WithMarshalJSONFunc(f func() ([]byte, error)) Option {
}) })
} }
func WithIsWithinFunc(f func(swarm.Address) bool) Option {
return optionFunc(func(d *mock) {
d.isWithinFunc = f
})
}
func NewTopologyDriver(opts ...Option) topology.Driver { func NewTopologyDriver(opts ...Option) topology.Driver {
d := new(mock) d := new(mock)
for _, o := range opts { for _, o := range opts {
...@@ -139,6 +146,9 @@ func (*mock) NeighborhoodDepth() uint8 { ...@@ -139,6 +146,9 @@ func (*mock) NeighborhoodDepth() uint8 {
} }
func (m *mock) IsWithinDepth(addr swarm.Address) bool { func (m *mock) IsWithinDepth(addr swarm.Address) bool {
if m.isWithinFunc != nil {
return m.isWithinFunc(addr)
}
return false return false
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment