Commit 33c9b5a4 authored by refcell's avatar refcell

Fix the split trace provider

parent 1e956691
......@@ -2,7 +2,6 @@ package split
import (
"context"
"fmt"
"github.com/ethereum-optimism/optimism/op-challenger/game/fault/types"
......@@ -10,74 +9,64 @@ import (
"github.com/ethereum/go-ethereum/log"
)
var (
GetStepDataErr = fmt.Errorf("GetStepData not supported")
NoProvidersErr = fmt.Errorf("no trace providers configured")
)
var _ types.TraceProvider = (*SplitTraceProvider)(nil)
// SplitTraceProvider is a [types.TraceProvider] implementation that
// routes requests to the correct internal trace provider based on the
// depth of the requested trace.
type SplitTraceProvider struct {
logger log.Logger
providers []types.TraceProvider
depthTiers []uint64
logger log.Logger
topProvider types.TraceProvider
bottomProvider types.TraceProvider
topDepth uint64
}
func NewTraceProvider(logger log.Logger, providers []types.TraceProvider, depthTiers []uint64) *SplitTraceProvider {
// NewTraceProvider creates a new [SplitTraceProvider] instance.
// The [topDepth] parameter specifies the depth at which the internal
// [types.TraceProvider] should be switched.
func NewTraceProvider(logger log.Logger, topProvider types.TraceProvider, bottomProvider types.TraceProvider, topDepth uint64) *SplitTraceProvider {
return &SplitTraceProvider{
logger: logger,
providers: providers,
depthTiers: depthTiers,
logger: logger,
topProvider: topProvider,
bottomProvider: bottomProvider,
topDepth: topDepth,
}
}
func (s *SplitTraceProvider) providerForDepth(depth uint64) (uint64, types.TraceProvider) {
reduced := uint64(0)
for i, tier := range s.depthTiers {
if depth <= tier {
return reduced, s.providers[i]
}
if i < len(s.providers)-1 {
reduced += tier
}
if depth <= s.topDepth {
return 0, s.topProvider
}
return reduced, s.providers[len(s.providers)-1]
return s.topDepth, s.bottomProvider
}
// Get routes the Get request to the internal [types.TraceProvider] that
// that serves the trace index at the depth.
func (s *SplitTraceProvider) Get(ctx context.Context, pos types.Position) (common.Hash, error) {
if len(s.providers) == 0 {
return common.Hash{}, NoProvidersErr
ancestorDepth, provider := s.providerForDepth(uint64(pos.Depth()))
relativePosition, err := pos.RelativeToAncestorAtDepth(ancestorDepth)
if err != nil {
return common.Hash{}, err
}
reduced, provider := s.providerForDepth(uint64(pos.Depth()))
localizedPosition := pos.Localize(reduced)
return provider.Get(ctx, localizedPosition)
return provider.Get(ctx, relativePosition)
}
// AbsolutePreStateCommitment returns the absolute prestate from the lowest internal [types.TraceProvider]
func (s *SplitTraceProvider) AbsolutePreStateCommitment(ctx context.Context) (hash common.Hash, err error) {
if len(s.providers) == 0 {
return common.Hash{}, NoProvidersErr
}
return s.providers[len(s.providers)-1].AbsolutePreStateCommitment(ctx)
return s.bottomProvider.AbsolutePreStateCommitment(ctx)
}
// AbsolutePreState routes the AbsolutePreState request to the lowest internal [types.TraceProvider].
func (s *SplitTraceProvider) AbsolutePreState(ctx context.Context) (preimage []byte, err error) {
if len(s.providers) == 0 {
return nil, NoProvidersErr
}
return s.providers[len(s.providers)-1].AbsolutePreState(ctx)
return s.bottomProvider.AbsolutePreState(ctx)
}
// GetStepData routes the GetStepData request to the lowest internal [types.TraceProvider].
func (s *SplitTraceProvider) GetStepData(ctx context.Context, i types.Position) (prestate []byte, proofData []byte, preimageData *types.PreimageOracleData, err error) {
if len(s.providers) == 0 {
return nil, nil, nil, NoProvidersErr
func (s *SplitTraceProvider) GetStepData(ctx context.Context, pos types.Position) (prestate []byte, proofData []byte, preimageData *types.PreimageOracleData, err error) {
ancestorDepth, _ := s.providerForDepth(uint64(pos.Depth()))
relativePosition, err := pos.RelativeToAncestorAtDepth(ancestorDepth)
if err != nil {
return nil, nil, nil, err
}
return s.providers[len(s.providers)-1].GetStepData(ctx, i)
return s.bottomProvider.GetStepData(ctx, relativePosition)
}
......@@ -23,9 +23,9 @@ func TestGet(t *testing.T) {
t.Run("ErrorBubblesUp", func(t *testing.T) {
mockOutputProvider := mockTraceProvider{getError: mockGetError}
splitProvider := SplitTraceProvider{
logger: testlog.Logger(t, log.LvlInfo),
providers: []types.TraceProvider{&mockOutputProvider},
depthTiers: []uint64{40, 20},
logger: testlog.Logger(t, log.LvlInfo),
topProvider: &mockOutputProvider,
topDepth: 40,
}
_, err := splitProvider.Get(context.Background(), types.NewPosition(1, 0))
require.ErrorIs(t, err, mockGetError)
......@@ -34,9 +34,9 @@ func TestGet(t *testing.T) {
t.Run("ReturnsCorrectOutput", func(t *testing.T) {
mockOutputProvider := mockTraceProvider{getOutput: mockOutput}
splitProvider := SplitTraceProvider{
logger: testlog.Logger(t, log.LvlInfo),
providers: []types.TraceProvider{&mockOutputProvider},
depthTiers: []uint64{40, 20},
logger: testlog.Logger(t, log.LvlInfo),
topProvider: &mockOutputProvider,
topDepth: 40,
}
output, err := splitProvider.Get(context.Background(), types.NewPosition(6, 3))
require.NoError(t, err)
......@@ -45,12 +45,13 @@ func TestGet(t *testing.T) {
})
t.Run("ReturnsCorrectOutputWithMultipleProviders", func(t *testing.T) {
firstOutputProvider := mockTraceProvider{}
secondOutputProvider := mockTraceProvider{getOutput: mockOutput}
topProvider := mockTraceProvider{}
bottomProvider := mockTraceProvider{getOutput: mockOutput}
splitProvider := SplitTraceProvider{
logger: testlog.Logger(t, log.LvlInfo),
providers: []types.TraceProvider{&firstOutputProvider, &secondOutputProvider},
depthTiers: []uint64{40, 20},
logger: testlog.Logger(t, log.LvlInfo),
topProvider: &topProvider,
bottomProvider: &bottomProvider,
topDepth: 40,
}
output, err := splitProvider.Get(context.Background(), types.NewPosition(42, 17))
require.NoError(t, err)
......@@ -63,9 +64,9 @@ func TestAbsolutePreStateCommitment(t *testing.T) {
t.Run("ErrorBubblesUp", func(t *testing.T) {
mockOutputProvider := mockTraceProvider{absolutePreStateCommitmentError: mockGetError}
splitProvider := SplitTraceProvider{
logger: testlog.Logger(t, log.LvlInfo),
providers: []types.TraceProvider{&mockOutputProvider},
depthTiers: []uint64{40, 20},
logger: testlog.Logger(t, log.LvlInfo),
bottomProvider: &mockOutputProvider,
topDepth: 40,
}
_, err := splitProvider.AbsolutePreStateCommitment(context.Background())
require.ErrorIs(t, err, mockGetError)
......@@ -74,9 +75,9 @@ func TestAbsolutePreStateCommitment(t *testing.T) {
t.Run("ReturnsCorrectOutput", func(t *testing.T) {
mockOutputProvider := mockTraceProvider{absolutePreStateCommitment: mockCommitment}
splitProvider := SplitTraceProvider{
logger: testlog.Logger(t, log.LvlInfo),
providers: []types.TraceProvider{&mockOutputProvider},
depthTiers: []uint64{40, 20},
logger: testlog.Logger(t, log.LvlInfo),
bottomProvider: &mockOutputProvider,
topDepth: 40,
}
output, err := splitProvider.AbsolutePreStateCommitment(context.Background())
require.NoError(t, err)
......@@ -88,9 +89,9 @@ func TestAbsolutePreState(t *testing.T) {
t.Run("ErrorBubblesUp", func(t *testing.T) {
mockOutputProvider := mockTraceProvider{absolutePreStateError: mockGetError}
splitProvider := SplitTraceProvider{
logger: testlog.Logger(t, log.LvlInfo),
providers: []types.TraceProvider{&mockOutputProvider},
depthTiers: []uint64{40},
logger: testlog.Logger(t, log.LvlInfo),
bottomProvider: &mockOutputProvider,
topDepth: 40,
}
_, err := splitProvider.AbsolutePreState(context.Background())
require.ErrorIs(t, err, mockGetError)
......@@ -101,9 +102,9 @@ func TestGetStepData(t *testing.T) {
t.Run("ErrorBubblesUp", func(t *testing.T) {
mockOutputProvider := mockTraceProvider{getStepDataError: mockGetError}
splitProvider := SplitTraceProvider{
logger: testlog.Logger(t, log.LvlInfo),
providers: []types.TraceProvider{&mockOutputProvider},
depthTiers: []uint64{40},
logger: testlog.Logger(t, log.LvlInfo),
bottomProvider: &mockOutputProvider,
topDepth: 40,
}
_, _, _, err := splitProvider.GetStepData(context.Background(), types.NewPosition(0, 0))
require.ErrorIs(t, err, mockGetError)
......
package types
import "fmt"
import (
"errors"
"fmt"
)
var (
PositionDepthTooSmall = errors.New("Position depth is too small")
)
// Position is a golang wrapper around the dispute game Position type.
type Position struct {
......@@ -25,13 +32,16 @@ func (p Position) MoveRight() Position {
}
}
// Localize returns a new position for a subtree.
// reduced is the number of levels to reduce the depth by.
func (p Position) Localize(reduced uint64) Position {
newPosDepth := uint64(p.depth) - reduced
// RelativeToAncestorAtDepth returns a new position for a subtree.
// [ancestor] is the depth of the subtree root node.
func (p Position) RelativeToAncestorAtDepth(ancestor uint64) (Position, error) {
if ancestor > uint64(p.depth) {
return Position{}, PositionDepthTooSmall
}
newPosDepth := uint64(p.depth) - ancestor
nodesAtDepth := 1 << newPosDepth
newIndexAtDepth := p.indexAtDepth % nodesAtDepth
return NewPosition(int(newPosDepth), newIndexAtDepth)
return NewPosition(int(newPosDepth), newIndexAtDepth), nil
}
func (p Position) Depth() int {
......
......@@ -119,3 +119,19 @@ func TestDefend(t *testing.T) {
require.Equalf(t, test.DefendGIndex, result.ToGIndex(), "Defend from GIndex %v", pos.ToGIndex())
}
}
func TestRelativeToAncestorAtDepth(t *testing.T) {
t.Run("ErrorsForDeepAncestor", func(t *testing.T) {
pos := NewPosition(1, 1)
_, err := pos.RelativeToAncestorAtDepth(2)
require.ErrorIs(t, err, PositionDepthTooSmall)
})
t.Run("Success", func(t *testing.T) {
pos := NewPosition(2, 1)
expectedRelativePosition := NewPosition(1, 1)
relativePosition, err := pos.RelativeToAncestorAtDepth(1)
require.NoError(t, err)
require.Equal(t, expectedRelativePosition, relativePosition)
})
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment