Commit c37b31db authored by Ralph Pichler's avatar Ralph Pichler Committed by GitHub

feat: async settlement (#1578)

parent 5bf7137f
...@@ -64,7 +64,7 @@ jobs: ...@@ -64,7 +64,7 @@ jobs:
run: | run: |
echo -e "127.0.0.10\tregistry.localhost" | sudo tee -a /etc/hosts echo -e "127.0.0.10\tregistry.localhost" | sudo tee -a /etc/hosts
for ((i=0; i<REPLICA; i++)); do echo -e "127.0.1.$((i+1))\tbee-${i}.localhost bee-${i}-debug.localhost"; done | sudo tee -a /etc/hosts for ((i=0; i<REPLICA; i++)); do echo -e "127.0.1.$((i+1))\tbee-${i}.localhost bee-${i}-debug.localhost"; done | sudo tee -a /etc/hosts
timeout 30m ./beeinfra.sh install --local -r "${REPLICA}" --bootnode /dnsaddr/localhost --geth --k3s --pay-threshold 1000000000000 --postage timeout 30m ./beeinfra.sh install --local -r "${REPLICA}" --bootnode /dnsaddr/localhost --geth --k3s --pay-threshold 2000000000000 --postage
- name: Test pingpong - name: Test pingpong
id: pingpong-1 id: pingpong-1
run: until ./beekeeper check pingpong --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}"; do echo "waiting for pingpong..."; sleep .3; done run: until ./beekeeper check pingpong --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}"; do echo "waiting for pingpong..."; sleep .3; done
...@@ -73,7 +73,7 @@ jobs: ...@@ -73,7 +73,7 @@ jobs:
run: ./beekeeper check fullconnectivity --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}" run: ./beekeeper check fullconnectivity --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}"
- name: Test settlements - name: Test settlements
id: settlements-1 id: settlements-1
run: ./beekeeper check settlements --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}" --upload-node-count "${REPLICA}" -t 1000000000000 run: ./beekeeper check settlements --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}" --upload-node-count "${REPLICA}" -t 2000000000000
- name: Test pushsync (chunks) - name: Test pushsync (chunks)
id: pushsync-chunks-1 id: pushsync-chunks-1
run: ./beekeeper check pushsync --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}" --upload-node-count "${REPLICA}" --chunks-per-node 3 --upload-chunks --retry-delay 10s run: ./beekeeper check pushsync --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}" --upload-node-count "${REPLICA}" --chunks-per-node 3 --upload-chunks --retry-delay 10s
...@@ -101,7 +101,7 @@ jobs: ...@@ -101,7 +101,7 @@ jobs:
cp /etc/rancher/k3s/k3s.yaml ~/.kube/config cp /etc/rancher/k3s/k3s.yaml ~/.kube/config
- name: Set testing cluster (Node connection and clef enabled) - name: Set testing cluster (Node connection and clef enabled)
run: | run: |
timeout 30m ./beeinfra.sh install --local -r "${REPLICA}" --geth --clef --k3s --pay-threshold 1000000000000 --postage timeout 30m ./beeinfra.sh install --local -r "${REPLICA}" --geth --clef --k3s --pay-threshold 2000000000000 --postage
- name: Test pingpong - name: Test pingpong
id: pingpong-2 id: pingpong-2
run: until ./beekeeper check pingpong --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}"; do echo "waiting for pingpong..."; sleep .3; done run: until ./beekeeper check pingpong --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}"; do echo "waiting for pingpong..."; sleep .3; done
...@@ -110,7 +110,7 @@ jobs: ...@@ -110,7 +110,7 @@ jobs:
run: ./beekeeper check fullconnectivity --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}" run: ./beekeeper check fullconnectivity --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}"
- name: Test settlements - name: Test settlements
id: settlements-2 id: settlements-2
run: ./beekeeper check settlements --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}" --upload-node-count "${REPLICA}" -t 1000000000000 run: ./beekeeper check settlements --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}" --upload-node-count "${REPLICA}" -t 2000000000000
- name: Destroy the cluster - name: Destroy the cluster
run: | run: |
./beeinfra.sh uninstall ./beeinfra.sh uninstall
...@@ -126,7 +126,7 @@ jobs: ...@@ -126,7 +126,7 @@ jobs:
cp /etc/rancher/k3s/k3s.yaml ~/.kube/config cp /etc/rancher/k3s/k3s.yaml ~/.kube/config
- name: Set testing cluster (storage incentives setup) - name: Set testing cluster (storage incentives setup)
run: | run: |
timeout 10m ./beeinfra.sh install --local -r "${REPLICA}" --geth --k3s --pay-threshold 1000000000000 --postage --db-capacity 100 timeout 10m ./beeinfra.sh install --local -r "${REPLICA}" --geth --k3s --pay-threshold 2000000000000 --postage --db-capacity 100
- name: Test gc - name: Test gc
id: gc-chunk-1 id: gc-chunk-1
run: ./beekeeper check gc --cache-capacity 100 --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}" run: ./beekeeper check gc --cache-capacity 100 --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}"
......
...@@ -18,7 +18,6 @@ import ( ...@@ -18,7 +18,6 @@ import (
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/p2p" "github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/pricing" "github.com/ethersphere/bee/pkg/pricing"
"github.com/ethersphere/bee/pkg/settlement"
"github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
) )
...@@ -55,11 +54,14 @@ type Interface interface { ...@@ -55,11 +54,14 @@ type Interface interface {
CompensatedBalances() (map[string]*big.Int, error) CompensatedBalances() (map[string]*big.Int, error)
} }
type PayFunc func(context.Context, swarm.Address, *big.Int)
// accountingPeer holds all in-memory accounting information for one peer. // accountingPeer holds all in-memory accounting information for one peer.
type accountingPeer struct { type accountingPeer struct {
lock sync.Mutex // lock to be held during any accounting action for this peer lock sync.Mutex // lock to be held during any accounting action for this peer
reservedBalance *big.Int // amount currently reserved for active peer interaction reservedBalance *big.Int // amount currently reserved for active peer interaction
paymentThreshold *big.Int // the threshold at which the peer expects us to pay paymentThreshold *big.Int // the threshold at which the peer expects us to pay
paymentOngoing bool // indicate if we are currently settling with the peer
} }
// Accounting is the main implementation of the accounting interface. // Accounting is the main implementation of the accounting interface.
...@@ -75,7 +77,7 @@ type Accounting struct { ...@@ -75,7 +77,7 @@ type Accounting struct {
// disconnect them. // disconnect them.
paymentTolerance *big.Int paymentTolerance *big.Int
earlyPayment *big.Int earlyPayment *big.Int
settlement settlement.Interface payFunction PayFunc
pricing pricing.Interface pricing pricing.Interface
metrics metrics metrics metrics
} }
...@@ -100,7 +102,6 @@ func NewAccounting( ...@@ -100,7 +102,6 @@ func NewAccounting(
EarlyPayment *big.Int, EarlyPayment *big.Int,
Logger logging.Logger, Logger logging.Logger,
Store storage.StateStorer, Store storage.StateStorer,
Settlement settlement.Interface,
Pricing pricing.Interface, Pricing pricing.Interface,
) (*Accounting, error) { ) (*Accounting, error) {
return &Accounting{ return &Accounting{
...@@ -110,7 +111,6 @@ func NewAccounting( ...@@ -110,7 +111,6 @@ func NewAccounting(
earlyPayment: new(big.Int).Set(EarlyPayment), earlyPayment: new(big.Int).Set(EarlyPayment),
logger: Logger, logger: Logger,
store: Store, store: Store,
settlement: Settlement,
pricing: Pricing, pricing: Pricing,
metrics: newMetrics(), metrics: newMetrics(),
}, nil }, nil
...@@ -118,10 +118,7 @@ func NewAccounting( ...@@ -118,10 +118,7 @@ func NewAccounting(
// Reserve reserves a portion of the balance for peer and attempts settlements if necessary. // Reserve reserves a portion of the balance for peer and attempts settlements if necessary.
func (a *Accounting) Reserve(ctx context.Context, peer swarm.Address, price uint64) error { func (a *Accounting) Reserve(ctx context.Context, peer swarm.Address, price uint64) error {
accountingPeer, err := a.getAccountingPeer(peer) accountingPeer := a.getAccountingPeer(peer)
if err != nil {
return err
}
accountingPeer.lock.Lock() accountingPeer.lock.Lock()
defer accountingPeer.lock.Unlock() defer accountingPeer.lock.Unlock()
...@@ -167,14 +164,10 @@ func (a *Accounting) Reserve(ctx context.Context, peer swarm.Address, price uint ...@@ -167,14 +164,10 @@ func (a *Accounting) Reserve(ctx context.Context, peer swarm.Address, price uint
// and we are actually in debt, trigger settlement. // and we are actually in debt, trigger settlement.
// we pay early to avoid needlessly blocking request later when concurrent requests occur and we are already close to the payment threshold. // we pay early to avoid needlessly blocking request later when concurrent requests occur and we are already close to the payment threshold.
if increasedExpectedDebt.Cmp(threshold) >= 0 && currentBalance.Cmp(big.NewInt(0)) < 0 { if increasedExpectedDebt.Cmp(threshold) >= 0 && currentBalance.Cmp(big.NewInt(0)) < 0 {
err = a.settle(ctx, peer, accountingPeer) err = a.settle(context.Background(), peer, accountingPeer)
if err != nil { if err != nil {
return fmt.Errorf("failed to settle with peer %v: %v", peer, err) return fmt.Errorf("failed to settle with peer %v: %v", peer, err)
} }
// if we settled successfully our balance is back at 0
// and the expected debt therefore equals next reserved amount
expectedDebt = nextReserved
increasedExpectedDebt = new(big.Int).Add(expectedDebt, additionalDebt)
} }
// if expectedDebt would still exceed the paymentThreshold at this point block this request // if expectedDebt would still exceed the paymentThreshold at this point block this request
...@@ -190,11 +183,7 @@ func (a *Accounting) Reserve(ctx context.Context, peer swarm.Address, price uint ...@@ -190,11 +183,7 @@ func (a *Accounting) Reserve(ctx context.Context, peer swarm.Address, price uint
// Release releases reserved funds. // Release releases reserved funds.
func (a *Accounting) Release(peer swarm.Address, price uint64) { func (a *Accounting) Release(peer swarm.Address, price uint64) {
accountingPeer, err := a.getAccountingPeer(peer) accountingPeer := a.getAccountingPeer(peer)
if err != nil {
a.logger.Errorf("cannot release balance for peer: %v", err)
return
}
accountingPeer.lock.Lock() accountingPeer.lock.Lock()
defer accountingPeer.lock.Unlock() defer accountingPeer.lock.Unlock()
...@@ -213,10 +202,7 @@ func (a *Accounting) Release(peer swarm.Address, price uint64) { ...@@ -213,10 +202,7 @@ func (a *Accounting) Release(peer swarm.Address, price uint64) {
// Credit increases the amount of credit we have with the given peer // Credit increases the amount of credit we have with the given peer
// (and decreases existing debt). // (and decreases existing debt).
func (a *Accounting) Credit(peer swarm.Address, price uint64) error { func (a *Accounting) Credit(peer swarm.Address, price uint64) error {
accountingPeer, err := a.getAccountingPeer(peer) accountingPeer := a.getAccountingPeer(peer)
if err != nil {
return err
}
accountingPeer.lock.Lock() accountingPeer.lock.Lock()
defer accountingPeer.lock.Unlock() defer accountingPeer.lock.Unlock()
...@@ -246,6 +232,10 @@ func (a *Accounting) Credit(peer swarm.Address, price uint64) error { ...@@ -246,6 +232,10 @@ func (a *Accounting) Credit(peer swarm.Address, price uint64) error {
// Settle all debt with a peer. The lock on the accountingPeer must be held when // Settle all debt with a peer. The lock on the accountingPeer must be held when
// called. // called.
func (a *Accounting) settle(ctx context.Context, peer swarm.Address, balance *accountingPeer) error { func (a *Accounting) settle(ctx context.Context, peer swarm.Address, balance *accountingPeer) error {
if balance.paymentOngoing {
return nil
}
oldBalance, err := a.Balance(peer) oldBalance, err := a.Balance(peer)
if err != nil { if err != nil {
if !errors.Is(err, ErrPeerNoBalance) { if !errors.Is(err, ErrPeerNoBalance) {
...@@ -263,24 +253,9 @@ func (a *Accounting) settle(ctx context.Context, peer swarm.Address, balance *ac ...@@ -263,24 +253,9 @@ func (a *Accounting) settle(ctx context.Context, peer swarm.Address, balance *ac
// This is safe because of the earlier check for oldbalance < 0 and the check for != MinInt64 // This is safe because of the earlier check for oldbalance < 0 and the check for != MinInt64
paymentAmount := new(big.Int).Neg(oldBalance) paymentAmount := new(big.Int).Neg(oldBalance)
// Try to save the next balance first. balance.paymentOngoing = true
// Otherwise we might pay and then not be able to save, forcing us to pay
// again after restart.
err = a.store.Put(peerBalanceKey(peer), big.NewInt(0))
if err != nil {
return fmt.Errorf("failed to persist balance: %w", err)
}
err = a.settlement.Pay(ctx, peer, paymentAmount) go a.payFunction(ctx, peer, paymentAmount)
if err != nil {
err = fmt.Errorf("settlement for amount %d failed: %w", paymentAmount, err)
// If the payment didn't succeed we should restore the old balance in
// the state store.
if storeErr := a.store.Put(peerBalanceKey(peer), oldBalance); storeErr != nil {
a.logger.Errorf("failed to restore balance after failed settlement for peer %v: %v", peer, storeErr)
}
return err
}
return nil return nil
} }
...@@ -288,16 +263,14 @@ func (a *Accounting) settle(ctx context.Context, peer swarm.Address, balance *ac ...@@ -288,16 +263,14 @@ func (a *Accounting) settle(ctx context.Context, peer swarm.Address, balance *ac
// Debit increases the amount of debt we have with the given peer (and decreases // Debit increases the amount of debt we have with the given peer (and decreases
// existing credit). // existing credit).
func (a *Accounting) Debit(peer swarm.Address, price uint64) error { func (a *Accounting) Debit(peer swarm.Address, price uint64) error {
accountingPeer, err := a.getAccountingPeer(peer) accountingPeer := a.getAccountingPeer(peer)
if err != nil {
return err
}
accountingPeer.lock.Lock() accountingPeer.lock.Lock()
defer accountingPeer.lock.Unlock() defer accountingPeer.lock.Unlock()
cost := new(big.Int).SetUint64(price) cost := new(big.Int).SetUint64(price)
// see if peer has surplus balance to deduct this transaction of // see if peer has surplus balance to deduct this transaction of
surplusBalance, err := a.SurplusBalance(peer) surplusBalance, err := a.SurplusBalance(peer)
if err != nil { if err != nil {
return fmt.Errorf("failed to get surplus balance: %w", err) return fmt.Errorf("failed to get surplus balance: %w", err)
...@@ -439,7 +412,7 @@ func peerSurplusBalanceKey(peer swarm.Address) string { ...@@ -439,7 +412,7 @@ func peerSurplusBalanceKey(peer swarm.Address) string {
// getAccountingPeer returns the accountingPeer for a given swarm address. // getAccountingPeer returns the accountingPeer for a given swarm address.
// If not found in memory it will initialize it. // If not found in memory it will initialize it.
func (a *Accounting) getAccountingPeer(peer swarm.Address) (*accountingPeer, error) { func (a *Accounting) getAccountingPeer(peer swarm.Address) *accountingPeer {
a.accountingPeersMu.Lock() a.accountingPeersMu.Lock()
defer a.accountingPeersMu.Unlock() defer a.accountingPeersMu.Unlock()
...@@ -453,7 +426,7 @@ func (a *Accounting) getAccountingPeer(peer swarm.Address) (*accountingPeer, err ...@@ -453,7 +426,7 @@ func (a *Accounting) getAccountingPeer(peer swarm.Address) (*accountingPeer, err
a.accountingPeers[peer.String()] = peerData a.accountingPeers[peer.String()] = peerData
} }
return peerData, nil return peerData
} }
// Balances gets balances for all peers from store. // Balances gets balances for all peers from store.
...@@ -569,11 +542,8 @@ func surplusBalanceKeyPeer(key []byte) (swarm.Address, error) { ...@@ -569,11 +542,8 @@ func surplusBalanceKeyPeer(key []byte) (swarm.Address, error) {
} }
// NotifyPayment is called by Settlement when we receive a payment. // NotifyPayment is called by Settlement when we receive a payment.
func (a *Accounting) NotifyPayment(peer swarm.Address, amount *big.Int) error { func (a *Accounting) NotifyPaymentReceived(peer swarm.Address, amount *big.Int) error {
accountingPeer, err := a.getAccountingPeer(peer) accountingPeer := a.getAccountingPeer(peer)
if err != nil {
return err
}
accountingPeer.lock.Lock() accountingPeer.lock.Lock()
defer accountingPeer.lock.Unlock() defer accountingPeer.lock.Unlock()
...@@ -644,28 +614,67 @@ func (a *Accounting) NotifyPayment(peer swarm.Address, amount *big.Int) error { ...@@ -644,28 +614,67 @@ func (a *Accounting) NotifyPayment(peer swarm.Address, amount *big.Int) error {
return nil return nil
} }
// AsyncNotifyPayment calls notify payment in a go routine. // NotifyPaymentThreshold should be called to notify accounting of changes in the payment threshold
// This is needed when accounting needs to be notified but the accounting lock is already held. func (a *Accounting) NotifyPaymentThreshold(peer swarm.Address, paymentThreshold *big.Int) error {
func (a *Accounting) AsyncNotifyPayment(peer swarm.Address, amount *big.Int) error { accountingPeer := a.getAccountingPeer(peer)
go func() {
err := a.NotifyPayment(peer, amount) accountingPeer.lock.Lock()
if err != nil { defer accountingPeer.lock.Unlock()
a.logger.Errorf("failed to notify accounting of payment: %v", err)
} accountingPeer.paymentThreshold.Set(paymentThreshold)
}()
return nil return nil
} }
// NotifyPaymentThreshold should be called to notify accounting of changes in the payment threshold func (a *Accounting) PeerDebt(peer swarm.Address) (*big.Int, error) {
func (a *Accounting) NotifyPaymentThreshold(peer swarm.Address, paymentThreshold *big.Int) error { zero := big.NewInt(0)
accountingPeer, err := a.getAccountingPeer(peer) balance, err := a.Balance(peer)
if err != nil { if err != nil {
return err if errors.Is(err, ErrPeerNoBalance) {
return zero, nil
}
return nil, err
}
if balance.Cmp(zero) <= 0 {
return zero, nil
} }
return balance, nil
}
func (a *Accounting) NotifyPaymentSent(peer swarm.Address, amount *big.Int, receivedError error) {
accountingPeer := a.getAccountingPeer(peer)
accountingPeer.lock.Lock() accountingPeer.lock.Lock()
defer accountingPeer.lock.Unlock() defer accountingPeer.lock.Unlock()
accountingPeer.paymentThreshold.Set(paymentThreshold) accountingPeer.paymentOngoing = false
return nil
if receivedError != nil {
a.logger.Warningf("accouting: payment failure %v", receivedError)
return
}
currentBalance, err := a.Balance(peer)
if err != nil {
if !errors.Is(err, ErrPeerNoBalance) {
a.logger.Warningf("accounting: notifypaymentsent failed to load balance: %v", err)
return
}
}
// Get nextBalance by safely increasing current balance with price
nextBalance := new(big.Int).Add(currentBalance, amount)
a.logger.Tracef("registering payment sent to peer %v with amount %d, new balance is %d", peer, amount, nextBalance)
err = a.store.Put(peerBalanceKey(peer), nextBalance)
if err != nil {
a.logger.Warningf("accounting: notifypaymentsent failed to persist balance: %v", err)
return
}
}
func (a *Accounting) SetPayFunc(f PayFunc) {
a.payFunction = f
} }
...@@ -10,11 +10,11 @@ import ( ...@@ -10,11 +10,11 @@ import (
"io/ioutil" "io/ioutil"
"math/big" "math/big"
"testing" "testing"
"time"
"github.com/ethersphere/bee/pkg/accounting" "github.com/ethersphere/bee/pkg/accounting"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/p2p" "github.com/ethersphere/bee/pkg/p2p"
mockSettlement "github.com/ethersphere/bee/pkg/settlement/swap/mock"
"github.com/ethersphere/bee/pkg/statestore/mock" "github.com/ethersphere/bee/pkg/statestore/mock"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
) )
...@@ -29,6 +29,11 @@ var ( ...@@ -29,6 +29,11 @@ var (
testPaymentThreshold = big.NewInt(10000) testPaymentThreshold = big.NewInt(10000)
) )
type paymentCall struct {
peer swarm.Address
amount *big.Int
}
// booking represents an accounting action and the expected result afterwards // booking represents an accounting action and the expected result afterwards
type booking struct { type booking struct {
peer swarm.Address peer swarm.Address
...@@ -43,7 +48,7 @@ func TestAccountingAddBalance(t *testing.T) { ...@@ -43,7 +48,7 @@ func TestAccountingAddBalance(t *testing.T) {
store := mock.NewStateStore() store := mock.NewStateStore()
defer store.Close() defer store.Close()
acc, err := accounting.NewAccounting(testPaymentThreshold, testPaymentTolerance, testPaymentEarly, logger, store, nil, nil) acc, err := accounting.NewAccounting(testPaymentThreshold, testPaymentTolerance, testPaymentEarly, logger, store, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -104,7 +109,7 @@ func TestAccountingAdd_persistentBalances(t *testing.T) { ...@@ -104,7 +109,7 @@ func TestAccountingAdd_persistentBalances(t *testing.T) {
store := mock.NewStateStore() store := mock.NewStateStore()
defer store.Close() defer store.Close()
acc, err := accounting.NewAccounting(testPaymentThreshold, testPaymentTolerance, testPaymentEarly, logger, store, nil, nil) acc, err := accounting.NewAccounting(testPaymentThreshold, testPaymentTolerance, testPaymentEarly, logger, store, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -131,7 +136,7 @@ func TestAccountingAdd_persistentBalances(t *testing.T) { ...@@ -131,7 +136,7 @@ func TestAccountingAdd_persistentBalances(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
acc, err = accounting.NewAccounting(testPaymentThreshold, testPaymentTolerance, testPaymentEarly, logger, store, nil, nil) acc, err = accounting.NewAccounting(testPaymentThreshold, testPaymentTolerance, testPaymentEarly, logger, store, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -162,7 +167,7 @@ func TestAccountingReserve(t *testing.T) { ...@@ -162,7 +167,7 @@ func TestAccountingReserve(t *testing.T) {
store := mock.NewStateStore() store := mock.NewStateStore()
defer store.Close() defer store.Close()
acc, err := accounting.NewAccounting(testPaymentThreshold, testPaymentTolerance, testPaymentEarly, logger, store, nil, nil) acc, err := accounting.NewAccounting(testPaymentThreshold, testPaymentTolerance, testPaymentEarly, logger, store, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -190,7 +195,7 @@ func TestAccountingDisconnect(t *testing.T) { ...@@ -190,7 +195,7 @@ func TestAccountingDisconnect(t *testing.T) {
store := mock.NewStateStore() store := mock.NewStateStore()
defer store.Close() defer store.Close()
acc, err := accounting.NewAccounting(testPaymentThreshold, testPaymentTolerance, testPaymentEarly, logger, store, nil, nil) acc, err := accounting.NewAccounting(testPaymentThreshold, testPaymentTolerance, testPaymentEarly, logger, store, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -225,30 +230,38 @@ func TestAccountingCallSettlement(t *testing.T) { ...@@ -225,30 +230,38 @@ func TestAccountingCallSettlement(t *testing.T) {
store := mock.NewStateStore() store := mock.NewStateStore()
defer store.Close() defer store.Close()
settlement := mockSettlement.New() acc, err := accounting.NewAccounting(testPaymentThreshold, testPaymentTolerance, testPaymentEarly, logger, store, nil)
acc, err := accounting.NewAccounting(testPaymentThreshold, testPaymentTolerance, testPaymentEarly, logger, store, settlement, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
paychan := make(chan paymentCall, 1)
f := func(ctx context.Context, peer swarm.Address, amount *big.Int) {
paychan <- paymentCall{peer: peer, amount: amount}
}
acc.SetPayFunc(f)
peer1Addr, err := swarm.ParseHexAddress("00112233") peer1Addr, err := swarm.ParseHexAddress("00112233")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = acc.Reserve(context.Background(), peer1Addr, testPaymentThreshold.Uint64()) requestPrice := testPaymentThreshold.Uint64() - 1000
err = acc.Reserve(context.Background(), peer1Addr, requestPrice)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
// Credit until payment treshold // Credit until payment treshold
err = acc.Credit(peer1Addr, testPaymentThreshold.Uint64()) err = acc.Credit(peer1Addr, requestPrice)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
acc.Release(peer1Addr, testPaymentThreshold.Uint64()) acc.Release(peer1Addr, requestPrice)
// try another request // try another request
err = acc.Reserve(context.Background(), peer1Addr, 1) err = acc.Reserve(context.Background(), peer1Addr, 1)
...@@ -256,17 +269,21 @@ func TestAccountingCallSettlement(t *testing.T) { ...@@ -256,17 +269,21 @@ func TestAccountingCallSettlement(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
acc.Release(peer1Addr, 1) select {
case call := <-paychan:
totalSent, err := settlement.TotalSent(peer1Addr) if call.amount.Cmp(big.NewInt(int64(requestPrice))) != 0 {
if err != nil { t.Fatalf("paid wrong amount. got %d wanted %d", call.amount, requestPrice)
t.Fatal(err)
} }
if !call.peer.Equal(peer1Addr) {
if totalSent.Cmp(testPaymentThreshold) != 0 { t.Fatalf("wrong peer address got %v wanted %v", call.peer, peer1Addr)
t.Fatalf("paid wrong amount. got %d wanted %d", totalSent, testPaymentThreshold) }
case <-time.After(1 * time.Second):
t.Fatal("timeout waiting for payment")
} }
acc.Release(peer1Addr, 1)
acc.NotifyPaymentSent(peer1Addr, big.NewInt(int64(requestPrice)), nil)
balance, err := acc.Balance(peer1Addr) balance, err := acc.Balance(peer1Addr)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
...@@ -282,7 +299,7 @@ func TestAccountingCallSettlement(t *testing.T) { ...@@ -282,7 +299,7 @@ func TestAccountingCallSettlement(t *testing.T) {
} }
// Credit until the expected debt exceeeds payment threshold // Credit until the expected debt exceeeds payment threshold
expectedAmount := testPaymentThreshold.Uint64() - 100 expectedAmount := testPaymentThreshold.Uint64() - 101
err = acc.Reserve(context.Background(), peer1Addr, expectedAmount) err = acc.Reserve(context.Background(), peer1Addr, expectedAmount)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
...@@ -295,7 +312,7 @@ func TestAccountingCallSettlement(t *testing.T) { ...@@ -295,7 +312,7 @@ func TestAccountingCallSettlement(t *testing.T) {
acc.Release(peer1Addr, expectedAmount) acc.Release(peer1Addr, expectedAmount)
// try another request // try another request to trigger settlement
err = acc.Reserve(context.Background(), peer1Addr, 1) err = acc.Reserve(context.Background(), peer1Addr, 1)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
...@@ -303,13 +320,16 @@ func TestAccountingCallSettlement(t *testing.T) { ...@@ -303,13 +320,16 @@ func TestAccountingCallSettlement(t *testing.T) {
acc.Release(peer1Addr, 1) acc.Release(peer1Addr, 1)
totalSent, err = settlement.TotalSent(peer1Addr) select {
if err != nil { case call := <-paychan:
t.Fatal(err) if call.amount.Cmp(big.NewInt(int64(expectedAmount))) != 0 {
t.Fatalf("paid wrong amount. got %d wanted %d", call.amount, expectedAmount)
} }
if !call.peer.Equal(peer1Addr) {
if totalSent.Cmp(new(big.Int).Add(new(big.Int).SetUint64(expectedAmount), testPaymentThreshold)) != 0 { t.Fatalf("wrong peer address got %v wanted %v", call.peer, peer1Addr)
t.Fatalf("paid wrong amount. got %d wanted %d", totalSent, new(big.Int).Add(new(big.Int).SetUint64(expectedAmount), testPaymentThreshold)) }
case <-time.After(1 * time.Second):
t.Fatal("timeout waiting for payment")
} }
acc.Release(peer1Addr, 100) acc.Release(peer1Addr, 100)
...@@ -322,15 +342,22 @@ func TestAccountingCallSettlementEarly(t *testing.T) { ...@@ -322,15 +342,22 @@ func TestAccountingCallSettlementEarly(t *testing.T) {
store := mock.NewStateStore() store := mock.NewStateStore()
defer store.Close() defer store.Close()
settlement := mockSettlement.New()
debt := uint64(500) debt := uint64(500)
earlyPayment := big.NewInt(1000) earlyPayment := big.NewInt(1000)
acc, err := accounting.NewAccounting(testPaymentThreshold, testPaymentTolerance, earlyPayment, logger, store, settlement, nil) acc, err := accounting.NewAccounting(testPaymentThreshold, testPaymentTolerance, earlyPayment, logger, store, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
paychan := make(chan paymentCall, 1)
f := func(ctx context.Context, peer swarm.Address, amount *big.Int) {
paychan <- paymentCall{peer: peer, amount: amount}
}
acc.SetPayFunc(f)
peer1Addr, err := swarm.ParseHexAddress("00112233") peer1Addr, err := swarm.ParseHexAddress("00112233")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
...@@ -349,14 +376,20 @@ func TestAccountingCallSettlementEarly(t *testing.T) { ...@@ -349,14 +376,20 @@ func TestAccountingCallSettlementEarly(t *testing.T) {
acc.Release(peer1Addr, payment) acc.Release(peer1Addr, payment)
totalSent, err := settlement.TotalSent(peer1Addr) select {
if err != nil { case call := <-paychan:
t.Fatal(err) if call.amount.Cmp(big.NewInt(int64(debt))) != 0 {
t.Fatalf("paid wrong amount. got %d wanted %d", call.amount, debt)
} }
if !call.peer.Equal(peer1Addr) {
if totalSent.Cmp(new(big.Int).SetUint64(debt)) != 0 { t.Fatalf("wrong peer address got %v wanted %v", call.peer, peer1Addr)
t.Fatalf("paid wrong amount. got %d wanted %d", totalSent, testPaymentThreshold)
} }
case <-time.After(1 * time.Second):
t.Fatal("timeout waiting for payment")
}
acc.Release(peer1Addr, 1)
acc.NotifyPaymentSent(peer1Addr, big.NewInt(int64(debt)), nil)
balance, err := acc.Balance(peer1Addr) balance, err := acc.Balance(peer1Addr)
if err != nil { if err != nil {
...@@ -373,9 +406,7 @@ func TestAccountingSurplusBalance(t *testing.T) { ...@@ -373,9 +406,7 @@ func TestAccountingSurplusBalance(t *testing.T) {
store := mock.NewStateStore() store := mock.NewStateStore()
defer store.Close() defer store.Close()
settlement := mockSettlement.New() acc, err := accounting.NewAccounting(testPaymentThreshold, big.NewInt(0), big.NewInt(0), logger, store, nil)
acc, err := accounting.NewAccounting(testPaymentThreshold, big.NewInt(0), big.NewInt(0), logger, store, settlement, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -389,7 +420,7 @@ func TestAccountingSurplusBalance(t *testing.T) { ...@@ -389,7 +420,7 @@ func TestAccountingSurplusBalance(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
// Notify of incoming payment from same peer, so balance goes to 0 with surplusbalance 2 // Notify of incoming payment from same peer, so balance goes to 0 with surplusbalance 2
err = acc.NotifyPayment(peer1Addr, new(big.Int).Add(testPaymentThreshold, big.NewInt(1))) err = acc.NotifyPaymentReceived(peer1Addr, new(big.Int).Add(testPaymentThreshold, big.NewInt(1)))
if err != nil { if err != nil {
t.Fatal("Unexpected overflow from doable NotifyPayment") t.Fatal("Unexpected overflow from doable NotifyPayment")
} }
...@@ -410,7 +441,7 @@ func TestAccountingSurplusBalance(t *testing.T) { ...@@ -410,7 +441,7 @@ func TestAccountingSurplusBalance(t *testing.T) {
t.Fatal("Not expected balance") t.Fatal("Not expected balance")
} }
// Notify of incoming payment from same peer, so balance goes to 0 with surplusbalance 10002 (testpaymentthreshold+2) // Notify of incoming payment from same peer, so balance goes to 0 with surplusbalance 10002 (testpaymentthreshold+2)
err = acc.NotifyPayment(peer1Addr, testPaymentThreshold) err = acc.NotifyPaymentReceived(peer1Addr, testPaymentThreshold)
if err != nil { if err != nil {
t.Fatal("Unexpected error from NotifyPayment") t.Fatal("Unexpected error from NotifyPayment")
} }
...@@ -454,7 +485,7 @@ func TestAccountingSurplusBalance(t *testing.T) { ...@@ -454,7 +485,7 @@ func TestAccountingSurplusBalance(t *testing.T) {
// Debit for same peer, so balance goes to 9998 (testpaymentthreshold - 2) with surplusbalance decreasing to 0 // Debit for same peer, so balance goes to 9998 (testpaymentthreshold - 2) with surplusbalance decreasing to 0
err = acc.Debit(peer1Addr, testPaymentThreshold.Uint64()) err = acc.Debit(peer1Addr, testPaymentThreshold.Uint64())
if err != nil { if err != nil {
t.Fatal("Unexpected error from NotifyPayment") t.Fatal("Unexpected error from Debit")
} }
// samity check surplus balance // samity check surplus balance
val, err = acc.SurplusBalance(peer1Addr) val, err = acc.SurplusBalance(peer1Addr)
...@@ -475,13 +506,13 @@ func TestAccountingSurplusBalance(t *testing.T) { ...@@ -475,13 +506,13 @@ func TestAccountingSurplusBalance(t *testing.T) {
} }
// TestAccountingNotifyPayment tests that payments adjust the balance and payment which put us into debt are rejected // TestAccountingNotifyPayment tests that payments adjust the balance and payment which put us into debt are rejected
func TestAccountingNotifyPayment(t *testing.T) { func TestAccountingNotifyPaymentReceived(t *testing.T) {
logger := logging.New(ioutil.Discard, 0) logger := logging.New(ioutil.Discard, 0)
store := mock.NewStateStore() store := mock.NewStateStore()
defer store.Close() defer store.Close()
acc, err := accounting.NewAccounting(testPaymentThreshold, testPaymentTolerance, testPaymentEarly, logger, store, nil, nil) acc, err := accounting.NewAccounting(testPaymentThreshold, testPaymentTolerance, testPaymentEarly, logger, store, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -497,7 +528,7 @@ func TestAccountingNotifyPayment(t *testing.T) { ...@@ -497,7 +528,7 @@ func TestAccountingNotifyPayment(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
err = acc.NotifyPayment(peer1Addr, new(big.Int).SetUint64(debtAmount+testPaymentTolerance.Uint64())) err = acc.NotifyPaymentReceived(peer1Addr, new(big.Int).SetUint64(debtAmount+testPaymentTolerance.Uint64()))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -507,7 +538,7 @@ func TestAccountingNotifyPayment(t *testing.T) { ...@@ -507,7 +538,7 @@ func TestAccountingNotifyPayment(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
err = acc.NotifyPayment(peer1Addr, new(big.Int).SetUint64(debtAmount+testPaymentTolerance.Uint64()+1)) err = acc.NotifyPaymentReceived(peer1Addr, new(big.Int).SetUint64(debtAmount+testPaymentTolerance.Uint64()+1))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -541,7 +572,7 @@ func TestAccountingConnected(t *testing.T) { ...@@ -541,7 +572,7 @@ func TestAccountingConnected(t *testing.T) {
pricing := &pricingMock{} pricing := &pricingMock{}
_, err := accounting.NewAccounting(testPaymentThreshold, testPaymentTolerance, testPaymentEarly, logger, store, nil, pricing) _, err := accounting.NewAccounting(testPaymentThreshold, testPaymentTolerance, testPaymentEarly, logger, store, pricing)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -576,13 +607,20 @@ func TestAccountingNotifyPaymentThreshold(t *testing.T) { ...@@ -576,13 +607,20 @@ func TestAccountingNotifyPaymentThreshold(t *testing.T) {
defer store.Close() defer store.Close()
pricing := &pricingMock{} pricing := &pricingMock{}
settlement := mockSettlement.New()
acc, err := accounting.NewAccounting(testPaymentThreshold, testPaymentTolerance, big.NewInt(0), logger, store, settlement, pricing) acc, err := accounting.NewAccounting(testPaymentThreshold, testPaymentTolerance, big.NewInt(0), logger, store, pricing)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
paychan := make(chan paymentCall, 1)
f := func(ctx context.Context, peer swarm.Address, amount *big.Int) {
paychan <- paymentCall{peer: peer, amount: amount}
}
acc.SetPayFunc(f)
peer1Addr, err := swarm.ParseHexAddress("00112233") peer1Addr, err := swarm.ParseHexAddress("00112233")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
...@@ -602,16 +640,75 @@ func TestAccountingNotifyPaymentThreshold(t *testing.T) { ...@@ -602,16 +640,75 @@ func TestAccountingNotifyPaymentThreshold(t *testing.T) {
} }
err = acc.Reserve(context.Background(), peer1Addr, lowerThreshold) err = acc.Reserve(context.Background(), peer1Addr, lowerThreshold)
if err == nil {
t.Fatal(err)
}
if !errors.Is(err, accounting.ErrOverdraft) {
t.Fatal(err)
}
select {
case call := <-paychan:
if call.amount.Cmp(big.NewInt(int64(debt))) != 0 {
t.Fatalf("paid wrong amount. got %d wanted %d", call.amount, debt)
}
if !call.peer.Equal(peer1Addr) {
t.Fatalf("wrong peer address got %v wanted %v", call.peer, peer1Addr)
}
case <-time.After(1 * time.Second):
t.Fatal("timeout waiting for payment")
}
}
func TestAccountingPeerDebt(t *testing.T) {
logger := logging.New(ioutil.Discard, 0)
store := mock.NewStateStore()
defer store.Close()
pricing := &pricingMock{}
acc, err := accounting.NewAccounting(testPaymentThreshold, testPaymentTolerance, big.NewInt(0), logger, store, pricing)
if err != nil {
t.Fatal(err)
}
peer1Addr := swarm.MustParseHexAddress("00112233")
debt := uint64(1000)
err = acc.Debit(peer1Addr, debt)
if err != nil {
t.Fatal(err)
}
actualDebt, err := acc.PeerDebt(peer1Addr)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if actualDebt.Cmp(new(big.Int).SetUint64(debt)) != 0 {
t.Fatalf("wrong actual debt. got %d wanted %d", actualDebt, debt)
}
totalSent, err := settlement.TotalSent(peer1Addr) peer2Addr := swarm.MustParseHexAddress("11112233")
err = acc.Credit(peer2Addr, 500)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
actualDebt, err = acc.PeerDebt(peer2Addr)
if err != nil {
t.Fatal(err)
}
if actualDebt.Cmp(big.NewInt(0)) != 0 {
t.Fatalf("wrong actual debt. got %d wanted 0", actualDebt)
}
if totalSent.Cmp(new(big.Int).SetUint64(debt)) != 0 { peer3Addr := swarm.MustParseHexAddress("22112233")
t.Fatalf("paid wrong amount. got %d wanted %d", totalSent, debt) actualDebt, err = acc.PeerDebt(peer3Addr)
if err != nil {
t.Fatal(err)
} }
if actualDebt.Cmp(big.NewInt(0)) != 0 {
t.Fatalf("wrong actual debt. got %d wanted 0", actualDebt)
}
} }
...@@ -16,6 +16,7 @@ import ( ...@@ -16,6 +16,7 @@ import (
"github.com/ethersphere/bee/pkg/crypto" "github.com/ethersphere/bee/pkg/crypto"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/p2p/libp2p" "github.com/ethersphere/bee/pkg/p2p/libp2p"
"github.com/ethersphere/bee/pkg/settlement"
"github.com/ethersphere/bee/pkg/settlement/swap" "github.com/ethersphere/bee/pkg/settlement/swap"
"github.com/ethersphere/bee/pkg/settlement/swap/chequebook" "github.com/ethersphere/bee/pkg/settlement/swap/chequebook"
"github.com/ethersphere/bee/pkg/settlement/swap/swapprotocol" "github.com/ethersphere/bee/pkg/settlement/swap/swapprotocol"
...@@ -185,6 +186,7 @@ func InitSwap( ...@@ -185,6 +186,7 @@ func InitSwap(
chequebookService chequebook.Service, chequebookService chequebook.Service,
chequeStore chequebook.ChequeStore, chequeStore chequebook.ChequeStore,
cashoutService chequebook.CashoutService, cashoutService chequebook.CashoutService,
accountingAPI settlement.AccountingAPI,
) (*swap.Service, error) { ) (*swap.Service, error) {
swapProtocol := swapprotocol.New(p2ps, logger, overlayEthAddress) swapProtocol := swapprotocol.New(p2ps, logger, overlayEthAddress)
swapAddressBook := swap.NewAddressbook(stateStore) swapAddressBook := swap.NewAddressbook(stateStore)
...@@ -199,6 +201,7 @@ func InitSwap( ...@@ -199,6 +201,7 @@ func InitSwap(
networkID, networkID,
cashoutService, cashoutService,
p2ps, p2ps,
accountingAPI,
) )
swapProtocol.SetSwap(swapService) swapProtocol.SetSwap(swapService)
......
...@@ -461,6 +461,26 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey, ...@@ -461,6 +461,26 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey,
logger.Debugf("p2p address: %s", addr) logger.Debugf("p2p address: %s", addr)
} }
paymentTolerance, ok := new(big.Int).SetString(o.PaymentTolerance, 10)
if !ok {
return nil, fmt.Errorf("invalid payment tolerance: %s", paymentTolerance)
}
paymentEarly, ok := new(big.Int).SetString(o.PaymentEarly, 10)
if !ok {
return nil, fmt.Errorf("invalid payment early: %s", paymentEarly)
}
acc, err := accounting.NewAccounting(
paymentThreshold,
paymentTolerance,
paymentEarly,
logger,
stateStore,
pricing,
)
if err != nil {
return nil, fmt.Errorf("accounting: %w", err)
}
if o.SwapEnable { if o.SwapEnable {
swapService, err = InitSwap( swapService, err = InitSwap(
p2ps, p2ps,
...@@ -471,43 +491,23 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey, ...@@ -471,43 +491,23 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey,
chequebookService, chequebookService,
chequeStore, chequeStore,
cashoutService, cashoutService,
acc,
) )
if err != nil { if err != nil {
return nil, err return nil, err
} }
settlement = swapService settlement = swapService
} else { } else {
pseudosettleService := pseudosettle.New(p2ps, logger, stateStore) pseudosettleService := pseudosettle.New(p2ps, logger, stateStore, acc)
if err = p2ps.AddProtocol(pseudosettleService.Protocol()); err != nil { if err = p2ps.AddProtocol(pseudosettleService.Protocol()); err != nil {
return nil, fmt.Errorf("pseudosettle service: %w", err) return nil, fmt.Errorf("pseudosettle service: %w", err)
} }
settlement = pseudosettleService settlement = pseudosettleService
} }
paymentTolerance, ok := new(big.Int).SetString(o.PaymentTolerance, 10) acc.SetPayFunc(settlement.Pay)
if !ok {
return nil, fmt.Errorf("invalid payment tolerance: %s", paymentTolerance)
}
paymentEarly, ok := new(big.Int).SetString(o.PaymentEarly, 10)
if !ok {
return nil, fmt.Errorf("invalid payment early: %s", paymentEarly)
}
acc, err := accounting.NewAccounting(
paymentThreshold,
paymentTolerance,
paymentEarly,
logger,
stateStore,
settlement,
pricing,
)
if err != nil {
return nil, fmt.Errorf("accounting: %w", err)
}
pricing.SetPaymentThresholdObserver(acc) pricing.SetPaymentThresholdObserver(acc)
settlement.SetNotifyPaymentFunc(acc.AsyncNotifyPayment)
retrieve := retrieval.New(swarmAddress, storer, p2ps, kad, logger, acc, pricer, tracer) retrieve := retrieval.New(swarmAddress, storer, p2ps, kad, logger, acc, pricer, tracer)
tagService := tags.NewTags(stateStore, logger) tagService := tags.NewTags(stateStore, logger)
......
...@@ -20,7 +20,7 @@ var ( ...@@ -20,7 +20,7 @@ var (
type Interface interface { type Interface interface {
// Pay initiates a payment to the given peer // Pay initiates a payment to the given peer
// It should return without error it is likely that the payment worked // It should return without error it is likely that the payment worked
Pay(ctx context.Context, peer swarm.Address, amount *big.Int) error Pay(ctx context.Context, peer swarm.Address, amount *big.Int)
// TotalSent returns the total amount sent to a peer // TotalSent returns the total amount sent to a peer
TotalSent(peer swarm.Address) (totalSent *big.Int, err error) TotalSent(peer swarm.Address) (totalSent *big.Int, err error)
// TotalReceived returns the total amount received from a peer // TotalReceived returns the total amount received from a peer
...@@ -29,9 +29,10 @@ type Interface interface { ...@@ -29,9 +29,10 @@ type Interface interface {
SettlementsSent() (map[string]*big.Int, error) SettlementsSent() (map[string]*big.Int, error)
// SettlementsReceived returns received settlements for each individual known peer // SettlementsReceived returns received settlements for each individual known peer
SettlementsReceived() (map[string]*big.Int, error) SettlementsReceived() (map[string]*big.Int, error)
// SetNotifyPaymentFunc sets the NotifyPaymentFunc to notify
SetNotifyPaymentFunc(notifyPaymentFunc NotifyPaymentFunc)
} }
// NotifyPaymentFunc is called when a payment from peer was successfully received type AccountingAPI interface {
type NotifyPaymentFunc func(peer swarm.Address, amount *big.Int) error PeerDebt(peer swarm.Address) (*big.Int, error)
NotifyPaymentReceived(peer swarm.Address, amount *big.Int) error
NotifyPaymentSent(peer swarm.Address, amount *big.Int, receivedError error)
}
...@@ -36,16 +36,17 @@ type Service struct { ...@@ -36,16 +36,17 @@ type Service struct {
streamer p2p.Streamer streamer p2p.Streamer
logger logging.Logger logger logging.Logger
store storage.StateStorer store storage.StateStorer
notifyPaymentFunc settlement.NotifyPaymentFunc accountingAPI settlement.AccountingAPI
metrics metrics metrics metrics
} }
func New(streamer p2p.Streamer, logger logging.Logger, store storage.StateStorer) *Service { func New(streamer p2p.Streamer, logger logging.Logger, store storage.StateStorer, accountingAPI settlement.AccountingAPI) *Service {
return &Service{ return &Service{
streamer: streamer, streamer: streamer,
logger: logger, logger: logger,
metrics: newMetrics(), metrics: newMetrics(),
store: store, store: store,
accountingAPI: accountingAPI,
} }
} }
...@@ -106,17 +107,22 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e ...@@ -106,17 +107,22 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e
return err return err
} }
return s.notifyPaymentFunc(p.Address, new(big.Int).SetUint64(req.Amount)) return s.accountingAPI.NotifyPaymentReceived(p.Address, new(big.Int).SetUint64(req.Amount))
} }
// Pay initiates a payment to the given peer // Pay initiates a payment to the given peer
func (s *Service) Pay(ctx context.Context, peer swarm.Address, amount *big.Int) error { func (s *Service) Pay(ctx context.Context, peer swarm.Address, amount *big.Int) {
ctx, cancel := context.WithTimeout(ctx, 5*time.Second) ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel() defer cancel()
var err error
defer func() {
if err != nil {
s.accountingAPI.NotifyPaymentSent(peer, nil, err)
}
}()
stream, err := s.streamer.NewStream(ctx, peer, nil, protocolName, protocolVersion, streamName) stream, err := s.streamer.NewStream(ctx, peer, nil, protocolName, protocolVersion, streamName)
if err != nil { if err != nil {
return err return
} }
defer func() { defer func() {
if err != nil { if err != nil {
...@@ -132,28 +138,27 @@ func (s *Service) Pay(ctx context.Context, peer swarm.Address, amount *big.Int) ...@@ -132,28 +138,27 @@ func (s *Service) Pay(ctx context.Context, peer swarm.Address, amount *big.Int)
Amount: amount.Uint64(), Amount: amount.Uint64(),
}) })
if err != nil { if err != nil {
return err return
} }
totalSent, err := s.TotalSent(peer) totalSent, err := s.TotalSent(peer)
if err != nil { if err != nil {
if !errors.Is(err, settlement.ErrPeerNoSettlements) { if !errors.Is(err, settlement.ErrPeerNoSettlements) {
return err return
} }
totalSent = big.NewInt(0) totalSent = big.NewInt(0)
} }
err = s.store.Put(totalKey(peer, SettlementSentPrefix), totalSent.Add(totalSent, amount)) err = s.store.Put(totalKey(peer, SettlementSentPrefix), totalSent.Add(totalSent, amount))
if err != nil { if err != nil {
return err return
} }
s.accountingAPI.NotifyPaymentSent(peer, amount, nil)
amountFloat, _ := new(big.Float).SetInt(amount).Float64() amountFloat, _ := new(big.Float).SetInt(amount).Float64()
s.metrics.TotalSentPseudoSettlements.Add(amountFloat) s.metrics.TotalSentPseudoSettlements.Add(amountFloat)
return nil
} }
// SetNotifyPaymentFunc sets the NotifyPaymentFunc to notify func (s *Service) SetAccountingAPI(accountingAPI settlement.AccountingAPI) {
func (s *Service) SetNotifyPaymentFunc(notifyPaymentFunc settlement.NotifyPaymentFunc) { s.accountingAPI = accountingAPI
s.notifyPaymentFunc = notifyPaymentFunc
} }
// TotalSent returns the total amount sent to a peer // TotalSent returns the total amount sent to a peer
......
...@@ -22,24 +22,47 @@ import ( ...@@ -22,24 +22,47 @@ import (
) )
type testObserver struct { type testObserver struct {
called chan struct{} receivedCalled chan notifyPaymentReceivedCall
sentCalled chan notifyPaymentSentCall
}
type notifyPaymentReceivedCall struct {
peer swarm.Address
amount *big.Int
}
type notifyPaymentSentCall struct {
peer swarm.Address peer swarm.Address
amount *big.Int amount *big.Int
err error
} }
func newTestObserver() *testObserver { func newTestObserver() *testObserver {
return &testObserver{ return &testObserver{
called: make(chan struct{}), receivedCalled: make(chan notifyPaymentReceivedCall, 1),
sentCalled: make(chan notifyPaymentSentCall, 1),
} }
} }
func (t *testObserver) NotifyPayment(peer swarm.Address, amount *big.Int) error { func (t *testObserver) PeerDebt(peer swarm.Address) (*big.Int, error) {
close(t.called) return nil, nil
t.peer = peer }
t.amount = amount
func (t *testObserver) NotifyPaymentReceived(peer swarm.Address, amount *big.Int) error {
t.receivedCalled <- notifyPaymentReceivedCall{
peer: peer,
amount: amount,
}
return nil return nil
} }
func (t *testObserver) NotifyPaymentSent(peer swarm.Address, amount *big.Int, err error) {
t.sentCalled <- notifyPaymentSentCall{
peer: peer,
amount: amount,
err: err,
}
}
func TestPayment(t *testing.T) { func TestPayment(t *testing.T) {
logger := logging.New(ioutil.Discard, 0) logger := logging.New(ioutil.Discard, 0)
...@@ -47,8 +70,7 @@ func TestPayment(t *testing.T) { ...@@ -47,8 +70,7 @@ func TestPayment(t *testing.T) {
defer storeRecipient.Close() defer storeRecipient.Close()
observer := newTestObserver() observer := newTestObserver()
recipient := pseudosettle.New(nil, logger, storeRecipient) recipient := pseudosettle.New(nil, logger, storeRecipient, observer)
recipient.SetNotifyPaymentFunc(observer.NotifyPayment)
peerID := swarm.MustParseHexAddress("9ee7add7") peerID := swarm.MustParseHexAddress("9ee7add7")
...@@ -60,14 +82,13 @@ func TestPayment(t *testing.T) { ...@@ -60,14 +82,13 @@ func TestPayment(t *testing.T) {
storePayer := mock.NewStateStore() storePayer := mock.NewStateStore()
defer storePayer.Close() defer storePayer.Close()
payer := pseudosettle.New(recorder, logger, storePayer) observer2 := newTestObserver()
payer := pseudosettle.New(recorder, logger, storePayer, observer2)
payer.SetAccountingAPI(observer2)
amount := big.NewInt(10000) amount := big.NewInt(10000)
err := payer.Pay(context.Background(), peerID, amount) payer.Pay(context.Background(), peerID, amount)
if err != nil {
t.Fatal(err)
}
records, err := recorder.Records(peerID, "pseudosettle", "1.0.0", "pseudosettle") records, err := recorder.Records(peerID, "pseudosettle", "1.0.0", "pseudosettle")
if err != nil { if err != nil {
...@@ -102,17 +123,34 @@ func TestPayment(t *testing.T) { ...@@ -102,17 +123,34 @@ func TestPayment(t *testing.T) {
} }
select { select {
case <-observer.called: case call := <-observer.receivedCalled:
if call.amount.Cmp(amount) != 0 {
t.Fatalf("observer called with wrong amount. got %d, want %d", call.amount, amount)
}
if !call.peer.Equal(peerID) {
t.Fatalf("observer called with wrong peer. got %v, want %v", call.peer, peerID)
}
case <-time.After(time.Second): case <-time.After(time.Second):
t.Fatal("expected observer to be called") t.Fatal("expected observer to be called")
} }
if observer.amount.Cmp(amount) != 0 { select {
t.Fatalf("observer called with wrong amount. got %d, want %d", observer.amount, amount) case call := <-observer2.sentCalled:
if call.amount.Cmp(amount) != 0 {
t.Fatalf("observer called with wrong amount. got %d, want %d", call.amount, amount)
}
if !call.peer.Equal(peerID) {
t.Fatalf("observer called with wrong peer. got %v, want %v", call.peer, peerID)
}
if call.err != nil {
t.Fatalf("observer called with error. got %v want nil", call.err)
} }
if !observer.peer.Equal(peerID) { case <-time.After(time.Second):
t.Fatalf("observer called with wrong peer. got %v, want %v", observer.peer, peerID) t.Fatal("expected observer to be called")
} }
totalSent, err := payer.TotalSent(peerID) totalSent, err := payer.TotalSent(peerID)
......
...@@ -27,8 +27,7 @@ type Service struct { ...@@ -27,8 +27,7 @@ type Service struct {
settlementsRecvFunc func() (map[string]*big.Int, error) settlementsRecvFunc func() (map[string]*big.Int, error)
receiveChequeFunc func(context.Context, swarm.Address, *chequebook.SignedCheque) error receiveChequeFunc func(context.Context, swarm.Address, *chequebook.SignedCheque) error
payFunc func(context.Context, swarm.Address, *big.Int) error payFunc func(context.Context, swarm.Address, *big.Int)
setNotifyPaymentFunc settlement.NotifyPaymentFunc
handshakeFunc func(swarm.Address, common.Address) error handshakeFunc func(swarm.Address, common.Address) error
lastSentChequeFunc func(swarm.Address) (*chequebook.SignedCheque, error) lastSentChequeFunc func(swarm.Address) (*chequebook.SignedCheque, error)
lastSentChequesFunc func() (map[string]*chequebook.SignedCheque, error) lastSentChequesFunc func() (map[string]*chequebook.SignedCheque, error)
...@@ -72,19 +71,12 @@ func WithReceiveChequeFunc(f func(context.Context, swarm.Address, *chequebook.Si ...@@ -72,19 +71,12 @@ func WithReceiveChequeFunc(f func(context.Context, swarm.Address, *chequebook.Si
}) })
} }
func WithPayFunc(f func(context.Context, swarm.Address, *big.Int) error) Option { func WithPayFunc(f func(context.Context, swarm.Address, *big.Int)) Option {
return optionFunc(func(s *Service) { return optionFunc(func(s *Service) {
s.payFunc = f s.payFunc = f
}) })
} }
// WithsettlementsFunc sets the mock settlements function
func WithSetNotifyPaymentFunc(f settlement.NotifyPaymentFunc) Option {
return optionFunc(func(s *Service) {
s.setNotifyPaymentFunc = f
})
}
func WithHandshakeFunc(f func(swarm.Address, common.Address) error) Option { func WithHandshakeFunc(f func(swarm.Address, common.Address) error) Option {
return optionFunc(func(s *Service) { return optionFunc(func(s *Service) {
s.handshakeFunc = f s.handshakeFunc = f
...@@ -155,22 +147,16 @@ func (s *Service) ReceiveCheque(ctx context.Context, peer swarm.Address, cheque ...@@ -155,22 +147,16 @@ func (s *Service) ReceiveCheque(ctx context.Context, peer swarm.Address, cheque
} }
// Pay is the mock Pay function of swap. // Pay is the mock Pay function of swap.
func (s *Service) Pay(ctx context.Context, peer swarm.Address, amount *big.Int) error { func (s *Service) Pay(ctx context.Context, peer swarm.Address, amount *big.Int) {
if s.payFunc != nil { if s.payFunc != nil {
return s.payFunc(ctx, peer, amount) s.payFunc(ctx, peer, amount)
return
} }
if settlement, ok := s.settlementsSent[peer.String()]; ok { if settlement, ok := s.settlementsSent[peer.String()]; ok {
s.settlementsSent[peer.String()] = big.NewInt(0).Add(settlement, amount) s.settlementsSent[peer.String()] = big.NewInt(0).Add(settlement, amount)
} else { } else {
s.settlementsSent[peer.String()] = amount s.settlementsSent[peer.String()] = amount
} }
return nil
}
func (s *Service) SetNotifyPaymentFunc(f settlement.NotifyPaymentFunc) {
if s.setNotifyPaymentFunc != nil {
s.SetNotifyPaymentFunc(f)
}
} }
// TotalSent is the mock TotalSent function of swap. // TotalSent is the mock TotalSent function of swap.
......
...@@ -50,7 +50,7 @@ type Service struct { ...@@ -50,7 +50,7 @@ type Service struct {
proto swapprotocol.Interface proto swapprotocol.Interface
logger logging.Logger logger logging.Logger
store storage.StateStorer store storage.StateStorer
notifyPaymentFunc settlement.NotifyPaymentFunc accountingAPI settlement.AccountingAPI
metrics metrics metrics metrics
chequebook chequebook.Service chequebook chequebook.Service
chequeStore chequebook.ChequeStore chequeStore chequebook.ChequeStore
...@@ -61,7 +61,7 @@ type Service struct { ...@@ -61,7 +61,7 @@ type Service struct {
} }
// New creates a new swap Service. // New creates a new swap Service.
func New(proto swapprotocol.Interface, logger logging.Logger, store storage.StateStorer, chequebook chequebook.Service, chequeStore chequebook.ChequeStore, addressbook Addressbook, networkID uint64, cashout chequebook.CashoutService, p2pService p2p.Service) *Service { func New(proto swapprotocol.Interface, logger logging.Logger, store storage.StateStorer, chequebook chequebook.Service, chequeStore chequebook.ChequeStore, addressbook Addressbook, networkID uint64, cashout chequebook.CashoutService, p2pService p2p.Service, accountingAPI settlement.AccountingAPI) *Service {
return &Service{ return &Service{
proto: proto, proto: proto,
logger: logger, logger: logger,
...@@ -73,6 +73,7 @@ func New(proto swapprotocol.Interface, logger logging.Logger, store storage.Stat ...@@ -73,6 +73,7 @@ func New(proto swapprotocol.Interface, logger logging.Logger, store storage.Stat
networkID: networkID, networkID: networkID,
cashout: cashout, cashout: cashout,
p2pService: p2pService, p2pService: p2pService,
accountingAPI: accountingAPI,
} }
} }
...@@ -103,40 +104,46 @@ func (s *Service) ReceiveCheque(ctx context.Context, peer swarm.Address, cheque ...@@ -103,40 +104,46 @@ func (s *Service) ReceiveCheque(ctx context.Context, peer swarm.Address, cheque
s.metrics.TotalReceived.Add(float64(amount.Uint64())) s.metrics.TotalReceived.Add(float64(amount.Uint64()))
s.metrics.ChequesReceived.Inc() s.metrics.ChequesReceived.Inc()
return s.notifyPaymentFunc(peer, amount) return s.accountingAPI.NotifyPaymentReceived(peer, amount)
} }
// Pay initiates a payment to the given peer // Pay initiates a payment to the given peer
func (s *Service) Pay(ctx context.Context, peer swarm.Address, amount *big.Int) error { func (s *Service) Pay(ctx context.Context, peer swarm.Address, amount *big.Int) {
var err error
defer func() {
if err != nil {
s.accountingAPI.NotifyPaymentSent(peer, nil, err)
}
}()
beneficiary, known, err := s.addressbook.Beneficiary(peer) beneficiary, known, err := s.addressbook.Beneficiary(peer)
if err != nil { if err != nil {
return err return
} }
if !known { if !known {
s.logger.Warningf("disconnecting non-swap peer %v", peer) s.logger.Warningf("disconnecting non-swap peer %v", peer)
err = s.p2pService.Disconnect(peer) err = s.p2pService.Disconnect(peer)
if err != nil { if err != nil {
return err return
} }
return ErrUnknownBeneficary err = ErrUnknownBeneficary
return
} }
balance, err := s.chequebook.Issue(ctx, beneficiary, amount, func(signedCheque *chequebook.SignedCheque) error { balance, err := s.chequebook.Issue(ctx, beneficiary, amount, func(signedCheque *chequebook.SignedCheque) error {
return s.proto.EmitCheque(ctx, peer, signedCheque) return s.proto.EmitCheque(ctx, peer, signedCheque)
}) })
if err != nil { if err != nil {
return err return
} }
bal, _ := big.NewFloat(0).SetInt(balance).Float64() bal, _ := big.NewFloat(0).SetInt(balance).Float64()
s.metrics.AvailableBalance.Set(bal) s.metrics.AvailableBalance.Set(bal)
s.accountingAPI.NotifyPaymentSent(peer, amount, nil)
amountFloat, _ := big.NewFloat(0).SetInt(amount).Float64() amountFloat, _ := big.NewFloat(0).SetInt(amount).Float64()
s.metrics.TotalSent.Add(amountFloat) s.metrics.TotalSent.Add(amountFloat)
s.metrics.ChequesSent.Inc() s.metrics.ChequesSent.Inc()
return nil
} }
// SetNotifyPaymentFunc sets the NotifyPaymentFunc to notify func (s *Service) SetAccountingAPI(accountingAPI settlement.AccountingAPI) {
func (s *Service) SetNotifyPaymentFunc(notifyPaymentFunc settlement.NotifyPaymentFunc) { s.accountingAPI = accountingAPI
s.notifyPaymentFunc = notifyPaymentFunc
} }
// TotalSent returns the total amount sent to a peer // TotalSent returns the total amount sent to a peer
......
...@@ -10,6 +10,7 @@ import ( ...@@ -10,6 +10,7 @@ import (
"io/ioutil" "io/ioutil"
"math/big" "math/big"
"testing" "testing"
"time"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethersphere/bee/pkg/crypto" "github.com/ethersphere/bee/pkg/crypto"
...@@ -35,18 +36,48 @@ func (m *swapProtocolMock) EmitCheque(ctx context.Context, peer swarm.Address, c ...@@ -35,18 +36,48 @@ func (m *swapProtocolMock) EmitCheque(ctx context.Context, peer swarm.Address, c
} }
type testObserver struct { type testObserver struct {
called bool receivedCalled chan notifyPaymentReceivedCall
sentCalled chan notifyPaymentSentCall
}
type notifyPaymentReceivedCall struct {
peer swarm.Address
amount *big.Int
}
type notifyPaymentSentCall struct {
peer swarm.Address peer swarm.Address
amount *big.Int amount *big.Int
err error
} }
func (t *testObserver) NotifyPayment(peer swarm.Address, amount *big.Int) error { func newTestObserver() *testObserver {
t.called = true return &testObserver{
t.peer = peer receivedCalled: make(chan notifyPaymentReceivedCall, 1),
t.amount = amount sentCalled: make(chan notifyPaymentSentCall, 1),
}
}
func (t *testObserver) PeerDebt(peer swarm.Address) (*big.Int, error) {
return nil, nil
}
func (t *testObserver) NotifyPaymentReceived(peer swarm.Address, amount *big.Int) error {
t.receivedCalled <- notifyPaymentReceivedCall{
peer: peer,
amount: amount,
}
return nil return nil
} }
func (t *testObserver) NotifyPaymentSent(peer swarm.Address, amount *big.Int, err error) {
t.sentCalled <- notifyPaymentSentCall{
peer: peer,
amount: amount,
err: err,
}
}
type addressbookMock struct { type addressbookMock struct {
beneficiary func(peer swarm.Address) (beneficiary common.Address, known bool, err error) beneficiary func(peer swarm.Address) (beneficiary common.Address, known bool, err error)
chequebook func(peer swarm.Address) (chequebookAddress common.Address, known bool, err error) chequebook func(peer swarm.Address) (chequebookAddress common.Address, known bool, err error)
...@@ -131,6 +162,8 @@ func TestReceiveCheque(t *testing.T) { ...@@ -131,6 +162,8 @@ func TestReceiveCheque(t *testing.T) {
}, },
} }
observer := newTestObserver()
swap := swap.New( swap := swap.New(
&swapProtocolMock{}, &swapProtocolMock{},
logger, logger,
...@@ -141,27 +174,28 @@ func TestReceiveCheque(t *testing.T) { ...@@ -141,27 +174,28 @@ func TestReceiveCheque(t *testing.T) {
networkID, networkID,
&cashoutMock{}, &cashoutMock{},
mockp2p.New(), mockp2p.New(),
observer,
) )
observer := &testObserver{}
swap.SetNotifyPaymentFunc(observer.NotifyPayment)
err := swap.ReceiveCheque(context.Background(), peer, cheque) err := swap.ReceiveCheque(context.Background(), peer, cheque)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if !observer.called { select {
t.Fatal("expected observer to be called") case call := <-observer.receivedCalled:
if call.amount.Cmp(amount) != 0 {
t.Fatalf("observer called with wrong amount. got %d, want %d", call.amount, amount)
} }
if observer.amount.Cmp(amount) != 0 { if !call.peer.Equal(peer) {
t.Fatalf("observer called with wrong amount. got %d, want %d", observer.amount, amount) t.Fatalf("observer called with wrong peer. got %v, want %v", call.peer, peer)
} }
if !observer.peer.Equal(peer) { case <-time.After(time.Second):
t.Fatalf("observer called with wrong peer. got %v, want %v", observer.peer, peer) t.Fatal("expected observer to be called")
} }
} }
func TestReceiveChequeReject(t *testing.T) { func TestReceiveChequeReject(t *testing.T) {
...@@ -194,6 +228,8 @@ func TestReceiveChequeReject(t *testing.T) { ...@@ -194,6 +228,8 @@ func TestReceiveChequeReject(t *testing.T) {
}, },
} }
observer := newTestObserver()
swap := swap.New( swap := swap.New(
&swapProtocolMock{}, &swapProtocolMock{},
logger, logger,
...@@ -204,11 +240,9 @@ func TestReceiveChequeReject(t *testing.T) { ...@@ -204,11 +240,9 @@ func TestReceiveChequeReject(t *testing.T) {
networkID, networkID,
&cashoutMock{}, &cashoutMock{},
mockp2p.New(), mockp2p.New(),
observer,
) )
observer := &testObserver{}
swap.SetNotifyPaymentFunc(observer.NotifyPayment)
err := swap.ReceiveCheque(context.Background(), peer, cheque) err := swap.ReceiveCheque(context.Background(), peer, cheque)
if err == nil { if err == nil {
t.Fatal("accepted invalid cheque") t.Fatal("accepted invalid cheque")
...@@ -217,9 +251,12 @@ func TestReceiveChequeReject(t *testing.T) { ...@@ -217,9 +251,12 @@ func TestReceiveChequeReject(t *testing.T) {
t.Fatalf("wrong error. wanted %v, got %v", errReject, err) t.Fatalf("wrong error. wanted %v, got %v", errReject, err)
} }
if observer.called { select {
t.Fatal("observer was be called for rejected payment") case <-observer.receivedCalled:
t.Fatalf("observer called by error.")
default:
} }
} }
func TestReceiveChequeWrongChequebook(t *testing.T) { func TestReceiveChequeWrongChequebook(t *testing.T) {
...@@ -246,6 +283,7 @@ func TestReceiveChequeWrongChequebook(t *testing.T) { ...@@ -246,6 +283,7 @@ func TestReceiveChequeWrongChequebook(t *testing.T) {
}, },
} }
observer := newTestObserver()
swapService := swap.New( swapService := swap.New(
&swapProtocolMock{}, &swapProtocolMock{},
logger, logger,
...@@ -256,11 +294,9 @@ func TestReceiveChequeWrongChequebook(t *testing.T) { ...@@ -256,11 +294,9 @@ func TestReceiveChequeWrongChequebook(t *testing.T) {
networkID, networkID,
&cashoutMock{}, &cashoutMock{},
mockp2p.New(), mockp2p.New(),
observer,
) )
observer := &testObserver{}
swapService.SetNotifyPaymentFunc(observer.NotifyPayment)
err := swapService.ReceiveCheque(context.Background(), peer, cheque) err := swapService.ReceiveCheque(context.Background(), peer, cheque)
if err == nil { if err == nil {
t.Fatal("accepted invalid cheque") t.Fatal("accepted invalid cheque")
...@@ -269,9 +305,12 @@ func TestReceiveChequeWrongChequebook(t *testing.T) { ...@@ -269,9 +305,12 @@ func TestReceiveChequeWrongChequebook(t *testing.T) {
t.Fatalf("wrong error. wanted %v, got %v", swap.ErrWrongChequebook, err) t.Fatalf("wrong error. wanted %v, got %v", swap.ErrWrongChequebook, err)
} }
if observer.called { select {
t.Fatal("observer was be called for rejected payment") case <-observer.receivedCalled:
t.Fatalf("observer called by error.")
default:
} }
} }
func TestPay(t *testing.T) { func TestPay(t *testing.T) {
...@@ -307,6 +346,8 @@ func TestPay(t *testing.T) { ...@@ -307,6 +346,8 @@ func TestPay(t *testing.T) {
}, },
} }
observer := newTestObserver()
var emitCalled bool var emitCalled bool
swap := swap.New( swap := swap.New(
&swapProtocolMock{ &swapProtocolMock{
...@@ -329,12 +370,10 @@ func TestPay(t *testing.T) { ...@@ -329,12 +370,10 @@ func TestPay(t *testing.T) {
networkID, networkID,
&cashoutMock{}, &cashoutMock{},
mockp2p.New(), mockp2p.New(),
observer,
) )
err := swap.Pay(context.Background(), peer, amount) swap.Pay(context.Background(), peer, amount)
if err != nil {
t.Fatal(err)
}
if !chequebookCalled { if !chequebookCalled {
t.Fatal("chequebook was not called") t.Fatal("chequebook was not called")
...@@ -380,12 +419,27 @@ func TestPayIssueError(t *testing.T) { ...@@ -380,12 +419,27 @@ func TestPayIssueError(t *testing.T) {
networkID, networkID,
&cashoutMock{}, &cashoutMock{},
mockp2p.New(), mockp2p.New(),
nil,
) )
err := swap.Pay(context.Background(), peer, amount) observer := newTestObserver()
if !errors.Is(err, errReject) { swap.SetAccountingAPI(observer)
t.Fatalf("wrong error. wanted %v, got %v", errReject, err)
swap.Pay(context.Background(), peer, amount)
select {
case call := <-observer.sentCalled:
if !call.peer.Equal(peer) {
t.Fatalf("observer called with wrong peer. got %v, want %v", call.peer, peer)
}
if !errors.Is(call.err, errReject) {
t.Fatalf("wrong error. wanted %v, got %v", errReject, call.err)
} }
case <-time.After(time.Second):
t.Fatal("expected observer to be called")
}
} }
func TestPayUnknownBeneficiary(t *testing.T) { func TestPayUnknownBeneficiary(t *testing.T) {
...@@ -404,6 +458,8 @@ func TestPayUnknownBeneficiary(t *testing.T) { ...@@ -404,6 +458,8 @@ func TestPayUnknownBeneficiary(t *testing.T) {
}, },
} }
observer := newTestObserver()
var disconnectCalled bool var disconnectCalled bool
swapService := swap.New( swapService := swap.New(
&swapProtocolMock{}, &swapProtocolMock{},
...@@ -423,11 +479,23 @@ func TestPayUnknownBeneficiary(t *testing.T) { ...@@ -423,11 +479,23 @@ func TestPayUnknownBeneficiary(t *testing.T) {
return nil return nil
}), }),
), ),
observer,
) )
err := swapService.Pay(context.Background(), peer, amount) swapService.Pay(context.Background(), peer, amount)
if !errors.Is(err, swap.ErrUnknownBeneficary) {
t.Fatalf("wrong error. wanted %v, got %v", swap.ErrUnknownBeneficary, err) select {
case call := <-observer.sentCalled:
if !call.peer.Equal(peer) {
t.Fatalf("observer called with wrong peer. got %v, want %v", call.peer, peer)
}
if !errors.Is(call.err, swap.ErrUnknownBeneficary) {
t.Fatalf("wrong error. wanted %v, got %v", swap.ErrUnknownBeneficary, call.err)
}
case <-time.After(time.Second):
t.Fatal("expected observer to be called")
} }
if !disconnectCalled { if !disconnectCalled {
...@@ -462,6 +530,7 @@ func TestHandshake(t *testing.T) { ...@@ -462,6 +530,7 @@ func TestHandshake(t *testing.T) {
networkID, networkID,
&cashoutMock{}, &cashoutMock{},
mockp2p.New(), mockp2p.New(),
nil,
) )
err := swapService.Handshake(peer, beneficiary) err := swapService.Handshake(peer, beneficiary)
...@@ -501,6 +570,7 @@ func TestHandshakeNewPeer(t *testing.T) { ...@@ -501,6 +570,7 @@ func TestHandshakeNewPeer(t *testing.T) {
networkID, networkID,
&cashoutMock{}, &cashoutMock{},
mockp2p.New(), mockp2p.New(),
nil,
) )
err := swapService.Handshake(peer, beneficiary) err := swapService.Handshake(peer, beneficiary)
...@@ -531,6 +601,7 @@ func TestHandshakeWrongBeneficiary(t *testing.T) { ...@@ -531,6 +601,7 @@ func TestHandshakeWrongBeneficiary(t *testing.T) {
networkID, networkID,
&cashoutMock{}, &cashoutMock{},
mockp2p.New(), mockp2p.New(),
nil,
) )
err := swapService.Handshake(peer, beneficiary) err := swapService.Handshake(peer, beneficiary)
...@@ -580,6 +651,7 @@ func TestCashout(t *testing.T) { ...@@ -580,6 +651,7 @@ func TestCashout(t *testing.T) {
}, },
}, },
mockp2p.New(), mockp2p.New(),
nil,
) )
returnedHash, err := swapService.CashCheque(context.Background(), peer) returnedHash, err := swapService.CashCheque(context.Background(), peer)
...@@ -626,6 +698,7 @@ func TestCashoutStatus(t *testing.T) { ...@@ -626,6 +698,7 @@ func TestCashoutStatus(t *testing.T) {
}, },
}, },
mockp2p.New(), mockp2p.New(),
nil,
) )
returnedStatus, err := swapService.CashoutStatus(context.Background(), peer) returnedStatus, err := swapService.CashoutStatus(context.Background(), peer)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment