Commit d3bd703c authored by refcell's avatar refcell

Clean up testing

parent 33c9b5a4
...@@ -63,10 +63,10 @@ func (s *SplitTraceProvider) AbsolutePreState(ctx context.Context) (preimage []b ...@@ -63,10 +63,10 @@ func (s *SplitTraceProvider) AbsolutePreState(ctx context.Context) (preimage []b
// GetStepData routes the GetStepData request to the lowest internal [types.TraceProvider]. // GetStepData routes the GetStepData request to the lowest internal [types.TraceProvider].
func (s *SplitTraceProvider) GetStepData(ctx context.Context, pos types.Position) (prestate []byte, proofData []byte, preimageData *types.PreimageOracleData, err error) { func (s *SplitTraceProvider) GetStepData(ctx context.Context, pos types.Position) (prestate []byte, proofData []byte, preimageData *types.PreimageOracleData, err error) {
ancestorDepth, _ := s.providerForDepth(uint64(pos.Depth())) ancestorDepth, provider := s.providerForDepth(uint64(pos.Depth()))
relativePosition, err := pos.RelativeToAncestorAtDepth(ancestorDepth) relativePosition, err := pos.RelativeToAncestorAtDepth(ancestorDepth)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
return s.bottomProvider.GetStepData(ctx, relativePosition) return provider.GetStepData(ctx, relativePosition)
} }
...@@ -22,22 +22,14 @@ var ( ...@@ -22,22 +22,14 @@ var (
func TestGet(t *testing.T) { func TestGet(t *testing.T) {
t.Run("ErrorBubblesUp", func(t *testing.T) { t.Run("ErrorBubblesUp", func(t *testing.T) {
mockOutputProvider := mockTraceProvider{getError: mockGetError} mockOutputProvider := mockTraceProvider{getError: mockGetError}
splitProvider := SplitTraceProvider{ splitProvider := newSplitTraceProvider(t, &mockOutputProvider, nil, 40)
logger: testlog.Logger(t, log.LvlInfo),
topProvider: &mockOutputProvider,
topDepth: 40,
}
_, err := splitProvider.Get(context.Background(), types.NewPosition(1, 0)) _, err := splitProvider.Get(context.Background(), types.NewPosition(1, 0))
require.ErrorIs(t, err, mockGetError) require.ErrorIs(t, err, mockGetError)
}) })
t.Run("ReturnsCorrectOutput", func(t *testing.T) { t.Run("ReturnsCorrectOutput", func(t *testing.T) {
mockOutputProvider := mockTraceProvider{getOutput: mockOutput} mockOutputProvider := mockTraceProvider{getOutput: mockOutput}
splitProvider := SplitTraceProvider{ splitProvider := newSplitTraceProvider(t, &mockOutputProvider, nil, 40)
logger: testlog.Logger(t, log.LvlInfo),
topProvider: &mockOutputProvider,
topDepth: 40,
}
output, err := splitProvider.Get(context.Background(), types.NewPosition(6, 3)) output, err := splitProvider.Get(context.Background(), types.NewPosition(6, 3))
require.NoError(t, err) require.NoError(t, err)
expectedGIndex := types.NewPosition(6, 3).ToGIndex() expectedGIndex := types.NewPosition(6, 3).ToGIndex()
...@@ -45,14 +37,8 @@ func TestGet(t *testing.T) { ...@@ -45,14 +37,8 @@ func TestGet(t *testing.T) {
}) })
t.Run("ReturnsCorrectOutputWithMultipleProviders", func(t *testing.T) { t.Run("ReturnsCorrectOutputWithMultipleProviders", func(t *testing.T) {
topProvider := mockTraceProvider{}
bottomProvider := mockTraceProvider{getOutput: mockOutput} bottomProvider := mockTraceProvider{getOutput: mockOutput}
splitProvider := SplitTraceProvider{ splitProvider := newSplitTraceProvider(t, &mockTraceProvider{}, &bottomProvider, 40)
logger: testlog.Logger(t, log.LvlInfo),
topProvider: &topProvider,
bottomProvider: &bottomProvider,
topDepth: 40,
}
output, err := splitProvider.Get(context.Background(), types.NewPosition(42, 17)) output, err := splitProvider.Get(context.Background(), types.NewPosition(42, 17))
require.NoError(t, err) require.NoError(t, err)
expectedGIndex := types.NewPosition(2, 1).ToGIndex() expectedGIndex := types.NewPosition(2, 1).ToGIndex()
...@@ -63,22 +49,14 @@ func TestGet(t *testing.T) { ...@@ -63,22 +49,14 @@ func TestGet(t *testing.T) {
func TestAbsolutePreStateCommitment(t *testing.T) { func TestAbsolutePreStateCommitment(t *testing.T) {
t.Run("ErrorBubblesUp", func(t *testing.T) { t.Run("ErrorBubblesUp", func(t *testing.T) {
mockOutputProvider := mockTraceProvider{absolutePreStateCommitmentError: mockGetError} mockOutputProvider := mockTraceProvider{absolutePreStateCommitmentError: mockGetError}
splitProvider := SplitTraceProvider{ splitProvider := newSplitTraceProvider(t, nil, &mockOutputProvider, 40)
logger: testlog.Logger(t, log.LvlInfo),
bottomProvider: &mockOutputProvider,
topDepth: 40,
}
_, err := splitProvider.AbsolutePreStateCommitment(context.Background()) _, err := splitProvider.AbsolutePreStateCommitment(context.Background())
require.ErrorIs(t, err, mockGetError) require.ErrorIs(t, err, mockGetError)
}) })
t.Run("ReturnsCorrectOutput", func(t *testing.T) { t.Run("ReturnsCorrectOutput", func(t *testing.T) {
mockOutputProvider := mockTraceProvider{absolutePreStateCommitment: mockCommitment} mockOutputProvider := mockTraceProvider{absolutePreStateCommitment: mockCommitment}
splitProvider := SplitTraceProvider{ splitProvider := newSplitTraceProvider(t, nil, &mockOutputProvider, 40)
logger: testlog.Logger(t, log.LvlInfo),
bottomProvider: &mockOutputProvider,
topDepth: 40,
}
output, err := splitProvider.AbsolutePreStateCommitment(context.Background()) output, err := splitProvider.AbsolutePreStateCommitment(context.Background())
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, mockCommitment, output) require.Equal(t, mockCommitment, output)
...@@ -88,27 +66,37 @@ func TestAbsolutePreStateCommitment(t *testing.T) { ...@@ -88,27 +66,37 @@ func TestAbsolutePreStateCommitment(t *testing.T) {
func TestAbsolutePreState(t *testing.T) { func TestAbsolutePreState(t *testing.T) {
t.Run("ErrorBubblesUp", func(t *testing.T) { t.Run("ErrorBubblesUp", func(t *testing.T) {
mockOutputProvider := mockTraceProvider{absolutePreStateError: mockGetError} mockOutputProvider := mockTraceProvider{absolutePreStateError: mockGetError}
splitProvider := SplitTraceProvider{ splitProvider := newSplitTraceProvider(t, nil, &mockOutputProvider, 40)
logger: testlog.Logger(t, log.LvlInfo),
bottomProvider: &mockOutputProvider,
topDepth: 40,
}
_, err := splitProvider.AbsolutePreState(context.Background()) _, err := splitProvider.AbsolutePreState(context.Background())
require.ErrorIs(t, err, mockGetError) require.ErrorIs(t, err, mockGetError)
}) })
t.Run("ReturnsCorrectPreimageData", func(t *testing.T) {
expectedPreimage := []byte{1, 2, 3, 4}
mockOutputProvider := mockTraceProvider{preImageData: expectedPreimage}
splitProvider := newSplitTraceProvider(t, nil, &mockOutputProvider, 40)
output, err := splitProvider.AbsolutePreState(context.Background())
require.NoError(t, err)
require.Equal(t, expectedPreimage, output)
})
} }
func TestGetStepData(t *testing.T) { func TestGetStepData(t *testing.T) {
t.Run("ErrorBubblesUp", func(t *testing.T) { t.Run("ErrorBubblesUp", func(t *testing.T) {
mockOutputProvider := mockTraceProvider{getStepDataError: mockGetError} mockOutputProvider := mockTraceProvider{getStepDataError: mockGetError}
splitProvider := SplitTraceProvider{ splitProvider := newSplitTraceProvider(t, &mockOutputProvider, nil, 40)
logger: testlog.Logger(t, log.LvlInfo),
bottomProvider: &mockOutputProvider,
topDepth: 40,
}
_, _, _, err := splitProvider.GetStepData(context.Background(), types.NewPosition(0, 0)) _, _, _, err := splitProvider.GetStepData(context.Background(), types.NewPosition(0, 0))
require.ErrorIs(t, err, mockGetError) require.ErrorIs(t, err, mockGetError)
}) })
t.Run("ReturnsCorrectStepData", func(t *testing.T) {
expectedStepData := []byte{1, 2, 3, 4}
mockOutputProvider := mockTraceProvider{stepPrestateData: expectedStepData}
splitProvider := newSplitTraceProvider(t, nil, &mockOutputProvider, 40)
output, _, _, err := splitProvider.GetStepData(context.Background(), types.NewPosition(41, 0))
require.NoError(t, err)
require.Equal(t, expectedStepData, output)
})
} }
type mockTraceProvider struct { type mockTraceProvider struct {
...@@ -117,7 +105,18 @@ type mockTraceProvider struct { ...@@ -117,7 +105,18 @@ type mockTraceProvider struct {
absolutePreStateCommitmentError error absolutePreStateCommitmentError error
absolutePreStateCommitment common.Hash absolutePreStateCommitment common.Hash
absolutePreStateError error absolutePreStateError error
preImageData []byte
getStepDataError error getStepDataError error
stepPrestateData []byte
}
func newSplitTraceProvider(t *testing.T, tp *mockTraceProvider, bp *mockTraceProvider, topDepth uint64) SplitTraceProvider {
return SplitTraceProvider{
logger: testlog.Logger(t, log.LvlInfo),
topProvider: tp,
bottomProvider: bp,
topDepth: topDepth,
}
} }
func (m *mockTraceProvider) Get(ctx context.Context, pos types.Position) (common.Hash, error) { func (m *mockTraceProvider) Get(ctx context.Context, pos types.Position) (common.Hash, error) {
...@@ -138,12 +137,12 @@ func (m *mockTraceProvider) AbsolutePreState(ctx context.Context) (preimage []by ...@@ -138,12 +137,12 @@ func (m *mockTraceProvider) AbsolutePreState(ctx context.Context) (preimage []by
if m.absolutePreStateError != nil { if m.absolutePreStateError != nil {
return []byte{}, m.absolutePreStateError return []byte{}, m.absolutePreStateError
} }
return []byte{}, nil return m.preImageData, nil
} }
func (m *mockTraceProvider) GetStepData(ctx context.Context, pos types.Position) ([]byte, []byte, *types.PreimageOracleData, error) { func (m *mockTraceProvider) GetStepData(ctx context.Context, pos types.Position) ([]byte, []byte, *types.PreimageOracleData, error) {
if m.getStepDataError != nil { if m.getStepDataError != nil {
return nil, nil, nil, m.getStepDataError return nil, nil, nil, m.getStepDataError
} }
return nil, nil, nil, nil return m.stepPrestateData, nil, nil, nil
} }
...@@ -6,7 +6,7 @@ import ( ...@@ -6,7 +6,7 @@ import (
) )
var ( var (
PositionDepthTooSmall = errors.New("Position depth is too small") ErrPositionDepthTooSmall = errors.New("Position depth is too small")
) )
// Position is a golang wrapper around the dispute game Position type. // Position is a golang wrapper around the dispute game Position type.
...@@ -36,7 +36,7 @@ func (p Position) MoveRight() Position { ...@@ -36,7 +36,7 @@ func (p Position) MoveRight() Position {
// [ancestor] is the depth of the subtree root node. // [ancestor] is the depth of the subtree root node.
func (p Position) RelativeToAncestorAtDepth(ancestor uint64) (Position, error) { func (p Position) RelativeToAncestorAtDepth(ancestor uint64) (Position, error) {
if ancestor > uint64(p.depth) { if ancestor > uint64(p.depth) {
return Position{}, PositionDepthTooSmall return Position{}, ErrPositionDepthTooSmall
} }
newPosDepth := uint64(p.depth) - ancestor newPosDepth := uint64(p.depth) - ancestor
nodesAtDepth := 1 << newPosDepth nodesAtDepth := 1 << newPosDepth
......
...@@ -124,7 +124,7 @@ func TestRelativeToAncestorAtDepth(t *testing.T) { ...@@ -124,7 +124,7 @@ func TestRelativeToAncestorAtDepth(t *testing.T) {
t.Run("ErrorsForDeepAncestor", func(t *testing.T) { t.Run("ErrorsForDeepAncestor", func(t *testing.T) {
pos := NewPosition(1, 1) pos := NewPosition(1, 1)
_, err := pos.RelativeToAncestorAtDepth(2) _, err := pos.RelativeToAncestorAtDepth(2)
require.ErrorIs(t, err, PositionDepthTooSmall) require.ErrorIs(t, err, ErrPositionDepthTooSmall)
}) })
t.Run("Success", func(t *testing.T) { t.Run("Success", func(t *testing.T) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment