Commit c1e87cef authored by Ralph Pichler's avatar Ralph Pichler Committed by GitHub

feat: time based settlements (#1711)

parent 4b730a6f
...@@ -64,7 +64,7 @@ jobs: ...@@ -64,7 +64,7 @@ jobs:
run: | run: |
echo -e "127.0.0.10\tregistry.localhost" | sudo tee -a /etc/hosts echo -e "127.0.0.10\tregistry.localhost" | sudo tee -a /etc/hosts
for ((i=0; i<REPLICA; i++)); do echo -e "127.0.1.$((i+1))\tbee-${i}.localhost bee-${i}-debug.localhost"; done | sudo tee -a /etc/hosts for ((i=0; i<REPLICA; i++)); do echo -e "127.0.1.$((i+1))\tbee-${i}.localhost bee-${i}-debug.localhost"; done | sudo tee -a /etc/hosts
timeout 30m ./beeinfra.sh install --local -r "${REPLICA}" --bootnode /dnsaddr/localhost --geth --k3s --pay-threshold 2000000000000 --postage timeout 30m ./beeinfra.sh install --local -r "${REPLICA}" --bootnode /dnsaddr/localhost --geth --k3s --pay-threshold 1500000000000 --postage
- name: Test pingpong - name: Test pingpong
id: pingpong-1 id: pingpong-1
run: until ./beekeeper check pingpong --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}"; do echo "waiting for pingpong..."; sleep .3; done run: until ./beekeeper check pingpong --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}"; do echo "waiting for pingpong..."; sleep .3; done
...@@ -73,7 +73,15 @@ jobs: ...@@ -73,7 +73,15 @@ jobs:
run: ./beekeeper check fullconnectivity --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}" run: ./beekeeper check fullconnectivity --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}"
- name: Test settlements - name: Test settlements
id: settlements-1 id: settlements-1
run: ./beekeeper check settlements --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}" --upload-node-count "${REPLICA}" -t 2000000000000 run: |
./beekeeper check settlements --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}" --upload-node-count "${REPLICA}" -t 1500000000000
sleep 2
- name: Test pss
id: pss
run: ./beekeeper check pss --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}" --timeout 5m
- name: Test soc
id: soc
run: ./beekeeper check soc --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}"
- name: Test pushsync (chunks) - name: Test pushsync (chunks)
id: pushsync-chunks-1 id: pushsync-chunks-1
run: ./beekeeper check pushsync --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}" --upload-node-count "${REPLICA}" --chunks-per-node 3 --upload-chunks --retry-delay 15s run: ./beekeeper check pushsync --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}" --upload-node-count "${REPLICA}" --chunks-per-node 3 --upload-chunks --retry-delay 15s
...@@ -83,12 +91,6 @@ jobs: ...@@ -83,12 +91,6 @@ jobs:
- name: Test manifest - name: Test manifest
id: manifest-1 id: manifest-1
run: ./beekeeper check manifest --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}" run: ./beekeeper check manifest --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}"
- name: Test pss
id: pss
run: ./beekeeper check pss --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}" --timeout 5m
- name: Test soc
id: soc
run: ./beekeeper check soc --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}"
- name: Destroy the cluster - name: Destroy the cluster
run: | run: |
./beeinfra.sh uninstall ./beeinfra.sh uninstall
...@@ -101,7 +103,7 @@ jobs: ...@@ -101,7 +103,7 @@ jobs:
cp /etc/rancher/k3s/k3s.yaml ~/.kube/config cp /etc/rancher/k3s/k3s.yaml ~/.kube/config
- name: Set testing cluster (Node connection and clef enabled) - name: Set testing cluster (Node connection and clef enabled)
run: | run: |
timeout 30m ./beeinfra.sh install --local -r "${REPLICA}" --geth --clef --k3s --pay-threshold 2000000000000 --postage timeout 30m ./beeinfra.sh install --local -r "${REPLICA}" --geth --clef --k3s --pay-threshold 1500000000000 --postage
- name: Test pingpong - name: Test pingpong
id: pingpong-2 id: pingpong-2
run: until ./beekeeper check pingpong --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}"; do echo "waiting for pingpong..."; sleep .3; done run: until ./beekeeper check pingpong --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}"; do echo "waiting for pingpong..."; sleep .3; done
...@@ -110,7 +112,7 @@ jobs: ...@@ -110,7 +112,7 @@ jobs:
run: ./beekeeper check fullconnectivity --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}" run: ./beekeeper check fullconnectivity --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}"
- name: Test settlements - name: Test settlements
id: settlements-2 id: settlements-2
run: ./beekeeper check settlements --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}" --upload-node-count "${REPLICA}" -t 2000000000000 run: ./beekeeper check settlements --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}" --upload-node-count "${REPLICA}" -t 1500000000000
- name: Destroy the cluster - name: Destroy the cluster
run: | run: |
./beeinfra.sh uninstall ./beeinfra.sh uninstall
...@@ -126,7 +128,7 @@ jobs: ...@@ -126,7 +128,7 @@ jobs:
cp /etc/rancher/k3s/k3s.yaml ~/.kube/config cp /etc/rancher/k3s/k3s.yaml ~/.kube/config
- name: Set testing cluster (storage incentives setup) - name: Set testing cluster (storage incentives setup)
run: | run: |
timeout 10m ./beeinfra.sh install --local -r "${REPLICA}" --geth --k3s --pay-threshold 2000000000000 --postage --db-capacity 100 timeout 10m ./beeinfra.sh install --local -r "${REPLICA}" --geth --k3s --pay-threshold 1500000000000 --postage --db-capacity 100
- name: Test gc - name: Test gc
id: gc-chunk-1 id: gc-chunk-1
run: ./beekeeper check gc --cache-capacity 100 --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}" run: ./beekeeper check gc --cache-capacity 100 --api-scheme http --debug-api-scheme http --disable-namespace --debug-api-domain localhost --api-domain localhost --node-count "${REPLICA}"
......
...@@ -218,7 +218,7 @@ func (c *command) setAllFlags(cmd *cobra.Command) { ...@@ -218,7 +218,7 @@ func (c *command) setAllFlags(cmd *cobra.Command) {
cmd.Flags().String(optionWelcomeMessage, "", "send a welcome message string during handshakes") cmd.Flags().String(optionWelcomeMessage, "", "send a welcome message string during handshakes")
cmd.Flags().Bool(optionNameGlobalPinningEnabled, false, "enable global pinning") cmd.Flags().Bool(optionNameGlobalPinningEnabled, false, "enable global pinning")
cmd.Flags().String(optionNamePaymentThreshold, "10000000000000", "threshold in BZZ where you expect to get paid from your peers") cmd.Flags().String(optionNamePaymentThreshold, "10000000000000", "threshold in BZZ where you expect to get paid from your peers")
cmd.Flags().String(optionNamePaymentTolerance, "50000000000000", "excess debt above payment threshold in BZZ where you disconnect from your peer") cmd.Flags().String(optionNamePaymentTolerance, "10000000000000", "excess debt above payment threshold in BZZ where you disconnect from your peer")
cmd.Flags().String(optionNamePaymentEarly, "1000000000000", "amount in BZZ below the peers payment threshold when we initiate settlement") cmd.Flags().String(optionNamePaymentEarly, "1000000000000", "amount in BZZ below the peers payment threshold when we initiate settlement")
cmd.Flags().StringSlice(optionNameResolverEndpoints, []string{}, "ENS compatible API endpoint for a TLD and with contract address, can be repeated, format [tld:][contract-addr@]url") cmd.Flags().StringSlice(optionNameResolverEndpoints, []string{}, "ENS compatible API endpoint for a TLD and with contract address, can be repeated, format [tld:][contract-addr@]url")
cmd.Flags().Bool(optionNameGatewayMode, false, "disable a set of sensitive features in the api") cmd.Flags().Bool(optionNameGatewayMode, false, "disable a set of sensitive features in the api")
......
...@@ -52,8 +52,8 @@ password-file: /var/lib/bee/password ...@@ -52,8 +52,8 @@ password-file: /var/lib/bee/password
# payment-early: 1000000000000 # payment-early: 1000000000000
## threshold in BZZ where you expect to get paid from your peers (default 10000000000000) ## threshold in BZZ where you expect to get paid from your peers (default 10000000000000)
# payment-threshold: 10000000000000 # payment-threshold: 10000000000000
## excess debt above payment threshold in BZZ where you disconnect from your peer (default 50000000000000) ## excess debt above payment threshold in BZZ where you disconnect from your peer (default 10000000000000)
# payment-tolerance: 50000000000000 # payment-tolerance: 10000000000000
## ENS compatible API endpoint for a TLD and with contract address, can be repeated, format [tld:][contract-addr@]url ## ENS compatible API endpoint for a TLD and with contract address, can be repeated, format [tld:][contract-addr@]url
# resolver-options: [] # resolver-options: []
## whether we want the node to start with no listen addresses for p2p ## whether we want the node to start with no listen addresses for p2p
......
...@@ -55,12 +55,12 @@ BEE_CLEF_SIGNER_ENABLE=true ...@@ -55,12 +55,12 @@ BEE_CLEF_SIGNER_ENABLE=true
# BEE_PASSWORD= # BEE_PASSWORD=
## path to a file that contains password for decrypting keys ## path to a file that contains password for decrypting keys
# BEE_PASSWORD_FILE= # BEE_PASSWORD_FILE=
## amount in BZZ below the peers payment threshold when we initiate settlement (default 10000) ## amount in BZZ below the peers payment threshold when we initiate settlement (default 1000000000000)
# BEE_PAYMENT_EARLY=10000 # BEE_PAYMENT_EARLY=1000000000000
## threshold in BZZ where you expect to get paid from your peers (default 100000) ## threshold in BZZ where you expect to get paid from your peers (default 10000000000000)
# BEE_PAYMENT_THRESHOLD=100000 # BEE_PAYMENT_THRESHOLD=10000000000000
## excess debt above payment threshold in BZZ where you disconnect from your peer (default 10000) ## excess debt above payment threshold in BZZ where you disconnect from your peer (default 10000000000000)
# BEE_PAYMENT_TOLERANCE=10000 # BEE_PAYMENT_TOLERANCE=10000000000000
## ENS compatible API endpoint for a TLD and with contract address, can be repeated, format [tld:][contract-addr@]url ## ENS compatible API endpoint for a TLD and with contract address, can be repeated, format [tld:][contract-addr@]url
# BEE_RESOLVER_OPTIONS=[] # BEE_RESOLVER_OPTIONS=[]
## whether we want the node to start with no listen addresses for p2p ## whether we want the node to start with no listen addresses for p2p
......
...@@ -52,8 +52,8 @@ password-file: /usr/local/var/lib/swarm-bee/password ...@@ -52,8 +52,8 @@ password-file: /usr/local/var/lib/swarm-bee/password
# payment-early: 1000000000000 # payment-early: 1000000000000
## threshold in BZZ where you expect to get paid from your peers (default 10000000000000) ## threshold in BZZ where you expect to get paid from your peers (default 10000000000000)
# payment-threshold: 10000000000000 # payment-threshold: 10000000000000
## excess debt above payment threshold in BZZ where you disconnect from your peer (default 50000000000000) ## excess debt above payment threshold in BZZ where you disconnect from your peer (default 10000000000000)
# payment-tolerance: 50000000000000 # payment-tolerance: 10000000000000
## ENS compatible API endpoint for a TLD and with contract address, can be repeated, format [tld:][contract-addr@]url ## ENS compatible API endpoint for a TLD and with contract address, can be repeated, format [tld:][contract-addr@]url
# resolver-options: [] # resolver-options: []
## whether we want the node to start with no listen addresses for p2p ## whether we want the node to start with no listen addresses for p2p
......
...@@ -38,12 +38,12 @@ data-dir: ./data ...@@ -38,12 +38,12 @@ data-dir: ./data
# password: "" # password: ""
## path to a file that contains password for decrypting keys ## path to a file that contains password for decrypting keys
password-file: ./password password-file: ./password
## amount in BZZ below the peers payment threshold when we initiate settlement (default 10000) ## amount in BZZ below the peers payment threshold when we initiate settlement (default 1000000000000)
# payment-early: 10000 # payment-early: 1000000000000
## threshold in BZZ where you expect to get paid from your peers (default 100000) ## threshold in BZZ where you expect to get paid from your peers (default 10000000000000)
# payment-threshold: 100000 # payment-threshold: 10000000000000
## excess debt above payment threshold in BZZ where you disconnect from your peer (default 10000) ## excess debt above payment threshold in BZZ where you disconnect from your peer (default 10000000000000)
# payment-tolerance: 10000 # payment-tolerance: 10000000000000
## ENS compatible API endpoint for a TLD and with contract address, can be repeated, format [tld:][contract-addr@]url ## ENS compatible API endpoint for a TLD and with contract address, can be repeated, format [tld:][contract-addr@]url
# resolver-options: [] # resolver-options: []
## whether we want the node to start with no listen addresses for p2p ## whether we want the node to start with no listen addresses for p2p
......
...@@ -26,6 +26,9 @@ var ( ...@@ -26,6 +26,9 @@ var (
_ Interface = (*Accounting)(nil) _ Interface = (*Accounting)(nil)
balancesPrefix string = "accounting_balance_" balancesPrefix string = "accounting_balance_"
balancesSurplusPrefix string = "accounting_surplusbalance_" balancesSurplusPrefix string = "accounting_surplusbalance_"
// fraction of the refresh rate that is the minimum for monetary settlement
// this value is chosen so that tiny payments are prevented while still allowing small payments in environments with lower payment thresholds
minimumPaymentDivisor = int64(5)
) )
// Interface is the Accounting interface. // Interface is the Accounting interface.
...@@ -40,8 +43,8 @@ type Interface interface { ...@@ -40,8 +43,8 @@ type Interface interface {
Release(peer swarm.Address, price uint64) Release(peer swarm.Address, price uint64)
// Credit increases the balance the peer has with us (we "pay" the peer). // Credit increases the balance the peer has with us (we "pay" the peer).
Credit(peer swarm.Address, price uint64) error Credit(peer swarm.Address, price uint64) error
// Debit increases the balance we have with the peer (we get "paid" back). // PrepareDebit returns an accounting Action for the later debit to be executed on and to implement shadowing a possibly credited part of reserve on the other side.
Debit(peer swarm.Address, price uint64) error PrepareDebit(peer swarm.Address, price uint64) Action
// Balance returns the current balance for the given peer. // Balance returns the current balance for the given peer.
Balance(peer swarm.Address) (*big.Int, error) Balance(peer swarm.Address) (*big.Int, error)
// SurplusBalance returns the current surplus balance for the given peer. // SurplusBalance returns the current surplus balance for the given peer.
...@@ -54,13 +57,36 @@ type Interface interface { ...@@ -54,13 +57,36 @@ type Interface interface {
CompensatedBalances() (map[string]*big.Int, error) CompensatedBalances() (map[string]*big.Int, error)
} }
// Action represents an accounting action that can be applied
type Action interface {
// Cleanup cleans up an action. Must be called wether it was applied or not.
Cleanup()
// Apply applies an action
Apply() error
}
// debitAction represents a future debit
type debitAction struct {
accounting *Accounting
price *big.Int
peer swarm.Address
accountingPeer *accountingPeer
applied bool
}
// PayFunc is the function used for async monetary settlement
type PayFunc func(context.Context, swarm.Address, *big.Int) type PayFunc func(context.Context, swarm.Address, *big.Int)
// RefreshFunc is the function used for sync time-based settlement
type RefreshFunc func(context.Context, swarm.Address, *big.Int, *big.Int) (*big.Int, int64, error)
// accountingPeer holds all in-memory accounting information for one peer. // accountingPeer holds all in-memory accounting information for one peer.
type accountingPeer struct { type accountingPeer struct {
lock sync.Mutex // lock to be held during any accounting action for this peer lock sync.Mutex // lock to be held during any accounting action for this peer
reservedBalance *big.Int // amount currently reserved for active peer interaction reservedBalance *big.Int // amount currently reserved for active peer interaction
shadowReservedBalance *big.Int // amount potentially to be debited for active peer interaction
paymentThreshold *big.Int // the threshold at which the peer expects us to pay paymentThreshold *big.Int // the threshold at which the peer expects us to pay
refreshTimestamp int64 // last time we attempted time-based settlement
paymentOngoing bool // indicate if we are currently settling with the peer paymentOngoing bool // indicate if we are currently settling with the peer
} }
...@@ -76,10 +102,21 @@ type Accounting struct { ...@@ -76,10 +102,21 @@ type Accounting struct {
// The amount in BZZ we let peers exceed the payment threshold before we // The amount in BZZ we let peers exceed the payment threshold before we
// disconnect them. // disconnect them.
paymentTolerance *big.Int paymentTolerance *big.Int
// Start settling when reserve plus debt reaches this close to threshold.
earlyPayment *big.Int earlyPayment *big.Int
// Limit to disconnect peer after going in debt over
disconnectLimit *big.Int
// function used for monetory settlement
payFunction PayFunc payFunction PayFunc
// function used for time settlement
refreshFunction RefreshFunc
// allowance based on time used in pseudosettle
refreshRate *big.Int
// lower bound for the value of issued cheques
minimumPayment *big.Int
pricing pricing.Interface pricing pricing.Interface
metrics metrics metrics metrics
timeNow func() time.Time
} }
var ( var (
...@@ -89,8 +126,6 @@ var ( ...@@ -89,8 +126,6 @@ var (
ErrDisconnectThresholdExceeded = errors.New("disconnect threshold exceeded") ErrDisconnectThresholdExceeded = errors.New("disconnect threshold exceeded")
// ErrPeerNoBalance is the error returned if no balance in store exists for a peer // ErrPeerNoBalance is the error returned if no balance in store exists for a peer
ErrPeerNoBalance = errors.New("no balance for peer") ErrPeerNoBalance = errors.New("no balance for peer")
// ErrOverflow denotes an arithmetic operation overflowed.
ErrOverflow = errors.New("overflow error")
// ErrInvalidValue denotes an invalid value read from store // ErrInvalidValue denotes an invalid value read from store
ErrInvalidValue = errors.New("invalid value") ErrInvalidValue = errors.New("invalid value")
) )
...@@ -103,16 +138,21 @@ func NewAccounting( ...@@ -103,16 +138,21 @@ func NewAccounting(
Logger logging.Logger, Logger logging.Logger,
Store storage.StateStorer, Store storage.StateStorer,
Pricing pricing.Interface, Pricing pricing.Interface,
refreshRate *big.Int,
) (*Accounting, error) { ) (*Accounting, error) {
return &Accounting{ return &Accounting{
accountingPeers: make(map[string]*accountingPeer), accountingPeers: make(map[string]*accountingPeer),
paymentThreshold: new(big.Int).Set(PaymentThreshold), paymentThreshold: new(big.Int).Set(PaymentThreshold),
paymentTolerance: new(big.Int).Set(PaymentTolerance), paymentTolerance: new(big.Int).Set(PaymentTolerance),
earlyPayment: new(big.Int).Set(EarlyPayment), earlyPayment: new(big.Int).Set(EarlyPayment),
disconnectLimit: new(big.Int).Add(PaymentThreshold, PaymentTolerance),
logger: Logger, logger: Logger,
store: Store, store: Store,
pricing: Pricing, pricing: Pricing,
metrics: newMetrics(), metrics: newMetrics(),
refreshRate: refreshRate,
timeNow: time.Now,
minimumPayment: new(big.Int).Div(refreshRate, big.NewInt(minimumPaymentDivisor)),
}, nil }, nil
} }
...@@ -129,17 +169,16 @@ func (a *Accounting) Reserve(ctx context.Context, peer swarm.Address, price uint ...@@ -129,17 +169,16 @@ func (a *Accounting) Reserve(ctx context.Context, peer swarm.Address, price uint
return fmt.Errorf("failed to load balance: %w", err) return fmt.Errorf("failed to load balance: %w", err)
} }
} }
currentDebt := new(big.Int).Neg(currentBalance)
if currentDebt.Cmp(big.NewInt(0)) < 0 {
currentDebt.SetInt64(0)
}
bigPrice := new(big.Int).SetUint64(price) bigPrice := new(big.Int).SetUint64(price)
nextReserved := new(big.Int).Add(accountingPeer.reservedBalance, bigPrice) nextReserved := new(big.Int).Add(accountingPeer.reservedBalance, bigPrice)
expectedBalance := new(big.Int).Sub(currentBalance, nextReserved) // debt if all reserved operations are successfully credited excluding debt created by surplus balance
expectedDebt := new(big.Int).Add(currentDebt, nextReserved)
// Determine if we will owe anything to the peer, if we owe less than 0, we conclude we owe nothing
expectedDebt := new(big.Int).Neg(expectedBalance)
if expectedDebt.Cmp(big.NewInt(0)) < 0 {
expectedDebt.SetInt64(0)
}
threshold := new(big.Int).Set(accountingPeer.paymentThreshold) threshold := new(big.Int).Set(accountingPeer.paymentThreshold)
if threshold.Cmp(a.earlyPayment) > 0 { if threshold.Cmp(a.earlyPayment) > 0 {
...@@ -148,23 +187,23 @@ func (a *Accounting) Reserve(ctx context.Context, peer swarm.Address, price uint ...@@ -148,23 +187,23 @@ func (a *Accounting) Reserve(ctx context.Context, peer swarm.Address, price uint
threshold.SetInt64(0) threshold.SetInt64(0)
} }
// additionalDebt is debt created by incoming payments which we don't consider debt for monetary settlement purposes
additionalDebt, err := a.SurplusBalance(peer) additionalDebt, err := a.SurplusBalance(peer)
if err != nil { if err != nil {
return fmt.Errorf("failed to load surplus balance: %w", err) return fmt.Errorf("failed to load surplus balance: %w", err)
} }
// uint64 conversion of surplusbalance is safe because surplusbalance is always positive // debt if all reserved operations are successfully credited including debt created by surplus balance
if additionalDebt.Cmp(big.NewInt(0)) < 0 {
return ErrInvalidValue
}
increasedExpectedDebt := new(big.Int).Add(expectedDebt, additionalDebt) increasedExpectedDebt := new(big.Int).Add(expectedDebt, additionalDebt)
// debt if all reserved operations are successfully credited and all shadow reserved operations are debited including debt created by surplus balance
// in other words this the debt the other node sees if everything pending is successful
increasedExpectedDebtReduced := new(big.Int).Sub(increasedExpectedDebt, accountingPeer.shadowReservedBalance)
// If our expected debt is less than earlyPayment away from our payment threshold // If our expected debt reduced by what could have been credited on the other side already is less than earlyPayment away from our payment threshold
// and we are actually in debt, trigger settlement. // and we are actually in debt, trigger settlement.
// we pay early to avoid needlessly blocking request later when concurrent requests occur and we are already close to the payment threshold. // we pay early to avoid needlessly blocking request later when concurrent requests occur and we are already close to the payment threshold.
if increasedExpectedDebt.Cmp(threshold) >= 0 && currentBalance.Cmp(big.NewInt(0)) < 0 { if increasedExpectedDebtReduced.Cmp(threshold) >= 0 && currentBalance.Cmp(big.NewInt(0)) < 0 {
err = a.settle(context.Background(), peer, accountingPeer) err = a.settle(peer, accountingPeer)
if err != nil { if err != nil {
return fmt.Errorf("failed to settle with peer %v: %v", peer, err) return fmt.Errorf("failed to settle with peer %v: %v", peer, err)
} }
...@@ -214,7 +253,7 @@ func (a *Accounting) Credit(peer swarm.Address, price uint64) error { ...@@ -214,7 +253,7 @@ func (a *Accounting) Credit(peer swarm.Address, price uint64) error {
} }
} }
// Calculate next balance by safely decreasing current balance with the price we credit // Calculate next balance by decreasing current balance with the price we credit
nextBalance := new(big.Int).Sub(currentBalance, new(big.Int).SetUint64(price)) nextBalance := new(big.Int).Sub(currentBalance, new(big.Int).SetUint64(price))
a.logger.Tracef("crediting peer %v with price %d, new balance is %d", peer, price, nextBalance) a.logger.Tracef("crediting peer %v with price %d, new balance is %d", peer, price, nextBalance)
...@@ -231,10 +270,9 @@ func (a *Accounting) Credit(peer swarm.Address, price uint64) error { ...@@ -231,10 +270,9 @@ func (a *Accounting) Credit(peer swarm.Address, price uint64) error {
// Settle all debt with a peer. The lock on the accountingPeer must be held when // Settle all debt with a peer. The lock on the accountingPeer must be held when
// called. // called.
func (a *Accounting) settle(ctx context.Context, peer swarm.Address, balance *accountingPeer) error { func (a *Accounting) settle(peer swarm.Address, balance *accountingPeer) error {
if balance.paymentOngoing { now := a.timeNow().Unix()
return nil timeElapsed := now - balance.refreshTimestamp
}
oldBalance, err := a.Balance(peer) oldBalance, err := a.Balance(peer)
if err != nil { if err != nil {
...@@ -243,101 +281,50 @@ func (a *Accounting) settle(ctx context.Context, peer swarm.Address, balance *ac ...@@ -243,101 +281,50 @@ func (a *Accounting) settle(ctx context.Context, peer swarm.Address, balance *ac
} }
} }
// Don't do anything if there is no actual debt. // compute the debt including debt created by incoming payments
// This might be the case if the peer owes us and the total reserve for a compensatedBalance, err := a.CompensatedBalance(peer)
// peer exceeds the payment treshold.
if oldBalance.Cmp(big.NewInt(0)) >= 0 {
return nil
}
// This is safe because of the earlier check for oldbalance < 0 and the check for != MinInt64
paymentAmount := new(big.Int).Neg(oldBalance)
balance.paymentOngoing = true
go a.payFunction(ctx, peer, paymentAmount)
return nil
}
// Debit increases the amount of debt we have with the given peer (and decreases
// existing credit).
func (a *Accounting) Debit(peer swarm.Address, price uint64) error {
accountingPeer := a.getAccountingPeer(peer)
accountingPeer.lock.Lock()
defer accountingPeer.lock.Unlock()
cost := new(big.Int).SetUint64(price)
// see if peer has surplus balance to deduct this transaction of
surplusBalance, err := a.SurplusBalance(peer)
if err != nil { if err != nil {
return fmt.Errorf("failed to get surplus balance: %w", err) return err
} }
if surplusBalance.Cmp(big.NewInt(0)) > 0 {
// get new surplus balance after deduct paymentAmount := new(big.Int).Neg(compensatedBalance)
newSurplusBalance := new(big.Int).Sub(surplusBalance, cost)
// if nothing left for debiting, store new surplus balance and return from debit // Don't do anything if there is no actual debt or no time passed since last refreshment attempt
if newSurplusBalance.Cmp(big.NewInt(0)) >= 0 { // This might be the case if the peer owes us and the total reserve for a peer exceeds the payment treshold.
a.logger.Tracef("surplus debiting peer %v with value %d, new surplus balance is %d", peer, price, newSurplusBalance) if paymentAmount.Cmp(big.NewInt(0)) > 0 && timeElapsed > 0 {
shadowBalance, err := a.shadowBalance(peer)
err = a.store.Put(peerSurplusBalanceKey(peer), newSurplusBalance)
if err != nil { if err != nil {
return fmt.Errorf("failed to persist surplus balance: %w", err) return err
}
// count debit operations, terminate early
a.metrics.TotalDebitedAmount.Add(float64(price))
a.metrics.DebitEventsCount.Inc()
return nil
} }
// if surplus balance didn't cover full transaction, let's continue with leftover part as cost acceptedAmount, timestamp, err := a.refreshFunction(context.Background(), peer, paymentAmount, shadowBalance)
debitIncrease := new(big.Int).Sub(new(big.Int).SetUint64(price), surplusBalance) if err != nil {
return fmt.Errorf("refresh failure: %w", err)
// conversion to uint64 is safe because we know the relationship between the values by now, but let's make a sanity check
if debitIncrease.Cmp(big.NewInt(0)) <= 0 {
return fmt.Errorf("sanity check failed for partial debit after surplus balance drawn")
} }
cost.Set(debitIncrease)
// if we still have something to debit, than have run out of surplus balance, balance.refreshTimestamp = timestamp
// let's store 0 as surplus balance
a.logger.Tracef("surplus debiting peer %v with value %d, new surplus balance is 0", peer, debitIncrease)
err = a.store.Put(peerSurplusBalanceKey(peer), big.NewInt(0)) oldBalance = new(big.Int).Add(oldBalance, acceptedAmount)
if err != nil {
return fmt.Errorf("failed to persist surplus balance: %w", err)
}
} a.logger.Tracef("registering refreshment sent to peer %v with amount %d, new balance is %d", peer, acceptedAmount, oldBalance)
currentBalance, err := a.Balance(peer) err = a.store.Put(peerBalanceKey(peer), oldBalance)
if err != nil { if err != nil {
if !errors.Is(err, ErrPeerNoBalance) { return fmt.Errorf("settle: failed to persist balance: %w", err)
return fmt.Errorf("failed to load balance: %w", err)
} }
} }
// Get nextBalance by safely increasing current balance with price if a.payFunction != nil && !balance.paymentOngoing {
nextBalance := new(big.Int).Add(currentBalance, cost) // if there is no monetary settlement happening, check if there is something to settle
// compute debt excluding debt created by incoming payments
a.logger.Tracef("debiting peer %v with price %d, new balance is %d", peer, price, nextBalance) paymentAmount := new(big.Int).Neg(oldBalance)
// if the remaining debt is still larger than some minimum amount, trigger monetary settlement
err = a.store.Put(peerBalanceKey(peer), nextBalance) if paymentAmount.Cmp(a.minimumPayment) >= 0 {
if err != nil { balance.paymentOngoing = true
return fmt.Errorf("failed to persist balance: %w", err) // add settled amount to shadow reserve before sending it
balance.shadowReservedBalance.Add(balance.shadowReservedBalance, paymentAmount)
go a.payFunction(context.Background(), peer, paymentAmount)
} }
a.metrics.TotalDebitedAmount.Add(float64(price))
a.metrics.DebitEventsCount.Inc()
if nextBalance.Cmp(new(big.Int).Add(a.paymentThreshold, a.paymentTolerance)) >= 0 {
// peer too much in debt
a.metrics.AccountingDisconnectsCount.Inc()
return p2p.NewBlockPeerError(10000*time.Hour, ErrDisconnectThresholdExceeded)
} }
return nil return nil
...@@ -368,21 +355,20 @@ func (a *Accounting) SurplusBalance(peer swarm.Address) (balance *big.Int, err e ...@@ -368,21 +355,20 @@ func (a *Accounting) SurplusBalance(peer swarm.Address) (balance *big.Int, err e
return nil, err return nil, err
} }
if balance.Cmp(big.NewInt(0)) < 0 {
return nil, ErrInvalidValue
}
return balance, nil return balance, nil
} }
// CompensatedBalance returns balance decreased by surplus balance // CompensatedBalance returns balance decreased by surplus balance
func (a *Accounting) CompensatedBalance(peer swarm.Address) (compensated *big.Int, err error) { func (a *Accounting) CompensatedBalance(peer swarm.Address) (compensated *big.Int, err error) {
surplus, err := a.SurplusBalance(peer) surplus, err := a.SurplusBalance(peer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if surplus.Cmp(big.NewInt(0)) < 0 {
return nil, ErrInvalidValue
}
balance, err := a.Balance(peer) balance, err := a.Balance(peer)
if err != nil { if err != nil {
if !errors.Is(err, ErrPeerNoBalance) { if !errors.Is(err, ErrPeerNoBalance) {
...@@ -420,6 +406,7 @@ func (a *Accounting) getAccountingPeer(peer swarm.Address) *accountingPeer { ...@@ -420,6 +406,7 @@ func (a *Accounting) getAccountingPeer(peer swarm.Address) *accountingPeer {
if !ok { if !ok {
peerData = &accountingPeer{ peerData = &accountingPeer{
reservedBalance: big.NewInt(0), reservedBalance: big.NewInt(0),
shadowReservedBalance: big.NewInt(0),
// initially assume the peer has the same threshold as us // initially assume the peer has the same threshold as us
paymentThreshold: new(big.Int).Set(a.paymentThreshold), paymentThreshold: new(big.Int).Set(a.paymentThreshold),
} }
...@@ -541,6 +528,117 @@ func surplusBalanceKeyPeer(key []byte) (swarm.Address, error) { ...@@ -541,6 +528,117 @@ func surplusBalanceKeyPeer(key []byte) (swarm.Address, error) {
return addr, nil return addr, nil
} }
// PeerDebt returns the positive part of the sum of the outstanding balance and the shadow reserve
func (a *Accounting) PeerDebt(peer swarm.Address) (*big.Int, error) {
accountingPeer := a.getAccountingPeer(peer)
accountingPeer.lock.Lock()
defer accountingPeer.lock.Unlock()
balance := new(big.Int)
zero := big.NewInt(0)
err := a.store.Get(peerBalanceKey(peer), &balance)
if err != nil {
if !errors.Is(err, storage.ErrNotFound) {
return nil, err
}
balance = big.NewInt(0)
}
peerDebt := new(big.Int).Add(balance, accountingPeer.shadowReservedBalance)
if peerDebt.Cmp(zero) < 0 {
return zero, nil
}
return peerDebt, nil
}
// shadowBalance returns the current debt reduced by any potentially debitable amount stored in shadowReservedBalance
// this represents how much less our debt could potentially be seen by the other party if it's ahead with processing credits corresponding to our shadow reserve
func (a *Accounting) shadowBalance(peer swarm.Address) (shadowBalance *big.Int, err error) {
accountingPeer := a.getAccountingPeer(peer)
balance := new(big.Int)
zero := big.NewInt(0)
err = a.store.Get(peerBalanceKey(peer), &balance)
if err != nil {
if errors.Is(err, storage.ErrNotFound) {
return zero, nil
}
return nil, err
}
if balance.Cmp(zero) >= 0 {
return zero, nil
}
negativeBalance := new(big.Int).Neg(balance)
surplusBalance, err := a.SurplusBalance(peer)
if err != nil {
return nil, err
}
debt := new(big.Int).Add(negativeBalance, surplusBalance)
if debt.Cmp(accountingPeer.shadowReservedBalance) < 0 {
return zero, nil
}
shadowBalance = new(big.Int).Sub(negativeBalance, accountingPeer.shadowReservedBalance)
return shadowBalance, nil
}
// NotifyPaymentSent is triggered by async monetary settlement to update our balance and remove it's price from the shadow reserve
func (a *Accounting) NotifyPaymentSent(peer swarm.Address, amount *big.Int, receivedError error) {
accountingPeer := a.getAccountingPeer(peer)
accountingPeer.lock.Lock()
defer accountingPeer.lock.Unlock()
accountingPeer.paymentOngoing = false
// decrease shadow reserve by payment value
accountingPeer.shadowReservedBalance.Sub(accountingPeer.shadowReservedBalance, amount)
if receivedError != nil {
a.logger.Warningf("accounting: payment failure %v", receivedError)
return
}
currentBalance, err := a.Balance(peer)
if err != nil {
if !errors.Is(err, ErrPeerNoBalance) {
a.logger.Errorf("accounting: notifypaymentsent failed to load balance: %v", err)
return
}
}
// Get nextBalance by increasing current balance with price
nextBalance := new(big.Int).Add(currentBalance, amount)
a.logger.Tracef("registering payment sent to peer %v with amount %d, new balance is %d", peer, amount, nextBalance)
err = a.store.Put(peerBalanceKey(peer), nextBalance)
if err != nil {
a.logger.Errorf("accounting: notifypaymentsent failed to persist balance: %v", err)
return
}
}
// NotifyPaymentThreshold should be called to notify accounting of changes in the payment threshold
func (a *Accounting) NotifyPaymentThreshold(peer swarm.Address, paymentThreshold *big.Int) error {
accountingPeer := a.getAccountingPeer(peer)
accountingPeer.lock.Lock()
defer accountingPeer.lock.Unlock()
accountingPeer.paymentThreshold.Set(paymentThreshold)
return nil
}
// NotifyPayment is called by Settlement when we receive a payment. // NotifyPayment is called by Settlement when we receive a payment.
func (a *Accounting) NotifyPaymentReceived(peer swarm.Address, amount *big.Int) error { func (a *Accounting) NotifyPaymentReceived(peer swarm.Address, amount *big.Int) error {
accountingPeer := a.getAccountingPeer(peer) accountingPeer := a.getAccountingPeer(peer)
...@@ -553,8 +651,8 @@ func (a *Accounting) NotifyPaymentReceived(peer swarm.Address, amount *big.Int) ...@@ -553,8 +651,8 @@ func (a *Accounting) NotifyPaymentReceived(peer swarm.Address, amount *big.Int)
if !errors.Is(err, ErrPeerNoBalance) { if !errors.Is(err, ErrPeerNoBalance) {
return err return err
} }
} }
// if balance is already negative or zero, we credit full amount received to surplus balance and terminate early // if balance is already negative or zero, we credit full amount received to surplus balance and terminate early
if currentBalance.Cmp(big.NewInt(0)) <= 0 { if currentBalance.Cmp(big.NewInt(0)) <= 0 {
surplus, err := a.SurplusBalance(peer) surplus, err := a.SurplusBalance(peer)
...@@ -603,7 +701,7 @@ func (a *Accounting) NotifyPaymentReceived(peer swarm.Address, amount *big.Int) ...@@ -603,7 +701,7 @@ func (a *Accounting) NotifyPaymentReceived(peer swarm.Address, amount *big.Int)
} }
increasedSurplus := new(big.Int).Add(surplus, surplusGrowth) increasedSurplus := new(big.Int).Add(surplus, surplusGrowth)
a.logger.Tracef("surplus crediting peer %v with amount %d due to payment, new surplus balance is %d", peer, surplusGrowth, increasedSurplus) a.logger.Tracef("surplus crediting peer %v with amount %d due to refreshment, new surplus balance is %d", peer, surplusGrowth, increasedSurplus)
err = a.store.Put(peerSurplusBalanceKey(peer), increasedSurplus) err = a.store.Put(peerSurplusBalanceKey(peer), increasedSurplus)
if err != nil { if err != nil {
...@@ -614,65 +712,160 @@ func (a *Accounting) NotifyPaymentReceived(peer swarm.Address, amount *big.Int) ...@@ -614,65 +712,160 @@ func (a *Accounting) NotifyPaymentReceived(peer swarm.Address, amount *big.Int)
return nil return nil
} }
// NotifyPaymentThreshold should be called to notify accounting of changes in the payment threshold // NotifyRefreshmentReceived is called by pseudosettle when we receive a time based settlement.
func (a *Accounting) NotifyPaymentThreshold(peer swarm.Address, paymentThreshold *big.Int) error { func (a *Accounting) NotifyRefreshmentReceived(peer swarm.Address, amount *big.Int) error {
accountingPeer := a.getAccountingPeer(peer) accountingPeer := a.getAccountingPeer(peer)
accountingPeer.lock.Lock() accountingPeer.lock.Lock()
defer accountingPeer.lock.Unlock() defer accountingPeer.lock.Unlock()
accountingPeer.paymentThreshold.Set(paymentThreshold) currentBalance, err := a.Balance(peer)
return nil
}
func (a *Accounting) PeerDebt(peer swarm.Address) (*big.Int, error) {
zero := big.NewInt(0)
balance, err := a.Balance(peer)
if err != nil { if err != nil {
if errors.Is(err, ErrPeerNoBalance) { if !errors.Is(err, ErrPeerNoBalance) {
return zero, nil return err
} }
return nil, err
} }
if balance.Cmp(zero) <= 0 { // Get nextBalance by increasing current balance with amount
return zero, nil nextBalance := new(big.Int).Sub(currentBalance, amount)
// We allow a refreshment to potentially put us into debt as it was previously negotiated and be limited to the peer's outstanding debt plus shadow reserve
a.logger.Tracef("crediting peer %v with amount %d due to payment, new balance is %d", peer, amount, nextBalance)
err = a.store.Put(peerBalanceKey(peer), nextBalance)
if err != nil {
return fmt.Errorf("failed to persist balance: %w", err)
} }
return balance, nil return nil
} }
func (a *Accounting) NotifyPaymentSent(peer swarm.Address, amount *big.Int, receivedError error) { // PrepareDebit prepares a debit operation by increasing the shadowReservedBalance
func (a *Accounting) PrepareDebit(peer swarm.Address, price uint64) Action {
accountingPeer := a.getAccountingPeer(peer) accountingPeer := a.getAccountingPeer(peer)
accountingPeer.lock.Lock() accountingPeer.lock.Lock()
defer accountingPeer.lock.Unlock() defer accountingPeer.lock.Unlock()
accountingPeer.paymentOngoing = false bigPrice := new(big.Int).SetUint64(price)
if receivedError != nil { accountingPeer.shadowReservedBalance = new(big.Int).Add(accountingPeer.shadowReservedBalance, bigPrice)
a.logger.Warningf("accouting: payment failure %v", receivedError)
return return &debitAction{
accounting: a,
price: bigPrice,
peer: peer,
accountingPeer: accountingPeer,
applied: false,
}
}
func (a *Accounting) increaseBalance(peer swarm.Address, accountingPeer *accountingPeer, price *big.Int) (*big.Int, error) {
cost := new(big.Int).Set(price)
// see if peer has surplus balance to deduct this transaction of
surplusBalance, err := a.SurplusBalance(peer)
if err != nil {
return nil, fmt.Errorf("failed to get surplus balance: %w", err)
}
if surplusBalance.Cmp(big.NewInt(0)) > 0 {
// get new surplus balance after deduct
newSurplusBalance := new(big.Int).Sub(surplusBalance, cost)
// if nothing left for debiting, store new surplus balance and return from debit
if newSurplusBalance.Cmp(big.NewInt(0)) >= 0 {
a.logger.Tracef("surplus debiting peer %v with value %d, new surplus balance is %d", peer, price, newSurplusBalance)
err = a.store.Put(peerSurplusBalanceKey(peer), newSurplusBalance)
if err != nil {
return nil, fmt.Errorf("failed to persist surplus balance: %w", err)
}
return a.Balance(peer)
}
// if surplus balance didn't cover full transaction, let's continue with leftover part as cost
debitIncrease := new(big.Int).Sub(price, surplusBalance)
// a sanity check
if debitIncrease.Cmp(big.NewInt(0)) <= 0 {
return nil, fmt.Errorf("sanity check failed for partial debit after surplus balance drawn")
}
cost.Set(debitIncrease)
// if we still have something to debit, than have run out of surplus balance,
// let's store 0 as surplus balance
a.logger.Tracef("surplus debiting peer %v with value %d, new surplus balance is 0", peer, debitIncrease)
err = a.store.Put(peerSurplusBalanceKey(peer), big.NewInt(0))
if err != nil {
return nil, fmt.Errorf("failed to persist surplus balance: %w", err)
}
} }
currentBalance, err := a.Balance(peer) currentBalance, err := a.Balance(peer)
if err != nil { if err != nil {
if !errors.Is(err, ErrPeerNoBalance) { if !errors.Is(err, ErrPeerNoBalance) {
a.logger.Warningf("accounting: notifypaymentsent failed to load balance: %v", err) return nil, fmt.Errorf("failed to load balance: %w", err)
return
} }
} }
// Get nextBalance by safely increasing current balance with price // Get nextBalance by increasing current balance with price
nextBalance := new(big.Int).Add(currentBalance, amount) nextBalance := new(big.Int).Add(currentBalance, cost)
a.logger.Tracef("registering payment sent to peer %v with amount %d, new balance is %d", peer, amount, nextBalance) a.logger.Tracef("debiting peer %v with price %d, new balance is %d", peer, price, nextBalance)
err = a.store.Put(peerBalanceKey(peer), nextBalance) err = a.store.Put(peerBalanceKey(peer), nextBalance)
if err != nil { if err != nil {
a.logger.Warningf("accounting: notifypaymentsent failed to persist balance: %v", err) return nil, fmt.Errorf("failed to persist balance: %w", err)
return }
return nextBalance, nil
}
// Apply applies the debit operation and decreases the shadowReservedBalance
func (d *debitAction) Apply() error {
d.accountingPeer.lock.Lock()
defer d.accountingPeer.lock.Unlock()
a := d.accounting
cost := new(big.Int).Set(d.price)
nextBalance, err := d.accounting.increaseBalance(d.peer, d.accountingPeer, cost)
if err != nil {
return err
} }
d.applied = true
d.accountingPeer.shadowReservedBalance = new(big.Int).Sub(d.accountingPeer.shadowReservedBalance, d.price)
tot, _ := big.NewFloat(0).SetInt(d.price).Float64()
a.metrics.TotalDebitedAmount.Add(tot)
a.metrics.DebitEventsCount.Inc()
if nextBalance.Cmp(a.disconnectLimit) >= 0 {
// peer too much in debt
a.metrics.AccountingDisconnectsCount.Inc()
return p2p.NewBlockPeerError(10000*time.Hour, ErrDisconnectThresholdExceeded)
}
return nil
}
// Cleanup reduces shadow reserve if and only if debitaction have not been applied
func (d *debitAction) Cleanup() {
if !d.applied {
d.accountingPeer.lock.Lock()
defer d.accountingPeer.lock.Unlock()
d.accountingPeer.shadowReservedBalance = new(big.Int).Sub(d.accountingPeer.shadowReservedBalance, d.price)
}
}
func (a *Accounting) SetRefreshFunc(f RefreshFunc) {
a.refreshFunction = f
} }
func (a *Accounting) SetPayFunc(f PayFunc) { func (a *Accounting) SetPayFunc(f PayFunc) {
......
...@@ -21,6 +21,7 @@ import ( ...@@ -21,6 +21,7 @@ import (
const ( const (
testPrice = uint64(10) testPrice = uint64(10)
testRefreshRate = int64(1000)
) )
var ( var (
...@@ -48,7 +49,7 @@ func TestAccountingAddBalance(t *testing.T) { ...@@ -48,7 +49,7 @@ func TestAccountingAddBalance(t *testing.T) {
store := mock.NewStateStore() store := mock.NewStateStore()
defer store.Close() defer store.Close()
acc, err := accounting.NewAccounting(testPaymentThreshold, testPaymentTolerance, testPaymentEarly, logger, store, nil) acc, err := accounting.NewAccounting(testPaymentThreshold, testPaymentTolerance, testPaymentEarly, logger, store, nil, big.NewInt(testRefreshRate))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -83,10 +84,12 @@ func TestAccountingAddBalance(t *testing.T) { ...@@ -83,10 +84,12 @@ func TestAccountingAddBalance(t *testing.T) {
} }
acc.Release(booking.peer, uint64(-booking.price)) acc.Release(booking.peer, uint64(-booking.price))
} else { } else {
err = acc.Debit(booking.peer, uint64(booking.price)) debitAction := acc.PrepareDebit(booking.peer, uint64(booking.price))
err = debitAction.Apply()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
debitAction.Cleanup()
} }
balance, err := acc.Balance(booking.peer) balance, err := acc.Balance(booking.peer)
...@@ -109,7 +112,7 @@ func TestAccountingAdd_persistentBalances(t *testing.T) { ...@@ -109,7 +112,7 @@ func TestAccountingAdd_persistentBalances(t *testing.T) {
store := mock.NewStateStore() store := mock.NewStateStore()
defer store.Close() defer store.Close()
acc, err := accounting.NewAccounting(testPaymentThreshold, testPaymentTolerance, testPaymentEarly, logger, store, nil) acc, err := accounting.NewAccounting(testPaymentThreshold, testPaymentTolerance, testPaymentEarly, logger, store, nil, big.NewInt(testRefreshRate))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -125,10 +128,12 @@ func TestAccountingAdd_persistentBalances(t *testing.T) { ...@@ -125,10 +128,12 @@ func TestAccountingAdd_persistentBalances(t *testing.T) {
} }
peer1DebitAmount := testPrice peer1DebitAmount := testPrice
err = acc.Debit(peer1Addr, peer1DebitAmount) debitAction := acc.PrepareDebit(peer1Addr, peer1DebitAmount)
err = debitAction.Apply()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
debitAction.Cleanup()
peer2CreditAmount := 2 * testPrice peer2CreditAmount := 2 * testPrice
err = acc.Credit(peer2Addr, peer2CreditAmount) err = acc.Credit(peer2Addr, peer2CreditAmount)
...@@ -136,7 +141,7 @@ func TestAccountingAdd_persistentBalances(t *testing.T) { ...@@ -136,7 +141,7 @@ func TestAccountingAdd_persistentBalances(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
acc, err = accounting.NewAccounting(testPaymentThreshold, testPaymentTolerance, testPaymentEarly, logger, store, nil) acc, err = accounting.NewAccounting(testPaymentThreshold, testPaymentTolerance, testPaymentEarly, logger, store, nil, big.NewInt(testRefreshRate))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -167,7 +172,7 @@ func TestAccountingReserve(t *testing.T) { ...@@ -167,7 +172,7 @@ func TestAccountingReserve(t *testing.T) {
store := mock.NewStateStore() store := mock.NewStateStore()
defer store.Close() defer store.Close()
acc, err := accounting.NewAccounting(testPaymentThreshold, testPaymentTolerance, testPaymentEarly, logger, store, nil) acc, err := accounting.NewAccounting(testPaymentThreshold, testPaymentTolerance, testPaymentEarly, logger, store, nil, big.NewInt(testRefreshRate))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -195,7 +200,7 @@ func TestAccountingDisconnect(t *testing.T) { ...@@ -195,7 +200,7 @@ func TestAccountingDisconnect(t *testing.T) {
store := mock.NewStateStore() store := mock.NewStateStore()
defer store.Close() defer store.Close()
acc, err := accounting.NewAccounting(testPaymentThreshold, testPaymentTolerance, testPaymentEarly, logger, store, nil) acc, err := accounting.NewAccounting(testPaymentThreshold, testPaymentTolerance, testPaymentEarly, logger, store, nil, big.NewInt(testRefreshRate))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -206,16 +211,20 @@ func TestAccountingDisconnect(t *testing.T) { ...@@ -206,16 +211,20 @@ func TestAccountingDisconnect(t *testing.T) {
} }
// put the peer 1 unit away from disconnect // put the peer 1 unit away from disconnect
err = acc.Debit(peer1Addr, testPaymentThreshold.Uint64()+testPaymentTolerance.Uint64()-1) debitAction := acc.PrepareDebit(peer1Addr, testPaymentThreshold.Uint64()+testPaymentTolerance.Uint64()-1)
err = debitAction.Apply()
if err != nil { if err != nil {
t.Fatal("expected no error while still within tolerance") t.Fatal("expected no error while still within tolerance")
} }
debitAction.Cleanup()
// put the peer over thee threshold // put the peer over thee threshold
err = acc.Debit(peer1Addr, 1) debitAction = acc.PrepareDebit(peer1Addr, 1)
err = debitAction.Apply()
if err == nil { if err == nil {
t.Fatal("expected Add to return error") t.Fatal("expected Add to return error")
} }
debitAction.Cleanup()
var e *p2p.BlockPeerError var e *p2p.BlockPeerError
if !errors.As(err, &e) { if !errors.As(err, &e) {
...@@ -230,18 +239,23 @@ func TestAccountingCallSettlement(t *testing.T) { ...@@ -230,18 +239,23 @@ func TestAccountingCallSettlement(t *testing.T) {
store := mock.NewStateStore() store := mock.NewStateStore()
defer store.Close() defer store.Close()
acc, err := accounting.NewAccounting(testPaymentThreshold, testPaymentTolerance, testPaymentEarly, logger, store, nil) acc, err := accounting.NewAccounting(testPaymentThreshold, testPaymentTolerance, testPaymentEarly, logger, store, nil, big.NewInt(testRefreshRate))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
paychan := make(chan paymentCall, 1) refreshchan := make(chan paymentCall, 1)
f := func(ctx context.Context, peer swarm.Address, amount *big.Int) { f := func(ctx context.Context, peer swarm.Address, amount *big.Int, shadowBalance *big.Int) (*big.Int, int64, error) {
paychan <- paymentCall{peer: peer, amount: amount} refreshchan <- paymentCall{peer: peer, amount: amount}
return amount, 0, nil
}
pay := func(ctx context.Context, peer swarm.Address, amount *big.Int) {
} }
acc.SetPayFunc(f) acc.SetRefreshFunc(f)
acc.SetPayFunc(pay)
peer1Addr, err := swarm.ParseHexAddress("00112233") peer1Addr, err := swarm.ParseHexAddress("00112233")
if err != nil { if err != nil {
...@@ -270,7 +284,7 @@ func TestAccountingCallSettlement(t *testing.T) { ...@@ -270,7 +284,7 @@ func TestAccountingCallSettlement(t *testing.T) {
} }
select { select {
case call := <-paychan: case call := <-refreshchan:
if call.amount.Cmp(big.NewInt(int64(requestPrice))) != 0 { if call.amount.Cmp(big.NewInt(int64(requestPrice))) != 0 {
t.Fatalf("paid wrong amount. got %d wanted %d", call.amount, requestPrice) t.Fatalf("paid wrong amount. got %d wanted %d", call.amount, requestPrice)
} }
...@@ -281,8 +295,11 @@ func TestAccountingCallSettlement(t *testing.T) { ...@@ -281,8 +295,11 @@ func TestAccountingCallSettlement(t *testing.T) {
t.Fatal("timeout waiting for payment") t.Fatal("timeout waiting for payment")
} }
if acc.IsPaymentOngoing(peer1Addr) {
t.Fatal("triggered monetary settlement")
}
acc.Release(peer1Addr, 1) acc.Release(peer1Addr, 1)
acc.NotifyPaymentSent(peer1Addr, big.NewInt(int64(requestPrice)), nil)
balance, err := acc.Balance(peer1Addr) balance, err := acc.Balance(peer1Addr)
if err != nil { if err != nil {
...@@ -321,7 +338,7 @@ func TestAccountingCallSettlement(t *testing.T) { ...@@ -321,7 +338,7 @@ func TestAccountingCallSettlement(t *testing.T) {
acc.Release(peer1Addr, 1) acc.Release(peer1Addr, 1)
select { select {
case call := <-paychan: case call := <-refreshchan:
if call.amount.Cmp(big.NewInt(int64(expectedAmount))) != 0 { if call.amount.Cmp(big.NewInt(int64(expectedAmount))) != 0 {
t.Fatalf("paid wrong amount. got %d wanted %d", call.amount, expectedAmount) t.Fatalf("paid wrong amount. got %d wanted %d", call.amount, expectedAmount)
} }
...@@ -332,9 +349,266 @@ func TestAccountingCallSettlement(t *testing.T) { ...@@ -332,9 +349,266 @@ func TestAccountingCallSettlement(t *testing.T) {
t.Fatal("timeout waiting for payment") t.Fatal("timeout waiting for payment")
} }
if acc.IsPaymentOngoing(peer1Addr) {
t.Fatal("triggered monetary settlement")
}
acc.Release(peer1Addr, 100) acc.Release(peer1Addr, 100)
} }
func TestAccountingCallSettlementMonetary(t *testing.T) {
logger := logging.New(ioutil.Discard, 0)
store := mock.NewStateStore()
defer store.Close()
acc, err := accounting.NewAccounting(testPaymentThreshold, testPaymentTolerance, testPaymentEarly, logger, store, nil, big.NewInt(testRefreshRate))
if err != nil {
t.Fatal(err)
}
refreshchan := make(chan paymentCall, 1)
paychan := make(chan paymentCall, 1)
notTimeSettledAmount := big.NewInt(testRefreshRate * 2)
acc.SetRefreshFunc(func(ctx context.Context, peer swarm.Address, amount *big.Int, shadowBalance *big.Int) (*big.Int, int64, error) {
refreshchan <- paymentCall{peer: peer, amount: amount}
return new(big.Int).Sub(amount, notTimeSettledAmount), 0, nil
})
acc.SetPayFunc(func(ctx context.Context, peer swarm.Address, amount *big.Int) {
paychan <- paymentCall{peer: peer, amount: amount}
})
peer1Addr, err := swarm.ParseHexAddress("00112233")
if err != nil {
t.Fatal(err)
}
requestPrice := testPaymentThreshold.Uint64() - 1000
err = acc.Reserve(context.Background(), peer1Addr, requestPrice)
if err != nil {
t.Fatal(err)
}
// Credit until payment treshold
err = acc.Credit(peer1Addr, requestPrice)
if err != nil {
t.Fatal(err)
}
acc.Release(peer1Addr, requestPrice)
// try another request
err = acc.Reserve(context.Background(), peer1Addr, 1)
if err != nil {
t.Fatal(err)
}
select {
case call := <-refreshchan:
if call.amount.Cmp(big.NewInt(int64(requestPrice))) != 0 {
t.Fatalf("paid wrong amount. got %d wanted %d", call.amount, requestPrice)
}
if !call.peer.Equal(peer1Addr) {
t.Fatalf("wrong peer address got %v wanted %v", call.peer, peer1Addr)
}
case <-time.After(1 * time.Second):
t.Fatal("timeout waiting for payment")
}
select {
case call := <-paychan:
if call.amount.Cmp(notTimeSettledAmount) != 0 {
t.Fatalf("paid wrong amount. got %d wanted %d", call.amount, notTimeSettledAmount)
}
if !call.peer.Equal(peer1Addr) {
t.Fatalf("wrong peer address got %v wanted %v", call.peer, peer1Addr)
}
case <-time.After(1 * time.Second):
t.Fatal("timeout waiting for payment")
}
acc.Release(peer1Addr, 1)
balance, err := acc.Balance(peer1Addr)
if err != nil {
t.Fatal(err)
}
if balance.Cmp(new(big.Int).Neg(notTimeSettledAmount)) != 0 {
t.Fatalf("expected balance to be adjusted. got %d", balance)
}
acc.SetRefreshFunc(func(ctx context.Context, peer swarm.Address, amount *big.Int, shadowBalance *big.Int) (*big.Int, int64, error) {
refreshchan <- paymentCall{peer: peer, amount: amount}
return big.NewInt(0), 0, nil
})
// Credit until the expected debt exceeeds payment threshold
expectedAmount := testPaymentThreshold.Uint64()
err = acc.Reserve(context.Background(), peer1Addr, expectedAmount)
if !errors.Is(err, accounting.ErrOverdraft) {
t.Fatalf("expected overdraft, got %v", err)
}
select {
case call := <-refreshchan:
if call.amount.Cmp(notTimeSettledAmount) != 0 {
t.Fatalf("paid wrong amount. got %d wanted %d", call.amount, notTimeSettledAmount)
}
if !call.peer.Equal(peer1Addr) {
t.Fatalf("wrong peer address got %v wanted %v", call.peer, peer1Addr)
}
case <-time.After(1 * time.Second):
t.Fatal("timeout waiting for refreshment")
}
select {
case <-paychan:
t.Fatal("pay called twice")
case <-time.After(1 * time.Second):
}
}
func TestAccountingCallSettlementTooSoon(t *testing.T) {
logger := logging.New(ioutil.Discard, 0)
store := mock.NewStateStore()
defer store.Close()
acc, err := accounting.NewAccounting(testPaymentThreshold, testPaymentTolerance, testPaymentEarly, logger, store, nil, big.NewInt(testRefreshRate))
if err != nil {
t.Fatal(err)
}
refreshchan := make(chan paymentCall, 1)
paychan := make(chan paymentCall, 1)
ts := int64(1000)
acc.SetRefreshFunc(func(ctx context.Context, peer swarm.Address, amount *big.Int, shadowBalance *big.Int) (*big.Int, int64, error) {
refreshchan <- paymentCall{peer: peer, amount: amount}
return amount, ts, nil
})
acc.SetPayFunc(func(ctx context.Context, peer swarm.Address, amount *big.Int) {
paychan <- paymentCall{peer: peer, amount: amount}
})
peer1Addr, err := swarm.ParseHexAddress("00112233")
if err != nil {
t.Fatal(err)
}
requestPrice := testPaymentThreshold.Uint64() - 1000
err = acc.Reserve(context.Background(), peer1Addr, requestPrice)
if err != nil {
t.Fatal(err)
}
// Credit until payment treshold
err = acc.Credit(peer1Addr, requestPrice)
if err != nil {
t.Fatal(err)
}
acc.Release(peer1Addr, requestPrice)
// try another request
err = acc.Reserve(context.Background(), peer1Addr, 1)
if err != nil {
t.Fatal(err)
}
select {
case call := <-refreshchan:
if call.amount.Cmp(big.NewInt(int64(requestPrice))) != 0 {
t.Fatalf("paid wrong amount. got %d wanted %d", call.amount, requestPrice)
}
if !call.peer.Equal(peer1Addr) {
t.Fatalf("wrong peer address got %v wanted %v", call.peer, peer1Addr)
}
case <-time.After(1 * time.Second):
t.Fatal("timeout waiting for payment")
}
acc.Release(peer1Addr, 1)
balance, err := acc.Balance(peer1Addr)
if err != nil {
t.Fatal(err)
}
if balance.Cmp(big.NewInt(0)) != 0 {
t.Fatalf("expected balance to be adjusted. got %d", balance)
}
acc.SetTime(ts)
err = acc.Reserve(context.Background(), peer1Addr, requestPrice)
if err != nil {
t.Fatal(err)
}
// Credit until payment treshold
err = acc.Credit(peer1Addr, requestPrice)
if err != nil {
t.Fatal(err)
}
acc.Release(peer1Addr, requestPrice)
// try another request
err = acc.Reserve(context.Background(), peer1Addr, 1)
if err != nil {
t.Fatal(err)
}
select {
case <-refreshchan:
t.Fatal("sent refreshment")
default:
}
select {
case call := <-paychan:
if call.amount.Cmp(big.NewInt(int64(requestPrice))) != 0 {
t.Fatalf("paid wrong amount. got %d wanted %d", call.amount, requestPrice)
}
if !call.peer.Equal(peer1Addr) {
t.Fatalf("wrong peer address got %v wanted %v", call.peer, peer1Addr)
}
case <-time.After(1 * time.Second):
t.Fatal("payment not sent")
}
acc.Release(peer1Addr, 1)
acc.NotifyPaymentSent(peer1Addr, big.NewInt(int64(requestPrice)), errors.New("error"))
acc.SetTime(ts + 1)
// try another request
err = acc.Reserve(context.Background(), peer1Addr, 1)
if err != nil {
t.Fatal(err)
}
select {
case call := <-refreshchan:
if call.amount.Cmp(big.NewInt(int64(requestPrice))) != 0 {
t.Fatalf("paid wrong amount. got %d wanted %d", call.amount, requestPrice)
}
if !call.peer.Equal(peer1Addr) {
t.Fatalf("wrong peer address got %v wanted %v", call.peer, peer1Addr)
}
default:
t.Fatal("no refreshment")
}
}
// TestAccountingCallSettlementEarly tests that settlement is called correctly if the payment threshold minus early payment is hit // TestAccountingCallSettlementEarly tests that settlement is called correctly if the payment threshold minus early payment is hit
func TestAccountingCallSettlementEarly(t *testing.T) { func TestAccountingCallSettlementEarly(t *testing.T) {
logger := logging.New(ioutil.Discard, 0) logger := logging.New(ioutil.Discard, 0)
...@@ -345,18 +619,19 @@ func TestAccountingCallSettlementEarly(t *testing.T) { ...@@ -345,18 +619,19 @@ func TestAccountingCallSettlementEarly(t *testing.T) {
debt := uint64(500) debt := uint64(500)
earlyPayment := big.NewInt(1000) earlyPayment := big.NewInt(1000)
acc, err := accounting.NewAccounting(testPaymentThreshold, testPaymentTolerance, earlyPayment, logger, store, nil) acc, err := accounting.NewAccounting(testPaymentThreshold, testPaymentTolerance, earlyPayment, logger, store, nil, big.NewInt(testRefreshRate))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
paychan := make(chan paymentCall, 1) refreshchan := make(chan paymentCall, 1)
f := func(ctx context.Context, peer swarm.Address, amount *big.Int) { f := func(ctx context.Context, peer swarm.Address, amount *big.Int, shadowBalance *big.Int) (*big.Int, int64, error) {
paychan <- paymentCall{peer: peer, amount: amount} refreshchan <- paymentCall{peer: peer, amount: amount}
return amount, 0, nil
} }
acc.SetPayFunc(f) acc.SetRefreshFunc(f)
peer1Addr, err := swarm.ParseHexAddress("00112233") peer1Addr, err := swarm.ParseHexAddress("00112233")
if err != nil { if err != nil {
...@@ -377,7 +652,7 @@ func TestAccountingCallSettlementEarly(t *testing.T) { ...@@ -377,7 +652,7 @@ func TestAccountingCallSettlementEarly(t *testing.T) {
acc.Release(peer1Addr, payment) acc.Release(peer1Addr, payment)
select { select {
case call := <-paychan: case call := <-refreshchan:
if call.amount.Cmp(big.NewInt(int64(debt))) != 0 { if call.amount.Cmp(big.NewInt(int64(debt))) != 0 {
t.Fatalf("paid wrong amount. got %d wanted %d", call.amount, debt) t.Fatalf("paid wrong amount. got %d wanted %d", call.amount, debt)
} }
...@@ -388,9 +663,6 @@ func TestAccountingCallSettlementEarly(t *testing.T) { ...@@ -388,9 +663,6 @@ func TestAccountingCallSettlementEarly(t *testing.T) {
t.Fatal("timeout waiting for payment") t.Fatal("timeout waiting for payment")
} }
acc.Release(peer1Addr, 1)
acc.NotifyPaymentSent(peer1Addr, big.NewInt(int64(debt)), nil)
balance, err := acc.Balance(peer1Addr) balance, err := acc.Balance(peer1Addr)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
...@@ -406,7 +678,7 @@ func TestAccountingSurplusBalance(t *testing.T) { ...@@ -406,7 +678,7 @@ func TestAccountingSurplusBalance(t *testing.T) {
store := mock.NewStateStore() store := mock.NewStateStore()
defer store.Close() defer store.Close()
acc, err := accounting.NewAccounting(testPaymentThreshold, big.NewInt(0), big.NewInt(0), logger, store, nil) acc, err := accounting.NewAccounting(testPaymentThreshold, big.NewInt(0), big.NewInt(0), logger, store, nil, big.NewInt(testRefreshRate))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -415,10 +687,12 @@ func TestAccountingSurplusBalance(t *testing.T) { ...@@ -415,10 +687,12 @@ func TestAccountingSurplusBalance(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
// Try Debiting a large amount to peer so balance is large positive // Try Debiting a large amount to peer so balance is large positive
err = acc.Debit(peer1Addr, testPaymentThreshold.Uint64()-1) debitAction := acc.PrepareDebit(peer1Addr, testPaymentThreshold.Uint64()-1)
err = debitAction.Apply()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
debitAction.Cleanup()
// Notify of incoming payment from same peer, so balance goes to 0 with surplusbalance 2 // Notify of incoming payment from same peer, so balance goes to 0 with surplusbalance 2
err = acc.NotifyPaymentReceived(peer1Addr, new(big.Int).Add(testPaymentThreshold, big.NewInt(1))) err = acc.NotifyPaymentReceived(peer1Addr, new(big.Int).Add(testPaymentThreshold, big.NewInt(1)))
if err != nil { if err != nil {
...@@ -462,10 +736,12 @@ func TestAccountingSurplusBalance(t *testing.T) { ...@@ -462,10 +736,12 @@ func TestAccountingSurplusBalance(t *testing.T) {
t.Fatal("Not expected balance, expected 0") t.Fatal("Not expected balance, expected 0")
} }
// Debit for same peer, so balance stays 0 with surplusbalance decreasing to 2 // Debit for same peer, so balance stays 0 with surplusbalance decreasing to 2
err = acc.Debit(peer1Addr, testPaymentThreshold.Uint64()) debitAction = acc.PrepareDebit(peer1Addr, testPaymentThreshold.Uint64())
err = debitAction.Apply()
if err != nil { if err != nil {
t.Fatal("Unexpected error from Credit") t.Fatal("Unexpected error from Credit")
} }
debitAction.Cleanup()
// samity check surplus balance // samity check surplus balance
val, err = acc.SurplusBalance(peer1Addr) val, err = acc.SurplusBalance(peer1Addr)
if err != nil { if err != nil {
...@@ -483,10 +759,12 @@ func TestAccountingSurplusBalance(t *testing.T) { ...@@ -483,10 +759,12 @@ func TestAccountingSurplusBalance(t *testing.T) {
t.Fatal("Not expected balance, expected 0") t.Fatal("Not expected balance, expected 0")
} }
// Debit for same peer, so balance goes to 9998 (testpaymentthreshold - 2) with surplusbalance decreasing to 0 // Debit for same peer, so balance goes to 9998 (testpaymentthreshold - 2) with surplusbalance decreasing to 0
err = acc.Debit(peer1Addr, testPaymentThreshold.Uint64()) debitAction = acc.PrepareDebit(peer1Addr, testPaymentThreshold.Uint64())
err = debitAction.Apply()
if err != nil { if err != nil {
t.Fatal("Unexpected error from Debit") t.Fatal("Unexpected error from Debit")
} }
debitAction.Cleanup()
// samity check surplus balance // samity check surplus balance
val, err = acc.SurplusBalance(peer1Addr) val, err = acc.SurplusBalance(peer1Addr)
if err != nil { if err != nil {
...@@ -512,7 +790,7 @@ func TestAccountingNotifyPaymentReceived(t *testing.T) { ...@@ -512,7 +790,7 @@ func TestAccountingNotifyPaymentReceived(t *testing.T) {
store := mock.NewStateStore() store := mock.NewStateStore()
defer store.Close() defer store.Close()
acc, err := accounting.NewAccounting(testPaymentThreshold, testPaymentTolerance, testPaymentEarly, logger, store, nil) acc, err := accounting.NewAccounting(testPaymentThreshold, testPaymentTolerance, testPaymentEarly, logger, store, nil, big.NewInt(testRefreshRate))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -523,20 +801,24 @@ func TestAccountingNotifyPaymentReceived(t *testing.T) { ...@@ -523,20 +801,24 @@ func TestAccountingNotifyPaymentReceived(t *testing.T) {
} }
debtAmount := uint64(100) debtAmount := uint64(100)
err = acc.Debit(peer1Addr, debtAmount+testPaymentTolerance.Uint64()) debitAction := acc.PrepareDebit(peer1Addr, debtAmount+testPaymentTolerance.Uint64())
err = debitAction.Apply()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
debitAction.Cleanup()
err = acc.NotifyPaymentReceived(peer1Addr, new(big.Int).SetUint64(debtAmount+testPaymentTolerance.Uint64())) err = acc.NotifyPaymentReceived(peer1Addr, new(big.Int).SetUint64(debtAmount+testPaymentTolerance.Uint64()))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = acc.Debit(peer1Addr, debtAmount) debitAction = acc.PrepareDebit(peer1Addr, debtAmount)
err = debitAction.Apply()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
debitAction.Cleanup()
err = acc.NotifyPaymentReceived(peer1Addr, new(big.Int).SetUint64(debtAmount+testPaymentTolerance.Uint64()+1)) err = acc.NotifyPaymentReceived(peer1Addr, new(big.Int).SetUint64(debtAmount+testPaymentTolerance.Uint64()+1))
if err != nil { if err != nil {
...@@ -572,7 +854,7 @@ func TestAccountingConnected(t *testing.T) { ...@@ -572,7 +854,7 @@ func TestAccountingConnected(t *testing.T) {
pricing := &pricingMock{} pricing := &pricingMock{}
_, err := accounting.NewAccounting(testPaymentThreshold, testPaymentTolerance, testPaymentEarly, logger, store, pricing) _, err := accounting.NewAccounting(testPaymentThreshold, testPaymentTolerance, testPaymentEarly, logger, store, pricing, big.NewInt(testRefreshRate))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -608,18 +890,19 @@ func TestAccountingNotifyPaymentThreshold(t *testing.T) { ...@@ -608,18 +890,19 @@ func TestAccountingNotifyPaymentThreshold(t *testing.T) {
pricing := &pricingMock{} pricing := &pricingMock{}
acc, err := accounting.NewAccounting(testPaymentThreshold, testPaymentTolerance, big.NewInt(0), logger, store, pricing) acc, err := accounting.NewAccounting(testPaymentThreshold, testPaymentTolerance, big.NewInt(0), logger, store, pricing, big.NewInt(testRefreshRate))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
paychan := make(chan paymentCall, 1) refreshchan := make(chan paymentCall, 1)
f := func(ctx context.Context, peer swarm.Address, amount *big.Int) { f := func(ctx context.Context, peer swarm.Address, amount *big.Int, shadowBalance *big.Int) (*big.Int, int64, error) {
paychan <- paymentCall{peer: peer, amount: amount} refreshchan <- paymentCall{peer: peer, amount: amount}
return amount, 0, nil
} }
acc.SetPayFunc(f) acc.SetRefreshFunc(f)
peer1Addr, err := swarm.ParseHexAddress("00112233") peer1Addr, err := swarm.ParseHexAddress("00112233")
if err != nil { if err != nil {
...@@ -649,7 +932,7 @@ func TestAccountingNotifyPaymentThreshold(t *testing.T) { ...@@ -649,7 +932,7 @@ func TestAccountingNotifyPaymentThreshold(t *testing.T) {
} }
select { select {
case call := <-paychan: case call := <-refreshchan:
if call.amount.Cmp(big.NewInt(int64(debt))) != 0 { if call.amount.Cmp(big.NewInt(int64(debt))) != 0 {
t.Fatalf("paid wrong amount. got %d wanted %d", call.amount, debt) t.Fatalf("paid wrong amount. got %d wanted %d", call.amount, debt)
} }
...@@ -659,7 +942,6 @@ func TestAccountingNotifyPaymentThreshold(t *testing.T) { ...@@ -659,7 +942,6 @@ func TestAccountingNotifyPaymentThreshold(t *testing.T) {
case <-time.After(1 * time.Second): case <-time.After(1 * time.Second):
t.Fatal("timeout waiting for payment") t.Fatal("timeout waiting for payment")
} }
} }
func TestAccountingPeerDebt(t *testing.T) { func TestAccountingPeerDebt(t *testing.T) {
...@@ -670,17 +952,19 @@ func TestAccountingPeerDebt(t *testing.T) { ...@@ -670,17 +952,19 @@ func TestAccountingPeerDebt(t *testing.T) {
pricing := &pricingMock{} pricing := &pricingMock{}
acc, err := accounting.NewAccounting(testPaymentThreshold, testPaymentTolerance, big.NewInt(0), logger, store, pricing) acc, err := accounting.NewAccounting(testPaymentThreshold, testPaymentTolerance, big.NewInt(0), logger, store, pricing, big.NewInt(testRefreshRate))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
peer1Addr := swarm.MustParseHexAddress("00112233") peer1Addr := swarm.MustParseHexAddress("00112233")
debt := uint64(1000) debt := uint64(1000)
err = acc.Debit(peer1Addr, debt) debitAction := acc.PrepareDebit(peer1Addr, debt)
err = debitAction.Apply()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
debitAction.Cleanup()
actualDebt, err := acc.PeerDebt(peer1Addr) actualDebt, err := acc.PeerDebt(peer1Addr)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
......
package accounting
import (
"time"
"github.com/ethersphere/bee/pkg/swarm"
)
func (s *Accounting) SetTimeNow(f func() time.Time) {
s.timeNow = f
}
func (s *Accounting) SetTime(k int64) {
s.SetTimeNow(func() time.Time {
return time.Unix(k, 0)
})
}
func (a *Accounting) IsPaymentOngoing(peer swarm.Address) bool {
return a.getAccountingPeer(peer).paymentOngoing
}
...@@ -22,8 +22,9 @@ type Service struct { ...@@ -22,8 +22,9 @@ type Service struct {
reserveFunc func(ctx context.Context, peer swarm.Address, price uint64) error reserveFunc func(ctx context.Context, peer swarm.Address, price uint64) error
releaseFunc func(peer swarm.Address, price uint64) releaseFunc func(peer swarm.Address, price uint64)
creditFunc func(peer swarm.Address, price uint64) error creditFunc func(peer swarm.Address, price uint64) error
debitFunc func(peer swarm.Address, price uint64) error prepareDebitFunc func(peer swarm.Address, price uint64) accounting.Action
balanceFunc func(swarm.Address) (*big.Int, error) balanceFunc func(swarm.Address) (*big.Int, error)
shadowBalanceFunc func(swarm.Address) (*big.Int, error)
balancesFunc func() (map[string]*big.Int, error) balancesFunc func() (map[string]*big.Int, error)
compensatedBalanceFunc func(swarm.Address) (*big.Int, error) compensatedBalanceFunc func(swarm.Address) (*big.Int, error)
compensatedBalancesFunc func() (map[string]*big.Int, error) compensatedBalancesFunc func() (map[string]*big.Int, error)
...@@ -31,6 +32,13 @@ type Service struct { ...@@ -31,6 +32,13 @@ type Service struct {
balanceSurplusFunc func(swarm.Address) (*big.Int, error) balanceSurplusFunc func(swarm.Address) (*big.Int, error)
} }
type debitAction struct {
accounting *Service
price *big.Int
peer swarm.Address
applied bool
}
// WithReserveFunc sets the mock Reserve function // WithReserveFunc sets the mock Reserve function
func WithReserveFunc(f func(ctx context.Context, peer swarm.Address, price uint64) error) Option { func WithReserveFunc(f func(ctx context.Context, peer swarm.Address, price uint64) error) Option {
return optionFunc(func(s *Service) { return optionFunc(func(s *Service) {
...@@ -53,9 +61,9 @@ func WithCreditFunc(f func(peer swarm.Address, price uint64) error) Option { ...@@ -53,9 +61,9 @@ func WithCreditFunc(f func(peer swarm.Address, price uint64) error) Option {
} }
// WithDebitFunc sets the mock Debit function // WithDebitFunc sets the mock Debit function
func WithDebitFunc(f func(peer swarm.Address, price uint64) error) Option { func WithPrepareDebitFunc(f func(peer swarm.Address, price uint64) accounting.Action) Option {
return optionFunc(func(s *Service) { return optionFunc(func(s *Service) {
s.debitFunc = f s.prepareDebitFunc = f
}) })
} }
...@@ -136,21 +144,36 @@ func (s *Service) Credit(peer swarm.Address, price uint64) error { ...@@ -136,21 +144,36 @@ func (s *Service) Credit(peer swarm.Address, price uint64) error {
} }
// Debit is the mock function wrapper that calls the set implementation // Debit is the mock function wrapper that calls the set implementation
func (s *Service) Debit(peer swarm.Address, price uint64) error { func (s *Service) PrepareDebit(peer swarm.Address, price uint64) accounting.Action {
if s.debitFunc != nil { if s.prepareDebitFunc != nil {
return s.debitFunc(peer, price) return s.prepareDebitFunc(peer, price)
} }
s.lock.Lock()
defer s.lock.Unlock()
if bal, ok := s.balances[peer.String()]; ok { bigPrice := new(big.Int).SetUint64(price)
s.balances[peer.String()] = new(big.Int).Add(bal, new(big.Int).SetUint64(price)) return &debitAction{
accounting: s,
price: bigPrice,
peer: peer,
applied: false,
}
}
func (a *debitAction) Apply() error {
a.accounting.lock.Lock()
defer a.accounting.lock.Unlock()
if bal, ok := a.accounting.balances[a.peer.String()]; ok {
a.accounting.balances[a.peer.String()] = new(big.Int).Add(bal, new(big.Int).Set(a.price))
} else { } else {
s.balances[peer.String()] = new(big.Int).SetUint64(price) a.accounting.balances[a.peer.String()] = new(big.Int).Set(a.price)
} }
return nil return nil
} }
func (a *debitAction) Cleanup() {}
// Balance is the mock function wrapper that calls the set implementation // Balance is the mock function wrapper that calls the set implementation
func (s *Service) Balance(peer swarm.Address) (*big.Int, error) { func (s *Service) Balance(peer swarm.Address) (*big.Int, error) {
if s.balanceFunc != nil { if s.balanceFunc != nil {
...@@ -165,6 +188,19 @@ func (s *Service) Balance(peer swarm.Address) (*big.Int, error) { ...@@ -165,6 +188,19 @@ func (s *Service) Balance(peer swarm.Address) (*big.Int, error) {
} }
} }
func (s *Service) ShadowBalance(peer swarm.Address) (*big.Int, error) {
if s.shadowBalanceFunc != nil {
return s.shadowBalanceFunc(peer)
}
s.lock.Lock()
defer s.lock.Unlock()
if bal, ok := s.balances[peer.String()]; ok {
return new(big.Int).Neg(bal), nil
} else {
return big.NewInt(0), nil
}
}
// Balances is the mock function wrapper that calls the set implementation // Balances is the mock function wrapper that calls the set implementation
func (s *Service) Balances() (map[string]*big.Int, error) { func (s *Service) Balances() (map[string]*big.Int, error) {
if s.balancesFunc != nil { if s.balancesFunc != nil {
......
...@@ -44,10 +44,10 @@ type Service struct { ...@@ -44,10 +44,10 @@ type Service struct {
tracer *tracing.Tracer tracer *tracing.Tracer
tags *tags.Tags tags *tags.Tags
accounting accounting.Interface accounting accounting.Interface
settlement settlement.Interface pseudosettle settlement.Interface
chequebookEnabled bool chequebookEnabled bool
chequebook chequebook.Service chequebook chequebook.Service
swap swap.ApiInterface swap swap.Interface
batchStore postage.Storer batchStore postage.Storer
corsAllowedOrigins []string corsAllowedOrigins []string
metricsRegistry *prometheus.Registry metricsRegistry *prometheus.Registry
...@@ -80,19 +80,19 @@ func New(overlay swarm.Address, publicKey, pssPublicKey ecdsa.PublicKey, ethereu ...@@ -80,19 +80,19 @@ func New(overlay swarm.Address, publicKey, pssPublicKey ecdsa.PublicKey, ethereu
// Configure injects required dependencies and configuration parameters and // Configure injects required dependencies and configuration parameters and
// constructs HTTP routes that depend on them. It is intended and safe to call // constructs HTTP routes that depend on them. It is intended and safe to call
// this method only once. // this method only once.
func (s *Service) Configure(p2p p2p.DebugService, pingpong pingpong.Interface, topologyDriver topology.Driver, lightNodes *lightnode.Container, storer storage.Storer, tags *tags.Tags, accounting accounting.Interface, settlement settlement.Interface, chequebookEnabled bool, swap swap.ApiInterface, chequebook chequebook.Service, batchStore postage.Storer) { func (s *Service) Configure(p2p p2p.DebugService, pingpong pingpong.Interface, topologyDriver topology.Driver, lightNodes *lightnode.Container, storer storage.Storer, tags *tags.Tags, accounting accounting.Interface, pseudosettle settlement.Interface, chequebookEnabled bool, swap swap.Interface, chequebook chequebook.Service, batchStore postage.Storer) {
s.p2p = p2p s.p2p = p2p
s.pingpong = pingpong s.pingpong = pingpong
s.topologyDriver = topologyDriver s.topologyDriver = topologyDriver
s.storer = storer s.storer = storer
s.tags = tags s.tags = tags
s.accounting = accounting s.accounting = accounting
s.settlement = settlement
s.chequebookEnabled = chequebookEnabled s.chequebookEnabled = chequebookEnabled
s.chequebook = chequebook s.chequebook = chequebook
s.swap = swap s.swap = swap
s.lightNodes = lightNodes s.lightNodes = lightNodes
s.batchStore = batchStore s.batchStore = batchStore
s.pseudosettle = pseudosettle
s.setRouter(s.newRouter()) s.setRouter(s.newRouter())
} }
......
...@@ -65,7 +65,7 @@ func newTestServer(t *testing.T, o testServerOptions) *testServer { ...@@ -65,7 +65,7 @@ func newTestServer(t *testing.T, o testServerOptions) *testServer {
acc := accountingmock.NewAccounting(o.AccountingOpts...) acc := accountingmock.NewAccounting(o.AccountingOpts...)
settlement := swapmock.New(o.SettlementOpts...) settlement := swapmock.New(o.SettlementOpts...)
chequebook := chequebookmock.NewChequebook(o.ChequebookOpts...) chequebook := chequebookmock.NewChequebook(o.ChequebookOpts...)
swapserv := swapmock.NewApiInterface(o.SwapOpts...) swapserv := swapmock.New(o.SwapOpts...)
ln := lightnode.NewContainer() ln := lightnode.NewContainer()
s := debugapi.New(o.Overlay, o.PublicKey, o.PSSPublicKey, o.EthereumAddress, logging.New(ioutil.Discard, 0), nil, o.CORSAllowedOrigins) s := debugapi.New(o.Overlay, o.PublicKey, o.PSSPublicKey, o.EthereumAddress, logging.New(ioutil.Discard, 0), nil, o.CORSAllowedOrigins)
s.Configure(o.P2P, o.Pingpong, topologyDriver, ln, o.Storer, o.Tags, acc, settlement, true, swapserv, chequebook, o.BatchStore) s.Configure(o.P2P, o.Pingpong, topologyDriver, ln, o.Storer, o.Tags, acc, settlement, true, swapserv, chequebook, o.BatchStore)
...@@ -132,7 +132,7 @@ func TestServer_Configure(t *testing.T) { ...@@ -132,7 +132,7 @@ func TestServer_Configure(t *testing.T) {
acc := accountingmock.NewAccounting(o.AccountingOpts...) acc := accountingmock.NewAccounting(o.AccountingOpts...)
settlement := swapmock.New(o.SettlementOpts...) settlement := swapmock.New(o.SettlementOpts...)
chequebook := chequebookmock.NewChequebook(o.ChequebookOpts...) chequebook := chequebookmock.NewChequebook(o.ChequebookOpts...)
swapserv := swapmock.NewApiInterface(o.SwapOpts...) swapserv := swapmock.New(o.SwapOpts...)
ln := lightnode.NewContainer() ln := lightnode.NewContainer()
s := debugapi.New(o.Overlay, o.PublicKey, o.PSSPublicKey, o.EthereumAddress, logging.New(ioutil.Discard, 0), nil, nil) s := debugapi.New(o.Overlay, o.PublicKey, o.PSSPublicKey, o.EthereumAddress, logging.New(ioutil.Discard, 0), nil, nil)
ts := httptest.NewServer(s) ts := httptest.NewServer(s)
......
...@@ -125,15 +125,18 @@ func (s *Service) newRouter() *mux.Router { ...@@ -125,15 +125,18 @@ func (s *Service) newRouter() *mux.Router {
"GET": http.HandlerFunc(s.peerBalanceHandler), "GET": http.HandlerFunc(s.peerBalanceHandler),
}) })
router.Handle("/timesettlements", jsonhttp.MethodHandler{
"GET": http.HandlerFunc(s.settlementsHandlerPseudosettle),
})
if s.chequebookEnabled {
router.Handle("/settlements", jsonhttp.MethodHandler{ router.Handle("/settlements", jsonhttp.MethodHandler{
"GET": http.HandlerFunc(s.settlementsHandler), "GET": http.HandlerFunc(s.settlementsHandler),
}) })
router.Handle("/settlements/{peer}", jsonhttp.MethodHandler{ router.Handle("/settlements/{peer}", jsonhttp.MethodHandler{
"GET": http.HandlerFunc(s.peerSettlementsHandler), "GET": http.HandlerFunc(s.peerSettlementsHandler),
}) })
if s.chequebookEnabled {
router.Handle("/chequebook/balance", jsonhttp.MethodHandler{ router.Handle("/chequebook/balance", jsonhttp.MethodHandler{
"GET": http.HandlerFunc(s.chequebookBalanceHandler), "GET": http.HandlerFunc(s.chequebookBalanceHandler),
}) })
......
...@@ -34,14 +34,14 @@ type settlementsResponse struct { ...@@ -34,14 +34,14 @@ type settlementsResponse struct {
func (s *Service) settlementsHandler(w http.ResponseWriter, r *http.Request) { func (s *Service) settlementsHandler(w http.ResponseWriter, r *http.Request) {
settlementsSent, err := s.settlement.SettlementsSent() settlementsSent, err := s.swap.SettlementsSent()
if err != nil { if err != nil {
jsonhttp.InternalServerError(w, errCantSettlements) jsonhttp.InternalServerError(w, errCantSettlements)
s.logger.Debugf("debug api: sent settlements: %v", err) s.logger.Debugf("debug api: sent settlements: %v", err)
s.logger.Error("debug api: can not get sent settlements") s.logger.Error("debug api: can not get sent settlements")
return return
} }
settlementsReceived, err := s.settlement.SettlementsReceived() settlementsReceived, err := s.swap.SettlementsReceived()
if err != nil { if err != nil {
jsonhttp.InternalServerError(w, errCantSettlements) jsonhttp.InternalServerError(w, errCantSettlements)
s.logger.Debugf("debug api: received settlements: %v", err) s.logger.Debugf("debug api: received settlements: %v", err)
...@@ -100,7 +100,7 @@ func (s *Service) peerSettlementsHandler(w http.ResponseWriter, r *http.Request) ...@@ -100,7 +100,7 @@ func (s *Service) peerSettlementsHandler(w http.ResponseWriter, r *http.Request)
peerexists := false peerexists := false
received, err := s.settlement.TotalReceived(peer) received, err := s.swap.TotalReceived(peer)
if err != nil { if err != nil {
if !errors.Is(err, settlement.ErrPeerNoSettlements) { if !errors.Is(err, settlement.ErrPeerNoSettlements) {
s.logger.Debugf("debug api: settlements peer: get peer %s received settlement: %v", peer.String(), err) s.logger.Debugf("debug api: settlements peer: get peer %s received settlement: %v", peer.String(), err)
...@@ -116,7 +116,7 @@ func (s *Service) peerSettlementsHandler(w http.ResponseWriter, r *http.Request) ...@@ -116,7 +116,7 @@ func (s *Service) peerSettlementsHandler(w http.ResponseWriter, r *http.Request)
peerexists = true peerexists = true
} }
sent, err := s.settlement.TotalSent(peer) sent, err := s.swap.TotalSent(peer)
if err != nil { if err != nil {
if !errors.Is(err, settlement.ErrPeerNoSettlements) { if !errors.Is(err, settlement.ErrPeerNoSettlements) {
s.logger.Debugf("debug api: settlements peer: get peer %s sent settlement: %v", peer.String(), err) s.logger.Debugf("debug api: settlements peer: get peer %s sent settlement: %v", peer.String(), err)
...@@ -143,3 +143,59 @@ func (s *Service) peerSettlementsHandler(w http.ResponseWriter, r *http.Request) ...@@ -143,3 +143,59 @@ func (s *Service) peerSettlementsHandler(w http.ResponseWriter, r *http.Request)
SettlementSent: sent, SettlementSent: sent,
}) })
} }
func (s *Service) settlementsHandlerPseudosettle(w http.ResponseWriter, r *http.Request) {
settlementsSent, err := s.pseudosettle.SettlementsSent()
if err != nil {
jsonhttp.InternalServerError(w, errCantSettlements)
s.logger.Debugf("debug api: sent settlements: %v", err)
s.logger.Error("debug api: can not get sent settlements")
return
}
settlementsReceived, err := s.pseudosettle.SettlementsReceived()
if err != nil {
jsonhttp.InternalServerError(w, errCantSettlements)
s.logger.Debugf("debug api: received settlements: %v", err)
s.logger.Error("debug api: can not get received settlements")
return
}
totalReceived := big.NewInt(0)
totalSent := big.NewInt(0)
settlementResponses := make(map[string]settlementResponse)
for a, b := range settlementsSent {
settlementResponses[a] = settlementResponse{
Peer: a,
SettlementSent: b,
SettlementReceived: big.NewInt(0),
}
totalSent.Add(b, totalSent)
}
for a, b := range settlementsReceived {
if _, ok := settlementResponses[a]; ok {
t := settlementResponses[a]
t.SettlementReceived = b
settlementResponses[a] = t
} else {
settlementResponses[a] = settlementResponse{
Peer: a,
SettlementSent: big.NewInt(0),
SettlementReceived: b,
}
}
totalReceived.Add(b, totalReceived)
}
settlementResponsesArray := make([]settlementResponse, len(settlementResponses))
i := 0
for k := range settlementResponses {
settlementResponsesArray[i] = settlementResponses[k]
i++
}
jsonhttp.OK(w, settlementsResponse{TotalSettlementReceived: totalReceived, TotalSettlementSent: totalSent, Settlements: settlementResponsesArray})
}
...@@ -36,7 +36,7 @@ func TestSettlements(t *testing.T) { ...@@ -36,7 +36,7 @@ func TestSettlements(t *testing.T) {
} }
testServer := newTestServer(t, testServerOptions{ testServer := newTestServer(t, testServerOptions{
SettlementOpts: []mock.Option{mock.WithSettlementsSentFunc(settlementsSentFunc), mock.WithSettlementsRecvFunc(settlementsRecvFunc)}, SwapOpts: []mock.Option{mock.WithSettlementsSentFunc(settlementsSentFunc), mock.WithSettlementsRecvFunc(settlementsRecvFunc)},
}) })
expected := &debugapi.SettlementsResponse{ expected := &debugapi.SettlementsResponse{
...@@ -84,7 +84,7 @@ func TestSettlementsError(t *testing.T) { ...@@ -84,7 +84,7 @@ func TestSettlementsError(t *testing.T) {
return nil, wantErr return nil, wantErr
} }
testServer := newTestServer(t, testServerOptions{ testServer := newTestServer(t, testServerOptions{
SettlementOpts: []mock.Option{mock.WithSettlementsSentFunc(settlementsSentFunc)}, SwapOpts: []mock.Option{mock.WithSettlementsSentFunc(settlementsSentFunc)},
}) })
jsonhttptest.Request(t, testServer.Client, http.MethodGet, "/settlements", http.StatusInternalServerError, jsonhttptest.Request(t, testServer.Client, http.MethodGet, "/settlements", http.StatusInternalServerError,
...@@ -101,7 +101,7 @@ func TestSettlementsPeers(t *testing.T) { ...@@ -101,7 +101,7 @@ func TestSettlementsPeers(t *testing.T) {
return big.NewInt(1000000000000000000), nil return big.NewInt(1000000000000000000), nil
} }
testServer := newTestServer(t, testServerOptions{ testServer := newTestServer(t, testServerOptions{
SettlementOpts: []mock.Option{mock.WithSettlementSentFunc(settlementSentFunc)}, SwapOpts: []mock.Option{mock.WithSettlementSentFunc(settlementSentFunc)},
}) })
jsonhttptest.Request(t, testServer.Client, http.MethodGet, "/settlements/"+peer, http.StatusOK, jsonhttptest.Request(t, testServer.Client, http.MethodGet, "/settlements/"+peer, http.StatusOK,
...@@ -124,7 +124,7 @@ func TestSettlementsPeersNoSettlements(t *testing.T) { ...@@ -124,7 +124,7 @@ func TestSettlementsPeersNoSettlements(t *testing.T) {
t.Run("no sent", func(t *testing.T) { t.Run("no sent", func(t *testing.T) {
testServer := newTestServer(t, testServerOptions{ testServer := newTestServer(t, testServerOptions{
SettlementOpts: []mock.Option{ SwapOpts: []mock.Option{
mock.WithSettlementSentFunc(errFunc), mock.WithSettlementSentFunc(errFunc),
mock.WithSettlementRecvFunc(noErrFunc), mock.WithSettlementRecvFunc(noErrFunc),
}, },
...@@ -141,7 +141,7 @@ func TestSettlementsPeersNoSettlements(t *testing.T) { ...@@ -141,7 +141,7 @@ func TestSettlementsPeersNoSettlements(t *testing.T) {
t.Run("no received", func(t *testing.T) { t.Run("no received", func(t *testing.T) {
testServer := newTestServer(t, testServerOptions{ testServer := newTestServer(t, testServerOptions{
SettlementOpts: []mock.Option{ SwapOpts: []mock.Option{
mock.WithSettlementSentFunc(noErrFunc), mock.WithSettlementSentFunc(noErrFunc),
mock.WithSettlementRecvFunc(errFunc), mock.WithSettlementRecvFunc(errFunc),
}, },
...@@ -164,7 +164,7 @@ func TestSettlementsPeersError(t *testing.T) { ...@@ -164,7 +164,7 @@ func TestSettlementsPeersError(t *testing.T) {
return nil, wantErr return nil, wantErr
} }
testServer := newTestServer(t, testServerOptions{ testServer := newTestServer(t, testServerOptions{
SettlementOpts: []mock.Option{mock.WithSettlementSentFunc(settlementSentFunc)}, SwapOpts: []mock.Option{mock.WithSettlementSentFunc(settlementSentFunc)},
}) })
jsonhttptest.Request(t, testServer.Client, http.MethodGet, "/settlements/"+peer, http.StatusInternalServerError, jsonhttptest.Request(t, testServer.Client, http.MethodGet, "/settlements/"+peer, http.StatusInternalServerError,
......
...@@ -203,7 +203,7 @@ func InitSwap( ...@@ -203,7 +203,7 @@ func InitSwap(
chequebookService chequebook.Service, chequebookService chequebook.Service,
chequeStore chequebook.ChequeStore, chequeStore chequebook.ChequeStore,
cashoutService chequebook.CashoutService, cashoutService chequebook.CashoutService,
accountingAPI settlement.AccountingAPI, accounting settlement.Accounting,
) (*swap.Service, error) { ) (*swap.Service, error) {
swapProtocol := swapprotocol.New(p2ps, logger, overlayEthAddress) swapProtocol := swapprotocol.New(p2ps, logger, overlayEthAddress)
swapAddressBook := swap.NewAddressbook(stateStore) swapAddressBook := swap.NewAddressbook(stateStore)
...@@ -218,7 +218,7 @@ func InitSwap( ...@@ -218,7 +218,7 @@ func InitSwap(
networkID, networkID,
cashoutService, cashoutService,
p2ps, p2ps,
accountingAPI, accounting,
) )
swapProtocol.SetSwap(swapService) swapProtocol.SetSwap(swapService)
......
...@@ -54,7 +54,6 @@ import ( ...@@ -54,7 +54,6 @@ import (
"github.com/ethersphere/bee/pkg/recovery" "github.com/ethersphere/bee/pkg/recovery"
"github.com/ethersphere/bee/pkg/resolver/multiresolver" "github.com/ethersphere/bee/pkg/resolver/multiresolver"
"github.com/ethersphere/bee/pkg/retrieval" "github.com/ethersphere/bee/pkg/retrieval"
settlement "github.com/ethersphere/bee/pkg/settlement"
"github.com/ethersphere/bee/pkg/settlement/pseudosettle" "github.com/ethersphere/bee/pkg/settlement/pseudosettle"
"github.com/ethersphere/bee/pkg/settlement/swap" "github.com/ethersphere/bee/pkg/settlement/swap"
"github.com/ethersphere/bee/pkg/settlement/swap/chequebook" "github.com/ethersphere/bee/pkg/settlement/swap/chequebook"
...@@ -136,6 +135,11 @@ type Options struct { ...@@ -136,6 +135,11 @@ type Options struct {
BlockTime uint64 BlockTime uint64
} }
const (
refreshRate = int64(1000000000000)
basePrice = 1000000000
)
func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey, signer crypto.Signer, networkID uint64, logger logging.Logger, libp2pPrivateKey, pssPrivateKey *ecdsa.PrivateKey, o Options) (b *Bee, err error) { func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey, signer crypto.Signer, networkID uint64, logger logging.Logger, libp2pPrivateKey, pssPrivateKey *ecdsa.PrivateKey, o Options) (b *Bee, err error) {
tracer, tracerCloser, err := tracing.NewTracer(&tracing.Options{ tracer, tracerCloser, err := tracing.NewTracer(&tracing.Options{
Enabled: o.TracingEnabled, Enabled: o.TracingEnabled,
...@@ -419,7 +423,6 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey, ...@@ -419,7 +423,6 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey,
} }
} }
var settlement settlement.Interface
var swapService *swap.Service var swapService *swap.Service
kad := kademlia.New(swarmAddress, addressbook, hive, p2ps, logger, kademlia.Options{Bootnodes: bootnodes, StandaloneMode: o.Standalone, BootnodeMode: o.BootnodeMode}) kad := kademlia.New(swarmAddress, addressbook, hive, p2ps, logger, kademlia.Options{Bootnodes: bootnodes, StandaloneMode: o.Standalone, BootnodeMode: o.BootnodeMode})
...@@ -445,7 +448,7 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey, ...@@ -445,7 +448,7 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey,
return nil, fmt.Errorf("invalid payment threshold: %s", paymentThreshold) return nil, fmt.Errorf("invalid payment threshold: %s", paymentThreshold)
} }
pricer := pricer.NewFixedPricer(swarmAddress, 1000000000) pricer := pricer.NewFixedPricer(swarmAddress, basePrice)
minThreshold := pricer.MostExpensive() minThreshold := pricer.MostExpensive()
...@@ -472,6 +475,7 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey, ...@@ -472,6 +475,7 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey,
if !ok { if !ok {
return nil, fmt.Errorf("invalid payment early: %s", paymentEarly) return nil, fmt.Errorf("invalid payment early: %s", paymentEarly)
} }
acc, err := accounting.NewAccounting( acc, err := accounting.NewAccounting(
paymentThreshold, paymentThreshold,
paymentTolerance, paymentTolerance,
...@@ -479,11 +483,19 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey, ...@@ -479,11 +483,19 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey,
logger, logger,
stateStore, stateStore,
pricing, pricing,
big.NewInt(refreshRate),
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("accounting: %w", err) return nil, fmt.Errorf("accounting: %w", err)
} }
pseudosettleService := pseudosettle.New(p2ps, logger, stateStore, acc, big.NewInt(refreshRate), p2ps)
if err = p2ps.AddProtocol(pseudosettleService.Protocol()); err != nil {
return nil, fmt.Errorf("pseudosettle service: %w", err)
}
acc.SetRefreshFunc(pseudosettleService.Pay)
if o.SwapEnable { if o.SwapEnable {
swapService, err = InitSwap( swapService, err = InitSwap(
p2ps, p2ps,
...@@ -499,17 +511,9 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey, ...@@ -499,17 +511,9 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey,
if err != nil { if err != nil {
return nil, err return nil, err
} }
settlement = swapService acc.SetPayFunc(swapService.Pay)
} else {
pseudosettleService := pseudosettle.New(p2ps, logger, stateStore, acc)
if err = p2ps.AddProtocol(pseudosettleService.Protocol()); err != nil {
return nil, fmt.Errorf("pseudosettle service: %w", err)
}
settlement = pseudosettleService
} }
acc.SetPayFunc(settlement.Pay)
pricing.SetPaymentThresholdObserver(acc) pricing.SetPaymentThresholdObserver(acc)
retrieve := retrieval.New(swarmAddress, storer, p2ps, kad, logger, acc, pricer, tracer) retrieve := retrieval.New(swarmAddress, storer, p2ps, kad, logger, acc, pricer, tracer)
...@@ -645,12 +649,14 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey, ...@@ -645,12 +649,14 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey,
debugAPIService.MustRegisterMetrics(l.Metrics()...) debugAPIService.MustRegisterMetrics(l.Metrics()...)
} }
if l, ok := settlement.(metrics.Collector); ok { debugAPIService.MustRegisterMetrics(pseudosettleService.Metrics()...)
debugAPIService.MustRegisterMetrics(l.Metrics()...)
if swapService != nil {
debugAPIService.MustRegisterMetrics(swapService.Metrics()...)
} }
// inject dependencies and configure full debug api http path routes // inject dependencies and configure full debug api http path routes
debugAPIService.Configure(p2ps, pingPong, kad, lightNodes, storer, tagService, acc, settlement, o.SwapEnable, swapService, chequebookService, batchStore) debugAPIService.Configure(p2ps, pingPong, kad, lightNodes, storer, tagService, acc, pseudosettleService, o.SwapEnable, swapService, chequebookService, batchStore)
} }
if err := kad.Start(p2pCtx); err != nil { if err := kad.Start(p2pCtx); err != nil {
......
...@@ -27,7 +27,7 @@ import ( ...@@ -27,7 +27,7 @@ import (
"github.com/ethersphere/bee/pkg/tags" "github.com/ethersphere/bee/pkg/tags"
"github.com/ethersphere/bee/pkg/topology" "github.com/ethersphere/bee/pkg/topology"
"github.com/ethersphere/bee/pkg/tracing" "github.com/ethersphere/bee/pkg/tracing"
"github.com/hashicorp/golang-lru" lru "github.com/hashicorp/golang-lru"
opentracing "github.com/opentracing/opentracing-go" opentracing "github.com/opentracing/opentracing-go"
) )
...@@ -158,12 +158,15 @@ func (ps *PushSync) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) ...@@ -158,12 +158,15 @@ func (ps *PushSync) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream)
ps.logger.Errorf("pushsync: chunk store: %v", err) ps.logger.Errorf("pushsync: chunk store: %v", err)
} }
debit := ps.accounting.PrepareDebit(p.Address, price)
defer debit.Cleanup()
// return back receipt // return back receipt
receipt := pb.Receipt{Address: chunk.Address().Bytes()} receipt := pb.Receipt{Address: chunk.Address().Bytes()}
if err := w.WriteMsgWithContext(ctxd, &receipt); err != nil { if err := w.WriteMsgWithContext(ctxd, &receipt); err != nil {
return fmt.Errorf("send receipt to peer %s: %w", p.Address.String(), err) return fmt.Errorf("send receipt to peer %s: %w", p.Address.String(), err)
} }
return ps.accounting.Debit(p.Address, price) return debit.Apply()
} }
return ErrOutOfDepthReplication return ErrOutOfDepthReplication
...@@ -283,23 +286,29 @@ func (ps *PushSync) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) ...@@ -283,23 +286,29 @@ func (ps *PushSync) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream)
} }
// return back receipt // return back receipt
debit := ps.accounting.PrepareDebit(p.Address, price)
defer debit.Cleanup()
receipt := pb.Receipt{Address: chunk.Address().Bytes(), Signature: signature} receipt := pb.Receipt{Address: chunk.Address().Bytes(), Signature: signature}
if err := w.WriteMsgWithContext(ctx, &receipt); err != nil { if err := w.WriteMsgWithContext(ctx, &receipt); err != nil {
return fmt.Errorf("send receipt to peer %s: %w", p.Address.String(), err) return fmt.Errorf("send receipt to peer %s: %w", p.Address.String(), err)
} }
return ps.accounting.Debit(p.Address, price) return debit.Apply()
} }
return fmt.Errorf("handler: push to closest: %w", err) return fmt.Errorf("handler: push to closest: %w", err)
} }
debit := ps.accounting.PrepareDebit(p.Address, price)
defer debit.Cleanup()
// pass back the receipt // pass back the receipt
if err := w.WriteMsgWithContext(ctx, receipt); err != nil { if err := w.WriteMsgWithContext(ctx, receipt); err != nil {
return fmt.Errorf("send receipt to peer %s: %w", p.Address.String(), err) return fmt.Errorf("send receipt to peer %s: %w", p.Address.String(), err)
} }
return ps.accounting.Debit(p.Address, price) return debit.Apply()
} }
// PushChunkToClosest sends chunk to the closest peer by opening a stream. It then waits for // PushChunkToClosest sends chunk to the closest peer by opening a stream. It then waits for
......
...@@ -362,6 +362,11 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e ...@@ -362,6 +362,11 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e
if err != nil { if err != nil {
return fmt.Errorf("stamp marshal: %w", err) return fmt.Errorf("stamp marshal: %w", err)
} }
chunkPrice := s.pricer.Price(chunk.Address())
debit := s.accounting.PrepareDebit(p.Address, chunkPrice)
defer debit.Cleanup()
if err := w.WriteMsgWithContext(ctx, &pb.Delivery{ if err := w.WriteMsgWithContext(ctx, &pb.Delivery{
Data: chunk.Data(), Data: chunk.Data(),
Stamp: stamp, Stamp: stamp,
...@@ -371,8 +376,6 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e ...@@ -371,8 +376,6 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e
s.logger.Tracef("retrieval protocol debiting peer %s", p.Address.String()) s.logger.Tracef("retrieval protocol debiting peer %s", p.Address.String())
chunkPrice := s.pricer.Price(chunk.Address())
// debit price from p's balance // debit price from p's balance
return s.accounting.Debit(p.Address, chunkPrice) return debit.Apply()
} }
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
package settlement package settlement
import ( import (
"context"
"errors" "errors"
"math/big" "math/big"
...@@ -18,9 +17,6 @@ var ( ...@@ -18,9 +17,6 @@ var (
// Interface is the interface used by Accounting to trigger settlement // Interface is the interface used by Accounting to trigger settlement
type Interface interface { type Interface interface {
// Pay initiates a payment to the given peer
// It should return without error it is likely that the payment worked
Pay(ctx context.Context, peer swarm.Address, amount *big.Int)
// TotalSent returns the total amount sent to a peer // TotalSent returns the total amount sent to a peer
TotalSent(peer swarm.Address) (totalSent *big.Int, err error) TotalSent(peer swarm.Address) (totalSent *big.Int, err error)
// TotalReceived returns the total amount received from a peer // TotalReceived returns the total amount received from a peer
...@@ -31,8 +27,9 @@ type Interface interface { ...@@ -31,8 +27,9 @@ type Interface interface {
SettlementsReceived() (map[string]*big.Int, error) SettlementsReceived() (map[string]*big.Int, error)
} }
type AccountingAPI interface { type Accounting interface {
PeerDebt(peer swarm.Address) (*big.Int, error) PeerDebt(peer swarm.Address) (*big.Int, error)
NotifyPaymentReceived(peer swarm.Address, amount *big.Int) error NotifyPaymentReceived(peer swarm.Address, amount *big.Int) error
NotifyPaymentSent(peer swarm.Address, amount *big.Int, receivedError error) NotifyPaymentSent(peer swarm.Address, amount *big.Int, receivedError error)
NotifyRefreshmentReceived(peer swarm.Address, amount *big.Int) error
} }
// Copyright 2021 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package pseudosettle
import (
"context"
"time"
"github.com/ethersphere/bee/pkg/p2p"
)
func (s *Service) SetTimeNow(f func() time.Time) {
s.timeNow = f
}
func (s *Service) SetTime(k int64) {
s.SetTimeNow(func() time.Time {
return time.Unix(k, 0)
})
}
func (s *Service) Init(ctx context.Context, peer p2p.Peer) error {
return s.init(ctx, peer)
}
func (s *Service) Terminate(peer p2p.Peer) error {
return s.terminate(peer)
}
...@@ -23,7 +23,7 @@ var _ = math.Inf ...@@ -23,7 +23,7 @@ var _ = math.Inf
const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package
type Payment struct { type Payment struct {
Amount uint64 `protobuf:"varint,1,opt,name=Amount,proto3" json:"Amount,omitempty"` Amount []byte `protobuf:"bytes,1,opt,name=Amount,proto3" json:"Amount,omitempty"`
} }
func (m *Payment) Reset() { *m = Payment{} } func (m *Payment) Reset() { *m = Payment{} }
...@@ -59,29 +59,84 @@ func (m *Payment) XXX_DiscardUnknown() { ...@@ -59,29 +59,84 @@ func (m *Payment) XXX_DiscardUnknown() {
var xxx_messageInfo_Payment proto.InternalMessageInfo var xxx_messageInfo_Payment proto.InternalMessageInfo
func (m *Payment) GetAmount() uint64 { func (m *Payment) GetAmount() []byte {
if m != nil { if m != nil {
return m.Amount return m.Amount
} }
return nil
}
type PaymentAck struct {
Amount []byte `protobuf:"bytes,1,opt,name=Amount,proto3" json:"Amount,omitempty"`
Timestamp int64 `protobuf:"varint,2,opt,name=Timestamp,proto3" json:"Timestamp,omitempty"`
}
func (m *PaymentAck) Reset() { *m = PaymentAck{} }
func (m *PaymentAck) String() string { return proto.CompactTextString(m) }
func (*PaymentAck) ProtoMessage() {}
func (*PaymentAck) Descriptor() ([]byte, []int) {
return fileDescriptor_3ff21bb6c9cf5e84, []int{1}
}
func (m *PaymentAck) XXX_Unmarshal(b []byte) error {
return m.Unmarshal(b)
}
func (m *PaymentAck) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
if deterministic {
return xxx_messageInfo_PaymentAck.Marshal(b, m, deterministic)
} else {
b = b[:cap(b)]
n, err := m.MarshalToSizedBuffer(b)
if err != nil {
return nil, err
}
return b[:n], nil
}
}
func (m *PaymentAck) XXX_Merge(src proto.Message) {
xxx_messageInfo_PaymentAck.Merge(m, src)
}
func (m *PaymentAck) XXX_Size() int {
return m.Size()
}
func (m *PaymentAck) XXX_DiscardUnknown() {
xxx_messageInfo_PaymentAck.DiscardUnknown(m)
}
var xxx_messageInfo_PaymentAck proto.InternalMessageInfo
func (m *PaymentAck) GetAmount() []byte {
if m != nil {
return m.Amount
}
return nil
}
func (m *PaymentAck) GetTimestamp() int64 {
if m != nil {
return m.Timestamp
}
return 0 return 0
} }
func init() { func init() {
proto.RegisterType((*Payment)(nil), "pseudosettle.Payment") proto.RegisterType((*Payment)(nil), "pseudosettle.Payment")
proto.RegisterType((*PaymentAck)(nil), "pseudosettle.PaymentAck")
} }
func init() { proto.RegisterFile("pseudosettle.proto", fileDescriptor_3ff21bb6c9cf5e84) } func init() { proto.RegisterFile("pseudosettle.proto", fileDescriptor_3ff21bb6c9cf5e84) }
var fileDescriptor_3ff21bb6c9cf5e84 = []byte{ var fileDescriptor_3ff21bb6c9cf5e84 = []byte{
// 114 bytes of a gzipped FileDescriptorProto // 148 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x12, 0x2a, 0x28, 0x4e, 0x2d, 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x12, 0x2a, 0x28, 0x4e, 0x2d,
0x4d, 0xc9, 0x2f, 0x4e, 0x2d, 0x29, 0xc9, 0x49, 0xd5, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2, 0x4d, 0xc9, 0x2f, 0x4e, 0x2d, 0x29, 0xc9, 0x49, 0xd5, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2,
0x41, 0x16, 0x53, 0x52, 0xe4, 0x62, 0x0f, 0x48, 0xac, 0xcc, 0x4d, 0xcd, 0x2b, 0x11, 0x12, 0xe3, 0x41, 0x16, 0x53, 0x52, 0xe4, 0x62, 0x0f, 0x48, 0xac, 0xcc, 0x4d, 0xcd, 0x2b, 0x11, 0x12, 0xe3,
0x62, 0x73, 0xcc, 0xcd, 0x2f, 0xcd, 0x2b, 0x91, 0x60, 0x54, 0x60, 0xd4, 0x60, 0x09, 0x82, 0xf2, 0x62, 0x73, 0xcc, 0xcd, 0x2f, 0xcd, 0x2b, 0x91, 0x60, 0x54, 0x60, 0xd4, 0xe0, 0x09, 0x82, 0xf2,
0x9c, 0x64, 0x4e, 0x3c, 0x92, 0x63, 0xbc, 0xf0, 0x48, 0x8e, 0xf1, 0xc1, 0x23, 0x39, 0xc6, 0x09, 0x94, 0x9c, 0xb8, 0xb8, 0xa0, 0x4a, 0x1c, 0x93, 0xb3, 0x71, 0xa9, 0x12, 0x92, 0xe1, 0xe2, 0x0c,
0x8f, 0xe5, 0x18, 0x2e, 0x3c, 0x96, 0x63, 0xb8, 0xf1, 0x58, 0x8e, 0x21, 0x8a, 0xa9, 0x20, 0x29, 0xc9, 0xcc, 0x4d, 0x2d, 0x2e, 0x49, 0xcc, 0x2d, 0x90, 0x60, 0x52, 0x60, 0xd4, 0x60, 0x0e, 0x42,
0x89, 0x0d, 0x6c, 0xaa, 0x31, 0x20, 0x00, 0x00, 0xff, 0xff, 0xfb, 0x97, 0x5c, 0xf8, 0x6b, 0x00, 0x08, 0x38, 0xc9, 0x9c, 0x78, 0x24, 0xc7, 0x78, 0xe1, 0x91, 0x1c, 0xe3, 0x83, 0x47, 0x72, 0x8c,
0x00, 0x00, 0x13, 0x1e, 0xcb, 0x31, 0x5c, 0x78, 0x2c, 0xc7, 0x70, 0xe3, 0xb1, 0x1c, 0x43, 0x14, 0x53, 0x41,
0x52, 0x12, 0x1b, 0xd8, 0x65, 0xc6, 0x80, 0x00, 0x00, 0x00, 0xff, 0xff, 0x87, 0xcb, 0xb8, 0x18,
0xaf, 0x00, 0x00, 0x00,
} }
func (m *Payment) Marshal() (dAtA []byte, err error) { func (m *Payment) Marshal() (dAtA []byte, err error) {
...@@ -104,10 +159,47 @@ func (m *Payment) MarshalToSizedBuffer(dAtA []byte) (int, error) { ...@@ -104,10 +159,47 @@ func (m *Payment) MarshalToSizedBuffer(dAtA []byte) (int, error) {
_ = i _ = i
var l int var l int
_ = l _ = l
if m.Amount != 0 { if len(m.Amount) > 0 {
i = encodeVarintPseudosettle(dAtA, i, uint64(m.Amount)) i -= len(m.Amount)
copy(dAtA[i:], m.Amount)
i = encodeVarintPseudosettle(dAtA, i, uint64(len(m.Amount)))
i--
dAtA[i] = 0xa
}
return len(dAtA) - i, nil
}
func (m *PaymentAck) Marshal() (dAtA []byte, err error) {
size := m.Size()
dAtA = make([]byte, size)
n, err := m.MarshalToSizedBuffer(dAtA[:size])
if err != nil {
return nil, err
}
return dAtA[:n], nil
}
func (m *PaymentAck) MarshalTo(dAtA []byte) (int, error) {
size := m.Size()
return m.MarshalToSizedBuffer(dAtA[:size])
}
func (m *PaymentAck) MarshalToSizedBuffer(dAtA []byte) (int, error) {
i := len(dAtA)
_ = i
var l int
_ = l
if m.Timestamp != 0 {
i = encodeVarintPseudosettle(dAtA, i, uint64(m.Timestamp))
i-- i--
dAtA[i] = 0x8 dAtA[i] = 0x10
}
if len(m.Amount) > 0 {
i -= len(m.Amount)
copy(dAtA[i:], m.Amount)
i = encodeVarintPseudosettle(dAtA, i, uint64(len(m.Amount)))
i--
dAtA[i] = 0xa
} }
return len(dAtA) - i, nil return len(dAtA) - i, nil
} }
...@@ -129,8 +221,25 @@ func (m *Payment) Size() (n int) { ...@@ -129,8 +221,25 @@ func (m *Payment) Size() (n int) {
} }
var l int var l int
_ = l _ = l
if m.Amount != 0 { l = len(m.Amount)
n += 1 + sovPseudosettle(uint64(m.Amount)) if l > 0 {
n += 1 + l + sovPseudosettle(uint64(l))
}
return n
}
func (m *PaymentAck) Size() (n int) {
if m == nil {
return 0
}
var l int
_ = l
l = len(m.Amount)
if l > 0 {
n += 1 + l + sovPseudosettle(uint64(l))
}
if m.Timestamp != 0 {
n += 1 + sovPseudosettle(uint64(m.Timestamp))
} }
return n return n
} }
...@@ -171,10 +280,131 @@ func (m *Payment) Unmarshal(dAtA []byte) error { ...@@ -171,10 +280,131 @@ func (m *Payment) Unmarshal(dAtA []byte) error {
} }
switch fieldNum { switch fieldNum {
case 1: case 1:
if wireType != 0 { if wireType != 2 {
return fmt.Errorf("proto: wrong wireType = %d for field Amount", wireType) return fmt.Errorf("proto: wrong wireType = %d for field Amount", wireType)
} }
m.Amount = 0 var byteLen int
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowPseudosettle
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
byteLen |= int(b&0x7F) << shift
if b < 0x80 {
break
}
}
if byteLen < 0 {
return ErrInvalidLengthPseudosettle
}
postIndex := iNdEx + byteLen
if postIndex < 0 {
return ErrInvalidLengthPseudosettle
}
if postIndex > l {
return io.ErrUnexpectedEOF
}
m.Amount = append(m.Amount[:0], dAtA[iNdEx:postIndex]...)
if m.Amount == nil {
m.Amount = []byte{}
}
iNdEx = postIndex
default:
iNdEx = preIndex
skippy, err := skipPseudosettle(dAtA[iNdEx:])
if err != nil {
return err
}
if skippy < 0 {
return ErrInvalidLengthPseudosettle
}
if (iNdEx + skippy) < 0 {
return ErrInvalidLengthPseudosettle
}
if (iNdEx + skippy) > l {
return io.ErrUnexpectedEOF
}
iNdEx += skippy
}
}
if iNdEx > l {
return io.ErrUnexpectedEOF
}
return nil
}
func (m *PaymentAck) Unmarshal(dAtA []byte) error {
l := len(dAtA)
iNdEx := 0
for iNdEx < l {
preIndex := iNdEx
var wire uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowPseudosettle
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
wire |= uint64(b&0x7F) << shift
if b < 0x80 {
break
}
}
fieldNum := int32(wire >> 3)
wireType := int(wire & 0x7)
if wireType == 4 {
return fmt.Errorf("proto: PaymentAck: wiretype end group for non-group")
}
if fieldNum <= 0 {
return fmt.Errorf("proto: PaymentAck: illegal tag %d (wire type %d)", fieldNum, wire)
}
switch fieldNum {
case 1:
if wireType != 2 {
return fmt.Errorf("proto: wrong wireType = %d for field Amount", wireType)
}
var byteLen int
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowPseudosettle
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
byteLen |= int(b&0x7F) << shift
if b < 0x80 {
break
}
}
if byteLen < 0 {
return ErrInvalidLengthPseudosettle
}
postIndex := iNdEx + byteLen
if postIndex < 0 {
return ErrInvalidLengthPseudosettle
}
if postIndex > l {
return io.ErrUnexpectedEOF
}
m.Amount = append(m.Amount[:0], dAtA[iNdEx:postIndex]...)
if m.Amount == nil {
m.Amount = []byte{}
}
iNdEx = postIndex
case 2:
if wireType != 0 {
return fmt.Errorf("proto: wrong wireType = %d for field Timestamp", wireType)
}
m.Timestamp = 0
for shift := uint(0); ; shift += 7 { for shift := uint(0); ; shift += 7 {
if shift >= 64 { if shift >= 64 {
return ErrIntOverflowPseudosettle return ErrIntOverflowPseudosettle
...@@ -184,7 +414,7 @@ func (m *Payment) Unmarshal(dAtA []byte) error { ...@@ -184,7 +414,7 @@ func (m *Payment) Unmarshal(dAtA []byte) error {
} }
b := dAtA[iNdEx] b := dAtA[iNdEx]
iNdEx++ iNdEx++
m.Amount |= uint64(b&0x7F) << shift m.Timestamp |= int64(b&0x7F) << shift
if b < 0x80 { if b < 0x80 {
break break
} }
......
...@@ -9,5 +9,10 @@ package pseudosettle; ...@@ -9,5 +9,10 @@ package pseudosettle;
option go_package = "pb"; option go_package = "pb";
message Payment { message Payment {
uint64 Amount = 1; bytes Amount = 1;
}
message PaymentAck {
bytes Amount = 1;
int64 Timestamp = 2;
} }
\ No newline at end of file
...@@ -10,6 +10,7 @@ import ( ...@@ -10,6 +10,7 @@ import (
"fmt" "fmt"
"math/big" "math/big"
"strings" "strings"
"sync"
"time" "time"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
...@@ -30,23 +31,47 @@ const ( ...@@ -30,23 +31,47 @@ const (
var ( var (
SettlementReceivedPrefix = "pseudosettle_total_received_" SettlementReceivedPrefix = "pseudosettle_total_received_"
SettlementSentPrefix = "pseudosettle_total_sent_" SettlementSentPrefix = "pseudosettle_total_sent_"
ErrSettlementTooSoon = errors.New("settlement too soon")
ErrNoPseudoSettlePeer = errors.New("settlement peer not found")
ErrDisconnectAllowanceCheckFailed = errors.New("settlement allowance below enforced amount")
ErrTimeOutOfSync = errors.New("settlement allowance timestamps differ beyond tolerance")
) )
type Service struct { type Service struct {
streamer p2p.Streamer streamer p2p.Streamer
logger logging.Logger logger logging.Logger
store storage.StateStorer store storage.StateStorer
accountingAPI settlement.AccountingAPI accounting settlement.Accounting
metrics metrics metrics metrics
refreshRate *big.Int
p2pService p2p.Service
timeNow func() time.Time
peersMu sync.Mutex
peers map[string]*pseudoSettlePeer
}
type pseudoSettlePeer struct {
lock sync.Mutex // lock to be held during receiving a payment from this peer
} }
func New(streamer p2p.Streamer, logger logging.Logger, store storage.StateStorer, accountingAPI settlement.AccountingAPI) *Service { type lastPayment struct {
Timestamp int64
CheckTimestamp int64
Total *big.Int
}
func New(streamer p2p.Streamer, logger logging.Logger, store storage.StateStorer, accounting settlement.Accounting, refreshRate *big.Int, p2pService p2p.Service) *Service {
return &Service{ return &Service{
streamer: streamer, streamer: streamer,
logger: logger, logger: logger,
metrics: newMetrics(), metrics: newMetrics(),
store: store, store: store,
accountingAPI: accountingAPI, accounting: accounting,
p2pService: p2pService,
refreshRate: refreshRate,
timeNow: time.Now,
peers: make(map[string]*pseudoSettlePeer),
} }
} }
...@@ -60,9 +85,34 @@ func (s *Service) Protocol() p2p.ProtocolSpec { ...@@ -60,9 +85,34 @@ func (s *Service) Protocol() p2p.ProtocolSpec {
Handler: s.handler, Handler: s.handler,
}, },
}, },
ConnectIn: s.init,
ConnectOut: s.init,
DisconnectIn: s.terminate,
DisconnectOut: s.terminate,
} }
} }
func (s *Service) init(ctx context.Context, p p2p.Peer) error {
s.peersMu.Lock()
defer s.peersMu.Unlock()
_, ok := s.peers[p.Address.String()]
if !ok {
peerData := &pseudoSettlePeer{}
s.peers[p.Address.String()] = peerData
}
return nil
}
func (s *Service) terminate(p p2p.Peer) error {
s.peersMu.Lock()
defer s.peersMu.Unlock()
delete(s.peers, p.Address.String())
return nil
}
func totalKey(peer swarm.Address, prefix string) string { func totalKey(peer swarm.Address, prefix string) string {
return fmt.Sprintf("%v%v", prefix, peer.String()) return fmt.Sprintf("%v%v", prefix, peer.String())
} }
...@@ -77,13 +127,44 @@ func totalKeyPeer(key []byte, prefix string) (peer swarm.Address, err error) { ...@@ -77,13 +127,44 @@ func totalKeyPeer(key []byte, prefix string) (peer swarm.Address, err error) {
return swarm.ParseHexAddress(split[1]) return swarm.ParseHexAddress(split[1])
} }
// peerAllowance computes the maximum incoming payment value we accept
// this is the time based allowance or the peers actual debt, whichever is less
func (s *Service) peerAllowance(peer swarm.Address) (limit *big.Int, stamp int64, err error) {
var lastTime lastPayment
err = s.store.Get(totalKey(peer, SettlementReceivedPrefix), &lastTime)
if err != nil {
if !errors.Is(err, storage.ErrNotFound) {
return nil, 0, err
}
lastTime.Timestamp = int64(0)
}
currentTime := s.timeNow().Unix()
if currentTime == lastTime.Timestamp {
return nil, 0, ErrSettlementTooSoon
}
maxAllowance := new(big.Int).Mul(big.NewInt(currentTime-lastTime.Timestamp), s.refreshRate)
peerDebt, err := s.accounting.PeerDebt(peer)
if err != nil {
return nil, 0, err
}
if peerDebt.Cmp(maxAllowance) >= 0 {
return maxAllowance, currentTime, nil
}
return peerDebt, currentTime, nil
}
func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (err error) { func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (err error) {
r := protobuf.NewReader(stream) w, r := protobuf.NewWriterAndReader(stream)
defer func() { defer func() {
if err != nil { if err != nil {
_ = stream.Reset() _ = stream.Reset()
} else { } else {
_ = stream.FullClose() go stream.FullClose()
} }
}() }()
var req pb.Payment var req pb.Payment
...@@ -91,100 +172,210 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e ...@@ -91,100 +172,210 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e
return fmt.Errorf("read request from peer %v: %w", p.Address, err) return fmt.Errorf("read request from peer %v: %w", p.Address, err)
} }
s.metrics.TotalReceivedPseudoSettlements.Add(float64(req.Amount)) attemptedAmount := big.NewInt(0).SetBytes(req.Amount)
s.logger.Tracef("received payment message from peer %v of %d", p.Address, req.Amount)
paymentAmount := new(big.Int).Set(attemptedAmount)
s.peersMu.Lock()
pseudoSettlePeer, ok := s.peers[p.Address.String()]
if !ok {
s.peersMu.Unlock()
return ErrNoPseudoSettlePeer
}
s.peersMu.Unlock()
pseudoSettlePeer.lock.Lock()
defer pseudoSettlePeer.lock.Unlock()
totalReceived, err := s.TotalReceived(p.Address) allowance, timestamp, err := s.peerAllowance(p.Address)
if err != nil { if err != nil {
if !errors.Is(err, settlement.ErrPeerNoSettlements) {
return err return err
} }
totalReceived = big.NewInt(0)
if allowance.Cmp(attemptedAmount) < 0 {
paymentAmount.Set(allowance)
s.logger.Tracef("pseudosettle accepting reduced payment from peer %v of %d", p.Address, paymentAmount)
} else {
s.logger.Tracef("pseudosettle accepting payment message from peer %v of %d", p.Address, paymentAmount)
}
if paymentAmount.Cmp(big.NewInt(0)) < 0 {
paymentAmount.Set(big.NewInt(0))
} }
err = s.store.Put(totalKey(p.Address, SettlementReceivedPrefix), totalReceived.Add(totalReceived, new(big.Int).SetUint64(req.Amount))) err = w.WriteMsgWithContext(ctx, &pb.PaymentAck{
Amount: paymentAmount.Bytes(),
Timestamp: timestamp,
})
if err != nil { if err != nil {
return err return err
} }
return s.accountingAPI.NotifyPaymentReceived(p.Address, new(big.Int).SetUint64(req.Amount)) var lastTime lastPayment
err = s.store.Get(totalKey(p.Address, SettlementReceivedPrefix), &lastTime)
if err != nil {
if !errors.Is(err, storage.ErrNotFound) {
return err
}
lastTime.Total = big.NewInt(0)
}
lastTime.Total = lastTime.Total.Add(lastTime.Total, paymentAmount)
lastTime.Timestamp = timestamp
err = s.store.Put(totalKey(p.Address, SettlementReceivedPrefix), lastTime)
if err != nil {
return err
}
receivedPaymentF64, _ := big.NewFloat(0).SetInt(paymentAmount).Float64()
s.metrics.TotalReceivedPseudoSettlements.Add(receivedPaymentF64)
return s.accounting.NotifyRefreshmentReceived(p.Address, paymentAmount)
} }
// Pay initiates a payment to the given peer // Pay initiates a payment to the given peer
func (s *Service) Pay(ctx context.Context, peer swarm.Address, amount *big.Int) { func (s *Service) Pay(ctx context.Context, peer swarm.Address, amount *big.Int, checkAllowance *big.Int) (*big.Int, int64, error) {
ctx, cancel := context.WithTimeout(ctx, 5*time.Second) ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel() defer cancel()
var err error var err error
defer func() {
var lastTime lastPayment
err = s.store.Get(totalKey(peer, SettlementSentPrefix), &lastTime)
if err != nil { if err != nil {
s.accountingAPI.NotifyPaymentSent(peer, nil, err) if !errors.Is(err, storage.ErrNotFound) {
return nil, 0, err
} }
}() lastTime.Total = big.NewInt(0)
lastTime.Timestamp = 0
}
currentTime := s.timeNow().Unix()
if currentTime == lastTime.Timestamp {
return nil, 0, ErrSettlementTooSoon
}
stream, err := s.streamer.NewStream(ctx, peer, nil, protocolName, protocolVersion, streamName) stream, err := s.streamer.NewStream(ctx, peer, nil, protocolName, protocolVersion, streamName)
if err != nil { if err != nil {
return return nil, 0, err
} }
defer func() { defer func() {
if err != nil { if err != nil {
_ = stream.Reset() _ = stream.Reset()
} else { } else {
go stream.FullClose() _ = stream.FullClose()
} }
}() }()
s.logger.Tracef("sending payment message to peer %v of %d", peer, amount) if checkAllowance.Cmp(amount) > 0 {
w := protobuf.NewWriter(stream) checkAllowance.Set(amount)
}
s.logger.Tracef("pseudosettle sending payment message to peer %v of %d", peer, amount)
w, r := protobuf.NewWriterAndReader(stream)
err = w.WriteMsgWithContext(ctx, &pb.Payment{ err = w.WriteMsgWithContext(ctx, &pb.Payment{
Amount: amount.Uint64(), Amount: amount.Bytes(),
}) })
if err != nil { if err != nil {
return return nil, 0, err
} }
totalSent, err := s.TotalSent(peer)
checkTime := s.timeNow().Unix()
var paymentAck pb.PaymentAck
err = r.ReadMsgWithContext(ctx, &paymentAck)
if err != nil { if err != nil {
if !errors.Is(err, settlement.ErrPeerNoSettlements) { return nil, 0, err
return
} }
totalSent = big.NewInt(0)
acceptedAmount := new(big.Int).SetBytes(paymentAck.Amount)
if acceptedAmount.Cmp(amount) > 0 {
err = fmt.Errorf("pseudosettle other peer %v accepted payment larger than expected", peer)
return nil, 0, err
} }
err = s.store.Put(totalKey(peer, SettlementSentPrefix), totalSent.Add(totalSent, amount)) experiencedInterval := checkTime - lastTime.CheckTimestamp
allegedInterval := paymentAck.Timestamp - lastTime.Timestamp
if allegedInterval < 0 {
return nil, 0, ErrTimeOutOfSync
}
experienceDifferenceRecent := paymentAck.Timestamp - checkTime
if experienceDifferenceRecent < -2 || experienceDifferenceRecent > 2 {
return nil, 0, ErrTimeOutOfSync
}
experienceDifferenceInterval := experiencedInterval - allegedInterval
if experienceDifferenceInterval < -3 || experienceDifferenceInterval > 3 {
return nil, 0, ErrTimeOutOfSync
}
// enforce allowance
// check if value is appropriate
expectedAllowance := new(big.Int).Mul(big.NewInt(allegedInterval), s.refreshRate)
if expectedAllowance.Cmp(checkAllowance) > 0 {
expectedAllowance = new(big.Int).Set(checkAllowance)
}
if expectedAllowance.Cmp(acceptedAmount) > 0 {
// disconnect peer
err = s.p2pService.Blocklist(peer, 10000*time.Hour)
if err != nil { if err != nil {
return return nil, 0, err
}
return nil, 0, ErrDisconnectAllowanceCheckFailed
} }
s.accountingAPI.NotifyPaymentSent(peer, amount, nil)
amountFloat, _ := new(big.Float).SetInt(amount).Float64() lastTime.Total = lastTime.Total.Add(lastTime.Total, acceptedAmount)
lastTime.Timestamp = paymentAck.Timestamp
lastTime.CheckTimestamp = checkTime
err = s.store.Put(totalKey(peer, SettlementSentPrefix), lastTime)
if err != nil {
return nil, 0, err
}
amountFloat, _ := new(big.Float).SetInt(acceptedAmount).Float64()
s.metrics.TotalSentPseudoSettlements.Add(amountFloat) s.metrics.TotalSentPseudoSettlements.Add(amountFloat)
return acceptedAmount, lastTime.CheckTimestamp, nil
} }
func (s *Service) SetAccountingAPI(accountingAPI settlement.AccountingAPI) { func (s *Service) SetAccounting(accounting settlement.Accounting) {
s.accountingAPI = accountingAPI s.accounting = accounting
} }
// TotalSent returns the total amount sent to a peer // TotalSent returns the total amount sent to a peer
func (s *Service) TotalSent(peer swarm.Address) (totalSent *big.Int, err error) { func (s *Service) TotalSent(peer swarm.Address) (totalSent *big.Int, err error) {
key := totalKey(peer, SettlementSentPrefix) var lastTime lastPayment
err = s.store.Get(key, &totalSent)
err = s.store.Get(totalKey(peer, SettlementSentPrefix), &lastTime)
if err != nil { if err != nil {
if errors.Is(err, storage.ErrNotFound) { if !errors.Is(err, storage.ErrNotFound) {
return nil, settlement.ErrPeerNoSettlements return nil, settlement.ErrPeerNoSettlements
} }
return nil, err lastTime.Total = big.NewInt(0)
} }
return totalSent, nil
return lastTime.Total, nil
} }
// TotalReceived returns the total amount received from a peer // TotalReceived returns the total amount received from a peer
func (s *Service) TotalReceived(peer swarm.Address) (totalReceived *big.Int, err error) { func (s *Service) TotalReceived(peer swarm.Address) (totalReceived *big.Int, err error) {
key := totalKey(peer, SettlementReceivedPrefix) var lastTime lastPayment
err = s.store.Get(key, &totalReceived)
err = s.store.Get(totalKey(peer, SettlementReceivedPrefix), &lastTime)
if err != nil { if err != nil {
if errors.Is(err, storage.ErrNotFound) { if !errors.Is(err, storage.ErrNotFound) {
return nil, settlement.ErrPeerNoSettlements return nil, settlement.ErrPeerNoSettlements
} }
return nil, err lastTime.Total = big.NewInt(0)
} }
return totalReceived, nil
return lastTime.Total, nil
} }
// SettlementsSent returns all stored sent settlement values for a given type of prefix // SettlementsSent returns all stored sent settlement values for a given type of prefix
...@@ -196,13 +387,13 @@ func (s *Service) SettlementsSent() (map[string]*big.Int, error) { ...@@ -196,13 +387,13 @@ func (s *Service) SettlementsSent() (map[string]*big.Int, error) {
return false, fmt.Errorf("parse address from key: %s: %w", string(key), err) return false, fmt.Errorf("parse address from key: %s: %w", string(key), err)
} }
if _, ok := sent[addr.String()]; !ok { if _, ok := sent[addr.String()]; !ok {
var storevalue *big.Int var storevalue lastPayment
err = s.store.Get(totalKey(addr, SettlementSentPrefix), &storevalue) err = s.store.Get(totalKey(addr, SettlementSentPrefix), &storevalue)
if err != nil { if err != nil {
return false, fmt.Errorf("get peer %s settlement balance: %w", addr.String(), err) return false, fmt.Errorf("get peer %s settlement balance: %w", addr.String(), err)
} }
sent[addr.String()] = storevalue sent[addr.String()] = storevalue.Total
} }
return false, nil return false, nil
}) })
...@@ -221,13 +412,13 @@ func (s *Service) SettlementsReceived() (map[string]*big.Int, error) { ...@@ -221,13 +412,13 @@ func (s *Service) SettlementsReceived() (map[string]*big.Int, error) {
return false, fmt.Errorf("parse address from key: %s: %w", string(key), err) return false, fmt.Errorf("parse address from key: %s: %w", string(key), err)
} }
if _, ok := received[addr.String()]; !ok { if _, ok := received[addr.String()]; !ok {
var storevalue *big.Int var storevalue lastPayment
err = s.store.Get(totalKey(addr, SettlementReceivedPrefix), &storevalue) err = s.store.Get(totalKey(addr, SettlementReceivedPrefix), &storevalue)
if err != nil { if err != nil {
return false, fmt.Errorf("get peer %s settlement balance: %w", addr.String(), err) return false, fmt.Errorf("get peer %s settlement balance: %w", addr.String(), err)
} }
received[addr.String()] = storevalue received[addr.String()] = storevalue.Total
} }
return false, nil return false, nil
}) })
......
...@@ -7,12 +7,15 @@ package pseudosettle_test ...@@ -7,12 +7,15 @@ package pseudosettle_test
import ( import (
"bytes" "bytes"
"context" "context"
"errors"
"io/ioutil" "io/ioutil"
"math/big" "math/big"
"testing" "testing"
"time" "time"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/p2p"
mockp2p "github.com/ethersphere/bee/pkg/p2p/mock"
"github.com/ethersphere/bee/pkg/p2p/protobuf" "github.com/ethersphere/bee/pkg/p2p/protobuf"
"github.com/ethersphere/bee/pkg/p2p/streamtest" "github.com/ethersphere/bee/pkg/p2p/streamtest"
"github.com/ethersphere/bee/pkg/settlement/pseudosettle" "github.com/ethersphere/bee/pkg/settlement/pseudosettle"
...@@ -24,6 +27,7 @@ import ( ...@@ -24,6 +27,7 @@ import (
type testObserver struct { type testObserver struct {
receivedCalled chan notifyPaymentReceivedCall receivedCalled chan notifyPaymentReceivedCall
sentCalled chan notifyPaymentSentCall sentCalled chan notifyPaymentSentCall
peerDebts map[string]*big.Int
} }
type notifyPaymentReceivedCall struct { type notifyPaymentReceivedCall struct {
...@@ -37,18 +41,27 @@ type notifyPaymentSentCall struct { ...@@ -37,18 +41,27 @@ type notifyPaymentSentCall struct {
err error err error
} }
func newTestObserver() *testObserver { func newTestObserver(debtAmounts map[string]*big.Int, shadowBalanceAmounts map[string]*big.Int) *testObserver {
return &testObserver{ return &testObserver{
receivedCalled: make(chan notifyPaymentReceivedCall, 1), receivedCalled: make(chan notifyPaymentReceivedCall, 1),
sentCalled: make(chan notifyPaymentSentCall, 1), sentCalled: make(chan notifyPaymentSentCall, 1),
peerDebts: debtAmounts,
} }
} }
func (t *testObserver) setPeerDebt(peer swarm.Address, debt *big.Int) {
t.peerDebts[peer.String()] = debt
}
func (t *testObserver) PeerDebt(peer swarm.Address) (*big.Int, error) { func (t *testObserver) PeerDebt(peer swarm.Address) (*big.Int, error) {
return nil, nil if debt, ok := t.peerDebts[peer.String()]; ok {
return debt, nil
}
return nil, errors.New("Peer not listed")
} }
func (t *testObserver) NotifyPaymentReceived(peer swarm.Address, amount *big.Int) error { func (t *testObserver) NotifyRefreshmentReceived(peer swarm.Address, amount *big.Int) error {
t.receivedCalled <- notifyPaymentReceivedCall{ t.receivedCalled <- notifyPaymentReceivedCall{
peer: peer, peer: peer,
amount: amount, amount: amount,
...@@ -56,6 +69,10 @@ func (t *testObserver) NotifyPaymentReceived(peer swarm.Address, amount *big.Int ...@@ -56,6 +69,10 @@ func (t *testObserver) NotifyPaymentReceived(peer swarm.Address, amount *big.Int
return nil return nil
} }
func (t *testObserver) NotifyPaymentReceived(peer swarm.Address, amount *big.Int) error {
return nil
}
func (t *testObserver) NotifyPaymentSent(peer swarm.Address, amount *big.Int, err error) { func (t *testObserver) NotifyPaymentSent(peer swarm.Address, amount *big.Int, err error) {
t.sentCalled <- notifyPaymentSentCall{ t.sentCalled <- notifyPaymentSentCall{
peer: peer, peer: peer,
...@@ -63,16 +80,34 @@ func (t *testObserver) NotifyPaymentSent(peer swarm.Address, amount *big.Int, er ...@@ -63,16 +80,34 @@ func (t *testObserver) NotifyPaymentSent(peer swarm.Address, amount *big.Int, er
err: err, err: err,
} }
} }
func (t *testObserver) Reserve(ctx context.Context, peer swarm.Address, amount uint64) error {
return nil
}
func (t *testObserver) Release(peer swarm.Address, amount uint64) {
}
var testRefreshRate = int64(10000)
func TestPayment(t *testing.T) { func TestPayment(t *testing.T) {
logger := logging.New(ioutil.Discard, 0) logger := logging.New(ioutil.Discard, 0)
storeRecipient := mock.NewStateStore() storeRecipient := mock.NewStateStore()
defer storeRecipient.Close() defer storeRecipient.Close()
observer := newTestObserver()
recipient := pseudosettle.New(nil, logger, storeRecipient, observer)
peerID := swarm.MustParseHexAddress("9ee7add7") peerID := swarm.MustParseHexAddress("9ee7add7")
peer := p2p.Peer{Address: peerID}
debt := int64(10000)
observer := newTestObserver(map[string]*big.Int{peerID.String(): big.NewInt(debt)}, map[string]*big.Int{})
recipient := pseudosettle.New(nil, logger, storeRecipient, observer, big.NewInt(testRefreshRate), mockp2p.New())
recipient.SetAccounting(observer)
err := recipient.Init(context.Background(), peer)
if err != nil {
t.Fatal(err)
}
recorder := streamtest.New( recorder := streamtest.New(
streamtest.WithProtocols(recipient.Protocol()), streamtest.WithProtocols(recipient.Protocol()),
...@@ -82,13 +117,20 @@ func TestPayment(t *testing.T) { ...@@ -82,13 +117,20 @@ func TestPayment(t *testing.T) {
storePayer := mock.NewStateStore() storePayer := mock.NewStateStore()
defer storePayer.Close() defer storePayer.Close()
observer2 := newTestObserver() observer2 := newTestObserver(map[string]*big.Int{}, map[string]*big.Int{peerID.String(): big.NewInt(debt)})
payer := pseudosettle.New(recorder, logger, storePayer, observer2) payer := pseudosettle.New(recorder, logger, storePayer, observer2, big.NewInt(testRefreshRate), mockp2p.New())
payer.SetAccountingAPI(observer2) payer.SetAccounting(observer2)
amount := big.NewInt(debt)
amount := big.NewInt(10000) acceptedAmount, _, err := payer.Pay(context.Background(), peerID, amount, amount)
if err != nil {
t.Fatal(err)
}
payer.Pay(context.Background(), peerID, amount) if acceptedAmount.Cmp(amount) != 0 {
t.Fatalf("full amount not accepted. wanted %d, got %d", amount, acceptedAmount)
}
records, err := recorder.Records(peerID, "pseudosettle", "1.0.0", "pseudosettle") records, err := recorder.Records(peerID, "pseudosettle", "1.0.0", "pseudosettle")
if err != nil { if err != nil {
...@@ -113,15 +155,28 @@ func TestPayment(t *testing.T) { ...@@ -113,15 +155,28 @@ func TestPayment(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
if len(messages) != 1 { receivedMessages, err := protobuf.ReadMessages(
t.Fatalf("got %v messages, want %v", len(messages), 1) bytes.NewReader(record.Out()),
func() protobuf.Message { return new(pb.PaymentAck) },
)
if err != nil {
t.Fatal(err)
} }
sentAmount := messages[0].(*pb.Payment).Amount if len(messages) != 1 || len(receivedMessages) != 1 {
if sentAmount != amount.Uint64() { t.Fatalf("got %v/%v messages, want %v/%v", len(messages), len(receivedMessages), 1, 1)
}
sentAmount := big.NewInt(0).SetBytes(messages[0].(*pb.Payment).Amount)
receivedAmount := big.NewInt(0).SetBytes(receivedMessages[0].(*pb.PaymentAck).Amount)
if sentAmount.Cmp(amount) != 0 {
t.Fatalf("got message with amount %v, want %v", sentAmount, amount) t.Fatalf("got message with amount %v, want %v", sentAmount, amount)
} }
if sentAmount.Cmp(receivedAmount) != 0 {
t.Fatalf("wrong settlement amount, got %v, want %v", receivedAmount, sentAmount)
}
select { select {
case call := <-observer.receivedCalled: case call := <-observer.receivedCalled:
if call.amount.Cmp(amount) != 0 { if call.amount.Cmp(amount) != 0 {
...@@ -136,8 +191,117 @@ func TestPayment(t *testing.T) { ...@@ -136,8 +191,117 @@ func TestPayment(t *testing.T) {
t.Fatal("expected observer to be called") t.Fatal("expected observer to be called")
} }
totalSent, err := payer.TotalSent(peerID)
if err != nil {
t.Fatal(err)
}
if totalSent.Cmp(sentAmount) != 0 {
t.Fatalf("stored wrong totalSent. got %d, want %d", totalSent, sentAmount)
}
totalReceived, err := recipient.TotalReceived(peerID)
if err != nil {
t.Fatal(err)
}
if totalReceived.Cmp(sentAmount) != 0 {
t.Fatalf("stored wrong totalReceived. got %d, want %d", totalReceived, sentAmount)
}
}
func TestTimeLimitedPayment(t *testing.T) {
logger := logging.New(ioutil.Discard, 0)
storeRecipient := mock.NewStateStore()
defer storeRecipient.Close()
peerID := swarm.MustParseHexAddress("9ee7add7")
peer := p2p.Peer{Address: peerID}
debt := testRefreshRate
observer := newTestObserver(map[string]*big.Int{peerID.String(): big.NewInt(debt)}, map[string]*big.Int{})
recipient := pseudosettle.New(nil, logger, storeRecipient, observer, big.NewInt(testRefreshRate), mockp2p.New())
recipient.SetAccounting(observer)
err := recipient.Init(context.Background(), peer)
if err != nil {
t.Fatal(err)
}
recorder := streamtest.New(
streamtest.WithProtocols(recipient.Protocol()),
streamtest.WithBaseAddr(peerID),
)
storePayer := mock.NewStateStore()
defer storePayer.Close()
observer2 := newTestObserver(map[string]*big.Int{}, map[string]*big.Int{peerID.String(): big.NewInt(debt)})
payer := pseudosettle.New(recorder, logger, storePayer, observer2, big.NewInt(testRefreshRate), mockp2p.New())
payer.SetAccounting(observer2)
payer.SetTime(int64(10000))
recipient.SetTime(int64(10000))
amount := big.NewInt(debt)
acceptedAmount, _, err := payer.Pay(context.Background(), peerID, amount, amount)
if err != nil {
t.Fatal(err)
}
if acceptedAmount.Cmp(amount) != 0 {
t.Fatalf("full amount not accepted. wanted %d, got %d", amount, acceptedAmount)
}
records, err := recorder.Records(peerID, "pseudosettle", "1.0.0", "pseudosettle")
if err != nil {
t.Fatal(err)
}
if l := len(records); l != 1 {
t.Fatalf("got %v records, want %v", l, 1)
}
record := records[0]
if err := record.Err(); err != nil {
t.Fatalf("record error: %v", err)
}
messages, err := protobuf.ReadMessages(
bytes.NewReader(record.In()),
func() protobuf.Message { return new(pb.Payment) },
)
if err != nil {
t.Fatal(err)
}
receivedMessages, err := protobuf.ReadMessages(
bytes.NewReader(record.Out()),
func() protobuf.Message { return new(pb.PaymentAck) },
)
if err != nil {
t.Fatal(err)
}
if len(messages) != 1 || len(receivedMessages) != 1 {
t.Fatalf("got %v/%v messages, want %v/%v", len(messages), len(receivedMessages), 1, 1)
}
sentAmount := big.NewInt(0).SetBytes(messages[0].(*pb.Payment).Amount)
receivedAmount := big.NewInt(0).SetBytes(receivedMessages[0].(*pb.PaymentAck).Amount)
if sentAmount.Cmp(amount) != 0 {
t.Fatalf("got message with amount %v, want %v", sentAmount, amount)
}
if sentAmount.Cmp(receivedAmount) != 0 {
t.Fatalf("wrong settlement amount, got %v, want %v", receivedAmount, sentAmount)
}
select { select {
case call := <-observer2.sentCalled: case call := <-observer.receivedCalled:
if call.amount.Cmp(amount) != 0 { if call.amount.Cmp(amount) != 0 {
t.Fatalf("observer called with wrong amount. got %d, want %d", call.amount, amount) t.Fatalf("observer called with wrong amount. got %d, want %d", call.amount, amount)
} }
...@@ -145,9 +309,6 @@ func TestPayment(t *testing.T) { ...@@ -145,9 +309,6 @@ func TestPayment(t *testing.T) {
if !call.peer.Equal(peerID) { if !call.peer.Equal(peerID) {
t.Fatalf("observer called with wrong peer. got %v, want %v", call.peer, peerID) t.Fatalf("observer called with wrong peer. got %v, want %v", call.peer, peerID)
} }
if call.err != nil {
t.Fatalf("observer called with error. got %v want nil", call.err)
}
case <-time.After(time.Second): case <-time.After(time.Second):
t.Fatal("expected observer to be called") t.Fatal("expected observer to be called")
...@@ -158,7 +319,7 @@ func TestPayment(t *testing.T) { ...@@ -158,7 +319,7 @@ func TestPayment(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
if totalSent.Cmp(new(big.Int).SetUint64(sentAmount)) != 0 { if totalSent.Cmp(sentAmount) != 0 {
t.Fatalf("stored wrong totalSent. got %d, want %d", totalSent, sentAmount) t.Fatalf("stored wrong totalSent. got %d, want %d", totalSent, sentAmount)
} }
...@@ -167,7 +328,453 @@ func TestPayment(t *testing.T) { ...@@ -167,7 +328,453 @@ func TestPayment(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
if totalReceived.Cmp(new(big.Int).SetUint64(sentAmount)) != 0 { if totalReceived.Cmp(sentAmount) != 0 {
t.Fatalf("stored wrong totalReceived. got %d, want %d", totalReceived, sentAmount) t.Fatalf("stored wrong totalReceived. got %d, want %d", totalReceived, sentAmount)
} }
sentSum := big.NewInt(testRefreshRate)
// Let 3 seconds pass, attempt settlement below time based refreshment rate
debt = testRefreshRate * 3 / 2
amount = big.NewInt(debt)
payer.SetTime(int64(10003))
recipient.SetTime(int64(10003))
observer.setPeerDebt(peerID, amount)
acceptedAmount, _, err = payer.Pay(context.Background(), peerID, amount, amount)
if err != nil {
t.Fatal(err)
}
if acceptedAmount.Cmp(amount) != 0 {
t.Fatalf("full amount not accepted. wanted %d, got %d", amount, acceptedAmount)
}
sentSum = sentSum.Add(sentSum, amount)
records, err = recorder.Records(peerID, "pseudosettle", "1.0.0", "pseudosettle")
if err != nil {
t.Fatal(err)
}
if l := len(records); l != 2 {
t.Fatalf("got %v records, want %v", l, 2)
}
record = records[1]
if err := record.Err(); err != nil {
t.Fatalf("record error: %v", err)
}
messages, err = protobuf.ReadMessages(
bytes.NewReader(record.In()),
func() protobuf.Message { return new(pb.Payment) },
)
if err != nil {
t.Fatal(err)
}
receivedMessages, err = protobuf.ReadMessages(
bytes.NewReader(record.Out()),
func() protobuf.Message { return new(pb.PaymentAck) },
)
if err != nil {
t.Fatal(err)
}
if len(messages) != 1 || len(receivedMessages) != 1 {
t.Fatalf("got %v/%v messages, want %v/%v", len(messages), len(receivedMessages), 1, 1)
}
sentAmount = big.NewInt(0).SetBytes(messages[0].(*pb.Payment).Amount)
receivedAmount = big.NewInt(0).SetBytes(receivedMessages[0].(*pb.PaymentAck).Amount)
if sentAmount.Cmp(amount) != 0 {
t.Fatalf("got message with amount %v, want %v", sentAmount, amount)
}
if sentAmount.Cmp(receivedAmount) != 0 {
t.Fatalf("wrong settlement amount, got %v, want %v", receivedAmount, sentAmount)
}
select {
case call := <-observer.receivedCalled:
if call.amount.Cmp(receivedAmount) != 0 {
t.Fatalf("observer called with wrong amount. got %d, want %d", call.amount, amount)
}
if !call.peer.Equal(peerID) {
t.Fatalf("observer called with wrong peer. got %v, want %v", call.peer, peerID)
}
case <-time.After(time.Second):
t.Fatal("expected observer to be called")
}
totalSent, err = payer.TotalSent(peerID)
if err != nil {
t.Fatal(err)
}
if totalSent.Cmp(sentSum) != 0 {
t.Fatalf("stored wrong totalSent. got %d, want %d", totalSent, sentSum)
}
totalReceived, err = recipient.TotalReceived(peerID)
if err != nil {
t.Fatal(err)
}
if totalReceived.Cmp(sentSum) != 0 {
t.Fatalf("stored wrong totalReceived. got %d, want %d", totalReceived, sentSum)
}
// attempt settlement over the time-based allowed limit 1 seconds later
debt = 3 * testRefreshRate
amount = big.NewInt(debt)
payer.SetTime(int64(10004))
recipient.SetTime(int64(10004))
observer.setPeerDebt(peerID, amount)
acceptedAmount, _, err = payer.Pay(context.Background(), peerID, amount, amount)
if err != nil {
t.Fatal(err)
}
testRefreshRateBigInt := big.NewInt(testRefreshRate)
if acceptedAmount.Cmp(testRefreshRateBigInt) != 0 {
t.Fatalf("full amount not accepted. wanted %d, got %d", amount, testRefreshRateBigInt)
}
sentSum = sentSum.Add(sentSum, testRefreshRateBigInt)
records, err = recorder.Records(peerID, "pseudosettle", "1.0.0", "pseudosettle")
if err != nil {
t.Fatal(err)
}
if l := len(records); l != 3 {
t.Fatalf("got %v records, want %v", l, 3)
}
record = records[2]
if err := record.Err(); err != nil {
t.Fatalf("record error: %v", err)
}
messages, err = protobuf.ReadMessages(
bytes.NewReader(record.In()),
func() protobuf.Message { return new(pb.Payment) },
)
if err != nil {
t.Fatal(err)
}
receivedMessages, err = protobuf.ReadMessages(
bytes.NewReader(record.Out()),
func() protobuf.Message { return new(pb.PaymentAck) },
)
if err != nil {
t.Fatal(err)
}
if len(messages) != 1 || len(receivedMessages) != 1 {
t.Fatalf("got %v/%v messages, want %v/%v", len(messages), len(receivedMessages), 1, 1)
}
sentAmount = big.NewInt(0).SetBytes(messages[0].(*pb.Payment).Amount)
receivedAmount = big.NewInt(0).SetBytes(receivedMessages[0].(*pb.PaymentAck).Amount)
if sentAmount.Cmp(amount) != 0 {
t.Fatalf("got message with amount %v, want %v", sentAmount, amount)
}
if receivedAmount.Cmp(testRefreshRateBigInt) != 0 {
t.Fatalf("wrong settlement amount, got %v, want %v", receivedAmount, testRefreshRateBigInt)
}
select {
case call := <-observer.receivedCalled:
if call.amount.Cmp(testRefreshRateBigInt) != 0 {
t.Fatalf("observer called with wrong amount. got %d, want %d", call.amount, testRefreshRate)
}
if !call.peer.Equal(peerID) {
t.Fatalf("observer called with wrong peer. got %v, want %v", call.peer, peerID)
}
case <-time.After(time.Second):
t.Fatal("expected observer to be called")
}
totalSent, err = payer.TotalSent(peerID)
if err != nil {
t.Fatal(err)
}
if totalSent.Cmp(sentSum) != 0 {
t.Fatalf("stored wrong totalSent. got %d, want %d", totalSent, sentSum)
}
totalReceived, err = recipient.TotalReceived(peerID)
if err != nil {
t.Fatal(err)
}
if totalReceived.Cmp(sentSum) != 0 {
t.Fatalf("stored wrong totalReceived. got %d, want %d", totalReceived, sentSum)
}
// attempt settle again in the same second without success
debt = 4 * testRefreshRate
amount = big.NewInt(debt)
observer.setPeerDebt(peerID, amount)
_, _, err = payer.Pay(context.Background(), peerID, amount, amount)
if !errors.Is(err, pseudosettle.ErrSettlementTooSoon) {
t.Fatal("sent settlement too soon")
}
records, err = recorder.Records(peerID, "pseudosettle", "1.0.0", "pseudosettle")
if err != nil {
t.Fatal(err)
}
if l := len(records); l != 3 {
t.Fatalf("got %v records, want %v", l, 3)
}
select {
case <-observer.receivedCalled:
t.Fatal("unexpected observer to be called")
case <-time.After(time.Second):
}
// attempt again while recipient is still supposed to be blocking based on time
debt = 2 * testRefreshRate
amount = big.NewInt(debt)
payer.SetTime(int64(10005))
recipient.SetTime(int64(10004))
observer.setPeerDebt(peerID, amount)
_, _, err = payer.Pay(context.Background(), peerID, amount, amount)
if err == nil {
t.Fatal("expected error")
}
records, err = recorder.Records(peerID, "pseudosettle", "1.0.0", "pseudosettle")
if err != nil {
t.Fatal(err)
}
if l := len(records); l != 4 {
t.Fatalf("got %v records, want %v", l, 4)
}
select {
case <-observer.receivedCalled:
t.Fatal("unexpected observer to be called")
case <-time.After(time.Second):
}
// attempt multiple seconds later with debt over time based allowance
debt = 9 * testRefreshRate
amount = big.NewInt(debt)
payer.SetTime(int64(10010))
recipient.SetTime(int64(10010))
observer.setPeerDebt(peerID, amount)
acceptedAmount, _, err = payer.Pay(context.Background(), peerID, amount, amount)
if err != nil {
t.Fatal(err)
}
testAmount := big.NewInt(6 * testRefreshRate)
if acceptedAmount.Cmp(testAmount) != 0 {
t.Fatalf("incorrect amount accepted. wanted %d, got %d", testAmount, acceptedAmount)
}
sentSum = sentSum.Add(sentSum, big.NewInt(6*testRefreshRate))
records, err = recorder.Records(peerID, "pseudosettle", "1.0.0", "pseudosettle")
if err != nil {
t.Fatal(err)
}
if l := len(records); l != 5 {
t.Fatalf("got %v records, want %v", l, 5)
}
record = records[4]
if err := record.Err(); err != nil {
t.Fatalf("record error: %v", err)
}
messages, err = protobuf.ReadMessages(
bytes.NewReader(record.In()),
func() protobuf.Message { return new(pb.Payment) },
)
if err != nil {
t.Fatal(err)
}
receivedMessages, err = protobuf.ReadMessages(
bytes.NewReader(record.Out()),
func() protobuf.Message { return new(pb.PaymentAck) },
)
if err != nil {
t.Fatal(err)
}
if len(messages) != 1 || len(receivedMessages) != 1 {
t.Fatalf("got %v/%v messages, want %v/%v", len(messages), len(receivedMessages), 1, 1)
}
sentAmount = big.NewInt(0).SetBytes(messages[0].(*pb.Payment).Amount)
receivedAmount = big.NewInt(0).SetBytes(receivedMessages[0].(*pb.PaymentAck).Amount)
if sentAmount.Cmp(amount) != 0 {
t.Fatalf("got message with amount %v, want %v", sentAmount, amount)
}
if receivedAmount.Cmp(testAmount) != 0 {
t.Fatalf("wrong settlement amount, got %v, want %v", receivedAmount, testAmount)
}
select {
case call := <-observer.receivedCalled:
if call.amount.Cmp(testAmount) != 0 {
t.Fatalf("observer called with wrong amount. got %d, want %d", call.amount, testAmount)
}
if !call.peer.Equal(peerID) {
t.Fatalf("observer called with wrong peer. got %v, want %v", call.peer, peerID)
}
case <-time.After(time.Second):
t.Fatal("expected observer to be called")
}
totalSent, err = payer.TotalSent(peerID)
if err != nil {
t.Fatal(err)
}
if totalSent.Cmp(sentSum) != 0 {
t.Fatalf("stored wrong totalSent. got %d, want %d", totalSent, sentSum)
}
totalReceived, err = recipient.TotalReceived(peerID)
if err != nil {
t.Fatal(err)
}
if totalReceived.Cmp(sentSum) != 0 {
t.Fatalf("stored wrong totalReceived. got %d, want %d", totalReceived, sentSum)
}
// attempt further settlement with less outstanding debt than time allowance would allow
debt = 5 * testRefreshRate
amount = big.NewInt(debt)
payer.SetTime(int64(10020))
recipient.SetTime(int64(10020))
observer.setPeerDebt(peerID, amount)
acceptedAmount, _, err = payer.Pay(context.Background(), peerID, amount, amount)
if err != nil {
t.Fatal(err)
}
testAmount = big.NewInt(5 * testRefreshRate)
if acceptedAmount.Cmp(testAmount) != 0 {
t.Fatalf("incorrect amount accepted. wanted %d, got %d", testAmount, acceptedAmount)
}
sentSum = sentSum.Add(sentSum, big.NewInt(5*testRefreshRate))
records, err = recorder.Records(peerID, "pseudosettle", "1.0.0", "pseudosettle")
if err != nil {
t.Fatal(err)
}
if l := len(records); l != 6 {
t.Fatalf("got %v records, want %v", l, 5)
}
record = records[5]
if err := record.Err(); err != nil {
t.Fatalf("record error: %v", err)
}
messages, err = protobuf.ReadMessages(
bytes.NewReader(record.In()),
func() protobuf.Message { return new(pb.Payment) },
)
if err != nil {
t.Fatal(err)
}
if len(messages) != 1 {
t.Fatalf("got %v messages, want %v", len(messages), 1)
}
sentAmount = big.NewInt(0).SetBytes(messages[0].(*pb.Payment).Amount)
if sentAmount.Cmp(testAmount) != 0 {
t.Fatalf("got message with amount %v, want %v", sentAmount, testAmount)
}
select {
case call := <-observer.receivedCalled:
if call.amount.Cmp(testAmount) != 0 {
t.Fatalf("observer called with wrong amount. got %d, want %d", call.amount, testAmount)
}
if !call.peer.Equal(peerID) {
t.Fatalf("observer called with wrong peer. got %v, want %v", call.peer, peerID)
}
case <-time.After(time.Second):
t.Fatal("expected observer to be called")
}
totalSent, err = payer.TotalSent(peerID)
if err != nil {
t.Fatal(err)
}
if totalSent.Cmp(sentSum) != 0 {
t.Fatalf("stored wrong totalSent. got %d, want %d", totalSent, sentSum)
}
totalReceived, err = recipient.TotalReceived(peerID)
if err != nil {
t.Fatal(err)
}
if totalReceived.Cmp(sentSum) != 0 {
t.Fatalf("stored wrong totalReceived. got %d, want %d", totalReceived, sentSum)
}
} }
...@@ -10,7 +10,6 @@ import ( ...@@ -10,7 +10,6 @@ import (
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethersphere/bee/pkg/settlement"
"github.com/ethersphere/bee/pkg/settlement/swap" "github.com/ethersphere/bee/pkg/settlement/swap"
"github.com/ethersphere/bee/pkg/settlement/swap/chequebook" "github.com/ethersphere/bee/pkg/settlement/swap/chequebook"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
...@@ -120,7 +119,7 @@ func WithCashoutStatusFunc(f func(ctx context.Context, peer swarm.Address) (*che ...@@ -120,7 +119,7 @@ func WithCashoutStatusFunc(f func(ctx context.Context, peer swarm.Address) (*che
} }
// New creates the mock swap implementation // New creates the mock swap implementation
func New(opts ...Option) settlement.Interface { func New(opts ...Option) swap.Interface {
mock := new(Service) mock := new(Service)
mock.settlementsSent = make(map[string]*big.Int) mock.settlementsSent = make(map[string]*big.Int)
mock.settlementsRecv = make(map[string]*big.Int) mock.settlementsRecv = make(map[string]*big.Int)
...@@ -130,14 +129,6 @@ func New(opts ...Option) settlement.Interface { ...@@ -130,14 +129,6 @@ func New(opts ...Option) settlement.Interface {
return mock return mock
} }
func NewApiInterface(opts ...Option) swap.ApiInterface {
mock := new(Service)
for _, o := range opts {
o.apply(mock)
}
return mock
}
// ReceiveCheque is the mock ReceiveCheque function of swap. // ReceiveCheque is the mock ReceiveCheque function of swap.
func (s *Service) ReceiveCheque(ctx context.Context, peer swarm.Address, cheque *chequebook.SignedCheque) (err error) { func (s *Service) ReceiveCheque(ctx context.Context, peer swarm.Address, cheque *chequebook.SignedCheque) (err error) {
if s.receiveChequeFunc != nil { if s.receiveChequeFunc != nil {
......
...@@ -30,7 +30,8 @@ var ( ...@@ -30,7 +30,8 @@ var (
ErrUnknownBeneficary = errors.New("unknown beneficiary for peer") ErrUnknownBeneficary = errors.New("unknown beneficiary for peer")
) )
type ApiInterface interface { type Interface interface {
settlement.Interface
// LastSentCheque returns the last sent cheque for the peer // LastSentCheque returns the last sent cheque for the peer
LastSentCheque(peer swarm.Address) (*chequebook.SignedCheque, error) LastSentCheque(peer swarm.Address) (*chequebook.SignedCheque, error)
// LastSentCheques returns the list of last sent cheques for all peers // LastSentCheques returns the list of last sent cheques for all peers
...@@ -50,7 +51,7 @@ type Service struct { ...@@ -50,7 +51,7 @@ type Service struct {
proto swapprotocol.Interface proto swapprotocol.Interface
logger logging.Logger logger logging.Logger
store storage.StateStorer store storage.StateStorer
accountingAPI settlement.AccountingAPI accounting settlement.Accounting
metrics metrics metrics metrics
chequebook chequebook.Service chequebook chequebook.Service
chequeStore chequebook.ChequeStore chequeStore chequebook.ChequeStore
...@@ -61,7 +62,7 @@ type Service struct { ...@@ -61,7 +62,7 @@ type Service struct {
} }
// New creates a new swap Service. // New creates a new swap Service.
func New(proto swapprotocol.Interface, logger logging.Logger, store storage.StateStorer, chequebook chequebook.Service, chequeStore chequebook.ChequeStore, addressbook Addressbook, networkID uint64, cashout chequebook.CashoutService, p2pService p2p.Service, accountingAPI settlement.AccountingAPI) *Service { func New(proto swapprotocol.Interface, logger logging.Logger, store storage.StateStorer, chequebook chequebook.Service, chequeStore chequebook.ChequeStore, addressbook Addressbook, networkID uint64, cashout chequebook.CashoutService, p2pService p2p.Service, accounting settlement.Accounting) *Service {
return &Service{ return &Service{
proto: proto, proto: proto,
logger: logger, logger: logger,
...@@ -73,7 +74,7 @@ func New(proto swapprotocol.Interface, logger logging.Logger, store storage.Stat ...@@ -73,7 +74,7 @@ func New(proto swapprotocol.Interface, logger logging.Logger, store storage.Stat
networkID: networkID, networkID: networkID,
cashout: cashout, cashout: cashout,
p2pService: p2pService, p2pService: p2pService,
accountingAPI: accountingAPI, accounting: accounting,
} }
} }
...@@ -101,10 +102,11 @@ func (s *Service) ReceiveCheque(ctx context.Context, peer swarm.Address, cheque ...@@ -101,10 +102,11 @@ func (s *Service) ReceiveCheque(ctx context.Context, peer swarm.Address, cheque
} }
} }
s.metrics.TotalReceived.Add(float64(amount.Uint64())) tot, _ := big.NewFloat(0).SetInt(amount).Float64()
s.metrics.TotalReceived.Add(tot)
s.metrics.ChequesReceived.Inc() s.metrics.ChequesReceived.Inc()
return s.accountingAPI.NotifyPaymentReceived(peer, amount) return s.accounting.NotifyPaymentReceived(peer, amount)
} }
// Pay initiates a payment to the given peer // Pay initiates a payment to the given peer
...@@ -112,7 +114,7 @@ func (s *Service) Pay(ctx context.Context, peer swarm.Address, amount *big.Int) ...@@ -112,7 +114,7 @@ func (s *Service) Pay(ctx context.Context, peer swarm.Address, amount *big.Int)
var err error var err error
defer func() { defer func() {
if err != nil { if err != nil {
s.accountingAPI.NotifyPaymentSent(peer, nil, err) s.accounting.NotifyPaymentSent(peer, amount, err)
} }
}() }()
beneficiary, known, err := s.addressbook.Beneficiary(peer) beneficiary, known, err := s.addressbook.Beneficiary(peer)
...@@ -136,14 +138,14 @@ func (s *Service) Pay(ctx context.Context, peer swarm.Address, amount *big.Int) ...@@ -136,14 +138,14 @@ func (s *Service) Pay(ctx context.Context, peer swarm.Address, amount *big.Int)
} }
bal, _ := big.NewFloat(0).SetInt(balance).Float64() bal, _ := big.NewFloat(0).SetInt(balance).Float64()
s.metrics.AvailableBalance.Set(bal) s.metrics.AvailableBalance.Set(bal)
s.accountingAPI.NotifyPaymentSent(peer, amount, nil) s.accounting.NotifyPaymentSent(peer, amount, nil)
amountFloat, _ := big.NewFloat(0).SetInt(amount).Float64() amountFloat, _ := big.NewFloat(0).SetInt(amount).Float64()
s.metrics.TotalSent.Add(amountFloat) s.metrics.TotalSent.Add(amountFloat)
s.metrics.ChequesSent.Inc() s.metrics.ChequesSent.Inc()
} }
func (s *Service) SetAccountingAPI(accountingAPI settlement.AccountingAPI) { func (s *Service) SetAccounting(accounting settlement.Accounting) {
s.accountingAPI = accountingAPI s.accounting = accounting
} }
// TotalSent returns the total amount sent to a peer // TotalSent returns the total amount sent to a peer
......
...@@ -70,6 +70,10 @@ func (t *testObserver) NotifyPaymentReceived(peer swarm.Address, amount *big.Int ...@@ -70,6 +70,10 @@ func (t *testObserver) NotifyPaymentReceived(peer swarm.Address, amount *big.Int
return nil return nil
} }
func (t *testObserver) NotifyRefreshmentReceived(peer swarm.Address, amount *big.Int) error {
return nil
}
func (t *testObserver) NotifyPaymentSent(peer swarm.Address, amount *big.Int, err error) { func (t *testObserver) NotifyPaymentSent(peer swarm.Address, amount *big.Int, err error) {
t.sentCalled <- notifyPaymentSentCall{ t.sentCalled <- notifyPaymentSentCall{
peer: peer, peer: peer,
...@@ -423,7 +427,7 @@ func TestPayIssueError(t *testing.T) { ...@@ -423,7 +427,7 @@ func TestPayIssueError(t *testing.T) {
) )
observer := newTestObserver() observer := newTestObserver()
swap.SetAccountingAPI(observer) swap.SetAccounting(observer)
swap.Pay(context.Background(), peer, amount) swap.Pay(context.Background(), peer, amount)
select { select {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment