Commit 1f31cf10 authored by metacertain's avatar metacertain Committed by GitHub

Pricing update strategy (#1134)

- add pricer pkg
- 0 out cost of in-neighborhood traffic
- communicate pricetable in pricing protocol
- add pricing headers and price communication to protocols
parent abd00984
...@@ -59,7 +59,7 @@ jobs: ...@@ -59,7 +59,7 @@ jobs:
run: | run: |
echo -e "127.0.0.10\tregistry.localhost" | sudo tee -a /etc/hosts echo -e "127.0.0.10\tregistry.localhost" | sudo tee -a /etc/hosts
for ((i=0; i<REPLICA; i++)); do echo -e "127.0.1.$((i+1))\tbee-${i}.localhost bee-${i}-debug.localhost"; done | sudo tee -a /etc/hosts for ((i=0; i<REPLICA; i++)); do echo -e "127.0.1.$((i+1))\tbee-${i}.localhost bee-${i}-debug.localhost"; done | sudo tee -a /etc/hosts
./beeinfra.sh install --local -r "${REPLICA}" --bootnode /dnsaddr/localhost --geth --k3s --pay-threshold 1000000000000 ./beeinfra.sh install --local -r "${REPLICA}" --bootnode /dnsaddr/localhost --geth --k3s
- name: Test pingpong - name: Test pingpong
id: pingpong-1 id: pingpong-1
run: until ./beekeeper check pingpong --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}"; do echo "waiting for pingpong..."; sleep .3; done run: until ./beekeeper check pingpong --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}"; do echo "waiting for pingpong..."; sleep .3; done
...@@ -68,7 +68,7 @@ jobs: ...@@ -68,7 +68,7 @@ jobs:
run: ./beekeeper check fullconnectivity --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}" run: ./beekeeper check fullconnectivity --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}"
- name: Test settlements - name: Test settlements
id: settlements-1 id: settlements-1
run: ./beekeeper check settlements --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}" --upload-node-count "${REPLICA}" -t 1000000000000 run: ./beekeeper check settlements --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}" --upload-node-count "${REPLICA}" --expect-settlements=false
- name: Test pushsync (bytes) - name: Test pushsync (bytes)
id: pushsync-bytes-1 id: pushsync-bytes-1
run: ./beekeeper check pushsync --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}" --upload-node-count "${REPLICA}" --chunks-per-node 3 run: ./beekeeper check pushsync --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}" --upload-node-count "${REPLICA}" --chunks-per-node 3
...@@ -105,7 +105,7 @@ jobs: ...@@ -105,7 +105,7 @@ jobs:
cp /etc/rancher/k3s/k3s.yaml ~/.kube/config cp /etc/rancher/k3s/k3s.yaml ~/.kube/config
- name: Set testing cluster (Node connection and clef enabled) - name: Set testing cluster (Node connection and clef enabled)
run: | run: |
./beeinfra.sh install --local -r "${REPLICA}" --geth --clef --k3s --pay-threshold 1000000000000 ./beeinfra.sh install --local -r "${REPLICA}" --geth --clef --k3s
- name: Test pingpong - name: Test pingpong
id: pingpong-2 id: pingpong-2
run: until ./beekeeper check pingpong --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}"; do echo "waiting for pingpong..."; sleep .3; done run: until ./beekeeper check pingpong --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}"; do echo "waiting for pingpong..."; sleep .3; done
...@@ -114,7 +114,7 @@ jobs: ...@@ -114,7 +114,7 @@ jobs:
run: ./beekeeper check fullconnectivity --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}" run: ./beekeeper check fullconnectivity --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}"
- name: Test settlements - name: Test settlements
id: settlements-2 id: settlements-2
run: ./beekeeper check settlements --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}" --upload-node-count "${REPLICA}" -t 1000000000000 run: ./beekeeper check settlements --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}" --upload-node-count "${REPLICA}" --expect-settlements=false
- name: Test pushsync (bytes) - name: Test pushsync (bytes)
id: pushsync-bytes-2 id: pushsync-bytes-2
run: ./beekeeper check pushsync --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}" --upload-node-count "${REPLICA}" --chunks-per-node 3 run: ./beekeeper check pushsync --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}" --upload-node-count "${REPLICA}" --chunks-per-node 3
......
...@@ -526,6 +526,13 @@ func (p *pricingMock) AnnouncePaymentThreshold(ctx context.Context, peer swarm.A ...@@ -526,6 +526,13 @@ func (p *pricingMock) AnnouncePaymentThreshold(ctx context.Context, peer swarm.A
return nil return nil
} }
func (p *pricingMock) AnnouncePaymentThresholdAndPriceTable(ctx context.Context, peer swarm.Address, paymentThreshold *big.Int) error {
p.called = true
p.peer = peer
p.paymentThreshold = paymentThreshold
return nil
}
func TestAccountingConnected(t *testing.T) { func TestAccountingConnected(t *testing.T) {
logger := logging.New(ioutil.Discard, 0) logger := logging.New(ioutil.Discard, 0)
......
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mock
import (
"github.com/ethersphere/bee/pkg/swarm"
)
type MockPricer struct {
peerPrice uint64
price uint64
}
func NewPricer(price, peerPrice uint64) *MockPricer {
return &MockPricer{
peerPrice: peerPrice,
price: price,
}
}
func (pricer *MockPricer) PeerPrice(peer, chunk swarm.Address) uint64 {
return pricer.peerPrice
}
func (pricer *MockPricer) Price(chunk swarm.Address) uint64 {
return pricer.price
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package accounting
import (
"github.com/ethersphere/bee/pkg/swarm"
)
// Pricer returns pricing information for chunk hashes.
type Pricer interface {
// PeerPrice is the price the peer charges for a given chunk hash.
PeerPrice(peer, chunk swarm.Address) uint64
// Price is the price we charge for a given chunk hash.
Price(chunk swarm.Address) uint64
}
// FixedPricer is a Pricer that has a fixed price for chunks.
type FixedPricer struct {
overlay swarm.Address
poPrice uint64
}
// NewFixedPricer returns a new FixedPricer with a given price.
func NewFixedPricer(overlay swarm.Address, poPrice uint64) *FixedPricer {
return &FixedPricer{
overlay: overlay,
poPrice: poPrice,
}
}
// PeerPrice implements Pricer.
func (pricer *FixedPricer) PeerPrice(peer, chunk swarm.Address) uint64 {
return uint64(swarm.MaxPO-swarm.Proximity(peer.Bytes(), chunk.Bytes())+1) * pricer.poPrice
}
// Price implements Pricer.
func (pricer *FixedPricer) Price(chunk swarm.Address) uint64 {
return pricer.PeerPrice(pricer.overlay, chunk)
}
// Copyright 2021 The Swarm Authors. All rights reserved. // Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
...@@ -35,6 +35,7 @@ import ( ...@@ -35,6 +35,7 @@ import (
"github.com/ethersphere/bee/pkg/netstore" "github.com/ethersphere/bee/pkg/netstore"
"github.com/ethersphere/bee/pkg/p2p/libp2p" "github.com/ethersphere/bee/pkg/p2p/libp2p"
"github.com/ethersphere/bee/pkg/pingpong" "github.com/ethersphere/bee/pkg/pingpong"
"github.com/ethersphere/bee/pkg/pricer"
"github.com/ethersphere/bee/pkg/pricing" "github.com/ethersphere/bee/pkg/pricing"
"github.com/ethersphere/bee/pkg/pss" "github.com/ethersphere/bee/pkg/pss"
"github.com/ethersphere/bee/pkg/puller" "github.com/ethersphere/bee/pkg/puller"
...@@ -310,6 +311,35 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey, ...@@ -310,6 +311,35 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey,
var settlement settlement.Interface var settlement settlement.Interface
var swapService *swap.Service var swapService *swap.Service
kad := kademlia.New(swarmAddress, addressbook, hive, p2ps, logger, kademlia.Options{Bootnodes: bootnodes, StandaloneMode: o.Standalone, BootnodeMode: o.BootnodeMode})
b.topologyCloser = kad
hive.SetAddPeersHandler(kad.AddPeers)
p2ps.SetPickyNotifier(kad)
paymentThreshold, ok := new(big.Int).SetString(o.PaymentThreshold, 10)
if !ok {
return nil, fmt.Errorf("invalid payment threshold: %s", paymentThreshold)
}
pricer := pricer.New(logger, stateStore, swarmAddress, 1000000000)
pricer.SetTopology(kad)
pricing := pricing.New(p2ps, logger, paymentThreshold, pricer)
pricing.SetPriceTableObserver(pricer)
if err = p2ps.AddProtocol(pricing.Protocol()); err != nil {
return nil, fmt.Errorf("pricing service: %w", err)
}
addrs, err := p2ps.Addresses()
if err != nil {
return nil, fmt.Errorf("get server addresses: %w", err)
}
for _, addr := range addrs {
logger.Debugf("p2p address: %s", addr)
}
if o.SwapEnable { if o.SwapEnable {
swapService, err = InitSwap( swapService, err = InitSwap(
p2ps, p2ps,
...@@ -333,15 +363,6 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey, ...@@ -333,15 +363,6 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey,
settlement = pseudosettleService settlement = pseudosettleService
} }
paymentThreshold, ok := new(big.Int).SetString(o.PaymentThreshold, 10)
if !ok {
return nil, fmt.Errorf("invalid payment threshold: %s", paymentThreshold)
}
pricing := pricing.New(p2ps, logger, paymentThreshold)
if err = p2ps.AddProtocol(pricing.Protocol()); err != nil {
return nil, fmt.Errorf("pricing service: %w", err)
}
paymentTolerance, ok := new(big.Int).SetString(o.PaymentTolerance, 10) paymentTolerance, ok := new(big.Int).SetString(o.PaymentTolerance, 10)
if !ok { if !ok {
return nil, fmt.Errorf("invalid payment tolerance: %s", paymentTolerance) return nil, fmt.Errorf("invalid payment tolerance: %s", paymentTolerance)
...@@ -359,25 +380,13 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey, ...@@ -359,25 +380,13 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey,
settlement, settlement,
pricing, pricing,
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("accounting: %w", err) return nil, fmt.Errorf("accounting: %w", err)
} }
settlement.SetNotifyPaymentFunc(acc.AsyncNotifyPayment)
pricing.SetPaymentThresholdObserver(acc) pricing.SetPaymentThresholdObserver(acc)
settlement.SetNotifyPaymentFunc(acc.AsyncNotifyPayment)
kad := kademlia.New(swarmAddress, addressbook, hive, p2ps, logger, kademlia.Options{Bootnodes: bootnodes, StandaloneMode: o.Standalone, BootnodeMode: o.BootnodeMode})
b.topologyCloser = kad
hive.SetAddPeersHandler(kad.AddPeers)
p2ps.SetPickyNotifier(kad)
addrs, err := p2ps.Addresses()
if err != nil {
return nil, fmt.Errorf("get server addresses: %w", err)
}
for _, addr := range addrs {
logger.Debugf("p2p address: %s", addr)
}
var path string var path string
...@@ -397,7 +406,7 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey, ...@@ -397,7 +406,7 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey,
} }
b.localstoreCloser = storer b.localstoreCloser = storer
retrieve := retrieval.New(swarmAddress, storer, p2ps, kad, logger, acc, accounting.NewFixedPricer(swarmAddress, 1000000000), tracer) retrieve := retrieval.New(swarmAddress, storer, p2ps, kad, logger, acc, pricer, tracer)
tagService := tags.NewTags(stateStore, logger) tagService := tags.NewTags(stateStore, logger)
b.tagsCloser = tagService b.tagsCloser = tagService
...@@ -419,7 +428,7 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey, ...@@ -419,7 +428,7 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey,
traversalService := traversal.NewService(ns) traversalService := traversal.NewService(ns)
pushSyncProtocol := pushsync.New(p2ps, storer, kad, tagService, pssService.TryUnwrap, logger, acc, accounting.NewFixedPricer(swarmAddress, 1000000000), tracer) pushSyncProtocol := pushsync.New(p2ps, storer, kad, tagService, pssService.TryUnwrap, logger, acc, pricer, tracer)
// set the pushSyncer in the PSS // set the pushSyncer in the PSS
pssService.SetPushSyncer(pushSyncProtocol) pssService.SetPushSyncer(pushSyncProtocol)
......
...@@ -12,6 +12,7 @@ import ( ...@@ -12,6 +12,7 @@ import (
"github.com/ethersphere/bee/pkg/p2p" "github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/p2p/libp2p/internal/headers/pb" "github.com/ethersphere/bee/pkg/p2p/libp2p/internal/headers/pb"
"github.com/ethersphere/bee/pkg/p2p/protobuf" "github.com/ethersphere/bee/pkg/p2p/protobuf"
"github.com/ethersphere/bee/pkg/swarm"
) )
var sendHeadersTimeout = 10 * time.Second var sendHeadersTimeout = 10 * time.Second
...@@ -36,7 +37,7 @@ func sendHeaders(ctx context.Context, headers p2p.Headers, stream *stream) error ...@@ -36,7 +37,7 @@ func sendHeaders(ctx context.Context, headers p2p.Headers, stream *stream) error
return nil return nil
} }
func handleHeaders(headler p2p.HeadlerFunc, stream *stream) error { func handleHeaders(headler p2p.HeadlerFunc, stream *stream, peerAddress swarm.Address) error {
w, r := protobuf.NewWriterAndReader(stream) w, r := protobuf.NewWriterAndReader(stream)
ctx, cancel := context.WithTimeout(context.Background(), sendHeadersTimeout) ctx, cancel := context.WithTimeout(context.Background(), sendHeadersTimeout)
...@@ -51,9 +52,11 @@ func handleHeaders(headler p2p.HeadlerFunc, stream *stream) error { ...@@ -51,9 +52,11 @@ func handleHeaders(headler p2p.HeadlerFunc, stream *stream) error {
var h p2p.Headers var h p2p.Headers
if headler != nil { if headler != nil {
h = headler(stream.headers) h = headler(stream.headers, peerAddress)
} }
stream.responseHeaders = h
if err := w.WriteMsgWithContext(ctx, headersP2PToPB(h)); err != nil { if err := w.WriteMsgWithContext(ctx, headersP2PToPB(h)); err != nil {
return fmt.Errorf("write message: %w", err) return fmt.Errorf("write message: %w", err)
} }
......
...@@ -11,6 +11,7 @@ import ( ...@@ -11,6 +11,7 @@ import (
"time" "time"
"github.com/ethersphere/bee/pkg/p2p" "github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/swarm"
) )
func TestHeaders(t *testing.T) { func TestHeaders(t *testing.T) {
...@@ -140,7 +141,7 @@ func TestHeadler(t *testing.T) { ...@@ -140,7 +141,7 @@ func TestHeadler(t *testing.T) {
Handler: func(_ context.Context, _ p2p.Peer, stream p2p.Stream) error { Handler: func(_ context.Context, _ p2p.Peer, stream p2p.Stream) error {
return nil return nil
}, },
Headler: func(headers p2p.Headers) p2p.Headers { Headler: func(headers p2p.Headers, address swarm.Address) p2p.Headers {
defer close(handled) defer close(handled)
gotReceivedHeaders = headers gotReceivedHeaders = headers
return sentHeaders return sentHeaders
......
...@@ -29,7 +29,7 @@ const ( ...@@ -29,7 +29,7 @@ const (
// ProtocolName is the text of the name of the handshake protocol. // ProtocolName is the text of the name of the handshake protocol.
ProtocolName = "handshake" ProtocolName = "handshake"
// ProtocolVersion is the current handshake protocol version. // ProtocolVersion is the current handshake protocol version.
ProtocolVersion = "2.0.0" ProtocolVersion = "3.0.0"
// StreamName is the name of the stream used for handshake purposes. // StreamName is the name of the stream used for handshake purposes.
StreamName = "handshake" StreamName = "handshake"
// MaxWelcomeMessageLength is maximum number of characters allowed in the welcome message. // MaxWelcomeMessageLength is maximum number of characters allowed in the welcome message.
......
...@@ -56,6 +56,10 @@ func (s *Stream) Headers() p2p.Headers { ...@@ -56,6 +56,10 @@ func (s *Stream) Headers() p2p.Headers {
return nil return nil
} }
func (s *Stream) ResponseHeaders() p2p.Headers {
return nil
}
func (s *Stream) Close() error { func (s *Stream) Close() error {
return nil return nil
} }
......
...@@ -372,7 +372,7 @@ func (s *Service) AddProtocol(p p2p.ProtocolSpec) (err error) { ...@@ -372,7 +372,7 @@ func (s *Service) AddProtocol(p p2p.ProtocolSpec) (err error) {
stream := newStream(streamlibp2p) stream := newStream(streamlibp2p)
// exchange headers // exchange headers
if err := handleHeaders(ss.Headler, stream); err != nil { if err := handleHeaders(ss.Headler, stream, overlay); err != nil {
s.logger.Debugf("handle protocol %s/%s: stream %s: peer %s: handle headers: %v", p.Name, p.Version, ss.Name, overlay, err) s.logger.Debugf("handle protocol %s/%s: stream %s: peer %s: handle headers: %v", p.Name, p.Version, ss.Name, overlay, err)
_ = stream.Reset() _ = stream.Reset()
return return
......
...@@ -21,7 +21,8 @@ var _ p2p.Stream = (*stream)(nil) ...@@ -21,7 +21,8 @@ var _ p2p.Stream = (*stream)(nil)
type stream struct { type stream struct {
network.Stream network.Stream
headers map[string][]byte headers map[string][]byte
responseHeaders map[string][]byte
} }
func NewStream(s network.Stream) p2p.Stream { func NewStream(s network.Stream) p2p.Stream {
...@@ -35,6 +36,10 @@ func (s *stream) Headers() p2p.Headers { ...@@ -35,6 +36,10 @@ func (s *stream) Headers() p2p.Headers {
return s.headers return s.headers
} }
func (s *stream) ResponseHeaders() p2p.Headers {
return s.responseHeaders
}
func (s *stream) FullClose() error { func (s *stream) FullClose() error {
// close the stream to make sure it is gc'd // close the stream to make sure it is gc'd
defer s.Close() defer s.Close()
......
...@@ -67,6 +67,7 @@ type StreamerDisconnecter interface { ...@@ -67,6 +67,7 @@ type StreamerDisconnecter interface {
type Stream interface { type Stream interface {
io.ReadWriter io.ReadWriter
io.Closer io.Closer
ResponseHeaders() Headers
Headers() Headers Headers() Headers
FullClose() error FullClose() error
Reset() error Reset() error
...@@ -103,7 +104,7 @@ type HandlerMiddleware func(HandlerFunc) HandlerFunc ...@@ -103,7 +104,7 @@ type HandlerMiddleware func(HandlerFunc) HandlerFunc
// HeadlerFunc is returning response headers based on the received request // HeadlerFunc is returning response headers based on the received request
// headers. // headers.
type HeadlerFunc func(Headers) Headers type HeadlerFunc func(Headers, swarm.Address) Headers
// Headers represents a collection of p2p header key value pairs. // Headers represents a collection of p2p header key value pairs.
type Headers map[string][]byte type Headers map[string][]byte
......
...@@ -323,6 +323,10 @@ func (noopWriteCloser) Headers() p2p.Headers { ...@@ -323,6 +323,10 @@ func (noopWriteCloser) Headers() p2p.Headers {
return nil return nil
} }
func (noopWriteCloser) ResponseHeaders() p2p.Headers {
return nil
}
func (noopWriteCloser) Close() error { func (noopWriteCloser) Close() error {
return nil return nil
} }
...@@ -351,6 +355,10 @@ func (noopReadCloser) Headers() p2p.Headers { ...@@ -351,6 +355,10 @@ func (noopReadCloser) Headers() p2p.Headers {
return nil return nil
} }
func (noopReadCloser) ResponseHeaders() p2p.Headers {
return nil
}
func (noopReadCloser) Close() error { func (noopReadCloser) Close() error {
return nil return nil
} }
......
...@@ -97,7 +97,7 @@ func (r *Recorder) NewStream(ctx context.Context, addr swarm.Address, h p2p.Head ...@@ -97,7 +97,7 @@ func (r *Recorder) NewStream(ctx context.Context, addr swarm.Address, h p2p.Head
handler = r.middlewares[i](handler) handler = r.middlewares[i](handler)
} }
if headler != nil { if headler != nil {
streamOut.headers = headler(h) streamOut.headers = headler(h, addr)
} }
record := &Record{in: recordIn, out: recordOut, done: make(chan struct{})} record := &Record{in: recordIn, out: recordOut, done: make(chan struct{})}
go func() { go func() {
...@@ -194,9 +194,10 @@ func (r *Record) setErr(err error) { ...@@ -194,9 +194,10 @@ func (r *Record) setErr(err error) {
} }
type stream struct { type stream struct {
in *record in *record
out *record out *record
headers p2p.Headers headers p2p.Headers
responseHeaders p2p.Headers
} }
func newStream(in, out *record) *stream { func newStream(in, out *record) *stream {
...@@ -215,6 +216,10 @@ func (s *stream) Headers() p2p.Headers { ...@@ -215,6 +216,10 @@ func (s *stream) Headers() p2p.Headers {
return s.headers return s.headers
} }
func (s *stream) ResponseHeaders() p2p.Headers {
return s.responseHeaders
}
func (s *stream) Close() error { func (s *stream) Close() error {
return s.in.Close() return s.in.Close()
} }
......
// Copyright 2021 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package pricer
import (
"github.com/ethersphere/bee/pkg/swarm"
)
func (s *Pricer) PeerPricePO(peer swarm.Address, po uint8) (uint64, error) {
return s.peerPricePO(peer, po)
}
// Copyright 2021 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package headerutils
const (
PriceFieldName = priceFieldName
TargetFieldName = targetFieldName
IndexFieldName = indexFieldName
)
// Copyright 2021 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package headerutils
import (
"encoding/binary"
"errors"
"github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/swarm"
)
const (
priceFieldName = "price"
targetFieldName = "target"
indexFieldName = "index"
)
var (
// ErrFieldLength denotes p2p.Header having malformed field length in bytes
ErrFieldLength = errors.New("field length error")
// ErrNoIndexHeader denotes p2p.Header lacking specified field
ErrNoIndexHeader = errors.New("no index header")
// ErrNoTargetHeader denotes p2p.Header lacking specified field
ErrNoTargetHeader = errors.New("no target header")
// ErrNoPriceHeader denotes p2p.Header lacking specified field
ErrNoPriceHeader = errors.New("no price header")
)
// Headers, utility functions
func MakePricingHeaders(chunkPrice uint64, addr swarm.Address) (p2p.Headers, error) {
chunkPriceInBytes := make([]byte, 8)
binary.BigEndian.PutUint64(chunkPriceInBytes, chunkPrice)
headers := p2p.Headers{
priceFieldName: chunkPriceInBytes,
targetFieldName: addr.Bytes(),
}
return headers, nil
}
func MakePricingResponseHeaders(chunkPrice uint64, addr swarm.Address, index uint8) (p2p.Headers, error) {
chunkPriceInBytes := make([]byte, 8)
chunkIndexInBytes := make([]byte, 1)
binary.BigEndian.PutUint64(chunkPriceInBytes, chunkPrice)
chunkIndexInBytes[0] = index
headers := p2p.Headers{
priceFieldName: chunkPriceInBytes,
targetFieldName: addr.Bytes(),
indexFieldName: chunkIndexInBytes,
}
return headers, nil
}
// ParsePricingHeaders used by responder to read address and price from stream headers
// Returns an error if no target field attached or the contents of it are not readable
func ParsePricingHeaders(receivedHeaders p2p.Headers) (swarm.Address, uint64, error) {
target, err := ParseTargetHeader(receivedHeaders)
if err != nil {
return swarm.ZeroAddress, 0, err
}
price, err := ParsePriceHeader(receivedHeaders)
if err != nil {
return swarm.ZeroAddress, 0, err
}
return target, price, nil
}
// ParsePricingResponseHeaders used by requester to read address, price and index from response headers
// Returns an error if any fields are missing or target is unreadable
func ParsePricingResponseHeaders(receivedHeaders p2p.Headers) (swarm.Address, uint64, uint8, error) {
target, err := ParseTargetHeader(receivedHeaders)
if err != nil {
return swarm.ZeroAddress, 0, 0, err
}
price, err := ParsePriceHeader(receivedHeaders)
if err != nil {
return swarm.ZeroAddress, 0, 0, err
}
index, err := ParseIndexHeader(receivedHeaders)
if err != nil {
return swarm.ZeroAddress, 0, 0, err
}
return target, price, index, nil
}
func ParseIndexHeader(receivedHeaders p2p.Headers) (uint8, error) {
if receivedHeaders[indexFieldName] == nil {
return 0, ErrNoIndexHeader
}
if len(receivedHeaders[indexFieldName]) != 1 {
return 0, ErrFieldLength
}
index := receivedHeaders[indexFieldName][0]
return index, nil
}
func ParseTargetHeader(receivedHeaders p2p.Headers) (swarm.Address, error) {
if receivedHeaders[targetFieldName] == nil {
return swarm.ZeroAddress, ErrNoTargetHeader
}
target := swarm.NewAddress(receivedHeaders[targetFieldName])
return target, nil
}
func ParsePriceHeader(receivedHeaders p2p.Headers) (uint64, error) {
if receivedHeaders[priceFieldName] == nil {
return 0, ErrNoPriceHeader
}
if len(receivedHeaders[priceFieldName]) != 8 {
return 0, ErrFieldLength
}
receivedPrice := binary.BigEndian.Uint64(receivedHeaders[priceFieldName])
return receivedPrice, nil
}
// Copyright 2021 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package headerutils_test
import (
"reflect"
"testing"
"github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/pricer/headerutils"
"github.com/ethersphere/bee/pkg/swarm"
)
func TestMakePricingHeaders(t *testing.T) {
addr := swarm.MustParseHexAddress("010101e1010101")
makeHeaders, err := headerutils.MakePricingHeaders(uint64(5348), addr)
if err != nil {
t.Fatal(err)
}
expectedHeaders := p2p.Headers{
headerutils.PriceFieldName: []byte{0, 0, 0, 0, 0, 0, 20, 228},
headerutils.TargetFieldName: []byte{1, 1, 1, 225, 1, 1, 1},
}
if !reflect.DeepEqual(makeHeaders, expectedHeaders) {
t.Fatalf("Made headers not as expected, got %+v, want %+v", makeHeaders, expectedHeaders)
}
}
func TestMakePricingResponseHeaders(t *testing.T) {
addr := swarm.MustParseHexAddress("010101e1010101")
makeHeaders, err := headerutils.MakePricingResponseHeaders(uint64(5348), addr, uint8(11))
if err != nil {
t.Fatal(err)
}
expectedHeaders := p2p.Headers{
headerutils.PriceFieldName: []byte{0, 0, 0, 0, 0, 0, 20, 228},
headerutils.TargetFieldName: []byte{1, 1, 1, 225, 1, 1, 1},
headerutils.IndexFieldName: []byte{11},
}
if !reflect.DeepEqual(makeHeaders, expectedHeaders) {
t.Fatalf("Made headers not as expected, got %+v, want %+v", makeHeaders, expectedHeaders)
}
}
func TestParsePricingHeaders(t *testing.T) {
toReadHeaders := p2p.Headers{
headerutils.PriceFieldName: []byte{0, 0, 0, 0, 0, 0, 20, 228},
headerutils.TargetFieldName: []byte{1, 1, 1, 225, 1, 1, 1},
}
parsedTarget, parsedPrice, err := headerutils.ParsePricingHeaders(toReadHeaders)
if err != nil {
t.Fatal(err)
}
addr := swarm.MustParseHexAddress("010101e1010101")
if parsedPrice != uint64(5348) {
t.Fatalf("Price mismatch, got %v, want %v", parsedPrice, 5348)
}
if !parsedTarget.Equal(addr) {
t.Fatalf("Target mismatch, got %v, want %v", parsedTarget, addr)
}
}
func TestParsePricingResponseHeaders(t *testing.T) {
toReadHeaders := p2p.Headers{
headerutils.PriceFieldName: []byte{0, 0, 0, 0, 0, 0, 20, 228},
headerutils.TargetFieldName: []byte{1, 1, 1, 225, 1, 1, 1},
headerutils.IndexFieldName: []byte{11},
}
parsedTarget, parsedPrice, parsedIndex, err := headerutils.ParsePricingResponseHeaders(toReadHeaders)
if err != nil {
t.Fatal(err)
}
addr := swarm.MustParseHexAddress("010101e1010101")
if parsedPrice != uint64(5348) {
t.Fatalf("Price mismatch, got %v, want %v", parsedPrice, 5348)
}
if parsedIndex != uint8(11) {
t.Fatalf("Price mismatch, got %v, want %v", parsedPrice, 5348)
}
if !parsedTarget.Equal(addr) {
t.Fatalf("Target mismatch, got %v, want %v", parsedTarget, addr)
}
}
func TestParseIndexHeader(t *testing.T) {
toReadHeaders := p2p.Headers{
headerutils.IndexFieldName: []byte{11},
}
parsedIndex, err := headerutils.ParseIndexHeader(toReadHeaders)
if err != nil {
t.Fatal(err)
}
if parsedIndex != uint8(11) {
t.Fatalf("Index mismatch, got %v, want %v", parsedIndex, 11)
}
}
func TestParseTargetHeader(t *testing.T) {
toReadHeaders := p2p.Headers{
headerutils.TargetFieldName: []byte{1, 1, 1, 225, 1, 1, 1},
}
parsedTarget, err := headerutils.ParseTargetHeader(toReadHeaders)
if err != nil {
t.Fatal(err)
}
addr := swarm.MustParseHexAddress("010101e1010101")
if !parsedTarget.Equal(addr) {
t.Fatalf("Target mismatch, got %v, want %v", parsedTarget, addr)
}
}
func TestParsePriceHeader(t *testing.T) {
toReadHeaders := p2p.Headers{
headerutils.PriceFieldName: []byte{0, 0, 0, 0, 0, 0, 20, 228},
}
parsedPrice, err := headerutils.ParsePriceHeader(toReadHeaders)
if err != nil {
t.Fatal(err)
}
if parsedPrice != uint64(5348) {
t.Fatalf("Index mismatch, got %v, want %v", parsedPrice, 5348)
}
}
func TestReadMalformedHeaders(t *testing.T) {
toReadHeaders := p2p.Headers{
headerutils.IndexFieldName: []byte{11, 0},
headerutils.TargetFieldName: []byte{1, 1, 1, 225, 1, 1, 1},
headerutils.PriceFieldName: []byte{0, 0, 0, 0, 0, 20, 228},
}
_, err := headerutils.ParseIndexHeader(toReadHeaders)
if err == nil {
t.Fatal("Expected error from bad length of index bytes")
}
_, err = headerutils.ParsePriceHeader(toReadHeaders)
if err == nil {
t.Fatal("Expected error from bad length of price bytes")
}
_, _, _, err = headerutils.ParsePricingResponseHeaders(toReadHeaders)
if err == nil {
t.Fatal("Expected error caused by bad length of fields")
}
_, _, err = headerutils.ParsePricingHeaders(toReadHeaders)
if err == nil {
t.Fatal("Expected error caused by bad length of fields")
}
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mock
import (
"github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/swarm"
)
type Service struct {
peerPrice uint64
price uint64
peerPriceFunc func(peer, chunk swarm.Address) uint64
priceForPeerFunc func(peer, chunk swarm.Address) uint64
priceTableFunc func() (priceTable []uint64)
notifyPriceTableFunc func(peer swarm.Address, priceTable []uint64) error
priceHeadlerFunc func(p2p.Headers, swarm.Address) p2p.Headers
notifyPeerPriceFunc func(peer swarm.Address, price uint64, index uint8) error
}
// WithReserveFunc sets the mock Reserve function
func WithPeerPriceFunc(f func(peer, chunk swarm.Address) uint64) Option {
return optionFunc(func(s *Service) {
s.peerPriceFunc = f
})
}
// WithReleaseFunc sets the mock Release function
func WithPriceForPeerFunc(f func(peer, chunk swarm.Address) uint64) Option {
return optionFunc(func(s *Service) {
s.priceForPeerFunc = f
})
}
func WithPrice(p uint64) Option {
return optionFunc(func(s *Service) {
s.price = p
})
}
func WithPeerPrice(p uint64) Option {
return optionFunc(func(s *Service) {
s.peerPrice = p
})
}
// WithPriceTableFunc sets the mock Release function
func WithPriceTableFunc(f func() (priceTable []uint64)) Option {
return optionFunc(func(s *Service) {
s.priceTableFunc = f
})
}
func WithPriceHeadlerFunc(f func(headers p2p.Headers, addr swarm.Address) p2p.Headers) Option {
return optionFunc(func(s *Service) {
s.priceHeadlerFunc = f
})
}
func NewMockService(opts ...Option) *Service {
mock := new(Service)
mock.price = 10
mock.peerPrice = 10
for _, o := range opts {
o.apply(mock)
}
return mock
}
func (pricer *Service) PeerPrice(peer, chunk swarm.Address) uint64 {
if pricer.peerPriceFunc != nil {
return pricer.peerPriceFunc(peer, chunk)
}
return pricer.peerPrice
}
func (pricer *Service) PriceForPeer(peer, chunk swarm.Address) uint64 {
if pricer.priceForPeerFunc != nil {
return pricer.priceForPeerFunc(peer, chunk)
}
return pricer.price
}
func (pricer *Service) PriceTable() (priceTable []uint64) {
if pricer.priceTableFunc != nil {
return pricer.priceTableFunc()
}
return nil
}
func (pricer *Service) NotifyPriceTable(peer swarm.Address, priceTable []uint64) error {
if pricer.notifyPriceTableFunc != nil {
return pricer.notifyPriceTableFunc(peer, priceTable)
}
return nil
}
func (pricer *Service) PriceHeadler(headers p2p.Headers, addr swarm.Address) p2p.Headers {
if pricer.priceHeadlerFunc != nil {
return pricer.priceHeadlerFunc(headers, addr)
}
return p2p.Headers{}
}
func (pricer *Service) NotifyPeerPrice(peer swarm.Address, price uint64, index uint8) error {
if pricer.notifyPeerPriceFunc != nil {
return pricer.notifyPeerPriceFunc(peer, price, index)
}
return nil
}
// Option is the option passed to the mock accounting service
type Option interface {
apply(*Service)
}
type optionFunc func(*Service)
func (f optionFunc) apply(r *Service) { f(r) }
This diff is collapsed.
This diff is collapsed.
...@@ -23,7 +23,8 @@ var _ = math.Inf ...@@ -23,7 +23,8 @@ var _ = math.Inf
const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package
type AnnouncePaymentThreshold struct { type AnnouncePaymentThreshold struct {
PaymentThreshold []byte `protobuf:"bytes,1,opt,name=PaymentThreshold,proto3" json:"PaymentThreshold,omitempty"` PaymentThreshold []byte `protobuf:"bytes,1,opt,name=PaymentThreshold,proto3" json:"PaymentThreshold,omitempty"`
ProximityPrice []uint64 `protobuf:"varint,2,rep,packed,name=ProximityPrice,proto3" json:"ProximityPrice,omitempty"`
} }
func (m *AnnouncePaymentThreshold) Reset() { *m = AnnouncePaymentThreshold{} } func (m *AnnouncePaymentThreshold) Reset() { *m = AnnouncePaymentThreshold{} }
...@@ -66,6 +67,13 @@ func (m *AnnouncePaymentThreshold) GetPaymentThreshold() []byte { ...@@ -66,6 +67,13 @@ func (m *AnnouncePaymentThreshold) GetPaymentThreshold() []byte {
return nil return nil
} }
func (m *AnnouncePaymentThreshold) GetProximityPrice() []uint64 {
if m != nil {
return m.ProximityPrice
}
return nil
}
func init() { func init() {
proto.RegisterType((*AnnouncePaymentThreshold)(nil), "pricing.AnnouncePaymentThreshold") proto.RegisterType((*AnnouncePaymentThreshold)(nil), "pricing.AnnouncePaymentThreshold")
} }
...@@ -73,15 +81,17 @@ func init() { ...@@ -73,15 +81,17 @@ func init() {
func init() { proto.RegisterFile("pricing.proto", fileDescriptor_ec4cc93d045d43d0) } func init() { proto.RegisterFile("pricing.proto", fileDescriptor_ec4cc93d045d43d0) }
var fileDescriptor_ec4cc93d045d43d0 = []byte{ var fileDescriptor_ec4cc93d045d43d0 = []byte{
// 122 bytes of a gzipped FileDescriptorProto // 150 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x2d, 0x28, 0xca, 0x4c, 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x2d, 0x28, 0xca, 0x4c,
0xce, 0xcc, 0x4b, 0xd7, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x87, 0x72, 0x95, 0xdc, 0xb8, 0xce, 0xcc, 0x4b, 0xd7, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x87, 0x72, 0x95, 0xf2, 0xb8,
0x24, 0x1c, 0xf3, 0xf2, 0xf2, 0x4b, 0xf3, 0x92, 0x53, 0x03, 0x12, 0x2b, 0x73, 0x53, 0xf3, 0x4a, 0x24, 0x1c, 0xf3, 0xf2, 0xf2, 0x4b, 0xf3, 0x92, 0x53, 0x03, 0x12, 0x2b, 0x73, 0x53, 0xf3, 0x4a,
0x42, 0x32, 0x8a, 0x52, 0x8b, 0x33, 0xf2, 0x73, 0x52, 0x84, 0xb4, 0xb8, 0x04, 0xd0, 0xc5, 0x24, 0x42, 0x32, 0x8a, 0x52, 0x8b, 0x33, 0xf2, 0x73, 0x52, 0x84, 0xb4, 0xb8, 0x04, 0xd0, 0xc5, 0x24,
0x18, 0x15, 0x18, 0x35, 0x78, 0x82, 0x30, 0xc4, 0x9d, 0x64, 0x4e, 0x3c, 0x92, 0x63, 0xbc, 0xf0, 0x18, 0x15, 0x18, 0x35, 0x78, 0x82, 0x30, 0xc4, 0x85, 0xd4, 0xb8, 0xf8, 0x02, 0x8a, 0xf2, 0x2b,
0x48, 0x8e, 0xf1, 0xc1, 0x23, 0x39, 0xc6, 0x09, 0x8f, 0xe5, 0x18, 0x2e, 0x3c, 0x96, 0x63, 0xb8, 0x32, 0x73, 0x33, 0x4b, 0x2a, 0x03, 0x8a, 0x32, 0x93, 0x53, 0x25, 0x98, 0x14, 0x98, 0x35, 0x58,
0xf1, 0x58, 0x8e, 0x21, 0x8a, 0xa9, 0x20, 0x29, 0x89, 0x0d, 0x6c, 0xab, 0x31, 0x20, 0x00, 0x00, 0x82, 0xd0, 0x44, 0x9d, 0x64, 0x4e, 0x3c, 0x92, 0x63, 0xbc, 0xf0, 0x48, 0x8e, 0xf1, 0xc1, 0x23,
0xff, 0xff, 0x50, 0xca, 0x0e, 0x0a, 0x86, 0x00, 0x00, 0x00, 0x39, 0xc6, 0x09, 0x8f, 0xe5, 0x18, 0x2e, 0x3c, 0x96, 0x63, 0xb8, 0xf1, 0x58, 0x8e, 0x21, 0x8a,
0xa9, 0x20, 0x29, 0x89, 0x0d, 0xec, 0x3a, 0x63, 0x40, 0x00, 0x00, 0x00, 0xff, 0xff, 0x90, 0x49,
0x1c, 0x6f, 0xae, 0x00, 0x00, 0x00,
} }
func (m *AnnouncePaymentThreshold) Marshal() (dAtA []byte, err error) { func (m *AnnouncePaymentThreshold) Marshal() (dAtA []byte, err error) {
...@@ -104,6 +114,24 @@ func (m *AnnouncePaymentThreshold) MarshalToSizedBuffer(dAtA []byte) (int, error ...@@ -104,6 +114,24 @@ func (m *AnnouncePaymentThreshold) MarshalToSizedBuffer(dAtA []byte) (int, error
_ = i _ = i
var l int var l int
_ = l _ = l
if len(m.ProximityPrice) > 0 {
dAtA2 := make([]byte, len(m.ProximityPrice)*10)
var j1 int
for _, num := range m.ProximityPrice {
for num >= 1<<7 {
dAtA2[j1] = uint8(uint64(num)&0x7f | 0x80)
num >>= 7
j1++
}
dAtA2[j1] = uint8(num)
j1++
}
i -= j1
copy(dAtA[i:], dAtA2[:j1])
i = encodeVarintPricing(dAtA, i, uint64(j1))
i--
dAtA[i] = 0x12
}
if len(m.PaymentThreshold) > 0 { if len(m.PaymentThreshold) > 0 {
i -= len(m.PaymentThreshold) i -= len(m.PaymentThreshold)
copy(dAtA[i:], m.PaymentThreshold) copy(dAtA[i:], m.PaymentThreshold)
...@@ -135,6 +163,13 @@ func (m *AnnouncePaymentThreshold) Size() (n int) { ...@@ -135,6 +163,13 @@ func (m *AnnouncePaymentThreshold) Size() (n int) {
if l > 0 { if l > 0 {
n += 1 + l + sovPricing(uint64(l)) n += 1 + l + sovPricing(uint64(l))
} }
if len(m.ProximityPrice) > 0 {
l = 0
for _, e := range m.ProximityPrice {
l += sovPricing(uint64(e))
}
n += 1 + sovPricing(uint64(l)) + l
}
return n return n
} }
...@@ -207,6 +242,82 @@ func (m *AnnouncePaymentThreshold) Unmarshal(dAtA []byte) error { ...@@ -207,6 +242,82 @@ func (m *AnnouncePaymentThreshold) Unmarshal(dAtA []byte) error {
m.PaymentThreshold = []byte{} m.PaymentThreshold = []byte{}
} }
iNdEx = postIndex iNdEx = postIndex
case 2:
if wireType == 0 {
var v uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowPricing
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
v |= uint64(b&0x7F) << shift
if b < 0x80 {
break
}
}
m.ProximityPrice = append(m.ProximityPrice, v)
} else if wireType == 2 {
var packedLen int
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowPricing
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
packedLen |= int(b&0x7F) << shift
if b < 0x80 {
break
}
}
if packedLen < 0 {
return ErrInvalidLengthPricing
}
postIndex := iNdEx + packedLen
if postIndex < 0 {
return ErrInvalidLengthPricing
}
if postIndex > l {
return io.ErrUnexpectedEOF
}
var elementCount int
var count int
for _, integer := range dAtA[iNdEx:postIndex] {
if integer < 128 {
count++
}
}
elementCount = count
if elementCount != 0 && len(m.ProximityPrice) == 0 {
m.ProximityPrice = make([]uint64, 0, elementCount)
}
for iNdEx < postIndex {
var v uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowPricing
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
v |= uint64(b&0x7F) << shift
if b < 0x80 {
break
}
}
m.ProximityPrice = append(m.ProximityPrice, v)
}
} else {
return fmt.Errorf("proto: wrong wireType = %d for field ProximityPrice", wireType)
}
default: default:
iNdEx = preIndex iNdEx = preIndex
skippy, err := skipPricing(dAtA[iNdEx:]) skippy, err := skipPricing(dAtA[iNdEx:])
......
...@@ -9,5 +9,6 @@ package pricing; ...@@ -9,5 +9,6 @@ package pricing;
option go_package = "pb"; option go_package = "pb";
message AnnouncePaymentThreshold { message AnnouncePaymentThreshold {
bytes PaymentThreshold = 1; bytes PaymentThreshold = 1;
repeated uint64 ProximityPrice = 2;
} }
...@@ -13,6 +13,7 @@ import ( ...@@ -13,6 +13,7 @@ import (
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/p2p" "github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/p2p/protobuf" "github.com/ethersphere/bee/pkg/p2p/protobuf"
"github.com/ethersphere/bee/pkg/pricer"
"github.com/ethersphere/bee/pkg/pricing/pb" "github.com/ethersphere/bee/pkg/pricing/pb"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
) )
...@@ -28,6 +29,12 @@ var _ Interface = (*Service)(nil) ...@@ -28,6 +29,12 @@ var _ Interface = (*Service)(nil)
// Interface is the main interface of the pricing protocol // Interface is the main interface of the pricing protocol
type Interface interface { type Interface interface {
AnnouncePaymentThreshold(ctx context.Context, peer swarm.Address, paymentThreshold *big.Int) error AnnouncePaymentThreshold(ctx context.Context, peer swarm.Address, paymentThreshold *big.Int) error
AnnouncePaymentThresholdAndPriceTable(ctx context.Context, peer swarm.Address, paymentThreshold *big.Int) error
}
// PriceTableObserver is used for being notified of price table updates
type PriceTableObserver interface {
NotifyPriceTable(peer swarm.Address, priceTable []uint64) error
} }
// PaymentThresholdObserver is used for being notified of payment threshold updates // PaymentThresholdObserver is used for being notified of payment threshold updates
...@@ -39,13 +46,16 @@ type Service struct { ...@@ -39,13 +46,16 @@ type Service struct {
streamer p2p.Streamer streamer p2p.Streamer
logger logging.Logger logger logging.Logger
paymentThreshold *big.Int paymentThreshold *big.Int
pricer pricer.Interface
paymentThresholdObserver PaymentThresholdObserver paymentThresholdObserver PaymentThresholdObserver
priceTableObserver PriceTableObserver
} }
func New(streamer p2p.Streamer, logger logging.Logger, paymentThreshold *big.Int) *Service { func New(streamer p2p.Streamer, logger logging.Logger, paymentThreshold *big.Int, pricer pricer.Interface) *Service {
return &Service{ return &Service{
streamer: streamer, streamer: streamer,
logger: logger, logger: logger,
pricer: pricer,
paymentThreshold: paymentThreshold, paymentThreshold: paymentThreshold,
} }
} }
...@@ -77,17 +87,30 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e ...@@ -77,17 +87,30 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e
var req pb.AnnouncePaymentThreshold var req pb.AnnouncePaymentThreshold
if err := r.ReadMsgWithContext(ctx, &req); err != nil { if err := r.ReadMsgWithContext(ctx, &req); err != nil {
s.logger.Debugf("could not receive payment threshold announcement from peer %v", p.Address) s.logger.Debugf("could not receive payment threshold and/or price table announcement from peer %v", p.Address)
return fmt.Errorf("read request from peer %v: %w", p.Address, err) return fmt.Errorf("read request from peer %v: %w", p.Address, err)
} }
paymentThreshold := big.NewInt(0).SetBytes(req.PaymentThreshold) paymentThreshold := big.NewInt(0).SetBytes(req.PaymentThreshold)
s.logger.Tracef("received payment threshold announcement from peer %v of %d", p.Address, paymentThreshold) s.logger.Tracef("received payment threshold announcement from peer %v of %d", p.Address, paymentThreshold)
if req.ProximityPrice != nil {
s.logger.Tracef("received pricetable announcement from peer %v of %v", p.Address, req.ProximityPrice)
err = s.priceTableObserver.NotifyPriceTable(p.Address, req.ProximityPrice)
if err != nil {
s.logger.Debugf("error receiving pricetable from peer %v: %w", p.Address, err)
s.logger.Errorf("error receiving pricetable from peer %v: %w", p.Address, err)
}
}
if paymentThreshold.Cmp(big.NewInt(0)) == 0 {
return err
}
return s.paymentThresholdObserver.NotifyPaymentThreshold(p.Address, paymentThreshold) return s.paymentThresholdObserver.NotifyPaymentThreshold(p.Address, paymentThreshold)
} }
func (s *Service) init(ctx context.Context, p p2p.Peer) error { func (s *Service) init(ctx context.Context, p p2p.Peer) error {
err := s.AnnouncePaymentThreshold(ctx, p.Address, s.paymentThreshold) err := s.AnnouncePaymentThresholdAndPriceTable(ctx, p.Address, s.paymentThreshold)
if err != nil { if err != nil {
s.logger.Warningf("could not send payment threshold announcement to peer %v", p.Address) s.logger.Warningf("could not send payment threshold announcement to peer %v", p.Address)
} }
...@@ -120,7 +143,39 @@ func (s *Service) AnnouncePaymentThreshold(ctx context.Context, peer swarm.Addre ...@@ -120,7 +143,39 @@ func (s *Service) AnnouncePaymentThreshold(ctx context.Context, peer swarm.Addre
return err return err
} }
// AnnouncePaymentThresholdAndPriceTable announces own payment threshold and pricetable to peer
func (s *Service) AnnouncePaymentThresholdAndPriceTable(ctx context.Context, peer swarm.Address, paymentThreshold *big.Int) error {
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
stream, err := s.streamer.NewStream(ctx, peer, nil, protocolName, protocolVersion, streamName)
if err != nil {
return err
}
defer func() {
if err != nil {
_ = stream.Reset()
} else {
go stream.FullClose()
}
}()
s.logger.Tracef("sending payment threshold announcement to peer %v of %d", peer, paymentThreshold)
w := protobuf.NewWriter(stream)
err = w.WriteMsgWithContext(ctx, &pb.AnnouncePaymentThreshold{
PaymentThreshold: paymentThreshold.Bytes(),
ProximityPrice: s.pricer.PriceTable(),
})
return err
}
// SetPaymentThresholdObserver sets the PaymentThresholdObserver to be used when receiving a new payment threshold // SetPaymentThresholdObserver sets the PaymentThresholdObserver to be used when receiving a new payment threshold
func (s *Service) SetPaymentThresholdObserver(observer PaymentThresholdObserver) { func (s *Service) SetPaymentThresholdObserver(observer PaymentThresholdObserver) {
s.paymentThresholdObserver = observer s.paymentThresholdObserver = observer
} }
// SetPriceTableObserver sets the PriceTableObserver to be used when receiving a new pricetable
func (s *Service) SetPriceTableObserver(observer PriceTableObserver) {
s.priceTableObserver = observer
}
...@@ -9,35 +9,52 @@ import ( ...@@ -9,35 +9,52 @@ import (
"context" "context"
"io/ioutil" "io/ioutil"
"math/big" "math/big"
"reflect"
"testing" "testing"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/p2p/protobuf" "github.com/ethersphere/bee/pkg/p2p/protobuf"
"github.com/ethersphere/bee/pkg/p2p/streamtest" "github.com/ethersphere/bee/pkg/p2p/streamtest"
pricermock "github.com/ethersphere/bee/pkg/pricer/mock"
"github.com/ethersphere/bee/pkg/pricing" "github.com/ethersphere/bee/pkg/pricing"
"github.com/ethersphere/bee/pkg/pricing/pb" "github.com/ethersphere/bee/pkg/pricing/pb"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
) )
type testObserver struct { type testThresholdObserver struct {
called bool called bool
peer swarm.Address peer swarm.Address
paymentThreshold *big.Int paymentThreshold *big.Int
} }
func (t *testObserver) NotifyPaymentThreshold(peer swarm.Address, paymentThreshold *big.Int) error { type testPriceTableObserver struct {
called bool
peer swarm.Address
priceTable []uint64
}
func (t *testThresholdObserver) NotifyPaymentThreshold(peerAddr swarm.Address, paymentThreshold *big.Int) error {
t.called = true t.called = true
t.peer = peer t.peer = peerAddr
t.paymentThreshold = paymentThreshold t.paymentThreshold = paymentThreshold
return nil return nil
} }
func (t *testPriceTableObserver) NotifyPriceTable(peerAddr swarm.Address, priceTable []uint64) error {
t.called = true
t.peer = peerAddr
t.priceTable = priceTable
return nil
}
func TestAnnouncePaymentThreshold(t *testing.T) { func TestAnnouncePaymentThreshold(t *testing.T) {
logger := logging.New(ioutil.Discard, 0) logger := logging.New(ioutil.Discard, 0)
testThreshold := big.NewInt(100000) testThreshold := big.NewInt(100000)
observer := &testObserver{} observer := &testThresholdObserver{}
pricerMockService := pricermock.NewMockService()
recipient := pricing.New(nil, logger, testThreshold) recipient := pricing.New(nil, logger, testThreshold, pricerMockService)
recipient.SetPaymentThresholdObserver(observer) recipient.SetPaymentThresholdObserver(observer)
peerID := swarm.MustParseHexAddress("9ee7add7") peerID := swarm.MustParseHexAddress("9ee7add7")
...@@ -47,7 +64,7 @@ func TestAnnouncePaymentThreshold(t *testing.T) { ...@@ -47,7 +64,7 @@ func TestAnnouncePaymentThreshold(t *testing.T) {
streamtest.WithBaseAddr(peerID), streamtest.WithBaseAddr(peerID),
) )
payer := pricing.New(recorder, logger, testThreshold) payer := pricing.New(recorder, logger, testThreshold, pricerMockService)
paymentThreshold := big.NewInt(10000) paymentThreshold := big.NewInt(10000)
...@@ -96,3 +113,93 @@ func TestAnnouncePaymentThreshold(t *testing.T) { ...@@ -96,3 +113,93 @@ func TestAnnouncePaymentThreshold(t *testing.T) {
t.Fatalf("observer called with wrong peer. got %v, want %v", observer.peer, peerID) t.Fatalf("observer called with wrong peer. got %v, want %v", observer.peer, peerID)
} }
} }
func TestAnnouncePaymentThresholdAndPriceTable(t *testing.T) {
logger := logging.New(ioutil.Discard, 0)
testThreshold := big.NewInt(100000)
observer1 := &testThresholdObserver{}
observer2 := &testPriceTableObserver{}
table := []uint64{50, 25, 12, 6}
priceTableFunc := func() []uint64 {
return table
}
pricerMockService := pricermock.NewMockService(pricermock.WithPriceTableFunc(priceTableFunc))
recipient := pricing.New(nil, logger, testThreshold, pricerMockService)
recipient.SetPaymentThresholdObserver(observer1)
recipient.SetPriceTableObserver(observer2)
peerID := swarm.MustParseHexAddress("9ee7add7")
recorder := streamtest.New(
streamtest.WithProtocols(recipient.Protocol()),
streamtest.WithBaseAddr(peerID),
)
payer := pricing.New(recorder, logger, testThreshold, pricerMockService)
paymentThreshold := big.NewInt(10000)
err := payer.AnnouncePaymentThresholdAndPriceTable(context.Background(), peerID, paymentThreshold)
if err != nil {
t.Fatal(err)
}
records, err := recorder.Records(peerID, "pricing", "1.0.0", "pricing")
if err != nil {
t.Fatal(err)
}
if l := len(records); l != 1 {
t.Fatalf("got %v records, want %v", l, 1)
}
record := records[0]
messages, err := protobuf.ReadMessages(
bytes.NewReader(record.In()),
func() protobuf.Message { return new(pb.AnnouncePaymentThreshold) },
)
if err != nil {
t.Fatal(err)
}
if len(messages) != 1 {
t.Fatalf("got %v messages, want %v", len(messages), 1)
}
sentPaymentThreshold := big.NewInt(0).SetBytes(messages[0].(*pb.AnnouncePaymentThreshold).PaymentThreshold)
if sentPaymentThreshold.Cmp(paymentThreshold) != 0 {
t.Fatalf("got message with amount %v, want %v", sentPaymentThreshold, paymentThreshold)
}
sentPriceTable := messages[0].(*pb.AnnouncePaymentThreshold).ProximityPrice
if !reflect.DeepEqual(sentPriceTable, table) {
t.Fatalf("got message with table %v, want %v", sentPriceTable, table)
}
if !observer1.called {
t.Fatal("expected threshold observer to be called")
}
if observer1.paymentThreshold.Cmp(paymentThreshold) != 0 {
t.Fatalf("observer called with wrong paymentThreshold. got %d, want %d", observer1.paymentThreshold, paymentThreshold)
}
if !observer1.peer.Equal(peerID) {
t.Fatalf("threshold observer called with wrong peer. got %v, want %v", observer1.peer, peerID)
}
if !observer2.called {
t.Fatal("expected table observer to be called")
}
if !reflect.DeepEqual(observer2.priceTable, table) {
t.Fatalf("table observer called with wrong priceTable. got %d, want %d", observer2.priceTable, table)
}
if !observer2.peer.Equal(peerID) {
t.Fatalf("table observer called with wrong peer. got %v, want %v", observer2.peer, peerID)
}
}
...@@ -17,6 +17,8 @@ import ( ...@@ -17,6 +17,8 @@ import (
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/p2p" "github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/p2p/protobuf" "github.com/ethersphere/bee/pkg/p2p/protobuf"
"github.com/ethersphere/bee/pkg/pricer"
"github.com/ethersphere/bee/pkg/pricer/headerutils"
"github.com/ethersphere/bee/pkg/pushsync/pb" "github.com/ethersphere/bee/pkg/pushsync/pb"
"github.com/ethersphere/bee/pkg/soc" "github.com/ethersphere/bee/pkg/soc"
"github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage"
...@@ -53,14 +55,14 @@ type PushSync struct { ...@@ -53,14 +55,14 @@ type PushSync struct {
unwrap func(swarm.Chunk) unwrap func(swarm.Chunk)
logger logging.Logger logger logging.Logger
accounting accounting.Interface accounting accounting.Interface
pricer accounting.Pricer pricer pricer.Interface
metrics metrics metrics metrics
tracer *tracing.Tracer tracer *tracing.Tracer
} }
var timeToLive = 5 * time.Second // request time to live var timeToLive = 5 * time.Second // request time to live
func New(streamer p2p.StreamerDisconnecter, storer storage.Putter, closestPeerer topology.ClosestPeerer, tagger *tags.Tags, unwrap func(swarm.Chunk), logger logging.Logger, accounting accounting.Interface, pricer accounting.Pricer, tracer *tracing.Tracer) *PushSync { func New(streamer p2p.StreamerDisconnecter, storer storage.Putter, closestPeerer topology.ClosestPeerer, tagger *tags.Tags, unwrap func(swarm.Chunk), logger logging.Logger, accounting accounting.Interface, pricer pricer.Interface, tracer *tracing.Tracer) *PushSync {
ps := &PushSync{ ps := &PushSync{
streamer: streamer, streamer: streamer,
storer: storer, storer: storer,
...@@ -84,6 +86,7 @@ func (s *PushSync) Protocol() p2p.ProtocolSpec { ...@@ -84,6 +86,7 @@ func (s *PushSync) Protocol() p2p.ProtocolSpec {
{ {
Name: streamName, Name: streamName,
Handler: s.handler, Handler: s.handler,
Headler: s.pricer.PriceHeadler,
}, },
}, },
} }
...@@ -122,6 +125,16 @@ func (ps *PushSync) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) ...@@ -122,6 +125,16 @@ func (ps *PushSync) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream)
span, _, ctx := ps.tracer.StartSpanFromContext(ctx, "pushsync-handler", ps.logger, opentracing.Tag{Key: "address", Value: chunk.Address().String()}) span, _, ctx := ps.tracer.StartSpanFromContext(ctx, "pushsync-handler", ps.logger, opentracing.Tag{Key: "address", Value: chunk.Address().String()})
defer span.Finish() defer span.Finish()
// Get price we charge for upstream peer read at headler
responseHeaders := stream.ResponseHeaders()
price, err := headerutils.ParsePriceHeader(responseHeaders)
if err != nil {
// if not found in returned header, compute the price we charge for this chunk and
ps.logger.Warningf("push sync: peer %v no price in previously issued response headers: %v", p.Address, err)
price = ps.pricer.PriceForPeer(p.Address, chunk.Address())
}
receipt, err := ps.pushToClosest(ctx, chunk) receipt, err := ps.pushToClosest(ctx, chunk)
if err != nil { if err != nil {
if errors.Is(err, topology.ErrWantSelf) { if errors.Is(err, topology.ErrWantSelf) {
...@@ -135,9 +148,10 @@ func (ps *PushSync) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) ...@@ -135,9 +148,10 @@ func (ps *PushSync) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream)
return fmt.Errorf("send receipt to peer %s: %w", p.Address.String(), err) return fmt.Errorf("send receipt to peer %s: %w", p.Address.String(), err)
} }
return ps.accounting.Debit(p.Address, ps.pricer.Price(chunk.Address())) return ps.accounting.Debit(p.Address, price)
} }
return fmt.Errorf("handler: push to closest: %w", err) return fmt.Errorf("handler: push to closest: %w", err)
} }
// pass back the receipt // pass back the receipt
...@@ -145,7 +159,7 @@ func (ps *PushSync) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) ...@@ -145,7 +159,7 @@ func (ps *PushSync) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream)
return fmt.Errorf("send receipt to peer %s: %w", p.Address.String(), err) return fmt.Errorf("send receipt to peer %s: %w", p.Address.String(), err)
} }
return ps.accounting.Debit(p.Address, ps.pricer.Price(chunk.Address())) return ps.accounting.Debit(p.Address, price)
} }
// PushChunkToClosest sends chunk to the closest peer by opening a stream. It then waits for // PushChunkToClosest sends chunk to the closest peer by opening a stream. It then waits for
...@@ -208,19 +222,44 @@ func (ps *PushSync) pushToClosest(ctx context.Context, ch swarm.Chunk) (rr *pb.R ...@@ -208,19 +222,44 @@ func (ps *PushSync) pushToClosest(ctx context.Context, ch swarm.Chunk) (rr *pb.R
// compute the price we pay for this receipt and reserve it for the rest of this function // compute the price we pay for this receipt and reserve it for the rest of this function
receiptPrice := ps.pricer.PeerPrice(peer, ch.Address()) receiptPrice := ps.pricer.PeerPrice(peer, ch.Address())
err = ps.accounting.Reserve(ctx, peer, receiptPrice)
headers, err := headerutils.MakePricingHeaders(receiptPrice, ch.Address())
if err != nil { if err != nil {
return nil, fmt.Errorf("reserve balance for peer %s: %w", peer.String(), err) return nil, err
} }
deferFuncs = append(deferFuncs, func() { ps.accounting.Release(peer, receiptPrice) })
streamer, err := ps.streamer.NewStream(ctx, peer, nil, protocolName, protocolVersion, streamName) streamer, err := ps.streamer.NewStream(ctx, peer, headers, protocolName, protocolVersion, streamName)
if err != nil { if err != nil {
lastErr = fmt.Errorf("new stream for peer %s: %w", peer.String(), err) lastErr = fmt.Errorf("new stream for peer %s: %w", peer.String(), err)
continue continue
} }
deferFuncs = append(deferFuncs, func() { go streamer.FullClose() }) deferFuncs = append(deferFuncs, func() { go streamer.FullClose() })
returnedHeaders := streamer.Headers()
returnedTarget, returnedPrice, returnedIndex, err := headerutils.ParsePricingResponseHeaders(returnedHeaders)
if err != nil {
return nil, fmt.Errorf("push price headers: read returned: %w", err)
}
ps.logger.Debugf("push price headers: returned target %v with price as %v, from peer %s", returnedTarget, returnedPrice, peer)
ps.logger.Debugf("push price headers: original target %v with price as %v, from peer %s", ch.Address(), receiptPrice, peer)
// check if returned price matches presumed price, if not, update price
if returnedPrice != receiptPrice {
err = ps.pricer.NotifyPeerPrice(peer, returnedPrice, returnedIndex) // save priceHeaders["price"] corresponding row for peer
if err != nil {
return nil, err
}
receiptPrice = returnedPrice
}
// Reserve to see whether we can make the request based on actual price
err = ps.accounting.Reserve(ctx, peer, receiptPrice)
if err != nil {
return nil, fmt.Errorf("reserve balance for peer %s: %w", peer.String(), err)
}
deferFuncs = append(deferFuncs, func() { ps.accounting.Release(peer, receiptPrice) })
w, r := protobuf.NewWriterAndReader(streamer) w, r := protobuf.NewWriterAndReader(streamer)
ctxd, canceld := context.WithTimeout(ctx, timeToLive) ctxd, canceld := context.WithTimeout(ctx, timeToLive)
deferFuncs = append(deferFuncs, func() { canceld() }) deferFuncs = append(deferFuncs, func() { canceld() })
......
This diff is collapsed.
...@@ -16,6 +16,7 @@ import ( ...@@ -16,6 +16,7 @@ import (
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/netstore" "github.com/ethersphere/bee/pkg/netstore"
"github.com/ethersphere/bee/pkg/p2p/streamtest" "github.com/ethersphere/bee/pkg/p2p/streamtest"
pricermock "github.com/ethersphere/bee/pkg/pricer/mock"
"github.com/ethersphere/bee/pkg/pss" "github.com/ethersphere/bee/pkg/pss"
"github.com/ethersphere/bee/pkg/pushsync" "github.com/ethersphere/bee/pkg/pushsync"
pushsyncmock "github.com/ethersphere/bee/pkg/pushsync/mock" pushsyncmock "github.com/ethersphere/bee/pkg/pushsync/mock"
...@@ -218,14 +219,14 @@ func newTestNetStore(t *testing.T, recoveryFunc recovery.Callback) storage.Store ...@@ -218,14 +219,14 @@ func newTestNetStore(t *testing.T, recoveryFunc recovery.Callback) storage.Store
mockStorer := storemock.NewStorer() mockStorer := storemock.NewStorer()
serverMockAccounting := accountingmock.NewAccounting() serverMockAccounting := accountingmock.NewAccounting()
price := uint64(12345)
pricerMock := accountingmock.NewPricer(price, price) pricerMock := pricermock.NewMockService()
peerID := swarm.MustParseHexAddress("deadbeef") peerID := swarm.MustParseHexAddress("deadbeef")
ps := mockPeerSuggester{eachPeerRevFunc: func(f topology.EachPeerFunc) error { ps := mockPeerSuggester{eachPeerRevFunc: func(f topology.EachPeerFunc) error {
_, _, _ = f(peerID, 0) _, _, _ = f(peerID, 0)
return nil return nil
}} }}
server := retrieval.New(swarm.ZeroAddress, mockStorer, nil, ps, logger, serverMockAccounting, nil, nil) server := retrieval.New(swarm.ZeroAddress, mockStorer, nil, ps, logger, serverMockAccounting, pricerMock, nil)
recorder := streamtest.New( recorder := streamtest.New(
streamtest.WithProtocols(server.Protocol()), streamtest.WithProtocols(server.Protocol()),
) )
......
...@@ -20,6 +20,8 @@ import ( ...@@ -20,6 +20,8 @@ import (
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/p2p" "github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/p2p/protobuf" "github.com/ethersphere/bee/pkg/p2p/protobuf"
"github.com/ethersphere/bee/pkg/pricer"
"github.com/ethersphere/bee/pkg/pricer/headerutils"
pb "github.com/ethersphere/bee/pkg/retrieval/pb" pb "github.com/ethersphere/bee/pkg/retrieval/pb"
"github.com/ethersphere/bee/pkg/soc" "github.com/ethersphere/bee/pkg/soc"
"github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage"
...@@ -52,12 +54,12 @@ type Service struct { ...@@ -52,12 +54,12 @@ type Service struct {
singleflight singleflight.Group singleflight singleflight.Group
logger logging.Logger logger logging.Logger
accounting accounting.Interface accounting accounting.Interface
pricer accounting.Pricer
metrics metrics metrics metrics
pricer pricer.Interface
tracer *tracing.Tracer tracer *tracing.Tracer
} }
func New(addr swarm.Address, storer storage.Storer, streamer p2p.Streamer, chunkPeerer topology.EachPeerer, logger logging.Logger, accounting accounting.Interface, pricer accounting.Pricer, tracer *tracing.Tracer) *Service { func New(addr swarm.Address, storer storage.Storer, streamer p2p.Streamer, chunkPeerer topology.EachPeerer, logger logging.Logger, accounting accounting.Interface, pricer pricer.Interface, tracer *tracing.Tracer) *Service {
return &Service{ return &Service{
addr: addr, addr: addr,
streamer: streamer, streamer: streamer,
...@@ -79,6 +81,7 @@ func (s *Service) Protocol() p2p.ProtocolSpec { ...@@ -79,6 +81,7 @@ func (s *Service) Protocol() p2p.ProtocolSpec {
{ {
Name: streamName, Name: streamName,
Handler: s.handler, Handler: s.handler,
Headler: s.pricer.PriceHeadler,
}, },
}, },
} }
...@@ -201,20 +204,23 @@ func (s *Service) retrieveChunk(ctx context.Context, addr swarm.Address, sp *ski ...@@ -201,20 +204,23 @@ func (s *Service) retrieveChunk(ctx context.Context, addr swarm.Address, sp *ski
sp.Add(peer) sp.Add(peer)
// compute the price we pay for this chunk and reserve it for the rest of this function // compute the price we presume to pay for this chunk for price header
chunkPrice := s.pricer.PeerPrice(peer, addr) chunkPrice := s.pricer.PeerPrice(peer, addr)
err = s.accounting.Reserve(ctx, peer, chunkPrice)
headers, err := headerutils.MakePricingHeaders(chunkPrice, addr)
if err != nil { if err != nil {
return nil, peer, err return nil, swarm.Address{}, err
} }
defer s.accounting.Release(peer, chunkPrice)
s.logger.Tracef("retrieval: requesting chunk %s from peer %s", addr, peer) s.logger.Tracef("retrieval: requesting chunk %s from peer %s", addr, peer)
stream, err := s.streamer.NewStream(ctx, peer, nil, protocolName, protocolVersion, streamName) stream, err := s.streamer.NewStream(ctx, peer, headers, protocolName, protocolVersion, streamName)
if err != nil { if err != nil {
s.metrics.TotalErrors.Inc() s.metrics.TotalErrors.Inc()
return nil, peer, fmt.Errorf("new stream: %w", err) return nil, peer, fmt.Errorf("new stream: %w", err)
} }
returnedHeaders := stream.Headers()
defer func() { defer func() {
if err != nil { if err != nil {
_ = stream.Reset() _ = stream.Reset()
...@@ -223,6 +229,29 @@ func (s *Service) retrieveChunk(ctx context.Context, addr swarm.Address, sp *ski ...@@ -223,6 +229,29 @@ func (s *Service) retrieveChunk(ctx context.Context, addr swarm.Address, sp *ski
} }
}() }()
returnedTarget, returnedPrice, returnedIndex, err := headerutils.ParsePricingResponseHeaders(returnedHeaders)
if err != nil {
return nil, peer, fmt.Errorf("retrieval headers: read returned: %w", err)
}
s.logger.Debugf("retrieval headers: returned target %v with price as %v, from peer %s", returnedTarget, returnedPrice, peer)
s.logger.Debugf("retrieval headers: original target %v with price as %v, from peer %s", addr, chunkPrice, peer)
// check if returned price matches presumed price, if not, update price
if returnedPrice != chunkPrice {
err = s.pricer.NotifyPeerPrice(peer, returnedPrice, returnedIndex) // save priceHeaders["price"] corresponding row for peer
if err != nil {
return nil, peer, err
}
chunkPrice = returnedPrice
}
// Reserve to see whether we can request the chunk based on actual price
err = s.accounting.Reserve(ctx, peer, chunkPrice)
if err != nil {
return nil, peer, err
}
defer s.accounting.Release(peer, chunkPrice)
w, r := protobuf.NewWriterAndReader(stream) w, r := protobuf.NewWriterAndReader(stream)
if err := w.WriteMsgWithContext(ctx, &pb.Request{ if err := w.WriteMsgWithContext(ctx, &pb.Request{
Addr: addr.Bytes(), Addr: addr.Bytes(),
...@@ -329,6 +358,7 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e ...@@ -329,6 +358,7 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e
if err := r.ReadMsgWithContext(ctx, &req); err != nil { if err := r.ReadMsgWithContext(ctx, &req); err != nil {
return fmt.Errorf("read request: %w peer %s", err, p.Address.String()) return fmt.Errorf("read request: %w peer %s", err, p.Address.String())
} }
span, _, ctx := s.tracer.StartSpanFromContext(ctx, "handle-retrieve-chunk", s.logger, opentracing.Tag{Key: "address", Value: swarm.NewAddress(req.Addr).String()}) span, _, ctx := s.tracer.StartSpanFromContext(ctx, "handle-retrieve-chunk", s.logger, opentracing.Tag{Key: "address", Value: swarm.NewAddress(req.Addr).String()})
defer span.Finish() defer span.Finish()
...@@ -355,12 +385,15 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e ...@@ -355,12 +385,15 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e
s.logger.Tracef("retrieval protocol debiting peer %s", p.Address.String()) s.logger.Tracef("retrieval protocol debiting peer %s", p.Address.String())
// compute the price we charge for this chunk and debit it from p's balance // to get price Read in headler,
chunkPrice := s.pricer.Price(chunk.Address()) returnedHeaders := stream.ResponseHeaders()
err = s.accounting.Debit(p.Address, chunkPrice) chunkPrice, err := headerutils.ParsePriceHeader(returnedHeaders)
if err != nil { if err != nil {
return err // if not found in returned header, compute the price we charge for this chunk and
s.logger.Warningf("retrieval: peer %v no price in previously issued response headers: %v", p.Address, err)
chunkPrice = s.pricer.PriceForPeer(p.Address, chunk.Address())
} }
// debit price from p's balance
return nil return s.accounting.Debit(p.Address, chunkPrice)
} }
...@@ -11,15 +11,19 @@ import ( ...@@ -11,15 +11,19 @@ import (
"errors" "errors"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"math/big"
"sync" "sync"
"testing" "testing"
"time" "time"
accountingmock "github.com/ethersphere/bee/pkg/accounting/mock" accountingmock "github.com/ethersphere/bee/pkg/accounting/mock"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/p2p" "github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/p2p/protobuf" "github.com/ethersphere/bee/pkg/p2p/protobuf"
"github.com/ethersphere/bee/pkg/p2p/streamtest" "github.com/ethersphere/bee/pkg/p2p/streamtest"
"github.com/ethersphere/bee/pkg/pricer/headerutils"
pricermock "github.com/ethersphere/bee/pkg/pricer/mock"
"github.com/ethersphere/bee/pkg/retrieval" "github.com/ethersphere/bee/pkg/retrieval"
pb "github.com/ethersphere/bee/pkg/retrieval/pb" pb "github.com/ethersphere/bee/pkg/retrieval/pb"
"github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage"
...@@ -34,17 +38,23 @@ var testTimeout = 5 * time.Second ...@@ -34,17 +38,23 @@ var testTimeout = 5 * time.Second
// TestDelivery tests that a naive request -> delivery flow works. // TestDelivery tests that a naive request -> delivery flow works.
func TestDelivery(t *testing.T) { func TestDelivery(t *testing.T) {
var ( var (
chunk = testingc.FixtureChunk("0033")
headlerFunc = func(h p2p.Headers, a swarm.Address) p2p.Headers {
headers, _ := headerutils.MakePricingResponseHeaders(10, chunk.Address(), 0)
return headers
}
logger = logging.New(ioutil.Discard, 0) logger = logging.New(ioutil.Discard, 0)
mockStorer = storemock.NewStorer() mockStorer = storemock.NewStorer()
chunk = testingc.FixtureChunk("0033")
clientMockAccounting = accountingmock.NewAccounting() clientMockAccounting = accountingmock.NewAccounting()
serverMockAccounting = accountingmock.NewAccounting() serverMockAccounting = accountingmock.NewAccounting()
clientAddr = swarm.MustParseHexAddress("9ee7add8") clientAddr = swarm.MustParseHexAddress("9ee7add8")
serverAddr = swarm.MustParseHexAddress("9ee7add7") serverAddr = swarm.MustParseHexAddress("9ee7add7")
price = uint64(10) price = uint64(10)
pricerMock = accountingmock.NewPricer(price, price) pricerMock = pricermock.NewMockService(pricermock.WithPriceHeadlerFunc(headlerFunc))
) )
// put testdata in the mock store of the server // put testdata in the mock store of the server
_, err := mockStorer.Put(context.Background(), storage.ModePutUpload, chunk) _, err := mockStorer.Put(context.Background(), storage.ModePutUpload, chunk)
if err != nil { if err != nil {
...@@ -132,10 +142,126 @@ func TestDelivery(t *testing.T) { ...@@ -132,10 +142,126 @@ func TestDelivery(t *testing.T) {
} }
} }
// TestDelivery tests that a naive request -> delivery flow works.
func TestDeliveryWithPriceUpdate(t *testing.T) {
var (
price = uint64(10)
serverPrice = uint64(17)
chunk = testingc.FixtureChunk("0033")
headlerFunc = func(h p2p.Headers, a swarm.Address) p2p.Headers {
headers, _ := headerutils.MakePricingResponseHeaders(serverPrice, chunk.Address(), 5)
return headers
}
logger = logging.New(ioutil.Discard, 0)
mockStorer = storemock.NewStorer()
clientMockAccounting = accountingmock.NewAccounting()
serverMockAccounting = accountingmock.NewAccounting()
clientAddr = swarm.MustParseHexAddress("9ee7add8")
serverAddr = swarm.MustParseHexAddress("9ee7add7")
clientPricerMock = pricermock.NewMockService(pricermock.WithPeerPrice(price))
serverPricerMock = pricermock.NewMockService(pricermock.WithPriceHeadlerFunc(headlerFunc), pricermock.WithPrice(serverPrice))
)
// put testdata in the mock store of the server
_, err := mockStorer.Put(context.Background(), storage.ModePutUpload, chunk)
if err != nil {
t.Fatal(err)
}
// create the server that will handle the request and will serve the response
server := retrieval.New(swarm.MustParseHexAddress("0034"), mockStorer, nil, nil, logger, serverMockAccounting, serverPricerMock, nil)
recorder := streamtest.New(
streamtest.WithProtocols(server.Protocol()),
streamtest.WithBaseAddr(clientAddr),
)
// client mock storer does not store any data at this point
// but should be checked at at the end of the test for the
// presence of the chunk address key and value to ensure delivery
// was successful
clientMockStorer := storemock.NewStorer()
ps := mockPeerSuggester{eachPeerRevFunc: func(f topology.EachPeerFunc) error {
_, _, _ = f(serverAddr, 0)
return nil
}}
client := retrieval.New(clientAddr, clientMockStorer, recorder, ps, logger, clientMockAccounting, clientPricerMock, nil)
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
v, err := client.RetrieveChunk(ctx, chunk.Address())
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(v.Data(), chunk.Data()) {
t.Fatalf("request and response data not equal. got %s want %s", v, chunk.Data())
}
records, err := recorder.Records(serverAddr, "retrieval", "1.0.0", "retrieval")
if err != nil {
t.Fatal(err)
}
if l := len(records); l != 1 {
t.Fatalf("got %v records, want %v", l, 1)
}
record := records[0]
messages, err := protobuf.ReadMessages(
bytes.NewReader(record.In()),
func() protobuf.Message { return new(pb.Request) },
)
if err != nil {
t.Fatal(err)
}
var reqs []string
for _, m := range messages {
reqs = append(reqs, hex.EncodeToString(m.(*pb.Request).Addr))
}
if len(reqs) != 1 {
t.Fatalf("got too many requests. want 1 got %d", len(reqs))
}
messages, err = protobuf.ReadMessages(
bytes.NewReader(record.Out()),
func() protobuf.Message { return new(pb.Delivery) },
)
if err != nil {
t.Fatal(err)
}
var gotDeliveries []string
for _, m := range messages {
gotDeliveries = append(gotDeliveries, string(m.(*pb.Delivery).Data))
}
if len(gotDeliveries) != 1 {
t.Fatalf("got too many deliveries. want 1 got %d", len(gotDeliveries))
}
clientBalance, _ := clientMockAccounting.Balance(serverAddr)
if clientBalance.Cmp(big.NewInt(-int64(serverPrice))) != 0 {
t.Fatalf("unexpected balance on client. want %d got %d", -serverPrice, clientBalance)
}
serverBalance, _ := serverMockAccounting.Balance(clientAddr)
if serverBalance.Cmp(big.NewInt(int64(serverPrice))) != 0 {
t.Fatalf("unexpected balance on server. want %d got %d", serverPrice, serverBalance)
}
}
func TestRetrieveChunk(t *testing.T) { func TestRetrieveChunk(t *testing.T) {
var ( var (
headlerFunc = func(h p2p.Headers, a swarm.Address) p2p.Headers {
target, _ := headerutils.ParseTargetHeader(h)
headers, _ := headerutils.MakePricingResponseHeaders(10, target, 0)
return headers
}
logger = logging.New(ioutil.Discard, 0) logger = logging.New(ioutil.Discard, 0)
pricer = accountingmock.NewPricer(1, 1) pricer = pricermock.NewMockService(pricermock.WithPriceHeadlerFunc(headlerFunc))
) )
// requesting a chunk from downstream peer is expected // requesting a chunk from downstream peer is expected
...@@ -236,8 +362,14 @@ func TestRetrievePreemptiveRetry(t *testing.T) { ...@@ -236,8 +362,14 @@ func TestRetrievePreemptiveRetry(t *testing.T) {
chunk := testingc.FixtureChunk("0025") chunk := testingc.FixtureChunk("0025")
someOtherChunk := testingc.FixtureChunk("0033") someOtherChunk := testingc.FixtureChunk("0033")
price := uint64(1) headlerFunc := func(h p2p.Headers, a swarm.Address) p2p.Headers {
pricerMock := accountingmock.NewPricer(price, price) target, _ := headerutils.ParseTargetHeader(h)
headers, _ := headerutils.MakePricingResponseHeaders(10, target, 0)
return headers
}
price := uint64(10)
pricerMock := pricermock.NewMockService(pricermock.WithPriceHeadlerFunc(headlerFunc))
clientAddress := swarm.MustParseHexAddress("1010") clientAddress := swarm.MustParseHexAddress("1010")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment