Commit ae312818 authored by Ralph Pichler's avatar Ralph Pichler Committed by GitHub

call notifypayment in go routine in swap and restore full close (#953)

parent 9e237f07
...@@ -601,8 +601,7 @@ func surplusBalanceKeyPeer(key []byte) (swarm.Address, error) { ...@@ -601,8 +601,7 @@ func surplusBalanceKeyPeer(key []byte) (swarm.Address, error) {
return addr, nil return addr, nil
} }
// NotifyPayment implements the PaymentObserver interface. It is called by // NotifyPayment is called by Settlement when we receive a payment.
// Settlement when we receive a payment.
func (a *Accounting) NotifyPayment(peer swarm.Address, amount uint64) error { func (a *Accounting) NotifyPayment(peer swarm.Address, amount uint64) error {
accountingPeer, err := a.getAccountingPeer(peer) accountingPeer, err := a.getAccountingPeer(peer)
if err != nil { if err != nil {
...@@ -690,6 +689,18 @@ func (a *Accounting) NotifyPayment(peer swarm.Address, amount uint64) error { ...@@ -690,6 +689,18 @@ func (a *Accounting) NotifyPayment(peer swarm.Address, amount uint64) error {
return nil return nil
} }
// AsyncNotifyPayment calls notify payment in a go routine.
// This is needed when accounting needs to be notified but the accounting lock is already held.
func (a *Accounting) AsyncNotifyPayment(peer swarm.Address, amount uint64) error {
go func() {
err := a.NotifyPayment(peer, amount)
if err != nil {
a.logger.Errorf("failed to notify accounting of payment: %v", err)
}
}()
return nil
}
// subtractI64mU64 is a helper function for safe subtraction of Int64 - Uint64 // subtractI64mU64 is a helper function for safe subtraction of Int64 - Uint64
// It checks for // It checks for
// - overflow safety in conversion of uint64 to int64 // - overflow safety in conversion of uint64 to int64
......
...@@ -299,7 +299,7 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey, ...@@ -299,7 +299,7 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey,
return nil, fmt.Errorf("accounting: %w", err) return nil, fmt.Errorf("accounting: %w", err)
} }
settlement.SetPaymentObserver(acc) settlement.SetNotifyPaymentFunc(acc.AsyncNotifyPayment)
pricing.SetPaymentThresholdObserver(acc) pricing.SetPaymentThresholdObserver(acc)
kad := kademlia.New(swarmAddress, addressbook, hive, p2ps, logger, kademlia.Options{Bootnodes: bootnodes, Standalone: o.Standalone}) kad := kademlia.New(swarmAddress, addressbook, hive, p2ps, logger, kademlia.Options{Bootnodes: bootnodes, Standalone: o.Standalone})
......
...@@ -28,12 +28,9 @@ type Interface interface { ...@@ -28,12 +28,9 @@ type Interface interface {
SettlementsSent() (map[string]uint64, error) SettlementsSent() (map[string]uint64, error)
// SettlementsReceived returns received settlements for each individual known peer // SettlementsReceived returns received settlements for each individual known peer
SettlementsReceived() (map[string]uint64, error) SettlementsReceived() (map[string]uint64, error)
// SetPaymentObserver sets the PaymentObserver to notify // SetNotifyPaymentFunc sets the NotifyPaymentFunc to notify
SetPaymentObserver(observer PaymentObserver) SetNotifyPaymentFunc(notifyPaymentFunc NotifyPaymentFunc)
} }
// PaymentObserver is the interface Settlement uses to notify other components of an incoming payment // NotifyPaymentFunc is called when a payment from peer was successfully received
type PaymentObserver interface { type NotifyPaymentFunc func(peer swarm.Address, amount uint64) error
// NotifyPayment is called when a payment from peer was successfully received
NotifyPayment(peer swarm.Address, amount uint64) error
}
...@@ -115,7 +115,7 @@ func (s *Service) SettlementsReceived() (map[string]uint64, error) { ...@@ -115,7 +115,7 @@ func (s *Service) SettlementsReceived() (map[string]uint64, error) {
return s.settlementsRecv, nil return s.settlementsRecv, nil
} }
func (s *Service) SetPaymentObserver(settlement.PaymentObserver) { func (s *Service) SetNotifyPaymentFunc(settlement.NotifyPaymentFunc) {
} }
// Option is the option passed to the mock settlement service // Option is the option passed to the mock settlement service
......
...@@ -32,11 +32,11 @@ var ( ...@@ -32,11 +32,11 @@ var (
) )
type Service struct { type Service struct {
streamer p2p.Streamer streamer p2p.Streamer
logger logging.Logger logger logging.Logger
store storage.StateStorer store storage.StateStorer
observer settlement.PaymentObserver notifyPaymentFunc settlement.NotifyPaymentFunc
metrics metrics metrics metrics
} }
func New(streamer p2p.Streamer, logger logging.Logger, store storage.StateStorer) *Service { func New(streamer p2p.Streamer, logger logging.Logger, store storage.StateStorer) *Service {
...@@ -105,7 +105,7 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e ...@@ -105,7 +105,7 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e
return err return err
} }
return s.observer.NotifyPayment(p.Address, req.Amount) return s.notifyPaymentFunc(p.Address, req.Amount)
} }
// Pay initiates a payment to the given peer // Pay initiates a payment to the given peer
...@@ -148,9 +148,9 @@ func (s *Service) Pay(ctx context.Context, peer swarm.Address, amount uint64) er ...@@ -148,9 +148,9 @@ func (s *Service) Pay(ctx context.Context, peer swarm.Address, amount uint64) er
return nil return nil
} }
// SetPaymentObserver sets the payment observer which will be notified of incoming payments // SetNotifyPaymentFunc sets the NotifyPaymentFunc to notify
func (s *Service) SetPaymentObserver(observer settlement.PaymentObserver) { func (s *Service) SetNotifyPaymentFunc(notifyPaymentFunc settlement.NotifyPaymentFunc) {
s.observer = observer s.notifyPaymentFunc = notifyPaymentFunc
} }
// TotalSent returns the total amount sent to a peer // TotalSent returns the total amount sent to a peer
......
...@@ -47,7 +47,7 @@ func TestPayment(t *testing.T) { ...@@ -47,7 +47,7 @@ func TestPayment(t *testing.T) {
observer := newTestObserver() observer := newTestObserver()
recipient := pseudosettle.New(nil, logger, storeRecipient) recipient := pseudosettle.New(nil, logger, storeRecipient)
recipient.SetPaymentObserver(observer) recipient.SetNotifyPaymentFunc(observer.NotifyPayment)
recorder := streamtest.New( recorder := streamtest.New(
streamtest.WithProtocols(recipient.Protocol()), streamtest.WithProtocols(recipient.Protocol()),
......
...@@ -22,12 +22,12 @@ type Service struct { ...@@ -22,12 +22,12 @@ type Service struct {
settlementsSentFunc func() (map[string]uint64, error) settlementsSentFunc func() (map[string]uint64, error)
settlementsRecvFunc func() (map[string]uint64, error) settlementsRecvFunc func() (map[string]uint64, error)
receiveChequeFunc func(context.Context, swarm.Address, *chequebook.SignedCheque) error receiveChequeFunc func(context.Context, swarm.Address, *chequebook.SignedCheque) error
payFunc func(context.Context, swarm.Address, uint64) error payFunc func(context.Context, swarm.Address, uint64) error
setPaymentObserverFunc func(observer settlement.PaymentObserver) setNotifyPaymentFunc settlement.NotifyPaymentFunc
handshakeFunc func(swarm.Address, common.Address) error handshakeFunc func(swarm.Address, common.Address) error
lastSentChequeFunc func(swarm.Address) (*chequebook.SignedCheque, error) lastSentChequeFunc func(swarm.Address) (*chequebook.SignedCheque, error)
lastSentChequesFunc func() (map[string]*chequebook.SignedCheque, error) lastSentChequesFunc func() (map[string]*chequebook.SignedCheque, error)
lastReceivedChequeFunc func(swarm.Address) (*chequebook.SignedCheque, error) lastReceivedChequeFunc func(swarm.Address) (*chequebook.SignedCheque, error)
lastReceivedChequesFunc func() (map[string]*chequebook.SignedCheque, error) lastReceivedChequesFunc func() (map[string]*chequebook.SignedCheque, error)
...@@ -75,9 +75,9 @@ func WithPayFunc(f func(context.Context, swarm.Address, uint64) error) Option { ...@@ -75,9 +75,9 @@ func WithPayFunc(f func(context.Context, swarm.Address, uint64) error) Option {
} }
// WithsettlementsFunc sets the mock settlements function // WithsettlementsFunc sets the mock settlements function
func WithSetPaymentObserverFunc(f func(observer settlement.PaymentObserver)) Option { func WithSetNotifyPaymentFunc(f settlement.NotifyPaymentFunc) Option {
return optionFunc(func(s *Service) { return optionFunc(func(s *Service) {
s.setPaymentObserverFunc = f s.setNotifyPaymentFunc = f
}) })
} }
...@@ -156,10 +156,9 @@ func (s *Service) Pay(ctx context.Context, peer swarm.Address, amount uint64) er ...@@ -156,10 +156,9 @@ func (s *Service) Pay(ctx context.Context, peer swarm.Address, amount uint64) er
return nil return nil
} }
// SetPaymentObserver is the mock SetPaymentObserver function of swap. func (s *Service) SetNotifyPaymentFunc(f settlement.NotifyPaymentFunc) {
func (s *Service) SetPaymentObserver(observer settlement.PaymentObserver) { if s.setNotifyPaymentFunc != nil {
if s.setPaymentObserverFunc != nil { s.SetNotifyPaymentFunc(f)
s.setPaymentObserverFunc(observer)
} }
} }
......
...@@ -47,17 +47,17 @@ type ApiInterface interface { ...@@ -47,17 +47,17 @@ type ApiInterface interface {
// Service is the implementation of the swap settlement layer. // Service is the implementation of the swap settlement layer.
type Service struct { type Service struct {
proto swapprotocol.Interface proto swapprotocol.Interface
logger logging.Logger logger logging.Logger
store storage.StateStorer store storage.StateStorer
observer settlement.PaymentObserver notifyPaymentFunc settlement.NotifyPaymentFunc
metrics metrics metrics metrics
chequebook chequebook.Service chequebook chequebook.Service
chequeStore chequebook.ChequeStore chequeStore chequebook.ChequeStore
cashout chequebook.CashoutService cashout chequebook.CashoutService
p2pService p2p.Service p2pService p2p.Service
addressbook Addressbook addressbook Addressbook
networkID uint64 networkID uint64
} }
// New creates a new swap Service. // New creates a new swap Service.
...@@ -102,7 +102,7 @@ func (s *Service) ReceiveCheque(ctx context.Context, peer swarm.Address, cheque ...@@ -102,7 +102,7 @@ func (s *Service) ReceiveCheque(ctx context.Context, peer swarm.Address, cheque
s.metrics.TotalReceived.Add(float64(amount.Uint64())) s.metrics.TotalReceived.Add(float64(amount.Uint64()))
return s.observer.NotifyPayment(peer, amount.Uint64()) return s.notifyPaymentFunc(peer, amount.Uint64())
} }
// Pay initiates a payment to the given peer // Pay initiates a payment to the given peer
...@@ -129,9 +129,9 @@ func (s *Service) Pay(ctx context.Context, peer swarm.Address, amount uint64) er ...@@ -129,9 +129,9 @@ func (s *Service) Pay(ctx context.Context, peer swarm.Address, amount uint64) er
return nil return nil
} }
// SetPaymentObserver sets the payment observer which will be notified of incoming payments // SetNotifyPaymentFunc sets the NotifyPaymentFunc to notify
func (s *Service) SetPaymentObserver(observer settlement.PaymentObserver) { func (s *Service) SetNotifyPaymentFunc(notifyPaymentFunc settlement.NotifyPaymentFunc) {
s.observer = observer s.notifyPaymentFunc = notifyPaymentFunc
} }
// TotalSent returns the total amount sent to a peer // TotalSent returns the total amount sent to a peer
......
...@@ -144,7 +144,7 @@ func TestReceiveCheque(t *testing.T) { ...@@ -144,7 +144,7 @@ func TestReceiveCheque(t *testing.T) {
) )
observer := &testObserver{} observer := &testObserver{}
swap.SetPaymentObserver(observer) swap.SetNotifyPaymentFunc(observer.NotifyPayment)
err := swap.ReceiveCheque(context.Background(), peer, cheque) err := swap.ReceiveCheque(context.Background(), peer, cheque)
if err != nil { if err != nil {
...@@ -207,7 +207,7 @@ func TestReceiveChequeReject(t *testing.T) { ...@@ -207,7 +207,7 @@ func TestReceiveChequeReject(t *testing.T) {
) )
observer := &testObserver{} observer := &testObserver{}
swap.SetPaymentObserver(observer) swap.SetNotifyPaymentFunc(observer.NotifyPayment)
err := swap.ReceiveCheque(context.Background(), peer, cheque) err := swap.ReceiveCheque(context.Background(), peer, cheque)
if err == nil { if err == nil {
...@@ -259,7 +259,7 @@ func TestReceiveChequeWrongChequebook(t *testing.T) { ...@@ -259,7 +259,7 @@ func TestReceiveChequeWrongChequebook(t *testing.T) {
) )
observer := &testObserver{} observer := &testObserver{}
swapService.SetPaymentObserver(observer) swapService.SetNotifyPaymentFunc(observer.NotifyPayment)
err := swapService.ReceiveCheque(context.Background(), peer, cheque) err := swapService.ReceiveCheque(context.Background(), peer, cheque)
if err == nil { if err == nil {
......
...@@ -186,8 +186,7 @@ func (s *Service) EmitCheque(ctx context.Context, peer swarm.Address, cheque *ch ...@@ -186,8 +186,7 @@ func (s *Service) EmitCheque(ctx context.Context, peer swarm.Address, cheque *ch
if err != nil { if err != nil {
_ = stream.Reset() _ = stream.Reset()
} else { } else {
// don't wait for full close to avoid deadlocks if cheques are sent simultaneously in both directions _ = stream.FullClose()
go stream.FullClose()
} }
}() }()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment