Commit 7a8dea9a authored by Peter Mrekaj's avatar Peter Mrekaj Committed by GitHub

refactor(node): simplify bee shutdown repetitive io.Closer procedure (#1808)

parent 84158764
......@@ -67,6 +67,7 @@ import (
"github.com/ethersphere/bee/pkg/topology/lightnode"
"github.com/ethersphere/bee/pkg/tracing"
"github.com/ethersphere/bee/pkg/traversal"
"github.com/hashicorp/go-multierror"
ma "github.com/multiformats/go-multiaddr"
"github.com/sirupsen/logrus"
"golang.org/x/sync/errgroup"
......@@ -682,14 +683,21 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey,
}
func (b *Bee) Shutdown(ctx context.Context) error {
errs := new(multiError)
var mErr error
if b.apiCloser != nil {
if err := b.apiCloser.Close(); err != nil {
errs.add(fmt.Errorf("api: %w", err))
// tryClose is a convenient closure which decrease
// repetitive io.Closer tryClose procedure.
tryClose := func(c io.Closer, errMsg string) {
if c == nil {
return
}
if err := c.Close(); err != nil {
mErr = multierror.Append(mErr, fmt.Errorf("%s: %w", errMsg, err))
}
}
tryClose(b.apiCloser, "api")
var eg errgroup.Group
if b.apiServer != nil {
eg.Go(func() error {
......@@ -709,113 +717,37 @@ func (b *Bee) Shutdown(ctx context.Context) error {
}
if err := eg.Wait(); err != nil {
errs.add(err)
mErr = multierror.Append(mErr, err)
}
if b.recoveryHandleCleanup != nil {
b.recoveryHandleCleanup()
}
if err := b.pusherCloser.Close(); err != nil {
errs.add(fmt.Errorf("pusher: %w", err))
}
if err := b.pullerCloser.Close(); err != nil {
errs.add(fmt.Errorf("puller: %w", err))
}
if err := b.pullSyncCloser.Close(); err != nil {
errs.add(fmt.Errorf("pull sync: %w", err))
}
if err := b.pssCloser.Close(); err != nil {
errs.add(fmt.Errorf("pss: %w", err))
}
tryClose(b.pusherCloser, "pusher")
tryClose(b.pullerCloser, "puller")
tryClose(b.pullSyncCloser, "pull sync")
tryClose(b.pssCloser, "pss")
b.p2pCancel()
if err := b.p2pService.Close(); err != nil {
errs.add(fmt.Errorf("p2p server: %w", err))
}
if b.transactionMonitorCloser != nil {
if err := b.transactionMonitorCloser.Close(); err != nil {
errs.add(fmt.Errorf("transaction monitor: %w", err))
}
}
tryClose(b.p2pService, "p2p server")
tryClose(b.transactionMonitorCloser, "transaction monitor")
if c := b.ethClientCloser; c != nil {
c()
}
if err := b.tracerCloser.Close(); err != nil {
errs.add(fmt.Errorf("tracer: %w", err))
}
if err := b.tagsCloser.Close(); err != nil {
errs.add(fmt.Errorf("tag persistence: %w", err))
}
if b.listenerCloser != nil {
if err := b.listenerCloser.Close(); err != nil {
errs.add(fmt.Errorf("listener: %w", err))
}
}
if err := b.postageServiceCloser.Close(); err != nil {
errs.add(fmt.Errorf("postage service: %w", err))
}
if err := b.stateStoreCloser.Close(); err != nil {
errs.add(fmt.Errorf("statestore: %w", err))
}
if err := b.localstoreCloser.Close(); err != nil {
errs.add(fmt.Errorf("localstore: %w", err))
}
if err := b.topologyCloser.Close(); err != nil {
errs.add(fmt.Errorf("topology driver: %w", err))
}
if err := b.errorLogWriter.Close(); err != nil {
errs.add(fmt.Errorf("error log writer: %w", err))
}
// Shutdown the resolver service only if it has been initialized.
if b.resolverCloser != nil {
if err := b.resolverCloser.Close(); err != nil {
errs.add(fmt.Errorf("resolver service: %w", err))
}
}
if errs.hasErrors() {
return errs
}
return nil
}
type multiError struct {
errors []error
}
func (e *multiError) Error() string {
if len(e.errors) == 0 {
return ""
}
s := e.errors[0].Error()
for _, err := range e.errors[1:] {
s += "; " + err.Error()
}
return s
}
func (e *multiError) add(err error) {
e.errors = append(e.errors, err)
}
tryClose(b.tracerCloser, "tracer")
tryClose(b.tagsCloser, "tag persistence")
tryClose(b.listenerCloser, "listener")
tryClose(b.postageServiceCloser, "postage service")
tryClose(b.stateStoreCloser, "statestore")
tryClose(b.localstoreCloser, "localstore")
tryClose(b.topologyCloser, "topology driver")
tryClose(b.errorLogWriter, "error log writer")
tryClose(b.resolverCloser, "resolver service")
func (e *multiError) hasErrors() bool {
return len(e.errors) > 0
return mErr
}
func getTxHash(stateStore storage.StateStorer, logger logging.Logger, o Options) ([]byte, error) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment