Commit 1f31cf10 authored by metacertain's avatar metacertain Committed by GitHub

Pricing update strategy (#1134)

- add pricer pkg
- 0 out cost of in-neighborhood traffic
- communicate pricetable in pricing protocol
- add pricing headers and price communication to protocols
parent abd00984
...@@ -59,7 +59,7 @@ jobs: ...@@ -59,7 +59,7 @@ jobs:
run: | run: |
echo -e "127.0.0.10\tregistry.localhost" | sudo tee -a /etc/hosts echo -e "127.0.0.10\tregistry.localhost" | sudo tee -a /etc/hosts
for ((i=0; i<REPLICA; i++)); do echo -e "127.0.1.$((i+1))\tbee-${i}.localhost bee-${i}-debug.localhost"; done | sudo tee -a /etc/hosts for ((i=0; i<REPLICA; i++)); do echo -e "127.0.1.$((i+1))\tbee-${i}.localhost bee-${i}-debug.localhost"; done | sudo tee -a /etc/hosts
./beeinfra.sh install --local -r "${REPLICA}" --bootnode /dnsaddr/localhost --geth --k3s --pay-threshold 1000000000000 ./beeinfra.sh install --local -r "${REPLICA}" --bootnode /dnsaddr/localhost --geth --k3s
- name: Test pingpong - name: Test pingpong
id: pingpong-1 id: pingpong-1
run: until ./beekeeper check pingpong --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}"; do echo "waiting for pingpong..."; sleep .3; done run: until ./beekeeper check pingpong --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}"; do echo "waiting for pingpong..."; sleep .3; done
...@@ -68,7 +68,7 @@ jobs: ...@@ -68,7 +68,7 @@ jobs:
run: ./beekeeper check fullconnectivity --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}" run: ./beekeeper check fullconnectivity --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}"
- name: Test settlements - name: Test settlements
id: settlements-1 id: settlements-1
run: ./beekeeper check settlements --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}" --upload-node-count "${REPLICA}" -t 1000000000000 run: ./beekeeper check settlements --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}" --upload-node-count "${REPLICA}" --expect-settlements=false
- name: Test pushsync (bytes) - name: Test pushsync (bytes)
id: pushsync-bytes-1 id: pushsync-bytes-1
run: ./beekeeper check pushsync --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}" --upload-node-count "${REPLICA}" --chunks-per-node 3 run: ./beekeeper check pushsync --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}" --upload-node-count "${REPLICA}" --chunks-per-node 3
...@@ -105,7 +105,7 @@ jobs: ...@@ -105,7 +105,7 @@ jobs:
cp /etc/rancher/k3s/k3s.yaml ~/.kube/config cp /etc/rancher/k3s/k3s.yaml ~/.kube/config
- name: Set testing cluster (Node connection and clef enabled) - name: Set testing cluster (Node connection and clef enabled)
run: | run: |
./beeinfra.sh install --local -r "${REPLICA}" --geth --clef --k3s --pay-threshold 1000000000000 ./beeinfra.sh install --local -r "${REPLICA}" --geth --clef --k3s
- name: Test pingpong - name: Test pingpong
id: pingpong-2 id: pingpong-2
run: until ./beekeeper check pingpong --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}"; do echo "waiting for pingpong..."; sleep .3; done run: until ./beekeeper check pingpong --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}"; do echo "waiting for pingpong..."; sleep .3; done
...@@ -114,7 +114,7 @@ jobs: ...@@ -114,7 +114,7 @@ jobs:
run: ./beekeeper check fullconnectivity --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}" run: ./beekeeper check fullconnectivity --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}"
- name: Test settlements - name: Test settlements
id: settlements-2 id: settlements-2
run: ./beekeeper check settlements --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}" --upload-node-count "${REPLICA}" -t 1000000000000 run: ./beekeeper check settlements --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}" --upload-node-count "${REPLICA}" --expect-settlements=false
- name: Test pushsync (bytes) - name: Test pushsync (bytes)
id: pushsync-bytes-2 id: pushsync-bytes-2
run: ./beekeeper check pushsync --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}" --upload-node-count "${REPLICA}" --chunks-per-node 3 run: ./beekeeper check pushsync --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}" --upload-node-count "${REPLICA}" --chunks-per-node 3
......
...@@ -526,6 +526,13 @@ func (p *pricingMock) AnnouncePaymentThreshold(ctx context.Context, peer swarm.A ...@@ -526,6 +526,13 @@ func (p *pricingMock) AnnouncePaymentThreshold(ctx context.Context, peer swarm.A
return nil return nil
} }
func (p *pricingMock) AnnouncePaymentThresholdAndPriceTable(ctx context.Context, peer swarm.Address, paymentThreshold *big.Int) error {
p.called = true
p.peer = peer
p.paymentThreshold = paymentThreshold
return nil
}
func TestAccountingConnected(t *testing.T) { func TestAccountingConnected(t *testing.T) {
logger := logging.New(ioutil.Discard, 0) logger := logging.New(ioutil.Discard, 0)
......
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mock
import (
"github.com/ethersphere/bee/pkg/swarm"
)
type MockPricer struct {
peerPrice uint64
price uint64
}
func NewPricer(price, peerPrice uint64) *MockPricer {
return &MockPricer{
peerPrice: peerPrice,
price: price,
}
}
func (pricer *MockPricer) PeerPrice(peer, chunk swarm.Address) uint64 {
return pricer.peerPrice
}
func (pricer *MockPricer) Price(chunk swarm.Address) uint64 {
return pricer.price
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package accounting
import (
"github.com/ethersphere/bee/pkg/swarm"
)
// Pricer returns pricing information for chunk hashes.
type Pricer interface {
// PeerPrice is the price the peer charges for a given chunk hash.
PeerPrice(peer, chunk swarm.Address) uint64
// Price is the price we charge for a given chunk hash.
Price(chunk swarm.Address) uint64
}
// FixedPricer is a Pricer that has a fixed price for chunks.
type FixedPricer struct {
overlay swarm.Address
poPrice uint64
}
// NewFixedPricer returns a new FixedPricer with a given price.
func NewFixedPricer(overlay swarm.Address, poPrice uint64) *FixedPricer {
return &FixedPricer{
overlay: overlay,
poPrice: poPrice,
}
}
// PeerPrice implements Pricer.
func (pricer *FixedPricer) PeerPrice(peer, chunk swarm.Address) uint64 {
return uint64(swarm.MaxPO-swarm.Proximity(peer.Bytes(), chunk.Bytes())+1) * pricer.poPrice
}
// Price implements Pricer.
func (pricer *FixedPricer) Price(chunk swarm.Address) uint64 {
return pricer.PeerPrice(pricer.overlay, chunk)
}
// Copyright 2021 The Swarm Authors. All rights reserved. // Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
...@@ -35,6 +35,7 @@ import ( ...@@ -35,6 +35,7 @@ import (
"github.com/ethersphere/bee/pkg/netstore" "github.com/ethersphere/bee/pkg/netstore"
"github.com/ethersphere/bee/pkg/p2p/libp2p" "github.com/ethersphere/bee/pkg/p2p/libp2p"
"github.com/ethersphere/bee/pkg/pingpong" "github.com/ethersphere/bee/pkg/pingpong"
"github.com/ethersphere/bee/pkg/pricer"
"github.com/ethersphere/bee/pkg/pricing" "github.com/ethersphere/bee/pkg/pricing"
"github.com/ethersphere/bee/pkg/pss" "github.com/ethersphere/bee/pkg/pss"
"github.com/ethersphere/bee/pkg/puller" "github.com/ethersphere/bee/pkg/puller"
...@@ -310,6 +311,35 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey, ...@@ -310,6 +311,35 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey,
var settlement settlement.Interface var settlement settlement.Interface
var swapService *swap.Service var swapService *swap.Service
kad := kademlia.New(swarmAddress, addressbook, hive, p2ps, logger, kademlia.Options{Bootnodes: bootnodes, StandaloneMode: o.Standalone, BootnodeMode: o.BootnodeMode})
b.topologyCloser = kad
hive.SetAddPeersHandler(kad.AddPeers)
p2ps.SetPickyNotifier(kad)
paymentThreshold, ok := new(big.Int).SetString(o.PaymentThreshold, 10)
if !ok {
return nil, fmt.Errorf("invalid payment threshold: %s", paymentThreshold)
}
pricer := pricer.New(logger, stateStore, swarmAddress, 1000000000)
pricer.SetTopology(kad)
pricing := pricing.New(p2ps, logger, paymentThreshold, pricer)
pricing.SetPriceTableObserver(pricer)
if err = p2ps.AddProtocol(pricing.Protocol()); err != nil {
return nil, fmt.Errorf("pricing service: %w", err)
}
addrs, err := p2ps.Addresses()
if err != nil {
return nil, fmt.Errorf("get server addresses: %w", err)
}
for _, addr := range addrs {
logger.Debugf("p2p address: %s", addr)
}
if o.SwapEnable { if o.SwapEnable {
swapService, err = InitSwap( swapService, err = InitSwap(
p2ps, p2ps,
...@@ -333,15 +363,6 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey, ...@@ -333,15 +363,6 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey,
settlement = pseudosettleService settlement = pseudosettleService
} }
paymentThreshold, ok := new(big.Int).SetString(o.PaymentThreshold, 10)
if !ok {
return nil, fmt.Errorf("invalid payment threshold: %s", paymentThreshold)
}
pricing := pricing.New(p2ps, logger, paymentThreshold)
if err = p2ps.AddProtocol(pricing.Protocol()); err != nil {
return nil, fmt.Errorf("pricing service: %w", err)
}
paymentTolerance, ok := new(big.Int).SetString(o.PaymentTolerance, 10) paymentTolerance, ok := new(big.Int).SetString(o.PaymentTolerance, 10)
if !ok { if !ok {
return nil, fmt.Errorf("invalid payment tolerance: %s", paymentTolerance) return nil, fmt.Errorf("invalid payment tolerance: %s", paymentTolerance)
...@@ -359,25 +380,13 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey, ...@@ -359,25 +380,13 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey,
settlement, settlement,
pricing, pricing,
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("accounting: %w", err) return nil, fmt.Errorf("accounting: %w", err)
} }
settlement.SetNotifyPaymentFunc(acc.AsyncNotifyPayment)
pricing.SetPaymentThresholdObserver(acc) pricing.SetPaymentThresholdObserver(acc)
settlement.SetNotifyPaymentFunc(acc.AsyncNotifyPayment)
kad := kademlia.New(swarmAddress, addressbook, hive, p2ps, logger, kademlia.Options{Bootnodes: bootnodes, StandaloneMode: o.Standalone, BootnodeMode: o.BootnodeMode})
b.topologyCloser = kad
hive.SetAddPeersHandler(kad.AddPeers)
p2ps.SetPickyNotifier(kad)
addrs, err := p2ps.Addresses()
if err != nil {
return nil, fmt.Errorf("get server addresses: %w", err)
}
for _, addr := range addrs {
logger.Debugf("p2p address: %s", addr)
}
var path string var path string
...@@ -397,7 +406,7 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey, ...@@ -397,7 +406,7 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey,
} }
b.localstoreCloser = storer b.localstoreCloser = storer
retrieve := retrieval.New(swarmAddress, storer, p2ps, kad, logger, acc, accounting.NewFixedPricer(swarmAddress, 1000000000), tracer) retrieve := retrieval.New(swarmAddress, storer, p2ps, kad, logger, acc, pricer, tracer)
tagService := tags.NewTags(stateStore, logger) tagService := tags.NewTags(stateStore, logger)
b.tagsCloser = tagService b.tagsCloser = tagService
...@@ -419,7 +428,7 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey, ...@@ -419,7 +428,7 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey,
traversalService := traversal.NewService(ns) traversalService := traversal.NewService(ns)
pushSyncProtocol := pushsync.New(p2ps, storer, kad, tagService, pssService.TryUnwrap, logger, acc, accounting.NewFixedPricer(swarmAddress, 1000000000), tracer) pushSyncProtocol := pushsync.New(p2ps, storer, kad, tagService, pssService.TryUnwrap, logger, acc, pricer, tracer)
// set the pushSyncer in the PSS // set the pushSyncer in the PSS
pssService.SetPushSyncer(pushSyncProtocol) pssService.SetPushSyncer(pushSyncProtocol)
......
...@@ -12,6 +12,7 @@ import ( ...@@ -12,6 +12,7 @@ import (
"github.com/ethersphere/bee/pkg/p2p" "github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/p2p/libp2p/internal/headers/pb" "github.com/ethersphere/bee/pkg/p2p/libp2p/internal/headers/pb"
"github.com/ethersphere/bee/pkg/p2p/protobuf" "github.com/ethersphere/bee/pkg/p2p/protobuf"
"github.com/ethersphere/bee/pkg/swarm"
) )
var sendHeadersTimeout = 10 * time.Second var sendHeadersTimeout = 10 * time.Second
...@@ -36,7 +37,7 @@ func sendHeaders(ctx context.Context, headers p2p.Headers, stream *stream) error ...@@ -36,7 +37,7 @@ func sendHeaders(ctx context.Context, headers p2p.Headers, stream *stream) error
return nil return nil
} }
func handleHeaders(headler p2p.HeadlerFunc, stream *stream) error { func handleHeaders(headler p2p.HeadlerFunc, stream *stream, peerAddress swarm.Address) error {
w, r := protobuf.NewWriterAndReader(stream) w, r := protobuf.NewWriterAndReader(stream)
ctx, cancel := context.WithTimeout(context.Background(), sendHeadersTimeout) ctx, cancel := context.WithTimeout(context.Background(), sendHeadersTimeout)
...@@ -51,9 +52,11 @@ func handleHeaders(headler p2p.HeadlerFunc, stream *stream) error { ...@@ -51,9 +52,11 @@ func handleHeaders(headler p2p.HeadlerFunc, stream *stream) error {
var h p2p.Headers var h p2p.Headers
if headler != nil { if headler != nil {
h = headler(stream.headers) h = headler(stream.headers, peerAddress)
} }
stream.responseHeaders = h
if err := w.WriteMsgWithContext(ctx, headersP2PToPB(h)); err != nil { if err := w.WriteMsgWithContext(ctx, headersP2PToPB(h)); err != nil {
return fmt.Errorf("write message: %w", err) return fmt.Errorf("write message: %w", err)
} }
......
...@@ -11,6 +11,7 @@ import ( ...@@ -11,6 +11,7 @@ import (
"time" "time"
"github.com/ethersphere/bee/pkg/p2p" "github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/swarm"
) )
func TestHeaders(t *testing.T) { func TestHeaders(t *testing.T) {
...@@ -140,7 +141,7 @@ func TestHeadler(t *testing.T) { ...@@ -140,7 +141,7 @@ func TestHeadler(t *testing.T) {
Handler: func(_ context.Context, _ p2p.Peer, stream p2p.Stream) error { Handler: func(_ context.Context, _ p2p.Peer, stream p2p.Stream) error {
return nil return nil
}, },
Headler: func(headers p2p.Headers) p2p.Headers { Headler: func(headers p2p.Headers, address swarm.Address) p2p.Headers {
defer close(handled) defer close(handled)
gotReceivedHeaders = headers gotReceivedHeaders = headers
return sentHeaders return sentHeaders
......
...@@ -29,7 +29,7 @@ const ( ...@@ -29,7 +29,7 @@ const (
// ProtocolName is the text of the name of the handshake protocol. // ProtocolName is the text of the name of the handshake protocol.
ProtocolName = "handshake" ProtocolName = "handshake"
// ProtocolVersion is the current handshake protocol version. // ProtocolVersion is the current handshake protocol version.
ProtocolVersion = "2.0.0" ProtocolVersion = "3.0.0"
// StreamName is the name of the stream used for handshake purposes. // StreamName is the name of the stream used for handshake purposes.
StreamName = "handshake" StreamName = "handshake"
// MaxWelcomeMessageLength is maximum number of characters allowed in the welcome message. // MaxWelcomeMessageLength is maximum number of characters allowed in the welcome message.
......
...@@ -56,6 +56,10 @@ func (s *Stream) Headers() p2p.Headers { ...@@ -56,6 +56,10 @@ func (s *Stream) Headers() p2p.Headers {
return nil return nil
} }
func (s *Stream) ResponseHeaders() p2p.Headers {
return nil
}
func (s *Stream) Close() error { func (s *Stream) Close() error {
return nil return nil
} }
......
...@@ -372,7 +372,7 @@ func (s *Service) AddProtocol(p p2p.ProtocolSpec) (err error) { ...@@ -372,7 +372,7 @@ func (s *Service) AddProtocol(p p2p.ProtocolSpec) (err error) {
stream := newStream(streamlibp2p) stream := newStream(streamlibp2p)
// exchange headers // exchange headers
if err := handleHeaders(ss.Headler, stream); err != nil { if err := handleHeaders(ss.Headler, stream, overlay); err != nil {
s.logger.Debugf("handle protocol %s/%s: stream %s: peer %s: handle headers: %v", p.Name, p.Version, ss.Name, overlay, err) s.logger.Debugf("handle protocol %s/%s: stream %s: peer %s: handle headers: %v", p.Name, p.Version, ss.Name, overlay, err)
_ = stream.Reset() _ = stream.Reset()
return return
......
...@@ -21,7 +21,8 @@ var _ p2p.Stream = (*stream)(nil) ...@@ -21,7 +21,8 @@ var _ p2p.Stream = (*stream)(nil)
type stream struct { type stream struct {
network.Stream network.Stream
headers map[string][]byte headers map[string][]byte
responseHeaders map[string][]byte
} }
func NewStream(s network.Stream) p2p.Stream { func NewStream(s network.Stream) p2p.Stream {
...@@ -35,6 +36,10 @@ func (s *stream) Headers() p2p.Headers { ...@@ -35,6 +36,10 @@ func (s *stream) Headers() p2p.Headers {
return s.headers return s.headers
} }
func (s *stream) ResponseHeaders() p2p.Headers {
return s.responseHeaders
}
func (s *stream) FullClose() error { func (s *stream) FullClose() error {
// close the stream to make sure it is gc'd // close the stream to make sure it is gc'd
defer s.Close() defer s.Close()
......
...@@ -67,6 +67,7 @@ type StreamerDisconnecter interface { ...@@ -67,6 +67,7 @@ type StreamerDisconnecter interface {
type Stream interface { type Stream interface {
io.ReadWriter io.ReadWriter
io.Closer io.Closer
ResponseHeaders() Headers
Headers() Headers Headers() Headers
FullClose() error FullClose() error
Reset() error Reset() error
...@@ -103,7 +104,7 @@ type HandlerMiddleware func(HandlerFunc) HandlerFunc ...@@ -103,7 +104,7 @@ type HandlerMiddleware func(HandlerFunc) HandlerFunc
// HeadlerFunc is returning response headers based on the received request // HeadlerFunc is returning response headers based on the received request
// headers. // headers.
type HeadlerFunc func(Headers) Headers type HeadlerFunc func(Headers, swarm.Address) Headers
// Headers represents a collection of p2p header key value pairs. // Headers represents a collection of p2p header key value pairs.
type Headers map[string][]byte type Headers map[string][]byte
......
...@@ -323,6 +323,10 @@ func (noopWriteCloser) Headers() p2p.Headers { ...@@ -323,6 +323,10 @@ func (noopWriteCloser) Headers() p2p.Headers {
return nil return nil
} }
func (noopWriteCloser) ResponseHeaders() p2p.Headers {
return nil
}
func (noopWriteCloser) Close() error { func (noopWriteCloser) Close() error {
return nil return nil
} }
...@@ -351,6 +355,10 @@ func (noopReadCloser) Headers() p2p.Headers { ...@@ -351,6 +355,10 @@ func (noopReadCloser) Headers() p2p.Headers {
return nil return nil
} }
func (noopReadCloser) ResponseHeaders() p2p.Headers {
return nil
}
func (noopReadCloser) Close() error { func (noopReadCloser) Close() error {
return nil return nil
} }
......
...@@ -97,7 +97,7 @@ func (r *Recorder) NewStream(ctx context.Context, addr swarm.Address, h p2p.Head ...@@ -97,7 +97,7 @@ func (r *Recorder) NewStream(ctx context.Context, addr swarm.Address, h p2p.Head
handler = r.middlewares[i](handler) handler = r.middlewares[i](handler)
} }
if headler != nil { if headler != nil {
streamOut.headers = headler(h) streamOut.headers = headler(h, addr)
} }
record := &Record{in: recordIn, out: recordOut, done: make(chan struct{})} record := &Record{in: recordIn, out: recordOut, done: make(chan struct{})}
go func() { go func() {
...@@ -194,9 +194,10 @@ func (r *Record) setErr(err error) { ...@@ -194,9 +194,10 @@ func (r *Record) setErr(err error) {
} }
type stream struct { type stream struct {
in *record in *record
out *record out *record
headers p2p.Headers headers p2p.Headers
responseHeaders p2p.Headers
} }
func newStream(in, out *record) *stream { func newStream(in, out *record) *stream {
...@@ -215,6 +216,10 @@ func (s *stream) Headers() p2p.Headers { ...@@ -215,6 +216,10 @@ func (s *stream) Headers() p2p.Headers {
return s.headers return s.headers
} }
func (s *stream) ResponseHeaders() p2p.Headers {
return s.responseHeaders
}
func (s *stream) Close() error { func (s *stream) Close() error {
return s.in.Close() return s.in.Close()
} }
......
// Copyright 2021 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package pricer
import (
"github.com/ethersphere/bee/pkg/swarm"
)
func (s *Pricer) PeerPricePO(peer swarm.Address, po uint8) (uint64, error) {
return s.peerPricePO(peer, po)
}
// Copyright 2021 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package headerutils
const (
PriceFieldName = priceFieldName
TargetFieldName = targetFieldName
IndexFieldName = indexFieldName
)
// Copyright 2021 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package headerutils
import (
"encoding/binary"
"errors"
"github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/swarm"
)
const (
priceFieldName = "price"
targetFieldName = "target"
indexFieldName = "index"
)
var (
// ErrFieldLength denotes p2p.Header having malformed field length in bytes
ErrFieldLength = errors.New("field length error")
// ErrNoIndexHeader denotes p2p.Header lacking specified field
ErrNoIndexHeader = errors.New("no index header")
// ErrNoTargetHeader denotes p2p.Header lacking specified field
ErrNoTargetHeader = errors.New("no target header")
// ErrNoPriceHeader denotes p2p.Header lacking specified field
ErrNoPriceHeader = errors.New("no price header")
)
// Headers, utility functions
func MakePricingHeaders(chunkPrice uint64, addr swarm.Address) (p2p.Headers, error) {
chunkPriceInBytes := make([]byte, 8)
binary.BigEndian.PutUint64(chunkPriceInBytes, chunkPrice)
headers := p2p.Headers{
priceFieldName: chunkPriceInBytes,
targetFieldName: addr.Bytes(),
}
return headers, nil
}
func MakePricingResponseHeaders(chunkPrice uint64, addr swarm.Address, index uint8) (p2p.Headers, error) {
chunkPriceInBytes := make([]byte, 8)
chunkIndexInBytes := make([]byte, 1)
binary.BigEndian.PutUint64(chunkPriceInBytes, chunkPrice)
chunkIndexInBytes[0] = index
headers := p2p.Headers{
priceFieldName: chunkPriceInBytes,
targetFieldName: addr.Bytes(),
indexFieldName: chunkIndexInBytes,
}
return headers, nil
}
// ParsePricingHeaders used by responder to read address and price from stream headers
// Returns an error if no target field attached or the contents of it are not readable
func ParsePricingHeaders(receivedHeaders p2p.Headers) (swarm.Address, uint64, error) {
target, err := ParseTargetHeader(receivedHeaders)
if err != nil {
return swarm.ZeroAddress, 0, err
}
price, err := ParsePriceHeader(receivedHeaders)
if err != nil {
return swarm.ZeroAddress, 0, err
}
return target, price, nil
}
// ParsePricingResponseHeaders used by requester to read address, price and index from response headers
// Returns an error if any fields are missing or target is unreadable
func ParsePricingResponseHeaders(receivedHeaders p2p.Headers) (swarm.Address, uint64, uint8, error) {
target, err := ParseTargetHeader(receivedHeaders)
if err != nil {
return swarm.ZeroAddress, 0, 0, err
}
price, err := ParsePriceHeader(receivedHeaders)
if err != nil {
return swarm.ZeroAddress, 0, 0, err
}
index, err := ParseIndexHeader(receivedHeaders)
if err != nil {
return swarm.ZeroAddress, 0, 0, err
}
return target, price, index, nil
}
func ParseIndexHeader(receivedHeaders p2p.Headers) (uint8, error) {
if receivedHeaders[indexFieldName] == nil {
return 0, ErrNoIndexHeader
}
if len(receivedHeaders[indexFieldName]) != 1 {
return 0, ErrFieldLength
}
index := receivedHeaders[indexFieldName][0]
return index, nil
}
func ParseTargetHeader(receivedHeaders p2p.Headers) (swarm.Address, error) {
if receivedHeaders[targetFieldName] == nil {
return swarm.ZeroAddress, ErrNoTargetHeader
}
target := swarm.NewAddress(receivedHeaders[targetFieldName])
return target, nil
}
func ParsePriceHeader(receivedHeaders p2p.Headers) (uint64, error) {
if receivedHeaders[priceFieldName] == nil {
return 0, ErrNoPriceHeader
}
if len(receivedHeaders[priceFieldName]) != 8 {
return 0, ErrFieldLength
}
receivedPrice := binary.BigEndian.Uint64(receivedHeaders[priceFieldName])
return receivedPrice, nil
}
// Copyright 2021 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package headerutils_test
import (
"reflect"
"testing"
"github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/pricer/headerutils"
"github.com/ethersphere/bee/pkg/swarm"
)
func TestMakePricingHeaders(t *testing.T) {
addr := swarm.MustParseHexAddress("010101e1010101")
makeHeaders, err := headerutils.MakePricingHeaders(uint64(5348), addr)
if err != nil {
t.Fatal(err)
}
expectedHeaders := p2p.Headers{
headerutils.PriceFieldName: []byte{0, 0, 0, 0, 0, 0, 20, 228},
headerutils.TargetFieldName: []byte{1, 1, 1, 225, 1, 1, 1},
}
if !reflect.DeepEqual(makeHeaders, expectedHeaders) {
t.Fatalf("Made headers not as expected, got %+v, want %+v", makeHeaders, expectedHeaders)
}
}
func TestMakePricingResponseHeaders(t *testing.T) {
addr := swarm.MustParseHexAddress("010101e1010101")
makeHeaders, err := headerutils.MakePricingResponseHeaders(uint64(5348), addr, uint8(11))
if err != nil {
t.Fatal(err)
}
expectedHeaders := p2p.Headers{
headerutils.PriceFieldName: []byte{0, 0, 0, 0, 0, 0, 20, 228},
headerutils.TargetFieldName: []byte{1, 1, 1, 225, 1, 1, 1},
headerutils.IndexFieldName: []byte{11},
}
if !reflect.DeepEqual(makeHeaders, expectedHeaders) {
t.Fatalf("Made headers not as expected, got %+v, want %+v", makeHeaders, expectedHeaders)
}
}
func TestParsePricingHeaders(t *testing.T) {
toReadHeaders := p2p.Headers{
headerutils.PriceFieldName: []byte{0, 0, 0, 0, 0, 0, 20, 228},
headerutils.TargetFieldName: []byte{1, 1, 1, 225, 1, 1, 1},
}
parsedTarget, parsedPrice, err := headerutils.ParsePricingHeaders(toReadHeaders)
if err != nil {
t.Fatal(err)
}
addr := swarm.MustParseHexAddress("010101e1010101")
if parsedPrice != uint64(5348) {
t.Fatalf("Price mismatch, got %v, want %v", parsedPrice, 5348)
}
if !parsedTarget.Equal(addr) {
t.Fatalf("Target mismatch, got %v, want %v", parsedTarget, addr)
}
}
func TestParsePricingResponseHeaders(t *testing.T) {
toReadHeaders := p2p.Headers{
headerutils.PriceFieldName: []byte{0, 0, 0, 0, 0, 0, 20, 228},
headerutils.TargetFieldName: []byte{1, 1, 1, 225, 1, 1, 1},
headerutils.IndexFieldName: []byte{11},
}
parsedTarget, parsedPrice, parsedIndex, err := headerutils.ParsePricingResponseHeaders(toReadHeaders)
if err != nil {
t.Fatal(err)
}
addr := swarm.MustParseHexAddress("010101e1010101")
if parsedPrice != uint64(5348) {
t.Fatalf("Price mismatch, got %v, want %v", parsedPrice, 5348)
}
if parsedIndex != uint8(11) {
t.Fatalf("Price mismatch, got %v, want %v", parsedPrice, 5348)
}
if !parsedTarget.Equal(addr) {
t.Fatalf("Target mismatch, got %v, want %v", parsedTarget, addr)
}
}
func TestParseIndexHeader(t *testing.T) {
toReadHeaders := p2p.Headers{
headerutils.IndexFieldName: []byte{11},
}
parsedIndex, err := headerutils.ParseIndexHeader(toReadHeaders)
if err != nil {
t.Fatal(err)
}
if parsedIndex != uint8(11) {
t.Fatalf("Index mismatch, got %v, want %v", parsedIndex, 11)
}
}
func TestParseTargetHeader(t *testing.T) {
toReadHeaders := p2p.Headers{
headerutils.TargetFieldName: []byte{1, 1, 1, 225, 1, 1, 1},
}
parsedTarget, err := headerutils.ParseTargetHeader(toReadHeaders)
if err != nil {
t.Fatal(err)
}
addr := swarm.MustParseHexAddress("010101e1010101")
if !parsedTarget.Equal(addr) {
t.Fatalf("Target mismatch, got %v, want %v", parsedTarget, addr)
}
}
func TestParsePriceHeader(t *testing.T) {
toReadHeaders := p2p.Headers{
headerutils.PriceFieldName: []byte{0, 0, 0, 0, 0, 0, 20, 228},
}
parsedPrice, err := headerutils.ParsePriceHeader(toReadHeaders)
if err != nil {
t.Fatal(err)
}
if parsedPrice != uint64(5348) {
t.Fatalf("Index mismatch, got %v, want %v", parsedPrice, 5348)
}
}
func TestReadMalformedHeaders(t *testing.T) {
toReadHeaders := p2p.Headers{
headerutils.IndexFieldName: []byte{11, 0},
headerutils.TargetFieldName: []byte{1, 1, 1, 225, 1, 1, 1},
headerutils.PriceFieldName: []byte{0, 0, 0, 0, 0, 20, 228},
}
_, err := headerutils.ParseIndexHeader(toReadHeaders)
if err == nil {
t.Fatal("Expected error from bad length of index bytes")
}
_, err = headerutils.ParsePriceHeader(toReadHeaders)
if err == nil {
t.Fatal("Expected error from bad length of price bytes")
}
_, _, _, err = headerutils.ParsePricingResponseHeaders(toReadHeaders)
if err == nil {
t.Fatal("Expected error caused by bad length of fields")
}
_, _, err = headerutils.ParsePricingHeaders(toReadHeaders)
if err == nil {
t.Fatal("Expected error caused by bad length of fields")
}
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mock
import (
"github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/swarm"
)
type Service struct {
peerPrice uint64
price uint64
peerPriceFunc func(peer, chunk swarm.Address) uint64
priceForPeerFunc func(peer, chunk swarm.Address) uint64
priceTableFunc func() (priceTable []uint64)
notifyPriceTableFunc func(peer swarm.Address, priceTable []uint64) error
priceHeadlerFunc func(p2p.Headers, swarm.Address) p2p.Headers
notifyPeerPriceFunc func(peer swarm.Address, price uint64, index uint8) error
}
// WithReserveFunc sets the mock Reserve function
func WithPeerPriceFunc(f func(peer, chunk swarm.Address) uint64) Option {
return optionFunc(func(s *Service) {
s.peerPriceFunc = f
})
}
// WithReleaseFunc sets the mock Release function
func WithPriceForPeerFunc(f func(peer, chunk swarm.Address) uint64) Option {
return optionFunc(func(s *Service) {
s.priceForPeerFunc = f
})
}
func WithPrice(p uint64) Option {
return optionFunc(func(s *Service) {
s.price = p
})
}
func WithPeerPrice(p uint64) Option {
return optionFunc(func(s *Service) {
s.peerPrice = p
})
}
// WithPriceTableFunc sets the mock Release function
func WithPriceTableFunc(f func() (priceTable []uint64)) Option {
return optionFunc(func(s *Service) {
s.priceTableFunc = f
})
}
func WithPriceHeadlerFunc(f func(headers p2p.Headers, addr swarm.Address) p2p.Headers) Option {
return optionFunc(func(s *Service) {
s.priceHeadlerFunc = f
})
}
func NewMockService(opts ...Option) *Service {
mock := new(Service)
mock.price = 10
mock.peerPrice = 10
for _, o := range opts {
o.apply(mock)
}
return mock
}
func (pricer *Service) PeerPrice(peer, chunk swarm.Address) uint64 {
if pricer.peerPriceFunc != nil {
return pricer.peerPriceFunc(peer, chunk)
}
return pricer.peerPrice
}
func (pricer *Service) PriceForPeer(peer, chunk swarm.Address) uint64 {
if pricer.priceForPeerFunc != nil {
return pricer.priceForPeerFunc(peer, chunk)
}
return pricer.price
}
func (pricer *Service) PriceTable() (priceTable []uint64) {
if pricer.priceTableFunc != nil {
return pricer.priceTableFunc()
}
return nil
}
func (pricer *Service) NotifyPriceTable(peer swarm.Address, priceTable []uint64) error {
if pricer.notifyPriceTableFunc != nil {
return pricer.notifyPriceTableFunc(peer, priceTable)
}
return nil
}
func (pricer *Service) PriceHeadler(headers p2p.Headers, addr swarm.Address) p2p.Headers {
if pricer.priceHeadlerFunc != nil {
return pricer.priceHeadlerFunc(headers, addr)
}
return p2p.Headers{}
}
func (pricer *Service) NotifyPeerPrice(peer swarm.Address, price uint64, index uint8) error {
if pricer.notifyPeerPriceFunc != nil {
return pricer.notifyPeerPriceFunc(peer, price, index)
}
return nil
}
// Option is the option passed to the mock accounting service
type Option interface {
apply(*Service)
}
type optionFunc func(*Service)
func (f optionFunc) apply(r *Service) { f(r) }
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package pricer
import (
"errors"
"fmt"
"sync"
"github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/pricer/headerutils"
"github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/topology"
)
const (
priceTablePrefix string = "pricetable_"
)
var _ Interface = (*Pricer)(nil)
// Pricer returns pricing information for chunk hashes and proximity orders
type Interface interface {
// PriceTable returns pricetable stored for the node
PriceTable() []uint64
// PeerPrice is the price the peer charges for a given chunk hash.
PeerPrice(peer, chunk swarm.Address) uint64
// PriceForPeer is the price we charge a peer for a given chunk hash.
PriceForPeer(peer, chunk swarm.Address) uint64
// NotifyPriceTable saves a provided pricetable for a peer to store
NotifyPriceTable(peer swarm.Address, priceTable []uint64) error
// NotifyPeerPrice changes a value that belongs to an index in a peer pricetable
NotifyPeerPrice(peer swarm.Address, price uint64, index uint8) error
// PriceHeadler creates response headers with pricing information
PriceHeadler(p2p.Headers, swarm.Address) p2p.Headers
}
var (
ErrPersistingBalancePeer = errors.New("failed to persist pricetable for peer")
)
type pricingPeer struct {
lock sync.Mutex
}
type Pricer struct {
pricingPeersMu sync.Mutex
pricingPeers map[string]*pricingPeer
logger logging.Logger
store storage.StateStorer
overlay swarm.Address
topology topology.Driver
poPrice uint64
}
func New(logger logging.Logger, store storage.StateStorer, overlay swarm.Address, poPrice uint64) *Pricer {
return &Pricer{
logger: logger,
pricingPeers: make(map[string]*pricingPeer),
store: store,
overlay: overlay,
poPrice: poPrice,
}
}
// PriceTable returns the pricetable stored for the node
// If not available, the default pricetable is provided
func (s *Pricer) PriceTable() (priceTable []uint64) {
err := s.store.Get(priceTableKey(), &priceTable)
if err != nil {
priceTable = s.defaultPriceTable()
}
return priceTable
}
// peerPriceTable returns the price table stored for the given peer.
// If we can't get price table from store, we return the default price table
func (s *Pricer) peerPriceTable(peer swarm.Address) (priceTable []uint64) {
err := s.store.Get(peerPriceTableKey(peer), &priceTable)
if err != nil {
priceTable = s.defaultPriceTable() // get default pricetable
}
return priceTable
}
// PriceForPeer returns the price for the PO of a chunk from the table stored for the node.
// Taking into consideration that the peer might be an in-neighborhood peer,
// if the chunk is at least neighborhood depth proximate to both the node and the peer, the price is 0
func (s *Pricer) PriceForPeer(peer, chunk swarm.Address) uint64 {
proximity := swarm.Proximity(s.overlay.Bytes(), chunk.Bytes())
neighborhoodDepth := s.neighborhoodDepth()
if proximity >= neighborhoodDepth {
peerproximity := swarm.Proximity(peer.Bytes(), chunk.Bytes())
if peerproximity >= neighborhoodDepth {
return 0
}
}
price, err := s.pricePO(proximity)
if err != nil {
price = s.defaultPrice(proximity)
}
return price
}
// priceWithIndexForPeer returns price for a chunk for a given peer,
// and the index of PO in pricetable which is used
func (s *Pricer) priceWithIndexForPeer(peer, chunk swarm.Address) (price uint64, index uint8) {
proximity := swarm.Proximity(s.overlay.Bytes(), chunk.Bytes())
neighborhoodDepth := s.neighborhoodDepth()
priceTable := s.PriceTable()
if int(proximity) >= len(priceTable) {
proximity = uint8(len(priceTable) - 1)
}
if proximity >= neighborhoodDepth {
proximity = neighborhoodDepth
peerproximity := swarm.Proximity(peer.Bytes(), chunk.Bytes())
if peerproximity >= neighborhoodDepth {
return 0, proximity
}
}
return priceTable[proximity], proximity
}
// pricePO returns the price for a PO from the table stored for the node.
func (s *Pricer) pricePO(po uint8) (uint64, error) {
priceTable := s.PriceTable()
proximity := po
if int(po) >= len(priceTable) {
proximity = uint8(len(priceTable) - 1)
}
return priceTable[proximity], nil
}
// PeerPrice returns the price for the PO of a chunk from the table stored for the given peer.
// Taking into consideration that the peer might be an in-neighborhood peer,
// if the chunk is at least neighborhood depth proximate to both the node and the peer, the price is 0
func (s *Pricer) PeerPrice(peer, chunk swarm.Address) uint64 {
proximity := swarm.Proximity(peer.Bytes(), chunk.Bytes())
// Determine neighborhood depth presumed by peer based on pricetable rows
var priceTable []uint64
err := s.store.Get(peerPriceTableKey(peer), &priceTable)
peerNeighborhoodDepth := uint8(len(priceTable) - 1)
if err != nil {
peerNeighborhoodDepth = s.neighborhoodDepth()
}
// determine whether the chunk is within presumed neighborhood depth of peer
if proximity >= peerNeighborhoodDepth {
// determine if the chunk is within presumed neighborhood depth of peer to us
selfproximity := swarm.Proximity(s.overlay.Bytes(), chunk.Bytes())
if selfproximity >= peerNeighborhoodDepth {
return 0
}
}
price, err := s.peerPricePO(peer, proximity)
if err != nil {
price = s.defaultPrice(proximity)
}
return price
}
// peerPricePO returns the price for a PO from the table stored for the given peer.
func (s *Pricer) peerPricePO(peer swarm.Address, po uint8) (uint64, error) {
var priceTable []uint64
err := s.store.Get(peerPriceTableKey(peer), &priceTable)
if err != nil {
if !errors.Is(err, storage.ErrNotFound) {
return 0, err
}
priceTable = s.defaultPriceTable()
}
proximity := po
if int(po) >= len(priceTable) {
proximity = uint8(len(priceTable) - 1)
}
return priceTable[proximity], nil
}
// peerPriceTableKey returns the price table storage key for the given peer.
func peerPriceTableKey(peer swarm.Address) string {
return fmt.Sprintf("%s%s", priceTablePrefix, peer.String())
}
// priceTableKey returns the price table storage key for own price table
func priceTableKey() string {
return fmt.Sprintf("%s%s", priceTablePrefix, "self")
}
func (s *Pricer) getPricingPeer(peer swarm.Address) (*pricingPeer, error) {
s.pricingPeersMu.Lock()
defer s.pricingPeersMu.Unlock()
peerData, ok := s.pricingPeers[peer.String()]
if !ok {
peerData = &pricingPeer{}
s.pricingPeers[peer.String()] = peerData
}
return peerData, nil
}
func (s *Pricer) storePriceTable(peer swarm.Address, priceTable []uint64) error {
s.logger.Tracef("Storing pricetable %v for peer %v", priceTable, peer)
err := s.store.Put(peerPriceTableKey(peer), priceTable)
if err != nil {
return err
}
return nil
}
// NotifyPriceTable should be called to notify pricer of changes in the peers pricetable
func (s *Pricer) NotifyPriceTable(peer swarm.Address, priceTable []uint64) error {
pricingPeer, err := s.getPricingPeer(peer)
if err != nil {
return err
}
pricingPeer.lock.Lock()
defer pricingPeer.lock.Unlock()
return s.storePriceTable(peer, priceTable)
}
func (s *Pricer) NotifyPeerPrice(peer swarm.Address, price uint64, index uint8) error {
if price == 0 {
return nil
}
pricingPeer, err := s.getPricingPeer(peer)
if err != nil {
return err
}
pricingPeer.lock.Lock()
defer pricingPeer.lock.Unlock()
priceTable := s.peerPriceTable(peer)
currentIndexDepth := uint8(len(priceTable)) - 1
if index <= currentIndexDepth {
// Simple case, already have index depth, single value change
priceTable[index] = price
return s.storePriceTable(peer, priceTable)
}
// Complicated case, index is larger than depth of already known table
newPriceTable := make([]uint64, index+1)
// Copy previous content
_ = copy(newPriceTable, priceTable)
// Check how many rows are missing
numberOfMissingRows := index - currentIndexDepth
for i := uint8(0); i < numberOfMissingRows; i++ {
currentrow := index - i
newPriceTable[currentrow] = price + uint64(i)*s.poPrice
s.logger.Debugf("Guessing price %v for extending pricetable %v for peer %v", newPriceTable[currentrow], newPriceTable, peer)
}
return s.storePriceTable(peer, newPriceTable)
}
func (s *Pricer) defaultPriceTable() []uint64 {
neighborhoodDepth := s.neighborhoodDepth()
priceTable := make([]uint64, neighborhoodDepth+1)
for i := uint8(0); i <= neighborhoodDepth; i++ {
priceTable[i] = uint64(neighborhoodDepth-i+1) * s.poPrice
}
return priceTable
}
func (s *Pricer) defaultPrice(po uint8) uint64 {
neighborhoodDepth := s.neighborhoodDepth()
if po > neighborhoodDepth {
po = neighborhoodDepth
}
return uint64(neighborhoodDepth-po+1) * s.poPrice
}
func (s *Pricer) neighborhoodDepth() uint8 {
var neighborhoodDepth uint8
if s.topology != nil {
neighborhoodDepth = s.topology.NeighborhoodDepth()
}
return neighborhoodDepth
}
func (s *Pricer) PriceHeadler(receivedHeaders p2p.Headers, peerAddress swarm.Address) (returnHeaders p2p.Headers) {
chunkAddress, receivedPrice, err := headerutils.ParsePricingHeaders(receivedHeaders)
if err != nil {
return p2p.Headers{
"error": []byte("Error reading pricing headers"),
}
}
s.logger.Debugf("price headler: received target %v with price as %v, from peer %s", chunkAddress, receivedPrice, peerAddress)
checkPrice, index := s.priceWithIndexForPeer(peerAddress, chunkAddress)
returnHeaders, err = headerutils.MakePricingResponseHeaders(checkPrice, chunkAddress, index)
if err != nil {
return p2p.Headers{
"error": []byte("Error creating response pricing headers"),
}
}
s.logger.Debugf("price headler: response target %v with price as %v, for peer %s", chunkAddress, checkPrice, peerAddress)
return returnHeaders
}
func (s *Pricer) SetTopology(top topology.Driver) {
s.topology = top
}
// Copyright 2021 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package pricer_test
import (
"encoding/binary"
"io/ioutil"
"reflect"
"testing"
mockkad "github.com/ethersphere/bee/pkg/kademlia/mock"
"github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/pricer"
"github.com/ethersphere/bee/pkg/pricer/headerutils"
"github.com/ethersphere/bee/pkg/statestore/mock"
"github.com/ethersphere/bee/pkg/swarm"
)
func TestPriceTableDefaultTables(t *testing.T) {
logger := logging.New(ioutil.Discard, 0)
store := mock.NewStateStore()
overlay := swarm.MustParseHexAddress("e5ef617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700")
defer store.Close()
pricer := pricer.New(logger, store, overlay, 10)
kad := mockkad.NewMockKademlia(mockkad.WithDepthCalls(0, 0, 2, 3, 5, 9))
pricer.SetTopology(kad)
table0 := []uint64{10}
table1 := []uint64{30, 20, 10}
table2 := []uint64{40, 30, 20, 10}
table3 := []uint64{60, 50, 40, 30, 20, 10}
table4 := []uint64{100, 90, 80, 70, 60, 50, 40, 30, 20, 10}
getTable0 := pricer.PriceTable()
if !reflect.DeepEqual(table0, getTable0) {
t.Fatalf("returned table does not match, got %+v expected %+v", getTable0, table0)
}
getTable1 := pricer.PriceTable()
if !reflect.DeepEqual(table1, getTable1) {
t.Fatalf("returned table does not match, got %+v expected %+v", getTable1, table1)
}
getTable2 := pricer.PriceTable()
if !reflect.DeepEqual(table2, getTable2) {
t.Fatalf("returned table does not match, got %+v expected %+v", getTable2, table2)
}
getTable3 := pricer.PriceTable()
if !reflect.DeepEqual(table3, getTable3) {
t.Fatalf("returned table does not match, got %+v expected %+v", getTable3, table3)
}
getTable4 := pricer.PriceTable()
if !reflect.DeepEqual(table4, getTable4) {
t.Fatalf("returned table does not match, got %+v expected %+v", getTable4, table4)
}
}
func TestPeerPrice(t *testing.T) {
logger := logging.New(ioutil.Discard, 0)
store := mock.NewStateStore()
defer store.Close()
overlay := swarm.MustParseHexAddress("0000617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700")
peer := swarm.MustParseHexAddress("e5ef617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700")
chunksByPOToPeer := []swarm.Address{
swarm.MustParseHexAddress("05ef617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700"),
swarm.MustParseHexAddress("95ef617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700"),
swarm.MustParseHexAddress("c5ef617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700"),
swarm.MustParseHexAddress("f0ef617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700"),
swarm.MustParseHexAddress("e6ef617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700"),
}
pricer := pricer.New(logger, store, overlay, 10)
kad := mockkad.NewMockKademlia(mockkad.WithDepth(3))
pricer.SetTopology(kad)
peerTable := []uint64{55, 45, 35, 25, 15}
err := pricer.NotifyPriceTable(peer, peerTable)
if err != nil {
t.Fatal(err)
}
for i, ch := range chunksByPOToPeer {
getPrice := pricer.PeerPrice(peer, ch)
if getPrice != peerTable[i] {
t.Fatalf("unexpected PeerPrice, got %v expected %v", getPrice, peerTable[i])
}
}
neighborhoodPeer := swarm.MustParseHexAddress("00ef617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700")
neighborhoodPeerTable := []uint64{36, 26, 16, 6}
err = pricer.NotifyPriceTable(neighborhoodPeer, neighborhoodPeerTable)
if err != nil {
t.Fatal(err)
}
chunksByPOToNeighborhoodPeer := []swarm.Address{
swarm.MustParseHexAddress("95ef617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700"),
swarm.MustParseHexAddress("55ef617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700"),
swarm.MustParseHexAddress("35ef617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700"),
swarm.MustParseHexAddress("10ef617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700"),
}
neighborhoodPeerExpectedPrices := []uint64{36, 26, 16, 0}
for i, ch := range chunksByPOToNeighborhoodPeer {
getPrice := pricer.PeerPrice(neighborhoodPeer, ch)
if getPrice != neighborhoodPeerExpectedPrices[i] {
t.Fatalf("unexpected PeerPrice, got %v expected %v", getPrice, neighborhoodPeerExpectedPrices[i])
}
}
}
func TestPriceForPeer(t *testing.T) {
logger := logging.New(ioutil.Discard, 0)
store := mock.NewStateStore()
defer store.Close()
peer := swarm.MustParseHexAddress("0000617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700")
overlay := swarm.MustParseHexAddress("e5ef617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700")
chunksByPOToOverlay := []swarm.Address{
swarm.MustParseHexAddress("05ef617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700"),
swarm.MustParseHexAddress("95ef617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700"),
swarm.MustParseHexAddress("c5ef617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700"),
swarm.MustParseHexAddress("f0ef617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700"),
swarm.MustParseHexAddress("e6ef617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700"),
}
pricer := pricer.New(logger, store, overlay, 10)
kad := mockkad.NewMockKademlia(mockkad.WithDepth(4))
pricer.SetTopology(kad)
defaultTable := []uint64{50, 40, 30, 20, 10}
for i, ch := range chunksByPOToOverlay {
getPrice := pricer.PriceForPeer(peer, ch)
if getPrice != defaultTable[i] {
t.Fatalf("unexpected price for peer, got %v expected %v", getPrice, defaultTable[i])
}
}
neighborhoodPeer := swarm.MustParseHexAddress("e7ef617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700")
expectedPricesInNeighborhood := []uint64{50, 40, 30, 20, 0}
for i, ch := range chunksByPOToOverlay {
getPrice := pricer.PriceForPeer(neighborhoodPeer, ch)
if getPrice != expectedPricesInNeighborhood[i] {
t.Fatalf("unexpected price for peer, got %v expected %v", getPrice, expectedPricesInNeighborhood[i])
}
}
}
func TestNotifyPriceTable(t *testing.T) {
logger := logging.New(ioutil.Discard, 0)
store := mock.NewStateStore()
overlay := swarm.MustParseHexAddress("e5ef617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700")
defer store.Close()
pricer := pricer.New(logger, store, overlay, 10)
kad := mockkad.NewMockKademlia(mockkad.WithDepth(0))
pricer.SetTopology(kad)
peer := swarm.MustParseHexAddress("ffff617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700")
peerTable := []uint64{66, 55, 44, 33, 22, 11}
err := pricer.NotifyPriceTable(peer, peerTable)
if err != nil {
t.Fatal(err)
}
for i := 0; i < len(peerTable); i++ {
getPrice, err := pricer.PeerPricePO(peer, uint8(i))
if err != nil {
t.Fatal(err)
}
if getPrice != peerTable[i] {
t.Fatalf("unexpected PeerPricePO, got %v expected %v", getPrice, peerTable[i])
}
}
}
func TestNotifyPeerPrice(t *testing.T) {
logger := logging.New(ioutil.Discard, 0)
store := mock.NewStateStore()
overlay := swarm.MustParseHexAddress("e5ef617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700")
defer store.Close()
pricer := pricer.New(logger, store, overlay, 10)
kad := mockkad.NewMockKademlia(mockkad.WithDepth(0))
pricer.SetTopology(kad)
peer := swarm.MustParseHexAddress("ffff617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700")
peerTable := []uint64{33, 22, 11}
err := pricer.NotifyPriceTable(peer, peerTable)
if err != nil {
t.Fatal(err)
}
for i := 0; i < len(peerTable); i++ {
getPrice, err := pricer.PeerPricePO(peer, uint8(i))
if err != nil {
t.Fatal(err)
}
if getPrice != peerTable[i] {
t.Fatalf("unexpected PeerPricePO, got %v expected %v", getPrice, peerTable[i])
}
}
err = pricer.NotifyPeerPrice(peer, 48, 8)
if err != nil {
t.Fatal(err)
}
expectedTable := []uint64{33, 22, 11, 98, 88, 78, 68, 58, 48}
for i := 0; i < len(expectedTable); i++ {
getPrice, err := pricer.PeerPricePO(peer, uint8(i))
if err != nil {
t.Fatal(err)
}
if getPrice != expectedTable[i] {
t.Fatalf("unexpected PeerPricePO, got %v expected %v", getPrice, expectedTable[i])
}
}
err = pricer.NotifyPeerPrice(peer, 43, 9)
if err != nil {
t.Fatal(err)
}
expectedTable = []uint64{33, 22, 11, 98, 88, 78, 68, 58, 48, 43}
for i := 0; i < len(expectedTable); i++ {
getPrice, err := pricer.PeerPricePO(peer, uint8(i))
if err != nil {
t.Fatal(err)
}
if getPrice != expectedTable[i] {
t.Fatalf("unexpected PeerPricePO, got %v expected %v", getPrice, expectedTable[i])
}
}
err = pricer.NotifyPeerPrice(peer, 60, 5)
if err != nil {
t.Fatal(err)
}
expectedTable = []uint64{33, 22, 11, 98, 88, 60, 68, 58, 48, 43}
for i := 0; i < len(expectedTable); i++ {
getPrice, err := pricer.PeerPricePO(peer, uint8(i))
if err != nil {
t.Fatal(err)
}
if getPrice != expectedTable[i] {
t.Fatalf("unexpected PeerPricePO, got %v expected %v", getPrice, expectedTable[i])
}
}
}
func TestPricerHeadler(t *testing.T) {
logger := logging.New(ioutil.Discard, 0)
store := mock.NewStateStore()
overlay := swarm.MustParseHexAddress("e5ef617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700")
defer store.Close()
pricer := pricer.New(logger, store, overlay, 10)
kad := mockkad.NewMockKademlia(mockkad.WithDepth(2))
pricer.SetTopology(kad)
peer := swarm.MustParseHexAddress("07e0d4ba628ad700fff48b16b52953487a22589362e34efdd7e0d4ba628ad700")
chunkAddress := swarm.MustParseHexAddress("7a22589362e34efdd7e0d4ba628ad7007a22589362e34efdd7e0d4ba628ad700")
requestHeaders, err := headerutils.MakePricingHeaders(50, chunkAddress)
if err != nil {
t.Fatal(err)
}
responseHeaders := pricer.PriceHeadler(requestHeaders, peer)
if !reflect.DeepEqual(requestHeaders["target"], responseHeaders["target"]) {
t.Fatalf("targets don't match, got %v, want %v", responseHeaders["target"], requestHeaders["target"])
}
chunkPriceInRequest := make([]byte, 8)
binary.BigEndian.PutUint64(chunkPriceInRequest, uint64(50))
chunkPriceInResponse := make([]byte, 8)
binary.BigEndian.PutUint64(chunkPriceInResponse, uint64(30))
if !reflect.DeepEqual(requestHeaders["price"], chunkPriceInRequest) {
t.Fatalf("targets don't match, got %v, want %v", responseHeaders["price"], chunkPriceInRequest)
}
if !reflect.DeepEqual(responseHeaders["price"], chunkPriceInResponse) {
t.Fatalf("targets don't match, got %v, want %v", responseHeaders["price"], chunkPriceInResponse)
}
}
func TestPricerHeadlerBadHeaders(t *testing.T) {
logger := logging.New(ioutil.Discard, 0)
store := mock.NewStateStore()
overlay := swarm.MustParseHexAddress("e5ef617cadab2af7fff48b16b52953487a22589362e34efdd7e0d4ba628ad700")
defer store.Close()
pricer := pricer.New(logger, store, overlay, 10)
kad := mockkad.NewMockKademlia(mockkad.WithDepth(2))
pricer.SetTopology(kad)
peer := swarm.MustParseHexAddress("07e0d4ba628ad700fff48b16b52953487a22589362e34efdd7e0d4ba628ad700")
requestHeaders := p2p.Headers{
"irrelevantfield": []byte{},
}
responseHeaders := pricer.PriceHeadler(requestHeaders, peer)
if responseHeaders["target"] != nil {
t.Fatal("only error should be returned")
}
if responseHeaders["price"] != nil {
t.Fatal("only error should be returned")
}
if responseHeaders["index"] != nil {
t.Fatal("only error should be returned")
}
if responseHeaders["error"] == nil {
t.Fatal("error should be returned")
}
}
...@@ -23,7 +23,8 @@ var _ = math.Inf ...@@ -23,7 +23,8 @@ var _ = math.Inf
const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package
type AnnouncePaymentThreshold struct { type AnnouncePaymentThreshold struct {
PaymentThreshold []byte `protobuf:"bytes,1,opt,name=PaymentThreshold,proto3" json:"PaymentThreshold,omitempty"` PaymentThreshold []byte `protobuf:"bytes,1,opt,name=PaymentThreshold,proto3" json:"PaymentThreshold,omitempty"`
ProximityPrice []uint64 `protobuf:"varint,2,rep,packed,name=ProximityPrice,proto3" json:"ProximityPrice,omitempty"`
} }
func (m *AnnouncePaymentThreshold) Reset() { *m = AnnouncePaymentThreshold{} } func (m *AnnouncePaymentThreshold) Reset() { *m = AnnouncePaymentThreshold{} }
...@@ -66,6 +67,13 @@ func (m *AnnouncePaymentThreshold) GetPaymentThreshold() []byte { ...@@ -66,6 +67,13 @@ func (m *AnnouncePaymentThreshold) GetPaymentThreshold() []byte {
return nil return nil
} }
func (m *AnnouncePaymentThreshold) GetProximityPrice() []uint64 {
if m != nil {
return m.ProximityPrice
}
return nil
}
func init() { func init() {
proto.RegisterType((*AnnouncePaymentThreshold)(nil), "pricing.AnnouncePaymentThreshold") proto.RegisterType((*AnnouncePaymentThreshold)(nil), "pricing.AnnouncePaymentThreshold")
} }
...@@ -73,15 +81,17 @@ func init() { ...@@ -73,15 +81,17 @@ func init() {
func init() { proto.RegisterFile("pricing.proto", fileDescriptor_ec4cc93d045d43d0) } func init() { proto.RegisterFile("pricing.proto", fileDescriptor_ec4cc93d045d43d0) }
var fileDescriptor_ec4cc93d045d43d0 = []byte{ var fileDescriptor_ec4cc93d045d43d0 = []byte{
// 122 bytes of a gzipped FileDescriptorProto // 150 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x2d, 0x28, 0xca, 0x4c, 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x2d, 0x28, 0xca, 0x4c,
0xce, 0xcc, 0x4b, 0xd7, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x87, 0x72, 0x95, 0xdc, 0xb8, 0xce, 0xcc, 0x4b, 0xd7, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x87, 0x72, 0x95, 0xf2, 0xb8,
0x24, 0x1c, 0xf3, 0xf2, 0xf2, 0x4b, 0xf3, 0x92, 0x53, 0x03, 0x12, 0x2b, 0x73, 0x53, 0xf3, 0x4a, 0x24, 0x1c, 0xf3, 0xf2, 0xf2, 0x4b, 0xf3, 0x92, 0x53, 0x03, 0x12, 0x2b, 0x73, 0x53, 0xf3, 0x4a,
0x42, 0x32, 0x8a, 0x52, 0x8b, 0x33, 0xf2, 0x73, 0x52, 0x84, 0xb4, 0xb8, 0x04, 0xd0, 0xc5, 0x24, 0x42, 0x32, 0x8a, 0x52, 0x8b, 0x33, 0xf2, 0x73, 0x52, 0x84, 0xb4, 0xb8, 0x04, 0xd0, 0xc5, 0x24,
0x18, 0x15, 0x18, 0x35, 0x78, 0x82, 0x30, 0xc4, 0x9d, 0x64, 0x4e, 0x3c, 0x92, 0x63, 0xbc, 0xf0, 0x18, 0x15, 0x18, 0x35, 0x78, 0x82, 0x30, 0xc4, 0x85, 0xd4, 0xb8, 0xf8, 0x02, 0x8a, 0xf2, 0x2b,
0x48, 0x8e, 0xf1, 0xc1, 0x23, 0x39, 0xc6, 0x09, 0x8f, 0xe5, 0x18, 0x2e, 0x3c, 0x96, 0x63, 0xb8, 0x32, 0x73, 0x33, 0x4b, 0x2a, 0x03, 0x8a, 0x32, 0x93, 0x53, 0x25, 0x98, 0x14, 0x98, 0x35, 0x58,
0xf1, 0x58, 0x8e, 0x21, 0x8a, 0xa9, 0x20, 0x29, 0x89, 0x0d, 0x6c, 0xab, 0x31, 0x20, 0x00, 0x00, 0x82, 0xd0, 0x44, 0x9d, 0x64, 0x4e, 0x3c, 0x92, 0x63, 0xbc, 0xf0, 0x48, 0x8e, 0xf1, 0xc1, 0x23,
0xff, 0xff, 0x50, 0xca, 0x0e, 0x0a, 0x86, 0x00, 0x00, 0x00, 0x39, 0xc6, 0x09, 0x8f, 0xe5, 0x18, 0x2e, 0x3c, 0x96, 0x63, 0xb8, 0xf1, 0x58, 0x8e, 0x21, 0x8a,
0xa9, 0x20, 0x29, 0x89, 0x0d, 0xec, 0x3a, 0x63, 0x40, 0x00, 0x00, 0x00, 0xff, 0xff, 0x90, 0x49,
0x1c, 0x6f, 0xae, 0x00, 0x00, 0x00,
} }
func (m *AnnouncePaymentThreshold) Marshal() (dAtA []byte, err error) { func (m *AnnouncePaymentThreshold) Marshal() (dAtA []byte, err error) {
...@@ -104,6 +114,24 @@ func (m *AnnouncePaymentThreshold) MarshalToSizedBuffer(dAtA []byte) (int, error ...@@ -104,6 +114,24 @@ func (m *AnnouncePaymentThreshold) MarshalToSizedBuffer(dAtA []byte) (int, error
_ = i _ = i
var l int var l int
_ = l _ = l
if len(m.ProximityPrice) > 0 {
dAtA2 := make([]byte, len(m.ProximityPrice)*10)
var j1 int
for _, num := range m.ProximityPrice {
for num >= 1<<7 {
dAtA2[j1] = uint8(uint64(num)&0x7f | 0x80)
num >>= 7
j1++
}
dAtA2[j1] = uint8(num)
j1++
}
i -= j1
copy(dAtA[i:], dAtA2[:j1])
i = encodeVarintPricing(dAtA, i, uint64(j1))
i--
dAtA[i] = 0x12
}
if len(m.PaymentThreshold) > 0 { if len(m.PaymentThreshold) > 0 {
i -= len(m.PaymentThreshold) i -= len(m.PaymentThreshold)
copy(dAtA[i:], m.PaymentThreshold) copy(dAtA[i:], m.PaymentThreshold)
...@@ -135,6 +163,13 @@ func (m *AnnouncePaymentThreshold) Size() (n int) { ...@@ -135,6 +163,13 @@ func (m *AnnouncePaymentThreshold) Size() (n int) {
if l > 0 { if l > 0 {
n += 1 + l + sovPricing(uint64(l)) n += 1 + l + sovPricing(uint64(l))
} }
if len(m.ProximityPrice) > 0 {
l = 0
for _, e := range m.ProximityPrice {
l += sovPricing(uint64(e))
}
n += 1 + sovPricing(uint64(l)) + l
}
return n return n
} }
...@@ -207,6 +242,82 @@ func (m *AnnouncePaymentThreshold) Unmarshal(dAtA []byte) error { ...@@ -207,6 +242,82 @@ func (m *AnnouncePaymentThreshold) Unmarshal(dAtA []byte) error {
m.PaymentThreshold = []byte{} m.PaymentThreshold = []byte{}
} }
iNdEx = postIndex iNdEx = postIndex
case 2:
if wireType == 0 {
var v uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowPricing
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
v |= uint64(b&0x7F) << shift
if b < 0x80 {
break
}
}
m.ProximityPrice = append(m.ProximityPrice, v)
} else if wireType == 2 {
var packedLen int
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowPricing
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
packedLen |= int(b&0x7F) << shift
if b < 0x80 {
break
}
}
if packedLen < 0 {
return ErrInvalidLengthPricing
}
postIndex := iNdEx + packedLen
if postIndex < 0 {
return ErrInvalidLengthPricing
}
if postIndex > l {
return io.ErrUnexpectedEOF
}
var elementCount int
var count int
for _, integer := range dAtA[iNdEx:postIndex] {
if integer < 128 {
count++
}
}
elementCount = count
if elementCount != 0 && len(m.ProximityPrice) == 0 {
m.ProximityPrice = make([]uint64, 0, elementCount)
}
for iNdEx < postIndex {
var v uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowPricing
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
v |= uint64(b&0x7F) << shift
if b < 0x80 {
break
}
}
m.ProximityPrice = append(m.ProximityPrice, v)
}
} else {
return fmt.Errorf("proto: wrong wireType = %d for field ProximityPrice", wireType)
}
default: default:
iNdEx = preIndex iNdEx = preIndex
skippy, err := skipPricing(dAtA[iNdEx:]) skippy, err := skipPricing(dAtA[iNdEx:])
......
...@@ -9,5 +9,6 @@ package pricing; ...@@ -9,5 +9,6 @@ package pricing;
option go_package = "pb"; option go_package = "pb";
message AnnouncePaymentThreshold { message AnnouncePaymentThreshold {
bytes PaymentThreshold = 1; bytes PaymentThreshold = 1;
repeated uint64 ProximityPrice = 2;
} }
...@@ -13,6 +13,7 @@ import ( ...@@ -13,6 +13,7 @@ import (
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/p2p" "github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/p2p/protobuf" "github.com/ethersphere/bee/pkg/p2p/protobuf"
"github.com/ethersphere/bee/pkg/pricer"
"github.com/ethersphere/bee/pkg/pricing/pb" "github.com/ethersphere/bee/pkg/pricing/pb"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
) )
...@@ -28,6 +29,12 @@ var _ Interface = (*Service)(nil) ...@@ -28,6 +29,12 @@ var _ Interface = (*Service)(nil)
// Interface is the main interface of the pricing protocol // Interface is the main interface of the pricing protocol
type Interface interface { type Interface interface {
AnnouncePaymentThreshold(ctx context.Context, peer swarm.Address, paymentThreshold *big.Int) error AnnouncePaymentThreshold(ctx context.Context, peer swarm.Address, paymentThreshold *big.Int) error
AnnouncePaymentThresholdAndPriceTable(ctx context.Context, peer swarm.Address, paymentThreshold *big.Int) error
}
// PriceTableObserver is used for being notified of price table updates
type PriceTableObserver interface {
NotifyPriceTable(peer swarm.Address, priceTable []uint64) error
} }
// PaymentThresholdObserver is used for being notified of payment threshold updates // PaymentThresholdObserver is used for being notified of payment threshold updates
...@@ -39,13 +46,16 @@ type Service struct { ...@@ -39,13 +46,16 @@ type Service struct {
streamer p2p.Streamer streamer p2p.Streamer
logger logging.Logger logger logging.Logger
paymentThreshold *big.Int paymentThreshold *big.Int
pricer pricer.Interface
paymentThresholdObserver PaymentThresholdObserver paymentThresholdObserver PaymentThresholdObserver
priceTableObserver PriceTableObserver
} }
func New(streamer p2p.Streamer, logger logging.Logger, paymentThreshold *big.Int) *Service { func New(streamer p2p.Streamer, logger logging.Logger, paymentThreshold *big.Int, pricer pricer.Interface) *Service {
return &Service{ return &Service{
streamer: streamer, streamer: streamer,
logger: logger, logger: logger,
pricer: pricer,
paymentThreshold: paymentThreshold, paymentThreshold: paymentThreshold,
} }
} }
...@@ -77,17 +87,30 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e ...@@ -77,17 +87,30 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e
var req pb.AnnouncePaymentThreshold var req pb.AnnouncePaymentThreshold
if err := r.ReadMsgWithContext(ctx, &req); err != nil { if err := r.ReadMsgWithContext(ctx, &req); err != nil {
s.logger.Debugf("could not receive payment threshold announcement from peer %v", p.Address) s.logger.Debugf("could not receive payment threshold and/or price table announcement from peer %v", p.Address)
return fmt.Errorf("read request from peer %v: %w", p.Address, err) return fmt.Errorf("read request from peer %v: %w", p.Address, err)
} }
paymentThreshold := big.NewInt(0).SetBytes(req.PaymentThreshold) paymentThreshold := big.NewInt(0).SetBytes(req.PaymentThreshold)
s.logger.Tracef("received payment threshold announcement from peer %v of %d", p.Address, paymentThreshold) s.logger.Tracef("received payment threshold announcement from peer %v of %d", p.Address, paymentThreshold)
if req.ProximityPrice != nil {
s.logger.Tracef("received pricetable announcement from peer %v of %v", p.Address, req.ProximityPrice)
err = s.priceTableObserver.NotifyPriceTable(p.Address, req.ProximityPrice)
if err != nil {
s.logger.Debugf("error receiving pricetable from peer %v: %w", p.Address, err)
s.logger.Errorf("error receiving pricetable from peer %v: %w", p.Address, err)
}
}
if paymentThreshold.Cmp(big.NewInt(0)) == 0 {
return err
}
return s.paymentThresholdObserver.NotifyPaymentThreshold(p.Address, paymentThreshold) return s.paymentThresholdObserver.NotifyPaymentThreshold(p.Address, paymentThreshold)
} }
func (s *Service) init(ctx context.Context, p p2p.Peer) error { func (s *Service) init(ctx context.Context, p p2p.Peer) error {
err := s.AnnouncePaymentThreshold(ctx, p.Address, s.paymentThreshold) err := s.AnnouncePaymentThresholdAndPriceTable(ctx, p.Address, s.paymentThreshold)
if err != nil { if err != nil {
s.logger.Warningf("could not send payment threshold announcement to peer %v", p.Address) s.logger.Warningf("could not send payment threshold announcement to peer %v", p.Address)
} }
...@@ -120,7 +143,39 @@ func (s *Service) AnnouncePaymentThreshold(ctx context.Context, peer swarm.Addre ...@@ -120,7 +143,39 @@ func (s *Service) AnnouncePaymentThreshold(ctx context.Context, peer swarm.Addre
return err return err
} }
// AnnouncePaymentThresholdAndPriceTable announces own payment threshold and pricetable to peer
func (s *Service) AnnouncePaymentThresholdAndPriceTable(ctx context.Context, peer swarm.Address, paymentThreshold *big.Int) error {
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
stream, err := s.streamer.NewStream(ctx, peer, nil, protocolName, protocolVersion, streamName)
if err != nil {
return err
}
defer func() {
if err != nil {
_ = stream.Reset()
} else {
go stream.FullClose()
}
}()
s.logger.Tracef("sending payment threshold announcement to peer %v of %d", peer, paymentThreshold)
w := protobuf.NewWriter(stream)
err = w.WriteMsgWithContext(ctx, &pb.AnnouncePaymentThreshold{
PaymentThreshold: paymentThreshold.Bytes(),
ProximityPrice: s.pricer.PriceTable(),
})
return err
}
// SetPaymentThresholdObserver sets the PaymentThresholdObserver to be used when receiving a new payment threshold // SetPaymentThresholdObserver sets the PaymentThresholdObserver to be used when receiving a new payment threshold
func (s *Service) SetPaymentThresholdObserver(observer PaymentThresholdObserver) { func (s *Service) SetPaymentThresholdObserver(observer PaymentThresholdObserver) {
s.paymentThresholdObserver = observer s.paymentThresholdObserver = observer
} }
// SetPriceTableObserver sets the PriceTableObserver to be used when receiving a new pricetable
func (s *Service) SetPriceTableObserver(observer PriceTableObserver) {
s.priceTableObserver = observer
}
...@@ -9,35 +9,52 @@ import ( ...@@ -9,35 +9,52 @@ import (
"context" "context"
"io/ioutil" "io/ioutil"
"math/big" "math/big"
"reflect"
"testing" "testing"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/p2p/protobuf" "github.com/ethersphere/bee/pkg/p2p/protobuf"
"github.com/ethersphere/bee/pkg/p2p/streamtest" "github.com/ethersphere/bee/pkg/p2p/streamtest"
pricermock "github.com/ethersphere/bee/pkg/pricer/mock"
"github.com/ethersphere/bee/pkg/pricing" "github.com/ethersphere/bee/pkg/pricing"
"github.com/ethersphere/bee/pkg/pricing/pb" "github.com/ethersphere/bee/pkg/pricing/pb"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
) )
type testObserver struct { type testThresholdObserver struct {
called bool called bool
peer swarm.Address peer swarm.Address
paymentThreshold *big.Int paymentThreshold *big.Int
} }
func (t *testObserver) NotifyPaymentThreshold(peer swarm.Address, paymentThreshold *big.Int) error { type testPriceTableObserver struct {
called bool
peer swarm.Address
priceTable []uint64
}
func (t *testThresholdObserver) NotifyPaymentThreshold(peerAddr swarm.Address, paymentThreshold *big.Int) error {
t.called = true t.called = true
t.peer = peer t.peer = peerAddr
t.paymentThreshold = paymentThreshold t.paymentThreshold = paymentThreshold
return nil return nil
} }
func (t *testPriceTableObserver) NotifyPriceTable(peerAddr swarm.Address, priceTable []uint64) error {
t.called = true
t.peer = peerAddr
t.priceTable = priceTable
return nil
}
func TestAnnouncePaymentThreshold(t *testing.T) { func TestAnnouncePaymentThreshold(t *testing.T) {
logger := logging.New(ioutil.Discard, 0) logger := logging.New(ioutil.Discard, 0)
testThreshold := big.NewInt(100000) testThreshold := big.NewInt(100000)
observer := &testObserver{} observer := &testThresholdObserver{}
pricerMockService := pricermock.NewMockService()
recipient := pricing.New(nil, logger, testThreshold) recipient := pricing.New(nil, logger, testThreshold, pricerMockService)
recipient.SetPaymentThresholdObserver(observer) recipient.SetPaymentThresholdObserver(observer)
peerID := swarm.MustParseHexAddress("9ee7add7") peerID := swarm.MustParseHexAddress("9ee7add7")
...@@ -47,7 +64,7 @@ func TestAnnouncePaymentThreshold(t *testing.T) { ...@@ -47,7 +64,7 @@ func TestAnnouncePaymentThreshold(t *testing.T) {
streamtest.WithBaseAddr(peerID), streamtest.WithBaseAddr(peerID),
) )
payer := pricing.New(recorder, logger, testThreshold) payer := pricing.New(recorder, logger, testThreshold, pricerMockService)
paymentThreshold := big.NewInt(10000) paymentThreshold := big.NewInt(10000)
...@@ -96,3 +113,93 @@ func TestAnnouncePaymentThreshold(t *testing.T) { ...@@ -96,3 +113,93 @@ func TestAnnouncePaymentThreshold(t *testing.T) {
t.Fatalf("observer called with wrong peer. got %v, want %v", observer.peer, peerID) t.Fatalf("observer called with wrong peer. got %v, want %v", observer.peer, peerID)
} }
} }
func TestAnnouncePaymentThresholdAndPriceTable(t *testing.T) {
logger := logging.New(ioutil.Discard, 0)
testThreshold := big.NewInt(100000)
observer1 := &testThresholdObserver{}
observer2 := &testPriceTableObserver{}
table := []uint64{50, 25, 12, 6}
priceTableFunc := func() []uint64 {
return table
}
pricerMockService := pricermock.NewMockService(pricermock.WithPriceTableFunc(priceTableFunc))
recipient := pricing.New(nil, logger, testThreshold, pricerMockService)
recipient.SetPaymentThresholdObserver(observer1)
recipient.SetPriceTableObserver(observer2)
peerID := swarm.MustParseHexAddress("9ee7add7")
recorder := streamtest.New(
streamtest.WithProtocols(recipient.Protocol()),
streamtest.WithBaseAddr(peerID),
)
payer := pricing.New(recorder, logger, testThreshold, pricerMockService)
paymentThreshold := big.NewInt(10000)
err := payer.AnnouncePaymentThresholdAndPriceTable(context.Background(), peerID, paymentThreshold)
if err != nil {
t.Fatal(err)
}
records, err := recorder.Records(peerID, "pricing", "1.0.0", "pricing")
if err != nil {
t.Fatal(err)
}
if l := len(records); l != 1 {
t.Fatalf("got %v records, want %v", l, 1)
}
record := records[0]
messages, err := protobuf.ReadMessages(
bytes.NewReader(record.In()),
func() protobuf.Message { return new(pb.AnnouncePaymentThreshold) },
)
if err != nil {
t.Fatal(err)
}
if len(messages) != 1 {
t.Fatalf("got %v messages, want %v", len(messages), 1)
}
sentPaymentThreshold := big.NewInt(0).SetBytes(messages[0].(*pb.AnnouncePaymentThreshold).PaymentThreshold)
if sentPaymentThreshold.Cmp(paymentThreshold) != 0 {
t.Fatalf("got message with amount %v, want %v", sentPaymentThreshold, paymentThreshold)
}
sentPriceTable := messages[0].(*pb.AnnouncePaymentThreshold).ProximityPrice
if !reflect.DeepEqual(sentPriceTable, table) {
t.Fatalf("got message with table %v, want %v", sentPriceTable, table)
}
if !observer1.called {
t.Fatal("expected threshold observer to be called")
}
if observer1.paymentThreshold.Cmp(paymentThreshold) != 0 {
t.Fatalf("observer called with wrong paymentThreshold. got %d, want %d", observer1.paymentThreshold, paymentThreshold)
}
if !observer1.peer.Equal(peerID) {
t.Fatalf("threshold observer called with wrong peer. got %v, want %v", observer1.peer, peerID)
}
if !observer2.called {
t.Fatal("expected table observer to be called")
}
if !reflect.DeepEqual(observer2.priceTable, table) {
t.Fatalf("table observer called with wrong priceTable. got %d, want %d", observer2.priceTable, table)
}
if !observer2.peer.Equal(peerID) {
t.Fatalf("table observer called with wrong peer. got %v, want %v", observer2.peer, peerID)
}
}
...@@ -17,6 +17,8 @@ import ( ...@@ -17,6 +17,8 @@ import (
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/p2p" "github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/p2p/protobuf" "github.com/ethersphere/bee/pkg/p2p/protobuf"
"github.com/ethersphere/bee/pkg/pricer"
"github.com/ethersphere/bee/pkg/pricer/headerutils"
"github.com/ethersphere/bee/pkg/pushsync/pb" "github.com/ethersphere/bee/pkg/pushsync/pb"
"github.com/ethersphere/bee/pkg/soc" "github.com/ethersphere/bee/pkg/soc"
"github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage"
...@@ -53,14 +55,14 @@ type PushSync struct { ...@@ -53,14 +55,14 @@ type PushSync struct {
unwrap func(swarm.Chunk) unwrap func(swarm.Chunk)
logger logging.Logger logger logging.Logger
accounting accounting.Interface accounting accounting.Interface
pricer accounting.Pricer pricer pricer.Interface
metrics metrics metrics metrics
tracer *tracing.Tracer tracer *tracing.Tracer
} }
var timeToLive = 5 * time.Second // request time to live var timeToLive = 5 * time.Second // request time to live
func New(streamer p2p.StreamerDisconnecter, storer storage.Putter, closestPeerer topology.ClosestPeerer, tagger *tags.Tags, unwrap func(swarm.Chunk), logger logging.Logger, accounting accounting.Interface, pricer accounting.Pricer, tracer *tracing.Tracer) *PushSync { func New(streamer p2p.StreamerDisconnecter, storer storage.Putter, closestPeerer topology.ClosestPeerer, tagger *tags.Tags, unwrap func(swarm.Chunk), logger logging.Logger, accounting accounting.Interface, pricer pricer.Interface, tracer *tracing.Tracer) *PushSync {
ps := &PushSync{ ps := &PushSync{
streamer: streamer, streamer: streamer,
storer: storer, storer: storer,
...@@ -84,6 +86,7 @@ func (s *PushSync) Protocol() p2p.ProtocolSpec { ...@@ -84,6 +86,7 @@ func (s *PushSync) Protocol() p2p.ProtocolSpec {
{ {
Name: streamName, Name: streamName,
Handler: s.handler, Handler: s.handler,
Headler: s.pricer.PriceHeadler,
}, },
}, },
} }
...@@ -122,6 +125,16 @@ func (ps *PushSync) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) ...@@ -122,6 +125,16 @@ func (ps *PushSync) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream)
span, _, ctx := ps.tracer.StartSpanFromContext(ctx, "pushsync-handler", ps.logger, opentracing.Tag{Key: "address", Value: chunk.Address().String()}) span, _, ctx := ps.tracer.StartSpanFromContext(ctx, "pushsync-handler", ps.logger, opentracing.Tag{Key: "address", Value: chunk.Address().String()})
defer span.Finish() defer span.Finish()
// Get price we charge for upstream peer read at headler
responseHeaders := stream.ResponseHeaders()
price, err := headerutils.ParsePriceHeader(responseHeaders)
if err != nil {
// if not found in returned header, compute the price we charge for this chunk and
ps.logger.Warningf("push sync: peer %v no price in previously issued response headers: %v", p.Address, err)
price = ps.pricer.PriceForPeer(p.Address, chunk.Address())
}
receipt, err := ps.pushToClosest(ctx, chunk) receipt, err := ps.pushToClosest(ctx, chunk)
if err != nil { if err != nil {
if errors.Is(err, topology.ErrWantSelf) { if errors.Is(err, topology.ErrWantSelf) {
...@@ -135,9 +148,10 @@ func (ps *PushSync) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) ...@@ -135,9 +148,10 @@ func (ps *PushSync) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream)
return fmt.Errorf("send receipt to peer %s: %w", p.Address.String(), err) return fmt.Errorf("send receipt to peer %s: %w", p.Address.String(), err)
} }
return ps.accounting.Debit(p.Address, ps.pricer.Price(chunk.Address())) return ps.accounting.Debit(p.Address, price)
} }
return fmt.Errorf("handler: push to closest: %w", err) return fmt.Errorf("handler: push to closest: %w", err)
} }
// pass back the receipt // pass back the receipt
...@@ -145,7 +159,7 @@ func (ps *PushSync) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) ...@@ -145,7 +159,7 @@ func (ps *PushSync) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream)
return fmt.Errorf("send receipt to peer %s: %w", p.Address.String(), err) return fmt.Errorf("send receipt to peer %s: %w", p.Address.String(), err)
} }
return ps.accounting.Debit(p.Address, ps.pricer.Price(chunk.Address())) return ps.accounting.Debit(p.Address, price)
} }
// PushChunkToClosest sends chunk to the closest peer by opening a stream. It then waits for // PushChunkToClosest sends chunk to the closest peer by opening a stream. It then waits for
...@@ -208,19 +222,44 @@ func (ps *PushSync) pushToClosest(ctx context.Context, ch swarm.Chunk) (rr *pb.R ...@@ -208,19 +222,44 @@ func (ps *PushSync) pushToClosest(ctx context.Context, ch swarm.Chunk) (rr *pb.R
// compute the price we pay for this receipt and reserve it for the rest of this function // compute the price we pay for this receipt and reserve it for the rest of this function
receiptPrice := ps.pricer.PeerPrice(peer, ch.Address()) receiptPrice := ps.pricer.PeerPrice(peer, ch.Address())
err = ps.accounting.Reserve(ctx, peer, receiptPrice)
headers, err := headerutils.MakePricingHeaders(receiptPrice, ch.Address())
if err != nil { if err != nil {
return nil, fmt.Errorf("reserve balance for peer %s: %w", peer.String(), err) return nil, err
} }
deferFuncs = append(deferFuncs, func() { ps.accounting.Release(peer, receiptPrice) })
streamer, err := ps.streamer.NewStream(ctx, peer, nil, protocolName, protocolVersion, streamName) streamer, err := ps.streamer.NewStream(ctx, peer, headers, protocolName, protocolVersion, streamName)
if err != nil { if err != nil {
lastErr = fmt.Errorf("new stream for peer %s: %w", peer.String(), err) lastErr = fmt.Errorf("new stream for peer %s: %w", peer.String(), err)
continue continue
} }
deferFuncs = append(deferFuncs, func() { go streamer.FullClose() }) deferFuncs = append(deferFuncs, func() { go streamer.FullClose() })
returnedHeaders := streamer.Headers()
returnedTarget, returnedPrice, returnedIndex, err := headerutils.ParsePricingResponseHeaders(returnedHeaders)
if err != nil {
return nil, fmt.Errorf("push price headers: read returned: %w", err)
}
ps.logger.Debugf("push price headers: returned target %v with price as %v, from peer %s", returnedTarget, returnedPrice, peer)
ps.logger.Debugf("push price headers: original target %v with price as %v, from peer %s", ch.Address(), receiptPrice, peer)
// check if returned price matches presumed price, if not, update price
if returnedPrice != receiptPrice {
err = ps.pricer.NotifyPeerPrice(peer, returnedPrice, returnedIndex) // save priceHeaders["price"] corresponding row for peer
if err != nil {
return nil, err
}
receiptPrice = returnedPrice
}
// Reserve to see whether we can make the request based on actual price
err = ps.accounting.Reserve(ctx, peer, receiptPrice)
if err != nil {
return nil, fmt.Errorf("reserve balance for peer %s: %w", peer.String(), err)
}
deferFuncs = append(deferFuncs, func() { ps.accounting.Release(peer, receiptPrice) })
w, r := protobuf.NewWriterAndReader(streamer) w, r := protobuf.NewWriterAndReader(streamer)
ctxd, canceld := context.WithTimeout(ctx, timeToLive) ctxd, canceld := context.WithTimeout(ctx, timeToLive)
deferFuncs = append(deferFuncs, func() { canceld() }) deferFuncs = append(deferFuncs, func() { canceld() })
......
...@@ -19,6 +19,8 @@ import ( ...@@ -19,6 +19,8 @@ import (
"github.com/ethersphere/bee/pkg/p2p" "github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/p2p/protobuf" "github.com/ethersphere/bee/pkg/p2p/protobuf"
"github.com/ethersphere/bee/pkg/p2p/streamtest" "github.com/ethersphere/bee/pkg/p2p/streamtest"
"github.com/ethersphere/bee/pkg/pricer/headerutils"
pricermock "github.com/ethersphere/bee/pkg/pricer/mock"
"github.com/ethersphere/bee/pkg/pushsync" "github.com/ethersphere/bee/pkg/pushsync"
"github.com/ethersphere/bee/pkg/pushsync/pb" "github.com/ethersphere/bee/pkg/pushsync/pb"
statestore "github.com/ethersphere/bee/pkg/statestore/mock" statestore "github.com/ethersphere/bee/pkg/statestore/mock"
...@@ -33,6 +35,15 @@ const ( ...@@ -33,6 +35,15 @@ const (
fixedPrice = uint64(10) fixedPrice = uint64(10)
) )
type pricerParameters struct {
price uint64
peerPrice uint64
}
var (
defaultPrices = pricerParameters{price: fixedPrice, peerPrice: fixedPrice}
)
// TestSendChunkAndGetReceipt inserts a chunk as uploaded chunk in db. This triggers sending a chunk to the closest node // TestSendChunkAndGetReceipt inserts a chunk as uploaded chunk in db. This triggers sending a chunk to the closest node
// and expects a receipt. The message are intercepted in the outgoing stream to check for correctness. // and expects a receipt. The message are intercepted in the outgoing stream to check for correctness.
func TestSendChunkAndReceiveReceipt(t *testing.T) { func TestSendChunkAndReceiveReceipt(t *testing.T) {
...@@ -45,14 +56,15 @@ func TestSendChunkAndReceiveReceipt(t *testing.T) { ...@@ -45,14 +56,15 @@ func TestSendChunkAndReceiveReceipt(t *testing.T) {
// peer is the node responding to the chunk receipt message // peer is the node responding to the chunk receipt message
// mock should return ErrWantSelf since there's no one to forward to // mock should return ErrWantSelf since there's no one to forward to
psPeer, storerPeer, _, peerAccounting := createPushSyncNode(t, closestPeer, nil, nil, mock.WithClosestPeerErr(topology.ErrWantSelf))
psPeer, storerPeer, _, peerAccounting := createPushSyncNode(t, closestPeer, defaultPrices, nil, nil, mock.WithClosestPeerErr(topology.ErrWantSelf))
defer storerPeer.Close() defer storerPeer.Close()
recorder := streamtest.New(streamtest.WithProtocols(psPeer.Protocol()), streamtest.WithBaseAddr(pivotNode)) recorder := streamtest.New(streamtest.WithProtocols(psPeer.Protocol()), streamtest.WithBaseAddr(pivotNode))
// pivot node needs the streamer since the chunk is intercepted by // pivot node needs the streamer since the chunk is intercepted by
// the chunk worker, then gets sent by opening a new stream // the chunk worker, then gets sent by opening a new stream
psPivot, storerPivot, _, pivotAccounting := createPushSyncNode(t, pivotNode, recorder, nil, mock.WithClosestPeer(closestPeer)) psPivot, storerPivot, _, pivotAccounting := createPushSyncNode(t, pivotNode, defaultPrices, recorder, nil, mock.WithClosestPeer(closestPeer))
defer storerPivot.Close() defer storerPivot.Close()
// Trigger the sending of chunk to the closest node // Trigger the sending of chunk to the closest node
...@@ -88,6 +100,65 @@ func TestSendChunkAndReceiveReceipt(t *testing.T) { ...@@ -88,6 +100,65 @@ func TestSendChunkAndReceiveReceipt(t *testing.T) {
} }
} }
func TestSendChunkAfterPriceUpdate(t *testing.T) {
// chunk data to upload
chunk := testingc.FixtureChunk("7000")
// create a pivot node and a mocked closest node
pivotNode := swarm.MustParseHexAddress("0000000000000000000000000000000000000000000000000000000000000000") // base is 0000
closestPeer := swarm.MustParseHexAddress("6000000000000000000000000000000000000000000000000000000000000000") // binary 0110 -> po 1
// peer is the node responding to the chunk receipt message
// mock should return ErrWantSelf since there's no one to forward to
serverPrice := uint64(17)
serverPrices := pricerParameters{price: serverPrice, peerPrice: fixedPrice}
psPeer, storerPeer, _, peerAccounting := createPushSyncNode(t, closestPeer, serverPrices, nil, nil, mock.WithClosestPeerErr(topology.ErrWantSelf))
defer storerPeer.Close()
recorder := streamtest.New(streamtest.WithProtocols(psPeer.Protocol()), streamtest.WithBaseAddr(pivotNode))
// pivot node needs the streamer since the chunk is intercepted by
// the chunk worker, then gets sent by opening a new stream
psPivot, storerPivot, _, pivotAccounting := createPushSyncNode(t, pivotNode, defaultPrices, recorder, nil, mock.WithClosestPeer(closestPeer))
defer storerPivot.Close()
// Trigger the sending of chunk to the closest node
receipt, err := psPivot.PushChunkToClosest(context.Background(), chunk)
if err != nil {
t.Fatal(err)
}
if !chunk.Address().Equal(receipt.Address) {
t.Fatal("invalid receipt")
}
// this intercepts the outgoing delivery message
waitOnRecordAndTest(t, closestPeer, recorder, chunk.Address(), chunk.Data())
// this intercepts the incoming receipt message
waitOnRecordAndTest(t, closestPeer, recorder, chunk.Address(), nil)
balance, err := pivotAccounting.Balance(closestPeer)
if err != nil {
t.Fatal(err)
}
if balance.Int64() != -int64(serverPrice) {
t.Fatalf("unexpected balance on pivot. want %d got %d", -int64(serverPrice), balance)
}
balance, err = peerAccounting.Balance(pivotNode)
if err != nil {
t.Fatal(err)
}
if balance.Int64() != int64(serverPrice) {
t.Fatalf("unexpected balance on peer. want %d got %d", int64(serverPrice), balance)
}
}
// PushChunkToClosest tests the sending of chunk to closest peer from the origination source perspective. // PushChunkToClosest tests the sending of chunk to closest peer from the origination source perspective.
// it also checks wether the tags are incremented properly if they are present // it also checks wether the tags are incremented properly if they are present
func TestPushChunkToClosest(t *testing.T) { func TestPushChunkToClosest(t *testing.T) {
...@@ -99,14 +170,14 @@ func TestPushChunkToClosest(t *testing.T) { ...@@ -99,14 +170,14 @@ func TestPushChunkToClosest(t *testing.T) {
callbackC := make(chan struct{}, 1) callbackC := make(chan struct{}, 1)
// peer is the node responding to the chunk receipt message // peer is the node responding to the chunk receipt message
// mock should return ErrWantSelf since there's no one to forward to // mock should return ErrWantSelf since there's no one to forward to
psPeer, storerPeer, _, peerAccounting := createPushSyncNode(t, closestPeer, nil, chanFunc(callbackC), mock.WithClosestPeerErr(topology.ErrWantSelf)) psPeer, storerPeer, _, peerAccounting := createPushSyncNode(t, closestPeer, defaultPrices, nil, chanFunc(callbackC), mock.WithClosestPeerErr(topology.ErrWantSelf))
defer storerPeer.Close() defer storerPeer.Close()
recorder := streamtest.New(streamtest.WithProtocols(psPeer.Protocol()), streamtest.WithBaseAddr(pivotNode)) recorder := streamtest.New(streamtest.WithProtocols(psPeer.Protocol()), streamtest.WithBaseAddr(pivotNode))
// pivot node needs the streamer since the chunk is intercepted by // pivot node needs the streamer since the chunk is intercepted by
// the chunk worker, then gets sent by opening a new stream // the chunk worker, then gets sent by opening a new stream
psPivot, storerPivot, pivotTags, pivotAccounting := createPushSyncNode(t, pivotNode, recorder, nil, mock.WithClosestPeer(closestPeer)) psPivot, storerPivot, pivotTags, pivotAccounting := createPushSyncNode(t, pivotNode, defaultPrices, recorder, nil, mock.WithClosestPeer(closestPeer))
defer storerPivot.Close() defer storerPivot.Close()
ta, err := pivotTags.Create(1) ta, err := pivotTags.Create(1)
...@@ -190,10 +261,10 @@ func TestPushChunkToNextClosest(t *testing.T) { ...@@ -190,10 +261,10 @@ func TestPushChunkToNextClosest(t *testing.T) {
// peer is the node responding to the chunk receipt message // peer is the node responding to the chunk receipt message
// mock should return ErrWantSelf since there's no one to forward to // mock should return ErrWantSelf since there's no one to forward to
psPeer1, storerPeer1, _, peerAccounting1 := createPushSyncNode(t, peer1, nil, nil, mock.WithClosestPeerErr(topology.ErrWantSelf)) psPeer1, storerPeer1, _, peerAccounting1 := createPushSyncNode(t, peer1, defaultPrices, nil, nil, mock.WithClosestPeerErr(topology.ErrWantSelf))
defer storerPeer1.Close() defer storerPeer1.Close()
psPeer2, storerPeer2, _, peerAccounting2 := createPushSyncNode(t, peer2, nil, nil, mock.WithClosestPeerErr(topology.ErrWantSelf)) psPeer2, storerPeer2, _, peerAccounting2 := createPushSyncNode(t, peer2, defaultPrices, nil, nil, mock.WithClosestPeerErr(topology.ErrWantSelf))
defer storerPeer2.Close() defer storerPeer2.Close()
recorder := streamtest.New( recorder := streamtest.New(
...@@ -223,7 +294,7 @@ func TestPushChunkToNextClosest(t *testing.T) { ...@@ -223,7 +294,7 @@ func TestPushChunkToNextClosest(t *testing.T) {
// pivot node needs the streamer since the chunk is intercepted by // pivot node needs the streamer since the chunk is intercepted by
// the chunk worker, then gets sent by opening a new stream // the chunk worker, then gets sent by opening a new stream
psPivot, storerPivot, pivotTags, pivotAccounting := createPushSyncNode(t, pivotNode, recorder, nil, psPivot, storerPivot, pivotTags, pivotAccounting := createPushSyncNode(t, pivotNode, defaultPrices, recorder, nil,
mock.WithPeers(peers...), mock.WithPeers(peers...),
) )
defer storerPivot.Close() defer storerPivot.Close()
...@@ -311,19 +382,19 @@ func TestHandler(t *testing.T) { ...@@ -311,19 +382,19 @@ func TestHandler(t *testing.T) {
closestPeer := swarm.MustParseHexAddress("f000000000000000000000000000000000000000000000000000000000000000") closestPeer := swarm.MustParseHexAddress("f000000000000000000000000000000000000000000000000000000000000000")
// Create the closest peer // Create the closest peer
psClosestPeer, closestStorerPeerDB, _, closestAccounting := createPushSyncNode(t, closestPeer, nil, nil, mock.WithClosestPeerErr(topology.ErrWantSelf)) psClosestPeer, closestStorerPeerDB, _, closestAccounting := createPushSyncNode(t, closestPeer, defaultPrices, nil, nil, mock.WithClosestPeerErr(topology.ErrWantSelf))
defer closestStorerPeerDB.Close() defer closestStorerPeerDB.Close()
closestRecorder := streamtest.New(streamtest.WithProtocols(psClosestPeer.Protocol()), streamtest.WithBaseAddr(pivotPeer)) closestRecorder := streamtest.New(streamtest.WithProtocols(psClosestPeer.Protocol()), streamtest.WithBaseAddr(pivotPeer))
// creating the pivot peer // creating the pivot peer
psPivot, storerPivotDB, _, pivotAccounting := createPushSyncNode(t, pivotPeer, closestRecorder, nil, mock.WithClosestPeer(closestPeer)) psPivot, storerPivotDB, _, pivotAccounting := createPushSyncNode(t, pivotPeer, defaultPrices, closestRecorder, nil, mock.WithClosestPeer(closestPeer))
defer storerPivotDB.Close() defer storerPivotDB.Close()
pivotRecorder := streamtest.New(streamtest.WithProtocols(psPivot.Protocol()), streamtest.WithBaseAddr(triggerPeer)) pivotRecorder := streamtest.New(streamtest.WithProtocols(psPivot.Protocol()), streamtest.WithBaseAddr(triggerPeer))
// Creating the trigger peer // Creating the trigger peer
psTriggerPeer, triggerStorerDB, _, triggerAccounting := createPushSyncNode(t, triggerPeer, pivotRecorder, nil, mock.WithClosestPeer(pivotPeer)) psTriggerPeer, triggerStorerDB, _, triggerAccounting := createPushSyncNode(t, triggerPeer, defaultPrices, pivotRecorder, nil, mock.WithClosestPeer(pivotPeer))
defer triggerStorerDB.Close() defer triggerStorerDB.Close()
receipt, err := psTriggerPeer.PushChunkToClosest(context.Background(), chunk) receipt, err := psTriggerPeer.PushChunkToClosest(context.Background(), chunk)
...@@ -385,7 +456,97 @@ func TestHandler(t *testing.T) { ...@@ -385,7 +456,97 @@ func TestHandler(t *testing.T) {
} }
} }
func createPushSyncNode(t *testing.T, addr swarm.Address, recorder *streamtest.Recorder, unwrap func(swarm.Chunk), mockOpts ...mock.Option) (*pushsync.PushSync, *localstore.DB, *tags.Tags, accounting.Interface) { func TestHandlerWithUpdate(t *testing.T) {
// chunk data to upload
chunk := testingc.FixtureChunk("7000")
serverPrice := uint64(17)
serverPrices := pricerParameters{price: serverPrice, peerPrice: fixedPrice}
// create a pivot node and a mocked closest node
pivotPeer := swarm.MustParseHexAddress("0000000000000000000000000000000000000000000000000000000000000000")
triggerPeer := swarm.MustParseHexAddress("6000000000000000000000000000000000000000000000000000000000000000")
closestPeer := swarm.MustParseHexAddress("f000000000000000000000000000000000000000000000000000000000000000")
// Create the closest peer with default prices (10)
psClosestPeer, closestStorerPeerDB, _, closestAccounting := createPushSyncNode(t, closestPeer, defaultPrices, nil, nil, mock.WithClosestPeerErr(topology.ErrWantSelf))
defer closestStorerPeerDB.Close()
closestRecorder := streamtest.New(streamtest.WithProtocols(psClosestPeer.Protocol()), streamtest.WithBaseAddr(pivotPeer))
// creating the pivot peer who will act as a forwarder node with a higher price (17)
psPivot, storerPivotDB, _, pivotAccounting := createPushSyncNode(t, pivotPeer, serverPrices, closestRecorder, nil, mock.WithClosestPeer(closestPeer))
defer storerPivotDB.Close()
pivotRecorder := streamtest.New(streamtest.WithProtocols(psPivot.Protocol()), streamtest.WithBaseAddr(triggerPeer))
// Creating the trigger peer with default price (10)
psTriggerPeer, triggerStorerDB, _, triggerAccounting := createPushSyncNode(t, triggerPeer, defaultPrices, pivotRecorder, nil, mock.WithClosestPeer(pivotPeer))
defer triggerStorerDB.Close()
receipt, err := psTriggerPeer.PushChunkToClosest(context.Background(), chunk)
if err != nil {
t.Fatal(err)
}
if !chunk.Address().Equal(receipt.Address) {
t.Fatal("invalid receipt")
}
// In pivot peer, intercept the incoming delivery chunk from the trigger peer and check for correctness
waitOnRecordAndTest(t, pivotPeer, pivotRecorder, chunk.Address(), chunk.Data())
// Pivot peer will forward the chunk to its closest peer. Intercept the incoming stream from pivot node and check
// for the correctness of the chunk
waitOnRecordAndTest(t, closestPeer, closestRecorder, chunk.Address(), chunk.Data())
// Similarly intercept the same incoming stream to see if the closest peer is sending a proper receipt
waitOnRecordAndTest(t, closestPeer, closestRecorder, chunk.Address(), nil)
// In the received stream, check if a receipt is sent from pivot peer and check for its correctness.
waitOnRecordAndTest(t, pivotPeer, pivotRecorder, chunk.Address(), nil)
balance, err := triggerAccounting.Balance(pivotPeer)
if err != nil {
t.Fatal(err)
}
// balance on triggering peer towards the forwarder should show negative serverPrice (17)
if balance.Int64() != -int64(serverPrice) {
t.Fatalf("unexpected balance on trigger. want %d got %d", -int64(serverPrice), balance)
}
// we need to check here for pivotPeer instead of triggerPeer because during streamtest the peer in the handler is actually the receiver
// balance on forwarding peer for the triggering peer should show serverPrice (17)
balance, err = pivotAccounting.Balance(triggerPeer)
if err != nil {
t.Fatal(err)
}
if balance.Int64() != int64(serverPrice) {
t.Fatalf("unexpected balance on pivot. want %d got %d", int64(serverPrice), balance)
}
balance, err = pivotAccounting.Balance(closestPeer)
if err != nil {
t.Fatal(err)
}
// balance of the forwarder peer for the closest peer should show negative default price (10)
if balance.Int64() != -int64(fixedPrice) {
t.Fatalf("unexpected balance on pivot. want %d got %d", -int64(fixedPrice), balance)
}
balance, err = closestAccounting.Balance(pivotPeer)
if err != nil {
t.Fatal(err)
}
// balance of the closest peer for the forwarder peer should show the default price (10)
if balance.Int64() != int64(fixedPrice) {
t.Fatalf("unexpected balance on closest. want %d got %d", int64(fixedPrice), balance)
}
}
func createPushSyncNode(t *testing.T, addr swarm.Address, prices pricerParameters, recorder *streamtest.Recorder, unwrap func(swarm.Chunk), mockOpts ...mock.Option) (*pushsync.PushSync, *localstore.DB, *tags.Tags, accounting.Interface) {
t.Helper() t.Helper()
logger := logging.New(ioutil.Discard, 0) logger := logging.New(ioutil.Discard, 0)
...@@ -398,7 +559,14 @@ func createPushSyncNode(t *testing.T, addr swarm.Address, recorder *streamtest.R ...@@ -398,7 +559,14 @@ func createPushSyncNode(t *testing.T, addr swarm.Address, recorder *streamtest.R
mockStatestore := statestore.NewStateStore() mockStatestore := statestore.NewStateStore()
mtag := tags.NewTags(mockStatestore, logger) mtag := tags.NewTags(mockStatestore, logger)
mockAccounting := accountingmock.NewAccounting() mockAccounting := accountingmock.NewAccounting()
mockPricer := accountingmock.NewPricer(fixedPrice, fixedPrice)
headlerFunc := func(h p2p.Headers, a swarm.Address) p2p.Headers {
target, _ := headerutils.ParseTargetHeader(h)
headers, _ := headerutils.MakePricingResponseHeaders(prices.price, target, 0)
return headers
}
mockPricer := pricermock.NewMockService(pricermock.WithPriceHeadlerFunc(headlerFunc), pricermock.WithPrice(prices.price), pricermock.WithPeerPrice(prices.peerPrice))
recorderDisconnecter := streamtest.NewRecorderDisconnecter(recorder) recorderDisconnecter := streamtest.NewRecorderDisconnecter(recorder)
if unwrap == nil { if unwrap == nil {
......
...@@ -16,6 +16,7 @@ import ( ...@@ -16,6 +16,7 @@ import (
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/netstore" "github.com/ethersphere/bee/pkg/netstore"
"github.com/ethersphere/bee/pkg/p2p/streamtest" "github.com/ethersphere/bee/pkg/p2p/streamtest"
pricermock "github.com/ethersphere/bee/pkg/pricer/mock"
"github.com/ethersphere/bee/pkg/pss" "github.com/ethersphere/bee/pkg/pss"
"github.com/ethersphere/bee/pkg/pushsync" "github.com/ethersphere/bee/pkg/pushsync"
pushsyncmock "github.com/ethersphere/bee/pkg/pushsync/mock" pushsyncmock "github.com/ethersphere/bee/pkg/pushsync/mock"
...@@ -218,14 +219,14 @@ func newTestNetStore(t *testing.T, recoveryFunc recovery.Callback) storage.Store ...@@ -218,14 +219,14 @@ func newTestNetStore(t *testing.T, recoveryFunc recovery.Callback) storage.Store
mockStorer := storemock.NewStorer() mockStorer := storemock.NewStorer()
serverMockAccounting := accountingmock.NewAccounting() serverMockAccounting := accountingmock.NewAccounting()
price := uint64(12345)
pricerMock := accountingmock.NewPricer(price, price) pricerMock := pricermock.NewMockService()
peerID := swarm.MustParseHexAddress("deadbeef") peerID := swarm.MustParseHexAddress("deadbeef")
ps := mockPeerSuggester{eachPeerRevFunc: func(f topology.EachPeerFunc) error { ps := mockPeerSuggester{eachPeerRevFunc: func(f topology.EachPeerFunc) error {
_, _, _ = f(peerID, 0) _, _, _ = f(peerID, 0)
return nil return nil
}} }}
server := retrieval.New(swarm.ZeroAddress, mockStorer, nil, ps, logger, serverMockAccounting, nil, nil) server := retrieval.New(swarm.ZeroAddress, mockStorer, nil, ps, logger, serverMockAccounting, pricerMock, nil)
recorder := streamtest.New( recorder := streamtest.New(
streamtest.WithProtocols(server.Protocol()), streamtest.WithProtocols(server.Protocol()),
) )
......
...@@ -20,6 +20,8 @@ import ( ...@@ -20,6 +20,8 @@ import (
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/p2p" "github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/p2p/protobuf" "github.com/ethersphere/bee/pkg/p2p/protobuf"
"github.com/ethersphere/bee/pkg/pricer"
"github.com/ethersphere/bee/pkg/pricer/headerutils"
pb "github.com/ethersphere/bee/pkg/retrieval/pb" pb "github.com/ethersphere/bee/pkg/retrieval/pb"
"github.com/ethersphere/bee/pkg/soc" "github.com/ethersphere/bee/pkg/soc"
"github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage"
...@@ -52,12 +54,12 @@ type Service struct { ...@@ -52,12 +54,12 @@ type Service struct {
singleflight singleflight.Group singleflight singleflight.Group
logger logging.Logger logger logging.Logger
accounting accounting.Interface accounting accounting.Interface
pricer accounting.Pricer
metrics metrics metrics metrics
pricer pricer.Interface
tracer *tracing.Tracer tracer *tracing.Tracer
} }
func New(addr swarm.Address, storer storage.Storer, streamer p2p.Streamer, chunkPeerer topology.EachPeerer, logger logging.Logger, accounting accounting.Interface, pricer accounting.Pricer, tracer *tracing.Tracer) *Service { func New(addr swarm.Address, storer storage.Storer, streamer p2p.Streamer, chunkPeerer topology.EachPeerer, logger logging.Logger, accounting accounting.Interface, pricer pricer.Interface, tracer *tracing.Tracer) *Service {
return &Service{ return &Service{
addr: addr, addr: addr,
streamer: streamer, streamer: streamer,
...@@ -79,6 +81,7 @@ func (s *Service) Protocol() p2p.ProtocolSpec { ...@@ -79,6 +81,7 @@ func (s *Service) Protocol() p2p.ProtocolSpec {
{ {
Name: streamName, Name: streamName,
Handler: s.handler, Handler: s.handler,
Headler: s.pricer.PriceHeadler,
}, },
}, },
} }
...@@ -201,20 +204,23 @@ func (s *Service) retrieveChunk(ctx context.Context, addr swarm.Address, sp *ski ...@@ -201,20 +204,23 @@ func (s *Service) retrieveChunk(ctx context.Context, addr swarm.Address, sp *ski
sp.Add(peer) sp.Add(peer)
// compute the price we pay for this chunk and reserve it for the rest of this function // compute the price we presume to pay for this chunk for price header
chunkPrice := s.pricer.PeerPrice(peer, addr) chunkPrice := s.pricer.PeerPrice(peer, addr)
err = s.accounting.Reserve(ctx, peer, chunkPrice)
headers, err := headerutils.MakePricingHeaders(chunkPrice, addr)
if err != nil { if err != nil {
return nil, peer, err return nil, swarm.Address{}, err
} }
defer s.accounting.Release(peer, chunkPrice)
s.logger.Tracef("retrieval: requesting chunk %s from peer %s", addr, peer) s.logger.Tracef("retrieval: requesting chunk %s from peer %s", addr, peer)
stream, err := s.streamer.NewStream(ctx, peer, nil, protocolName, protocolVersion, streamName) stream, err := s.streamer.NewStream(ctx, peer, headers, protocolName, protocolVersion, streamName)
if err != nil { if err != nil {
s.metrics.TotalErrors.Inc() s.metrics.TotalErrors.Inc()
return nil, peer, fmt.Errorf("new stream: %w", err) return nil, peer, fmt.Errorf("new stream: %w", err)
} }
returnedHeaders := stream.Headers()
defer func() { defer func() {
if err != nil { if err != nil {
_ = stream.Reset() _ = stream.Reset()
...@@ -223,6 +229,29 @@ func (s *Service) retrieveChunk(ctx context.Context, addr swarm.Address, sp *ski ...@@ -223,6 +229,29 @@ func (s *Service) retrieveChunk(ctx context.Context, addr swarm.Address, sp *ski
} }
}() }()
returnedTarget, returnedPrice, returnedIndex, err := headerutils.ParsePricingResponseHeaders(returnedHeaders)
if err != nil {
return nil, peer, fmt.Errorf("retrieval headers: read returned: %w", err)
}
s.logger.Debugf("retrieval headers: returned target %v with price as %v, from peer %s", returnedTarget, returnedPrice, peer)
s.logger.Debugf("retrieval headers: original target %v with price as %v, from peer %s", addr, chunkPrice, peer)
// check if returned price matches presumed price, if not, update price
if returnedPrice != chunkPrice {
err = s.pricer.NotifyPeerPrice(peer, returnedPrice, returnedIndex) // save priceHeaders["price"] corresponding row for peer
if err != nil {
return nil, peer, err
}
chunkPrice = returnedPrice
}
// Reserve to see whether we can request the chunk based on actual price
err = s.accounting.Reserve(ctx, peer, chunkPrice)
if err != nil {
return nil, peer, err
}
defer s.accounting.Release(peer, chunkPrice)
w, r := protobuf.NewWriterAndReader(stream) w, r := protobuf.NewWriterAndReader(stream)
if err := w.WriteMsgWithContext(ctx, &pb.Request{ if err := w.WriteMsgWithContext(ctx, &pb.Request{
Addr: addr.Bytes(), Addr: addr.Bytes(),
...@@ -329,6 +358,7 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e ...@@ -329,6 +358,7 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e
if err := r.ReadMsgWithContext(ctx, &req); err != nil { if err := r.ReadMsgWithContext(ctx, &req); err != nil {
return fmt.Errorf("read request: %w peer %s", err, p.Address.String()) return fmt.Errorf("read request: %w peer %s", err, p.Address.String())
} }
span, _, ctx := s.tracer.StartSpanFromContext(ctx, "handle-retrieve-chunk", s.logger, opentracing.Tag{Key: "address", Value: swarm.NewAddress(req.Addr).String()}) span, _, ctx := s.tracer.StartSpanFromContext(ctx, "handle-retrieve-chunk", s.logger, opentracing.Tag{Key: "address", Value: swarm.NewAddress(req.Addr).String()})
defer span.Finish() defer span.Finish()
...@@ -355,12 +385,15 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e ...@@ -355,12 +385,15 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e
s.logger.Tracef("retrieval protocol debiting peer %s", p.Address.String()) s.logger.Tracef("retrieval protocol debiting peer %s", p.Address.String())
// compute the price we charge for this chunk and debit it from p's balance // to get price Read in headler,
chunkPrice := s.pricer.Price(chunk.Address()) returnedHeaders := stream.ResponseHeaders()
err = s.accounting.Debit(p.Address, chunkPrice) chunkPrice, err := headerutils.ParsePriceHeader(returnedHeaders)
if err != nil { if err != nil {
return err // if not found in returned header, compute the price we charge for this chunk and
s.logger.Warningf("retrieval: peer %v no price in previously issued response headers: %v", p.Address, err)
chunkPrice = s.pricer.PriceForPeer(p.Address, chunk.Address())
} }
// debit price from p's balance
return nil return s.accounting.Debit(p.Address, chunkPrice)
} }
...@@ -11,15 +11,19 @@ import ( ...@@ -11,15 +11,19 @@ import (
"errors" "errors"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"math/big"
"sync" "sync"
"testing" "testing"
"time" "time"
accountingmock "github.com/ethersphere/bee/pkg/accounting/mock" accountingmock "github.com/ethersphere/bee/pkg/accounting/mock"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/p2p" "github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/p2p/protobuf" "github.com/ethersphere/bee/pkg/p2p/protobuf"
"github.com/ethersphere/bee/pkg/p2p/streamtest" "github.com/ethersphere/bee/pkg/p2p/streamtest"
"github.com/ethersphere/bee/pkg/pricer/headerutils"
pricermock "github.com/ethersphere/bee/pkg/pricer/mock"
"github.com/ethersphere/bee/pkg/retrieval" "github.com/ethersphere/bee/pkg/retrieval"
pb "github.com/ethersphere/bee/pkg/retrieval/pb" pb "github.com/ethersphere/bee/pkg/retrieval/pb"
"github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage"
...@@ -34,17 +38,23 @@ var testTimeout = 5 * time.Second ...@@ -34,17 +38,23 @@ var testTimeout = 5 * time.Second
// TestDelivery tests that a naive request -> delivery flow works. // TestDelivery tests that a naive request -> delivery flow works.
func TestDelivery(t *testing.T) { func TestDelivery(t *testing.T) {
var ( var (
chunk = testingc.FixtureChunk("0033")
headlerFunc = func(h p2p.Headers, a swarm.Address) p2p.Headers {
headers, _ := headerutils.MakePricingResponseHeaders(10, chunk.Address(), 0)
return headers
}
logger = logging.New(ioutil.Discard, 0) logger = logging.New(ioutil.Discard, 0)
mockStorer = storemock.NewStorer() mockStorer = storemock.NewStorer()
chunk = testingc.FixtureChunk("0033")
clientMockAccounting = accountingmock.NewAccounting() clientMockAccounting = accountingmock.NewAccounting()
serverMockAccounting = accountingmock.NewAccounting() serverMockAccounting = accountingmock.NewAccounting()
clientAddr = swarm.MustParseHexAddress("9ee7add8") clientAddr = swarm.MustParseHexAddress("9ee7add8")
serverAddr = swarm.MustParseHexAddress("9ee7add7") serverAddr = swarm.MustParseHexAddress("9ee7add7")
price = uint64(10) price = uint64(10)
pricerMock = accountingmock.NewPricer(price, price) pricerMock = pricermock.NewMockService(pricermock.WithPriceHeadlerFunc(headlerFunc))
) )
// put testdata in the mock store of the server // put testdata in the mock store of the server
_, err := mockStorer.Put(context.Background(), storage.ModePutUpload, chunk) _, err := mockStorer.Put(context.Background(), storage.ModePutUpload, chunk)
if err != nil { if err != nil {
...@@ -132,10 +142,126 @@ func TestDelivery(t *testing.T) { ...@@ -132,10 +142,126 @@ func TestDelivery(t *testing.T) {
} }
} }
// TestDelivery tests that a naive request -> delivery flow works.
func TestDeliveryWithPriceUpdate(t *testing.T) {
var (
price = uint64(10)
serverPrice = uint64(17)
chunk = testingc.FixtureChunk("0033")
headlerFunc = func(h p2p.Headers, a swarm.Address) p2p.Headers {
headers, _ := headerutils.MakePricingResponseHeaders(serverPrice, chunk.Address(), 5)
return headers
}
logger = logging.New(ioutil.Discard, 0)
mockStorer = storemock.NewStorer()
clientMockAccounting = accountingmock.NewAccounting()
serverMockAccounting = accountingmock.NewAccounting()
clientAddr = swarm.MustParseHexAddress("9ee7add8")
serverAddr = swarm.MustParseHexAddress("9ee7add7")
clientPricerMock = pricermock.NewMockService(pricermock.WithPeerPrice(price))
serverPricerMock = pricermock.NewMockService(pricermock.WithPriceHeadlerFunc(headlerFunc), pricermock.WithPrice(serverPrice))
)
// put testdata in the mock store of the server
_, err := mockStorer.Put(context.Background(), storage.ModePutUpload, chunk)
if err != nil {
t.Fatal(err)
}
// create the server that will handle the request and will serve the response
server := retrieval.New(swarm.MustParseHexAddress("0034"), mockStorer, nil, nil, logger, serverMockAccounting, serverPricerMock, nil)
recorder := streamtest.New(
streamtest.WithProtocols(server.Protocol()),
streamtest.WithBaseAddr(clientAddr),
)
// client mock storer does not store any data at this point
// but should be checked at at the end of the test for the
// presence of the chunk address key and value to ensure delivery
// was successful
clientMockStorer := storemock.NewStorer()
ps := mockPeerSuggester{eachPeerRevFunc: func(f topology.EachPeerFunc) error {
_, _, _ = f(serverAddr, 0)
return nil
}}
client := retrieval.New(clientAddr, clientMockStorer, recorder, ps, logger, clientMockAccounting, clientPricerMock, nil)
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
v, err := client.RetrieveChunk(ctx, chunk.Address())
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(v.Data(), chunk.Data()) {
t.Fatalf("request and response data not equal. got %s want %s", v, chunk.Data())
}
records, err := recorder.Records(serverAddr, "retrieval", "1.0.0", "retrieval")
if err != nil {
t.Fatal(err)
}
if l := len(records); l != 1 {
t.Fatalf("got %v records, want %v", l, 1)
}
record := records[0]
messages, err := protobuf.ReadMessages(
bytes.NewReader(record.In()),
func() protobuf.Message { return new(pb.Request) },
)
if err != nil {
t.Fatal(err)
}
var reqs []string
for _, m := range messages {
reqs = append(reqs, hex.EncodeToString(m.(*pb.Request).Addr))
}
if len(reqs) != 1 {
t.Fatalf("got too many requests. want 1 got %d", len(reqs))
}
messages, err = protobuf.ReadMessages(
bytes.NewReader(record.Out()),
func() protobuf.Message { return new(pb.Delivery) },
)
if err != nil {
t.Fatal(err)
}
var gotDeliveries []string
for _, m := range messages {
gotDeliveries = append(gotDeliveries, string(m.(*pb.Delivery).Data))
}
if len(gotDeliveries) != 1 {
t.Fatalf("got too many deliveries. want 1 got %d", len(gotDeliveries))
}
clientBalance, _ := clientMockAccounting.Balance(serverAddr)
if clientBalance.Cmp(big.NewInt(-int64(serverPrice))) != 0 {
t.Fatalf("unexpected balance on client. want %d got %d", -serverPrice, clientBalance)
}
serverBalance, _ := serverMockAccounting.Balance(clientAddr)
if serverBalance.Cmp(big.NewInt(int64(serverPrice))) != 0 {
t.Fatalf("unexpected balance on server. want %d got %d", serverPrice, serverBalance)
}
}
func TestRetrieveChunk(t *testing.T) { func TestRetrieveChunk(t *testing.T) {
var ( var (
headlerFunc = func(h p2p.Headers, a swarm.Address) p2p.Headers {
target, _ := headerutils.ParseTargetHeader(h)
headers, _ := headerutils.MakePricingResponseHeaders(10, target, 0)
return headers
}
logger = logging.New(ioutil.Discard, 0) logger = logging.New(ioutil.Discard, 0)
pricer = accountingmock.NewPricer(1, 1) pricer = pricermock.NewMockService(pricermock.WithPriceHeadlerFunc(headlerFunc))
) )
// requesting a chunk from downstream peer is expected // requesting a chunk from downstream peer is expected
...@@ -236,8 +362,14 @@ func TestRetrievePreemptiveRetry(t *testing.T) { ...@@ -236,8 +362,14 @@ func TestRetrievePreemptiveRetry(t *testing.T) {
chunk := testingc.FixtureChunk("0025") chunk := testingc.FixtureChunk("0025")
someOtherChunk := testingc.FixtureChunk("0033") someOtherChunk := testingc.FixtureChunk("0033")
price := uint64(1) headlerFunc := func(h p2p.Headers, a swarm.Address) p2p.Headers {
pricerMock := accountingmock.NewPricer(price, price) target, _ := headerutils.ParseTargetHeader(h)
headers, _ := headerutils.MakePricingResponseHeaders(10, target, 0)
return headers
}
price := uint64(10)
pricerMock := pricermock.NewMockService(pricermock.WithPriceHeadlerFunc(headlerFunc))
clientAddress := swarm.MustParseHexAddress("1010") clientAddress := swarm.MustParseHexAddress("1010")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment