Commit c3ac3445 authored by Adrian Sutton's avatar Adrian Sutton

op-program: Use a single defer to perform all cleanup.

parent deeeea19
......@@ -53,38 +53,44 @@ func Main(logger log.Logger, cfg *config.Config) error {
// FaultProofProgram is the programmatic entry-point for the fault proof program
func FaultProofProgram(ctx context.Context, logger log.Logger, cfg *config.Config) error {
var (
serverErr chan error
pClientRW oppio.FileChannel
hClientRW oppio.FileChannel
)
defer func() {
if pClientRW != nil {
_ = pClientRW.Close()
}
if hClientRW != nil {
_ = hClientRW.Close()
}
if serverErr != nil {
err := <-serverErr
if err != nil {
logger.Error("preimage server failed", "err", err)
}
logger.Debug("Preimage server stopped")
}
}()
// Setup client I/O for preimage oracle interaction
pClientRW, pHostRW, err := oppio.CreateBidirectionalChannel()
if err != nil {
return fmt.Errorf("failed to create preimage pipe: %w", err)
}
pClientCloser := oppio.NewSafeClose(pClientRW)
defer pClientCloser.Close()
// Setup client I/O for hint comms
hClientRW, hHostRW, err := oppio.CreateBidirectionalChannel()
if err != nil {
return fmt.Errorf("failed to create hints pipe: %w", err)
}
hClientCloser := oppio.NewSafeClose(hClientRW)
defer hClientCloser.Close()
// Use a channel to receive the server result so we can wait for it to complete before returning
serverErr := make(chan error)
serverErr = make(chan error)
go func() {
defer close(serverErr)
serverErr <- PreimageServer(ctx, logger, cfg, pHostRW, hHostRW)
}()
defer func() {
// Ensure the client streams are closed to trigger the server to exit
pClientCloser.Close()
hClientCloser.Close()
logger.Debug("Waiting for preimage server to exit")
err := <-serverErr
if err != nil {
logger.Error("preimage server failed", "err", err)
}
logger.Debug("Preimage server stopped")
}()
var cmd *exec.Cmd
if cfg.ExecCmd != "" {
......@@ -116,13 +122,20 @@ func FaultProofProgram(ctx context.Context, logger log.Logger, cfg *config.Confi
// If either returns an error both handlers are stopped.
// The supplied preimageChannel and hintChannel will be closed before this function returns.
func PreimageServer(ctx context.Context, logger log.Logger, cfg *config.Config, preimageChannel oppio.FileChannel, hintChannel oppio.FileChannel) error {
preimageCloser := oppio.NewSafeClose(preimageChannel)
hintCloser := oppio.NewSafeClose(hintChannel)
closeChannels := func() {
_ = preimageCloser.Close()
_ = hintCloser.Close()
}
defer closeChannels()
var serverDone chan error
var hinterDone chan error
defer func() {
preimageChannel.Close()
hintChannel.Close()
if serverDone != nil {
// Wait for pre-image server to complete
<-serverDone
}
if hinterDone != nil {
// Wait for hinter to complete
<-hinterDone
}
}()
logger.Info("Starting preimage server")
var kv kvstore.KV
if cfg.DataDir == "" {
......@@ -160,18 +173,12 @@ func PreimageServer(ctx context.Context, logger log.Logger, cfg *config.Config,
splitter := kvstore.NewPreimageSourceSplitter(localPreimageSource.Get, getPreimage)
preimageGetter := splitter.Get
serverDone := launchOracleServer(logger, preimageChannel, preimageGetter)
hinterDone := routeHints(logger, hintChannel, hinter)
serverDone = launchOracleServer(logger, preimageChannel, preimageGetter)
hinterDone = routeHints(logger, hintChannel, hinter)
select {
case err := <-serverDone:
// Close the channels to trigger the hinter to exit and wait for it to complete
closeChannels()
<-hinterDone
return err
case err := <-hinterDone:
// Close the channels to trigger the oracle server to exit and wait for it to complete
closeChannels()
<-serverDone
return err
}
}
......
package io
import (
"io"
"sync"
)
type safeClose struct {
c io.Closer
once sync.Once
}
func (s *safeClose) Close() error {
var err error
s.once.Do(func() {
err = s.c.Close()
})
return err
}
func NewSafeClose(c io.Closer) io.Closer {
return &safeClose{
c: c,
}
}
package io
import (
"errors"
"testing"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
func TestOnlyCallsCloseOnce(t *testing.T) {
delegate := new(mockCloser)
defer delegate.AssertExpectations(t)
safeClose := NewSafeClose(delegate)
// Only expects one close call
delegate.ExpectClose(nil)
require.NoError(t, safeClose.Close())
require.NoError(t, safeClose.Close())
}
func TestReturnsErrorFromFirstCall(t *testing.T) {
delegate := new(mockCloser)
defer delegate.AssertExpectations(t)
safeClose := NewSafeClose(delegate)
err := errors.New("expected")
// Only expects one close call
delegate.ExpectClose(err)
require.ErrorIs(t, safeClose.Close(), err)
// Later calls should not return an error as they didn't need to call Close
require.NoError(t, safeClose.Close())
}
type mockCloser struct {
mock.Mock
}
func (t *mockCloser) Close() error {
err := t.Mock.MethodCalled("Close").Get(0)
if err != nil {
return err.(error)
}
return nil
}
func (t *mockCloser) ExpectClose(err error) {
t.Mock.On("Close").Return(err)
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment