Commit 7a8dea9a authored by Peter Mrekaj's avatar Peter Mrekaj Committed by GitHub

refactor(node): simplify bee shutdown repetitive io.Closer procedure (#1808)

parent 84158764
...@@ -67,6 +67,7 @@ import ( ...@@ -67,6 +67,7 @@ import (
"github.com/ethersphere/bee/pkg/topology/lightnode" "github.com/ethersphere/bee/pkg/topology/lightnode"
"github.com/ethersphere/bee/pkg/tracing" "github.com/ethersphere/bee/pkg/tracing"
"github.com/ethersphere/bee/pkg/traversal" "github.com/ethersphere/bee/pkg/traversal"
"github.com/hashicorp/go-multierror"
ma "github.com/multiformats/go-multiaddr" ma "github.com/multiformats/go-multiaddr"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
...@@ -682,13 +683,20 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey, ...@@ -682,13 +683,20 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey,
} }
func (b *Bee) Shutdown(ctx context.Context) error { func (b *Bee) Shutdown(ctx context.Context) error {
errs := new(multiError) var mErr error
if b.apiCloser != nil { // tryClose is a convenient closure which decrease
if err := b.apiCloser.Close(); err != nil { // repetitive io.Closer tryClose procedure.
errs.add(fmt.Errorf("api: %w", err)) tryClose := func(c io.Closer, errMsg string) {
if c == nil {
return
} }
if err := c.Close(); err != nil {
mErr = multierror.Append(mErr, fmt.Errorf("%s: %w", errMsg, err))
} }
}
tryClose(b.apiCloser, "api")
var eg errgroup.Group var eg errgroup.Group
if b.apiServer != nil { if b.apiServer != nil {
...@@ -709,113 +717,37 @@ func (b *Bee) Shutdown(ctx context.Context) error { ...@@ -709,113 +717,37 @@ func (b *Bee) Shutdown(ctx context.Context) error {
} }
if err := eg.Wait(); err != nil { if err := eg.Wait(); err != nil {
errs.add(err) mErr = multierror.Append(mErr, err)
} }
if b.recoveryHandleCleanup != nil { if b.recoveryHandleCleanup != nil {
b.recoveryHandleCleanup() b.recoveryHandleCleanup()
} }
if err := b.pusherCloser.Close(); err != nil { tryClose(b.pusherCloser, "pusher")
errs.add(fmt.Errorf("pusher: %w", err)) tryClose(b.pullerCloser, "puller")
} tryClose(b.pullSyncCloser, "pull sync")
tryClose(b.pssCloser, "pss")
if err := b.pullerCloser.Close(); err != nil {
errs.add(fmt.Errorf("puller: %w", err))
}
if err := b.pullSyncCloser.Close(); err != nil {
errs.add(fmt.Errorf("pull sync: %w", err))
}
if err := b.pssCloser.Close(); err != nil {
errs.add(fmt.Errorf("pss: %w", err))
}
b.p2pCancel() b.p2pCancel()
if err := b.p2pService.Close(); err != nil { tryClose(b.p2pService, "p2p server")
errs.add(fmt.Errorf("p2p server: %w", err)) tryClose(b.transactionMonitorCloser, "transaction monitor")
}
if b.transactionMonitorCloser != nil {
if err := b.transactionMonitorCloser.Close(); err != nil {
errs.add(fmt.Errorf("transaction monitor: %w", err))
}
}
if c := b.ethClientCloser; c != nil { if c := b.ethClientCloser; c != nil {
c() c()
} }
if err := b.tracerCloser.Close(); err != nil { tryClose(b.tracerCloser, "tracer")
errs.add(fmt.Errorf("tracer: %w", err)) tryClose(b.tagsCloser, "tag persistence")
} tryClose(b.listenerCloser, "listener")
tryClose(b.postageServiceCloser, "postage service")
if err := b.tagsCloser.Close(); err != nil { tryClose(b.stateStoreCloser, "statestore")
errs.add(fmt.Errorf("tag persistence: %w", err)) tryClose(b.localstoreCloser, "localstore")
} tryClose(b.topologyCloser, "topology driver")
tryClose(b.errorLogWriter, "error log writer")
if b.listenerCloser != nil { tryClose(b.resolverCloser, "resolver service")
if err := b.listenerCloser.Close(); err != nil {
errs.add(fmt.Errorf("listener: %w", err))
}
}
if err := b.postageServiceCloser.Close(); err != nil {
errs.add(fmt.Errorf("postage service: %w", err))
}
if err := b.stateStoreCloser.Close(); err != nil {
errs.add(fmt.Errorf("statestore: %w", err))
}
if err := b.localstoreCloser.Close(); err != nil {
errs.add(fmt.Errorf("localstore: %w", err))
}
if err := b.topologyCloser.Close(); err != nil {
errs.add(fmt.Errorf("topology driver: %w", err))
}
if err := b.errorLogWriter.Close(); err != nil {
errs.add(fmt.Errorf("error log writer: %w", err))
}
// Shutdown the resolver service only if it has been initialized.
if b.resolverCloser != nil {
if err := b.resolverCloser.Close(); err != nil {
errs.add(fmt.Errorf("resolver service: %w", err))
}
}
if errs.hasErrors() {
return errs
}
return nil
}
type multiError struct {
errors []error
}
func (e *multiError) Error() string {
if len(e.errors) == 0 {
return ""
}
s := e.errors[0].Error()
for _, err := range e.errors[1:] {
s += "; " + err.Error()
}
return s
}
func (e *multiError) add(err error) {
e.errors = append(e.errors, err)
}
func (e *multiError) hasErrors() bool { return mErr
return len(e.errors) > 0
} }
func getTxHash(stateStore storage.StateStorer, logger logging.Logger, o Options) ([]byte, error) { func getTxHash(stateStore storage.StateStorer, logger logging.Logger, o Options) ([]byte, error) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment