Commit ae312818 authored by Ralph Pichler's avatar Ralph Pichler Committed by GitHub

call notifypayment in go routine in swap and restore full close (#953)

parent 9e237f07
...@@ -601,8 +601,7 @@ func surplusBalanceKeyPeer(key []byte) (swarm.Address, error) { ...@@ -601,8 +601,7 @@ func surplusBalanceKeyPeer(key []byte) (swarm.Address, error) {
return addr, nil return addr, nil
} }
// NotifyPayment implements the PaymentObserver interface. It is called by // NotifyPayment is called by Settlement when we receive a payment.
// Settlement when we receive a payment.
func (a *Accounting) NotifyPayment(peer swarm.Address, amount uint64) error { func (a *Accounting) NotifyPayment(peer swarm.Address, amount uint64) error {
accountingPeer, err := a.getAccountingPeer(peer) accountingPeer, err := a.getAccountingPeer(peer)
if err != nil { if err != nil {
...@@ -690,6 +689,18 @@ func (a *Accounting) NotifyPayment(peer swarm.Address, amount uint64) error { ...@@ -690,6 +689,18 @@ func (a *Accounting) NotifyPayment(peer swarm.Address, amount uint64) error {
return nil return nil
} }
// AsyncNotifyPayment calls notify payment in a go routine.
// This is needed when accounting needs to be notified but the accounting lock is already held.
func (a *Accounting) AsyncNotifyPayment(peer swarm.Address, amount uint64) error {
go func() {
err := a.NotifyPayment(peer, amount)
if err != nil {
a.logger.Errorf("failed to notify accounting of payment: %v", err)
}
}()
return nil
}
// subtractI64mU64 is a helper function for safe subtraction of Int64 - Uint64 // subtractI64mU64 is a helper function for safe subtraction of Int64 - Uint64
// It checks for // It checks for
// - overflow safety in conversion of uint64 to int64 // - overflow safety in conversion of uint64 to int64
......
...@@ -299,7 +299,7 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey, ...@@ -299,7 +299,7 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey,
return nil, fmt.Errorf("accounting: %w", err) return nil, fmt.Errorf("accounting: %w", err)
} }
settlement.SetPaymentObserver(acc) settlement.SetNotifyPaymentFunc(acc.AsyncNotifyPayment)
pricing.SetPaymentThresholdObserver(acc) pricing.SetPaymentThresholdObserver(acc)
kad := kademlia.New(swarmAddress, addressbook, hive, p2ps, logger, kademlia.Options{Bootnodes: bootnodes, Standalone: o.Standalone}) kad := kademlia.New(swarmAddress, addressbook, hive, p2ps, logger, kademlia.Options{Bootnodes: bootnodes, Standalone: o.Standalone})
......
...@@ -28,12 +28,9 @@ type Interface interface { ...@@ -28,12 +28,9 @@ type Interface interface {
SettlementsSent() (map[string]uint64, error) SettlementsSent() (map[string]uint64, error)
// SettlementsReceived returns received settlements for each individual known peer // SettlementsReceived returns received settlements for each individual known peer
SettlementsReceived() (map[string]uint64, error) SettlementsReceived() (map[string]uint64, error)
// SetPaymentObserver sets the PaymentObserver to notify // SetNotifyPaymentFunc sets the NotifyPaymentFunc to notify
SetPaymentObserver(observer PaymentObserver) SetNotifyPaymentFunc(notifyPaymentFunc NotifyPaymentFunc)
} }
// PaymentObserver is the interface Settlement uses to notify other components of an incoming payment // NotifyPaymentFunc is called when a payment from peer was successfully received
type PaymentObserver interface { type NotifyPaymentFunc func(peer swarm.Address, amount uint64) error
// NotifyPayment is called when a payment from peer was successfully received
NotifyPayment(peer swarm.Address, amount uint64) error
}
...@@ -115,7 +115,7 @@ func (s *Service) SettlementsReceived() (map[string]uint64, error) { ...@@ -115,7 +115,7 @@ func (s *Service) SettlementsReceived() (map[string]uint64, error) {
return s.settlementsRecv, nil return s.settlementsRecv, nil
} }
func (s *Service) SetPaymentObserver(settlement.PaymentObserver) { func (s *Service) SetNotifyPaymentFunc(settlement.NotifyPaymentFunc) {
} }
// Option is the option passed to the mock settlement service // Option is the option passed to the mock settlement service
......
...@@ -35,7 +35,7 @@ type Service struct { ...@@ -35,7 +35,7 @@ type Service struct {
streamer p2p.Streamer streamer p2p.Streamer
logger logging.Logger logger logging.Logger
store storage.StateStorer store storage.StateStorer
observer settlement.PaymentObserver notifyPaymentFunc settlement.NotifyPaymentFunc
metrics metrics metrics metrics
} }
...@@ -105,7 +105,7 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e ...@@ -105,7 +105,7 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e
return err return err
} }
return s.observer.NotifyPayment(p.Address, req.Amount) return s.notifyPaymentFunc(p.Address, req.Amount)
} }
// Pay initiates a payment to the given peer // Pay initiates a payment to the given peer
...@@ -148,9 +148,9 @@ func (s *Service) Pay(ctx context.Context, peer swarm.Address, amount uint64) er ...@@ -148,9 +148,9 @@ func (s *Service) Pay(ctx context.Context, peer swarm.Address, amount uint64) er
return nil return nil
} }
// SetPaymentObserver sets the payment observer which will be notified of incoming payments // SetNotifyPaymentFunc sets the NotifyPaymentFunc to notify
func (s *Service) SetPaymentObserver(observer settlement.PaymentObserver) { func (s *Service) SetNotifyPaymentFunc(notifyPaymentFunc settlement.NotifyPaymentFunc) {
s.observer = observer s.notifyPaymentFunc = notifyPaymentFunc
} }
// TotalSent returns the total amount sent to a peer // TotalSent returns the total amount sent to a peer
......
...@@ -47,7 +47,7 @@ func TestPayment(t *testing.T) { ...@@ -47,7 +47,7 @@ func TestPayment(t *testing.T) {
observer := newTestObserver() observer := newTestObserver()
recipient := pseudosettle.New(nil, logger, storeRecipient) recipient := pseudosettle.New(nil, logger, storeRecipient)
recipient.SetPaymentObserver(observer) recipient.SetNotifyPaymentFunc(observer.NotifyPayment)
recorder := streamtest.New( recorder := streamtest.New(
streamtest.WithProtocols(recipient.Protocol()), streamtest.WithProtocols(recipient.Protocol()),
......
...@@ -24,7 +24,7 @@ type Service struct { ...@@ -24,7 +24,7 @@ type Service struct {
receiveChequeFunc func(context.Context, swarm.Address, *chequebook.SignedCheque) error receiveChequeFunc func(context.Context, swarm.Address, *chequebook.SignedCheque) error
payFunc func(context.Context, swarm.Address, uint64) error payFunc func(context.Context, swarm.Address, uint64) error
setPaymentObserverFunc func(observer settlement.PaymentObserver) setNotifyPaymentFunc settlement.NotifyPaymentFunc
handshakeFunc func(swarm.Address, common.Address) error handshakeFunc func(swarm.Address, common.Address) error
lastSentChequeFunc func(swarm.Address) (*chequebook.SignedCheque, error) lastSentChequeFunc func(swarm.Address) (*chequebook.SignedCheque, error)
lastSentChequesFunc func() (map[string]*chequebook.SignedCheque, error) lastSentChequesFunc func() (map[string]*chequebook.SignedCheque, error)
...@@ -75,9 +75,9 @@ func WithPayFunc(f func(context.Context, swarm.Address, uint64) error) Option { ...@@ -75,9 +75,9 @@ func WithPayFunc(f func(context.Context, swarm.Address, uint64) error) Option {
} }
// WithsettlementsFunc sets the mock settlements function // WithsettlementsFunc sets the mock settlements function
func WithSetPaymentObserverFunc(f func(observer settlement.PaymentObserver)) Option { func WithSetNotifyPaymentFunc(f settlement.NotifyPaymentFunc) Option {
return optionFunc(func(s *Service) { return optionFunc(func(s *Service) {
s.setPaymentObserverFunc = f s.setNotifyPaymentFunc = f
}) })
} }
...@@ -156,10 +156,9 @@ func (s *Service) Pay(ctx context.Context, peer swarm.Address, amount uint64) er ...@@ -156,10 +156,9 @@ func (s *Service) Pay(ctx context.Context, peer swarm.Address, amount uint64) er
return nil return nil
} }
// SetPaymentObserver is the mock SetPaymentObserver function of swap. func (s *Service) SetNotifyPaymentFunc(f settlement.NotifyPaymentFunc) {
func (s *Service) SetPaymentObserver(observer settlement.PaymentObserver) { if s.setNotifyPaymentFunc != nil {
if s.setPaymentObserverFunc != nil { s.SetNotifyPaymentFunc(f)
s.setPaymentObserverFunc(observer)
} }
} }
......
...@@ -50,7 +50,7 @@ type Service struct { ...@@ -50,7 +50,7 @@ type Service struct {
proto swapprotocol.Interface proto swapprotocol.Interface
logger logging.Logger logger logging.Logger
store storage.StateStorer store storage.StateStorer
observer settlement.PaymentObserver notifyPaymentFunc settlement.NotifyPaymentFunc
metrics metrics metrics metrics
chequebook chequebook.Service chequebook chequebook.Service
chequeStore chequebook.ChequeStore chequeStore chequebook.ChequeStore
...@@ -102,7 +102,7 @@ func (s *Service) ReceiveCheque(ctx context.Context, peer swarm.Address, cheque ...@@ -102,7 +102,7 @@ func (s *Service) ReceiveCheque(ctx context.Context, peer swarm.Address, cheque
s.metrics.TotalReceived.Add(float64(amount.Uint64())) s.metrics.TotalReceived.Add(float64(amount.Uint64()))
return s.observer.NotifyPayment(peer, amount.Uint64()) return s.notifyPaymentFunc(peer, amount.Uint64())
} }
// Pay initiates a payment to the given peer // Pay initiates a payment to the given peer
...@@ -129,9 +129,9 @@ func (s *Service) Pay(ctx context.Context, peer swarm.Address, amount uint64) er ...@@ -129,9 +129,9 @@ func (s *Service) Pay(ctx context.Context, peer swarm.Address, amount uint64) er
return nil return nil
} }
// SetPaymentObserver sets the payment observer which will be notified of incoming payments // SetNotifyPaymentFunc sets the NotifyPaymentFunc to notify
func (s *Service) SetPaymentObserver(observer settlement.PaymentObserver) { func (s *Service) SetNotifyPaymentFunc(notifyPaymentFunc settlement.NotifyPaymentFunc) {
s.observer = observer s.notifyPaymentFunc = notifyPaymentFunc
} }
// TotalSent returns the total amount sent to a peer // TotalSent returns the total amount sent to a peer
......
...@@ -144,7 +144,7 @@ func TestReceiveCheque(t *testing.T) { ...@@ -144,7 +144,7 @@ func TestReceiveCheque(t *testing.T) {
) )
observer := &testObserver{} observer := &testObserver{}
swap.SetPaymentObserver(observer) swap.SetNotifyPaymentFunc(observer.NotifyPayment)
err := swap.ReceiveCheque(context.Background(), peer, cheque) err := swap.ReceiveCheque(context.Background(), peer, cheque)
if err != nil { if err != nil {
...@@ -207,7 +207,7 @@ func TestReceiveChequeReject(t *testing.T) { ...@@ -207,7 +207,7 @@ func TestReceiveChequeReject(t *testing.T) {
) )
observer := &testObserver{} observer := &testObserver{}
swap.SetPaymentObserver(observer) swap.SetNotifyPaymentFunc(observer.NotifyPayment)
err := swap.ReceiveCheque(context.Background(), peer, cheque) err := swap.ReceiveCheque(context.Background(), peer, cheque)
if err == nil { if err == nil {
...@@ -259,7 +259,7 @@ func TestReceiveChequeWrongChequebook(t *testing.T) { ...@@ -259,7 +259,7 @@ func TestReceiveChequeWrongChequebook(t *testing.T) {
) )
observer := &testObserver{} observer := &testObserver{}
swapService.SetPaymentObserver(observer) swapService.SetNotifyPaymentFunc(observer.NotifyPayment)
err := swapService.ReceiveCheque(context.Background(), peer, cheque) err := swapService.ReceiveCheque(context.Background(), peer, cheque)
if err == nil { if err == nil {
......
...@@ -186,8 +186,7 @@ func (s *Service) EmitCheque(ctx context.Context, peer swarm.Address, cheque *ch ...@@ -186,8 +186,7 @@ func (s *Service) EmitCheque(ctx context.Context, peer swarm.Address, cheque *ch
if err != nil { if err != nil {
_ = stream.Reset() _ = stream.Reset()
} else { } else {
// don't wait for full close to avoid deadlocks if cheques are sent simultaneously in both directions _ = stream.FullClose()
go stream.FullClose()
} }
}() }()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment