Commit 12829cb6 authored by Andreas Bigger's avatar Andreas Bigger

Fix trace indexing

parent 3d1cc98b
...@@ -15,7 +15,6 @@ import ( ...@@ -15,7 +15,6 @@ import (
var ( var (
GetStepDataErr = fmt.Errorf("GetStepData not supported") GetStepDataErr = fmt.Errorf("GetStepData not supported")
AbsolutePreStateErr = fmt.Errorf("AbsolutePreState not supported") AbsolutePreStateErr = fmt.Errorf("AbsolutePreState not supported")
PreStateRequestErr = fmt.Errorf("Requested trace index is before prestate block")
) )
var _ types.TraceProvider = (*OutputTraceProvider)(nil) var _ types.TraceProvider = (*OutputTraceProvider)(nil)
...@@ -27,31 +26,33 @@ type OutputRollupClient interface { ...@@ -27,31 +26,33 @@ type OutputRollupClient interface {
// OutputTraceProvider is a [types.TraceProvider] implementation that uses // OutputTraceProvider is a [types.TraceProvider] implementation that uses
// output roots for given L2 Blocks as a trace. // output roots for given L2 Blocks as a trace.
type OutputTraceProvider struct { type OutputTraceProvider struct {
logger log.Logger logger log.Logger
rollupClient OutputRollupClient rollupClient OutputRollupClient
prestateBlock uint64 prestateBlock uint64
poststateBlock uint64
} }
func NewTraceProvider(ctx context.Context, logger log.Logger, rollupRpc string, prestateBlock uint64) (*OutputTraceProvider, error) { func NewTraceProvider(ctx context.Context, logger log.Logger, rollupRpc string, prestateBlock, poststateBlock uint64) (*OutputTraceProvider, error) {
rollupClient, err := client.DialRollupClientWithTimeout(client.DefaultDialTimeout, logger, rollupRpc) rollupClient, err := client.DialRollupClientWithTimeout(client.DefaultDialTimeout, logger, rollupRpc)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return NewTraceProviderFromInputs(logger, rollupClient, prestateBlock), nil return NewTraceProviderFromInputs(logger, rollupClient, prestateBlock, poststateBlock), nil
} }
func NewTraceProviderFromInputs(logger log.Logger, rollupClient OutputRollupClient, prestateBlock uint64) *OutputTraceProvider { func NewTraceProviderFromInputs(logger log.Logger, rollupClient OutputRollupClient, prestateBlock, poststateBlock uint64) *OutputTraceProvider {
return &OutputTraceProvider{ return &OutputTraceProvider{
logger: logger, logger: logger,
rollupClient: rollupClient, rollupClient: rollupClient,
prestateBlock: prestateBlock, prestateBlock: prestateBlock,
poststateBlock: poststateBlock,
} }
} }
func (o *OutputTraceProvider) Get(ctx context.Context, traceIndex uint64) (common.Hash, error) { func (o *OutputTraceProvider) Get(ctx context.Context, traceIndex uint64) (common.Hash, error) {
outputBlock := traceIndex + 1 outputBlock := traceIndex + o.prestateBlock + 1
if outputBlock < o.prestateBlock { if outputBlock > o.poststateBlock {
return common.Hash{}, PreStateRequestErr outputBlock = o.poststateBlock
} }
output, err := o.rollupClient.OutputAtBlock(ctx, outputBlock) output, err := o.rollupClient.OutputAtBlock(ctx, outputBlock)
if err != nil { if err != nil {
......
...@@ -14,43 +14,54 @@ import ( ...@@ -14,43 +14,54 @@ import (
) )
var ( var (
prestateBlock = uint64(100) prestateBlock = uint64(100)
prestateOutputRoot = common.HexToHash("0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa") poststateBlock = uint64(200)
firstOutputRoot = common.HexToHash("0xbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb") prestateOutputRoot = common.HexToHash("0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")
firstOutputRoot = common.HexToHash("0xbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb")
poststateOutputRoot = common.HexToHash("0xcccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc")
) )
func TestGet(t *testing.T) { func TestGet(t *testing.T) {
t.Run("TraceIndexBeforePrestate", func(t *testing.T) { t.Run("FirstBlockAfterPrestate", func(t *testing.T) {
provider, _ := setupWithTestData(t, prestateBlock) provider, _ := setupWithTestData(t, prestateBlock, poststateBlock)
_, err := provider.Get(context.Background(), 0) value, err := provider.Get(context.Background(), 0)
require.ErrorIs(t, err, PreStateRequestErr) require.NoError(t, err)
require.Equal(t, value, firstOutputRoot)
}) })
t.Run("MissingOutputAtBlock", func(t *testing.T) { t.Run("MissingOutputAtBlock", func(t *testing.T) {
provider, _ := setupWithTestData(t, prestateBlock) provider, _ := setupWithTestData(t, prestateBlock, poststateBlock)
traceIndex := 101 _, err := provider.Get(context.Background(), 1)
_, err := provider.Get(context.Background(), uint64(traceIndex)) require.ErrorAs(t, fmt.Errorf("no output at block %d", prestateBlock+2), &err)
require.ErrorAs(t, fmt.Errorf("no output at block %d", uint64(traceIndex+1)), &err)
}) })
t.Run("Success", func(t *testing.T) { t.Run("PostStateBlock", func(t *testing.T) {
provider, _ := setupWithTestData(t, prestateBlock) provider, _ := setupWithTestData(t, prestateBlock, poststateBlock)
value, err := provider.Get(context.Background(), 100) traceIndex := poststateBlock - prestateBlock
value, err := provider.Get(context.Background(), traceIndex)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, value, firstOutputRoot) require.Equal(t, value, poststateOutputRoot)
})
t.Run("AfterPostStateBlock", func(t *testing.T) {
provider, _ := setupWithTestData(t, prestateBlock, poststateBlock)
traceIndex := poststateBlock - prestateBlock + 1
value, err := provider.Get(context.Background(), traceIndex)
require.NoError(t, err)
require.Equal(t, value, poststateOutputRoot)
}) })
} }
func TestAbsolutePreStateCommitment(t *testing.T) { func TestAbsolutePreStateCommitment(t *testing.T) {
t.Run("FailedToFetchOutput", func(t *testing.T) { t.Run("FailedToFetchOutput", func(t *testing.T) {
provider, rollupClient := setupWithTestData(t, prestateBlock) provider, rollupClient := setupWithTestData(t, prestateBlock, poststateBlock)
rollupClient.errorsOnPrestateFetch = true rollupClient.errorsOnPrestateFetch = true
_, err := provider.AbsolutePreStateCommitment(context.Background()) _, err := provider.AbsolutePreStateCommitment(context.Background())
require.ErrorAs(t, fmt.Errorf("no output at block %d", prestateBlock), &err) require.ErrorAs(t, fmt.Errorf("no output at block %d", prestateBlock), &err)
}) })
t.Run("Success", func(t *testing.T) { t.Run("ReturnsCorrectPrestateOutput", func(t *testing.T) {
provider, _ := setupWithTestData(t, prestateBlock) provider, _ := setupWithTestData(t, prestateBlock, poststateBlock)
value, err := provider.AbsolutePreStateCommitment(context.Background()) value, err := provider.AbsolutePreStateCommitment(context.Background())
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, value, prestateOutputRoot) require.Equal(t, value, prestateOutputRoot)
...@@ -58,32 +69,36 @@ func TestAbsolutePreStateCommitment(t *testing.T) { ...@@ -58,32 +69,36 @@ func TestAbsolutePreStateCommitment(t *testing.T) {
} }
func TestGetStepData(t *testing.T) { func TestGetStepData(t *testing.T) {
provider, _ := setupWithTestData(t, prestateBlock) provider, _ := setupWithTestData(t, prestateBlock, poststateBlock)
_, _, _, err := provider.GetStepData(context.Background(), 0) _, _, _, err := provider.GetStepData(context.Background(), 0)
require.ErrorIs(t, err, GetStepDataErr) require.ErrorIs(t, err, GetStepDataErr)
} }
func TestAbsolutePreState(t *testing.T) { func TestAbsolutePreState(t *testing.T) {
provider, _ := setupWithTestData(t, prestateBlock) provider, _ := setupWithTestData(t, prestateBlock, poststateBlock)
_, err := provider.AbsolutePreState(context.Background()) _, err := provider.AbsolutePreState(context.Background())
require.ErrorIs(t, err, AbsolutePreStateErr) require.ErrorIs(t, err, AbsolutePreStateErr)
} }
func setupWithTestData(t *testing.T, prestateBlock uint64) (*OutputTraceProvider, *stubRollupClient) { func setupWithTestData(t *testing.T, prestateBlock, poststateBlock uint64) (*OutputTraceProvider, *stubRollupClient) {
rollupClient := stubRollupClient{ rollupClient := stubRollupClient{
outputs: map[uint64]*eth.OutputResponse{ outputs: map[uint64]*eth.OutputResponse{
100: { prestateBlock: {
OutputRoot: eth.Bytes32(prestateOutputRoot), OutputRoot: eth.Bytes32(prestateOutputRoot),
}, },
101: { 101: {
OutputRoot: eth.Bytes32(firstOutputRoot), OutputRoot: eth.Bytes32(firstOutputRoot),
}, },
poststateBlock: {
OutputRoot: eth.Bytes32(poststateOutputRoot),
},
}, },
} }
return &OutputTraceProvider{ return &OutputTraceProvider{
logger: testlog.Logger(t, log.LvlInfo), logger: testlog.Logger(t, log.LvlInfo),
rollupClient: &rollupClient, rollupClient: &rollupClient,
prestateBlock: prestateBlock, prestateBlock: prestateBlock,
poststateBlock: poststateBlock,
}, &rollupClient }, &rollupClient
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment