Commit c3ac3445 authored by Adrian Sutton's avatar Adrian Sutton

op-program: Use a single defer to perform all cleanup.

parent deeeea19
...@@ -53,38 +53,44 @@ func Main(logger log.Logger, cfg *config.Config) error { ...@@ -53,38 +53,44 @@ func Main(logger log.Logger, cfg *config.Config) error {
// FaultProofProgram is the programmatic entry-point for the fault proof program // FaultProofProgram is the programmatic entry-point for the fault proof program
func FaultProofProgram(ctx context.Context, logger log.Logger, cfg *config.Config) error { func FaultProofProgram(ctx context.Context, logger log.Logger, cfg *config.Config) error {
var (
serverErr chan error
pClientRW oppio.FileChannel
hClientRW oppio.FileChannel
)
defer func() {
if pClientRW != nil {
_ = pClientRW.Close()
}
if hClientRW != nil {
_ = hClientRW.Close()
}
if serverErr != nil {
err := <-serverErr
if err != nil {
logger.Error("preimage server failed", "err", err)
}
logger.Debug("Preimage server stopped")
}
}()
// Setup client I/O for preimage oracle interaction // Setup client I/O for preimage oracle interaction
pClientRW, pHostRW, err := oppio.CreateBidirectionalChannel() pClientRW, pHostRW, err := oppio.CreateBidirectionalChannel()
if err != nil { if err != nil {
return fmt.Errorf("failed to create preimage pipe: %w", err) return fmt.Errorf("failed to create preimage pipe: %w", err)
} }
pClientCloser := oppio.NewSafeClose(pClientRW)
defer pClientCloser.Close()
// Setup client I/O for hint comms // Setup client I/O for hint comms
hClientRW, hHostRW, err := oppio.CreateBidirectionalChannel() hClientRW, hHostRW, err := oppio.CreateBidirectionalChannel()
if err != nil { if err != nil {
return fmt.Errorf("failed to create hints pipe: %w", err) return fmt.Errorf("failed to create hints pipe: %w", err)
} }
hClientCloser := oppio.NewSafeClose(hClientRW)
defer hClientCloser.Close()
// Use a channel to receive the server result so we can wait for it to complete before returning // Use a channel to receive the server result so we can wait for it to complete before returning
serverErr := make(chan error) serverErr = make(chan error)
go func() { go func() {
defer close(serverErr)
serverErr <- PreimageServer(ctx, logger, cfg, pHostRW, hHostRW) serverErr <- PreimageServer(ctx, logger, cfg, pHostRW, hHostRW)
}() }()
defer func() {
// Ensure the client streams are closed to trigger the server to exit
pClientCloser.Close()
hClientCloser.Close()
logger.Debug("Waiting for preimage server to exit")
err := <-serverErr
if err != nil {
logger.Error("preimage server failed", "err", err)
}
logger.Debug("Preimage server stopped")
}()
var cmd *exec.Cmd var cmd *exec.Cmd
if cfg.ExecCmd != "" { if cfg.ExecCmd != "" {
...@@ -116,13 +122,20 @@ func FaultProofProgram(ctx context.Context, logger log.Logger, cfg *config.Confi ...@@ -116,13 +122,20 @@ func FaultProofProgram(ctx context.Context, logger log.Logger, cfg *config.Confi
// If either returns an error both handlers are stopped. // If either returns an error both handlers are stopped.
// The supplied preimageChannel and hintChannel will be closed before this function returns. // The supplied preimageChannel and hintChannel will be closed before this function returns.
func PreimageServer(ctx context.Context, logger log.Logger, cfg *config.Config, preimageChannel oppio.FileChannel, hintChannel oppio.FileChannel) error { func PreimageServer(ctx context.Context, logger log.Logger, cfg *config.Config, preimageChannel oppio.FileChannel, hintChannel oppio.FileChannel) error {
preimageCloser := oppio.NewSafeClose(preimageChannel) var serverDone chan error
hintCloser := oppio.NewSafeClose(hintChannel) var hinterDone chan error
closeChannels := func() { defer func() {
_ = preimageCloser.Close() preimageChannel.Close()
_ = hintCloser.Close() hintChannel.Close()
if serverDone != nil {
// Wait for pre-image server to complete
<-serverDone
}
if hinterDone != nil {
// Wait for hinter to complete
<-hinterDone
} }
defer closeChannels() }()
logger.Info("Starting preimage server") logger.Info("Starting preimage server")
var kv kvstore.KV var kv kvstore.KV
if cfg.DataDir == "" { if cfg.DataDir == "" {
...@@ -160,18 +173,12 @@ func PreimageServer(ctx context.Context, logger log.Logger, cfg *config.Config, ...@@ -160,18 +173,12 @@ func PreimageServer(ctx context.Context, logger log.Logger, cfg *config.Config,
splitter := kvstore.NewPreimageSourceSplitter(localPreimageSource.Get, getPreimage) splitter := kvstore.NewPreimageSourceSplitter(localPreimageSource.Get, getPreimage)
preimageGetter := splitter.Get preimageGetter := splitter.Get
serverDone := launchOracleServer(logger, preimageChannel, preimageGetter) serverDone = launchOracleServer(logger, preimageChannel, preimageGetter)
hinterDone := routeHints(logger, hintChannel, hinter) hinterDone = routeHints(logger, hintChannel, hinter)
select { select {
case err := <-serverDone: case err := <-serverDone:
// Close the channels to trigger the hinter to exit and wait for it to complete
closeChannels()
<-hinterDone
return err return err
case err := <-hinterDone: case err := <-hinterDone:
// Close the channels to trigger the oracle server to exit and wait for it to complete
closeChannels()
<-serverDone
return err return err
} }
} }
......
package io
import (
"io"
"sync"
)
type safeClose struct {
c io.Closer
once sync.Once
}
func (s *safeClose) Close() error {
var err error
s.once.Do(func() {
err = s.c.Close()
})
return err
}
func NewSafeClose(c io.Closer) io.Closer {
return &safeClose{
c: c,
}
}
package io
import (
"errors"
"testing"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
func TestOnlyCallsCloseOnce(t *testing.T) {
delegate := new(mockCloser)
defer delegate.AssertExpectations(t)
safeClose := NewSafeClose(delegate)
// Only expects one close call
delegate.ExpectClose(nil)
require.NoError(t, safeClose.Close())
require.NoError(t, safeClose.Close())
}
func TestReturnsErrorFromFirstCall(t *testing.T) {
delegate := new(mockCloser)
defer delegate.AssertExpectations(t)
safeClose := NewSafeClose(delegate)
err := errors.New("expected")
// Only expects one close call
delegate.ExpectClose(err)
require.ErrorIs(t, safeClose.Close(), err)
// Later calls should not return an error as they didn't need to call Close
require.NoError(t, safeClose.Close())
}
type mockCloser struct {
mock.Mock
}
func (t *mockCloser) Close() error {
err := t.Mock.MethodCalled("Close").Get(0)
if err != nil {
return err.(error)
}
return nil
}
func (t *mockCloser) ExpectClose(err error) {
t.Mock.On("Close").Return(err)
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment