Commit cf41f5c4 authored by OptimismBot's avatar OptimismBot Committed by GitHub

Merge pull request #6471 from ethereum-optimism/aj/trace-provider-ctx

op-challenger: Pass a context through to TraceProvider
parents e0b17bed 5571000a
...@@ -95,7 +95,7 @@ func (a *Agent) newGameFromContracts(ctx context.Context) (types.Game, error) { ...@@ -95,7 +95,7 @@ func (a *Agent) newGameFromContracts(ctx context.Context) (types.Game, error) {
// move determines & executes the next move given a claim // move determines & executes the next move given a claim
func (a *Agent) move(ctx context.Context, claim types.Claim, game types.Game) error { func (a *Agent) move(ctx context.Context, claim types.Claim, game types.Game) error {
nextMove, err := a.solver.NextMove(claim, game.AgreeWithClaimLevel(claim)) nextMove, err := a.solver.NextMove(ctx, claim, game.AgreeWithClaimLevel(claim))
if err != nil { if err != nil {
return fmt.Errorf("execute next move: %w", err) return fmt.Errorf("execute next move: %w", err)
} }
...@@ -133,7 +133,7 @@ func (a *Agent) step(ctx context.Context, claim types.Claim, game types.Game) er ...@@ -133,7 +133,7 @@ func (a *Agent) step(ctx context.Context, claim types.Claim, game types.Game) er
} }
a.log.Info("Attempting step", "claim_depth", claim.Depth(), "maxDepth", a.maxDepth) a.log.Info("Attempting step", "claim_depth", claim.Depth(), "maxDepth", a.maxDepth)
step, err := a.solver.AttemptStep(claim, agreeWithClaimLevel) step, err := a.solver.AttemptStep(ctx, claim, agreeWithClaimLevel)
if err != nil { if err != nil {
return fmt.Errorf("attempt step: %w", err) return fmt.Errorf("attempt step: %w", err)
} }
......
package alphabet package alphabet
import ( import (
"context"
"errors" "errors"
"math/big" "math/big"
"strings" "strings"
...@@ -30,33 +31,33 @@ func NewAlphabetProvider(state string, depth uint64) *AlphabetProvider { ...@@ -30,33 +31,33 @@ func NewAlphabetProvider(state string, depth uint64) *AlphabetProvider {
} }
// GetOracleData should not return any preimage oracle data for the alphabet provider. // GetOracleData should not return any preimage oracle data for the alphabet provider.
func (p *AlphabetProvider) GetOracleData(i uint64) (*types.PreimageOracleData, error) { func (p *AlphabetProvider) GetOracleData(ctx context.Context, i uint64) (*types.PreimageOracleData, error) {
return &types.PreimageOracleData{}, nil return &types.PreimageOracleData{}, nil
} }
// GetPreimage returns the preimage for the given hash. // GetPreimage returns the preimage for the given hash.
func (ap *AlphabetProvider) GetPreimage(i uint64) ([]byte, []byte, error) { func (ap *AlphabetProvider) GetPreimage(ctx context.Context, i uint64) ([]byte, []byte, error) {
// The index cannot be larger than the maximum index as computed by the depth. // The index cannot be larger than the maximum index as computed by the depth.
if i >= ap.maxLen { if i >= ap.maxLen {
return nil, nil, ErrIndexTooLarge return nil, nil, ErrIndexTooLarge
} }
// We extend the deepest hash to the maximum depth if the trace is not expansive. // We extend the deepest hash to the maximum depth if the trace is not expansive.
if i >= uint64(len(ap.state)) { if i >= uint64(len(ap.state)) {
return ap.GetPreimage(uint64(len(ap.state)) - 1) return ap.GetPreimage(ctx, uint64(len(ap.state))-1)
} }
return BuildAlphabetPreimage(i, ap.state[i]), []byte{}, nil return BuildAlphabetPreimage(i, ap.state[i]), []byte{}, nil
} }
// Get returns the claim value at the given index in the trace. // Get returns the claim value at the given index in the trace.
func (ap *AlphabetProvider) Get(i uint64) (common.Hash, error) { func (ap *AlphabetProvider) Get(ctx context.Context, i uint64) (common.Hash, error) {
claimBytes, _, err := ap.GetPreimage(i) claimBytes, _, err := ap.GetPreimage(ctx, i)
if err != nil { if err != nil {
return common.Hash{}, err return common.Hash{}, err
} }
return crypto.Keccak256Hash(claimBytes), nil return crypto.Keccak256Hash(claimBytes), nil
} }
func (ap *AlphabetProvider) AbsolutePreState() []byte { func (ap *AlphabetProvider) AbsolutePreState(ctx context.Context) []byte {
return common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000060") return common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000060")
} }
......
package alphabet package alphabet
import ( import (
"context"
"math/big" "math/big"
"testing" "testing"
...@@ -39,7 +40,7 @@ func TestAlphabetProvider_Get_ClaimsByTraceIndex(t *testing.T) { ...@@ -39,7 +40,7 @@ func TestAlphabetProvider_Get_ClaimsByTraceIndex(t *testing.T) {
// Execute each trace and check the alphabet provider returns the expected hash. // Execute each trace and check the alphabet provider returns the expected hash.
for _, trace := range traces { for _, trace := range traces {
expectedHash, err := canonicalProvider.Get(trace.traceIndex) expectedHash, err := canonicalProvider.Get(context.Background(), trace.traceIndex)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, trace.expectedHash, expectedHash) require.Equal(t, trace.expectedHash, expectedHash)
} }
...@@ -60,7 +61,7 @@ func FuzzIndexToBytes(f *testing.F) { ...@@ -60,7 +61,7 @@ func FuzzIndexToBytes(f *testing.F) {
func TestGetPreimage_Succeeds(t *testing.T) { func TestGetPreimage_Succeeds(t *testing.T) {
ap := NewAlphabetProvider("abc", 2) ap := NewAlphabetProvider("abc", 2)
expected := BuildAlphabetPreimage(0, "a'") expected := BuildAlphabetPreimage(0, "a'")
retrieved, proof, err := ap.GetPreimage(uint64(0)) retrieved, proof, err := ap.GetPreimage(context.Background(), uint64(0))
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, expected, retrieved) require.Equal(t, expected, retrieved)
require.Empty(t, proof) require.Empty(t, proof)
...@@ -70,14 +71,14 @@ func TestGetPreimage_Succeeds(t *testing.T) { ...@@ -70,14 +71,14 @@ func TestGetPreimage_Succeeds(t *testing.T) {
// function errors if the index is too large. // function errors if the index is too large.
func TestGetPreimage_TooLargeIndex_Fails(t *testing.T) { func TestGetPreimage_TooLargeIndex_Fails(t *testing.T) {
ap := NewAlphabetProvider("abc", 2) ap := NewAlphabetProvider("abc", 2)
_, _, err := ap.GetPreimage(4) _, _, err := ap.GetPreimage(context.Background(), 4)
require.ErrorIs(t, err, ErrIndexTooLarge) require.ErrorIs(t, err, ErrIndexTooLarge)
} }
// TestGet_Succeeds tests the Get function. // TestGet_Succeeds tests the Get function.
func TestGet_Succeeds(t *testing.T) { func TestGet_Succeeds(t *testing.T) {
ap := NewAlphabetProvider("abc", 2) ap := NewAlphabetProvider("abc", 2)
claim, err := ap.Get(0) claim, err := ap.Get(context.Background(), 0)
require.NoError(t, err) require.NoError(t, err)
expected := alphabetClaim(0, "a") expected := alphabetClaim(0, "a")
require.Equal(t, expected, claim) require.Equal(t, expected, claim)
...@@ -87,7 +88,7 @@ func TestGet_Succeeds(t *testing.T) { ...@@ -87,7 +88,7 @@ func TestGet_Succeeds(t *testing.T) {
// greater than the number of indices: 2^depth - 1. // greater than the number of indices: 2^depth - 1.
func TestGet_IndexTooLarge(t *testing.T) { func TestGet_IndexTooLarge(t *testing.T) {
ap := NewAlphabetProvider("abc", 2) ap := NewAlphabetProvider("abc", 2)
_, err := ap.Get(4) _, err := ap.Get(context.Background(), 4)
require.ErrorIs(t, err, ErrIndexTooLarge) require.ErrorIs(t, err, ErrIndexTooLarge)
} }
...@@ -95,7 +96,7 @@ func TestGet_IndexTooLarge(t *testing.T) { ...@@ -95,7 +96,7 @@ func TestGet_IndexTooLarge(t *testing.T) {
// than the trace, but smaller than the maximum depth. // than the trace, but smaller than the maximum depth.
func TestGet_Extends(t *testing.T) { func TestGet_Extends(t *testing.T) {
ap := NewAlphabetProvider("abc", 2) ap := NewAlphabetProvider("abc", 2)
claim, err := ap.Get(3) claim, err := ap.Get(context.Background(), 3)
require.NoError(t, err) require.NoError(t, err)
expected := alphabetClaim(2, "c") expected := alphabetClaim(2, "c")
require.Equal(t, expected, claim) require.Equal(t, expected, claim)
......
package cannon package cannon
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"os" "os"
...@@ -42,7 +43,7 @@ func NewExecutor(logger log.Logger, cfg *config.Config) *Executor { ...@@ -42,7 +43,7 @@ func NewExecutor(logger log.Logger, cfg *config.Config) *Executor {
} }
} }
func (e *Executor) GenerateProof(dir string, i uint64) error { func (e *Executor) GenerateProof(ctx context.Context, dir string, i uint64) error {
start, err := e.selectSnapshot(e.logger, filepath.Join(e.dataDir, snapsDir), e.absolutePreState, i) start, err := e.selectSnapshot(e.logger, filepath.Join(e.dataDir, snapsDir), e.absolutePreState, i)
if err != nil { if err != nil {
return fmt.Errorf("find starting snapshot: %w", err) return fmt.Errorf("find starting snapshot: %w", err)
......
package cannon package cannon
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
...@@ -28,7 +29,7 @@ type proofData struct { ...@@ -28,7 +29,7 @@ type proofData struct {
type ProofGenerator interface { type ProofGenerator interface {
// GenerateProof executes cannon to generate a proof at the specified trace index in dataDir. // GenerateProof executes cannon to generate a proof at the specified trace index in dataDir.
GenerateProof(dataDir string, proofAt uint64) error GenerateProof(ctx context.Context, dataDir string, proofAt uint64) error
} }
type CannonTraceProvider struct { type CannonTraceProvider struct {
...@@ -43,8 +44,8 @@ func NewCannonTraceProvider(logger log.Logger, cfg *config.Config) *CannonTraceP ...@@ -43,8 +44,8 @@ func NewCannonTraceProvider(logger log.Logger, cfg *config.Config) *CannonTraceP
} }
} }
func (p *CannonTraceProvider) GetOracleData(i uint64) (*types.PreimageOracleData, error) { func (p *CannonTraceProvider) GetOracleData(ctx context.Context, i uint64) (*types.PreimageOracleData, error) {
proof, err := p.loadProof(i) proof, err := p.loadProof(ctx, i)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -52,8 +53,8 @@ func (p *CannonTraceProvider) GetOracleData(i uint64) (*types.PreimageOracleData ...@@ -52,8 +53,8 @@ func (p *CannonTraceProvider) GetOracleData(i uint64) (*types.PreimageOracleData
return &data, nil return &data, nil
} }
func (p *CannonTraceProvider) Get(i uint64) (common.Hash, error) { func (p *CannonTraceProvider) Get(ctx context.Context, i uint64) (common.Hash, error) {
proof, err := p.loadProof(i) proof, err := p.loadProof(ctx, i)
if err != nil { if err != nil {
return common.Hash{}, err return common.Hash{}, err
} }
...@@ -65,8 +66,8 @@ func (p *CannonTraceProvider) Get(i uint64) (common.Hash, error) { ...@@ -65,8 +66,8 @@ func (p *CannonTraceProvider) Get(i uint64) (common.Hash, error) {
return value, nil return value, nil
} }
func (p *CannonTraceProvider) GetPreimage(i uint64) ([]byte, []byte, error) { func (p *CannonTraceProvider) GetPreimage(ctx context.Context, i uint64) ([]byte, []byte, error) {
proof, err := p.loadProof(i) proof, err := p.loadProof(ctx, i)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
...@@ -81,15 +82,15 @@ func (p *CannonTraceProvider) GetPreimage(i uint64) ([]byte, []byte, error) { ...@@ -81,15 +82,15 @@ func (p *CannonTraceProvider) GetPreimage(i uint64) ([]byte, []byte, error) {
return value, data, nil return value, data, nil
} }
func (p *CannonTraceProvider) AbsolutePreState() []byte { func (p *CannonTraceProvider) AbsolutePreState(ctx context.Context) []byte {
panic("absolute prestate not yet supported") panic("absolute prestate not yet supported")
} }
func (p *CannonTraceProvider) loadProof(i uint64) (*proofData, error) { func (p *CannonTraceProvider) loadProof(ctx context.Context, i uint64) (*proofData, error) {
path := filepath.Join(p.dir, proofsDir, fmt.Sprintf("%d.json", i)) path := filepath.Join(p.dir, proofsDir, fmt.Sprintf("%d.json", i))
file, err := os.Open(path) file, err := os.Open(path)
if errors.Is(err, os.ErrNotExist) { if errors.Is(err, os.ErrNotExist) {
if err := p.generator.GenerateProof(p.dir, i); err != nil { if err := p.generator.GenerateProof(ctx, p.dir, i); err != nil {
return nil, fmt.Errorf("generate cannon trace with proof at %v: %w", i, err) return nil, fmt.Errorf("generate cannon trace with proof at %v: %w", i, err)
} }
// Try opening the file again now and it should exist. // Try opening the file again now and it should exist.
......
package cannon package cannon
import ( import (
"context"
"embed" "embed"
_ "embed" _ "embed"
"os" "os"
...@@ -18,7 +19,7 @@ func TestGet(t *testing.T) { ...@@ -18,7 +19,7 @@ func TestGet(t *testing.T) {
dataDir := setupTestData(t) dataDir := setupTestData(t)
t.Run("ExistingProof", func(t *testing.T) { t.Run("ExistingProof", func(t *testing.T) {
provider, generator := setupWithTestData(dataDir) provider, generator := setupWithTestData(dataDir)
value, err := provider.Get(0) value, err := provider.Get(context.Background(), 0)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, common.HexToHash("0x45fd9aa59768331c726e719e76aa343e73123af888804604785ae19506e65e87"), value) require.Equal(t, common.HexToHash("0x45fd9aa59768331c726e719e76aa343e73123af888804604785ae19506e65e87"), value)
require.Empty(t, generator.generated) require.Empty(t, generator.generated)
...@@ -26,21 +27,21 @@ func TestGet(t *testing.T) { ...@@ -26,21 +27,21 @@ func TestGet(t *testing.T) {
t.Run("ProofUnavailable", func(t *testing.T) { t.Run("ProofUnavailable", func(t *testing.T) {
provider, generator := setupWithTestData(dataDir) provider, generator := setupWithTestData(dataDir)
_, err := provider.Get(7) _, err := provider.Get(context.Background(), 7)
require.ErrorIs(t, err, os.ErrNotExist) require.ErrorIs(t, err, os.ErrNotExist)
require.Contains(t, generator.generated, 7, "should have tried to generate the proof") require.Contains(t, generator.generated, 7, "should have tried to generate the proof")
}) })
t.Run("MissingPostHash", func(t *testing.T) { t.Run("MissingPostHash", func(t *testing.T) {
provider, generator := setupWithTestData(dataDir) provider, generator := setupWithTestData(dataDir)
_, err := provider.Get(1) _, err := provider.Get(context.Background(), 1)
require.ErrorContains(t, err, "missing post hash") require.ErrorContains(t, err, "missing post hash")
require.Empty(t, generator.generated) require.Empty(t, generator.generated)
}) })
t.Run("IgnoreUnknownFields", func(t *testing.T) { t.Run("IgnoreUnknownFields", func(t *testing.T) {
provider, generator := setupWithTestData(dataDir) provider, generator := setupWithTestData(dataDir)
value, err := provider.Get(2) value, err := provider.Get(context.Background(), 2)
require.NoError(t, err) require.NoError(t, err)
expected := common.HexToHash("bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb") expected := common.HexToHash("bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb")
require.Equal(t, expected, value) require.Equal(t, expected, value)
...@@ -52,7 +53,7 @@ func TestGetOracleData(t *testing.T) { ...@@ -52,7 +53,7 @@ func TestGetOracleData(t *testing.T) {
dataDir := setupTestData(t) dataDir := setupTestData(t)
t.Run("ExistingProof", func(t *testing.T) { t.Run("ExistingProof", func(t *testing.T) {
provider, generator := setupWithTestData(dataDir) provider, generator := setupWithTestData(dataDir)
oracleData, err := provider.GetOracleData(420) oracleData, err := provider.GetOracleData(context.Background(), 420)
require.NoError(t, err) require.NoError(t, err)
require.False(t, oracleData.IsLocal) require.False(t, oracleData.IsLocal)
expectedKey := common.Hex2Bytes("eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee") expectedKey := common.Hex2Bytes("eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee")
...@@ -64,14 +65,14 @@ func TestGetOracleData(t *testing.T) { ...@@ -64,14 +65,14 @@ func TestGetOracleData(t *testing.T) {
t.Run("ProofUnavailable", func(t *testing.T) { t.Run("ProofUnavailable", func(t *testing.T) {
provider, generator := setupWithTestData(dataDir) provider, generator := setupWithTestData(dataDir)
_, err := provider.GetOracleData(7) _, err := provider.GetOracleData(context.Background(), 7)
require.ErrorIs(t, err, os.ErrNotExist) require.ErrorIs(t, err, os.ErrNotExist)
require.Contains(t, generator.generated, 7, "should have tried to generate the proof") require.Contains(t, generator.generated, 7, "should have tried to generate the proof")
}) })
t.Run("IgnoreUnknownFields", func(t *testing.T) { t.Run("IgnoreUnknownFields", func(t *testing.T) {
provider, generator := setupWithTestData(dataDir) provider, generator := setupWithTestData(dataDir)
oracleData, err := provider.GetOracleData(421) oracleData, err := provider.GetOracleData(context.Background(), 421)
require.NoError(t, err) require.NoError(t, err)
require.False(t, oracleData.IsLocal) require.False(t, oracleData.IsLocal)
expectedKey := common.Hex2Bytes("eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee") expectedKey := common.Hex2Bytes("eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee")
...@@ -86,7 +87,7 @@ func TestGetPreimage(t *testing.T) { ...@@ -86,7 +87,7 @@ func TestGetPreimage(t *testing.T) {
dataDir := setupTestData(t) dataDir := setupTestData(t)
t.Run("ExistingProof", func(t *testing.T) { t.Run("ExistingProof", func(t *testing.T) {
provider, generator := setupWithTestData(dataDir) provider, generator := setupWithTestData(dataDir)
value, proof, err := provider.GetPreimage(0) value, proof, err := provider.GetPreimage(context.Background(), 0)
require.NoError(t, err) require.NoError(t, err)
expected := common.Hex2Bytes("b8f068de604c85ea0e2acd437cdb47add074a2d70b81d018390c504b71fe26f400000000000000000000000000000000000000000000000000000000000000000000000000") expected := common.Hex2Bytes("b8f068de604c85ea0e2acd437cdb47add074a2d70b81d018390c504b71fe26f400000000000000000000000000000000000000000000000000000000000000000000000000")
require.Equal(t, expected, value) require.Equal(t, expected, value)
...@@ -97,21 +98,21 @@ func TestGetPreimage(t *testing.T) { ...@@ -97,21 +98,21 @@ func TestGetPreimage(t *testing.T) {
t.Run("ProofUnavailable", func(t *testing.T) { t.Run("ProofUnavailable", func(t *testing.T) {
provider, generator := setupWithTestData(dataDir) provider, generator := setupWithTestData(dataDir)
_, _, err := provider.GetPreimage(7) _, _, err := provider.GetPreimage(context.Background(), 7)
require.ErrorIs(t, err, os.ErrNotExist) require.ErrorIs(t, err, os.ErrNotExist)
require.Contains(t, generator.generated, 7, "should have tried to generate the proof") require.Contains(t, generator.generated, 7, "should have tried to generate the proof")
}) })
t.Run("MissingStateData", func(t *testing.T) { t.Run("MissingStateData", func(t *testing.T) {
provider, generator := setupWithTestData(dataDir) provider, generator := setupWithTestData(dataDir)
_, _, err := provider.GetPreimage(1) _, _, err := provider.GetPreimage(context.Background(), 1)
require.ErrorContains(t, err, "missing state data") require.ErrorContains(t, err, "missing state data")
require.Empty(t, generator.generated) require.Empty(t, generator.generated)
}) })
t.Run("IgnoreUnknownFields", func(t *testing.T) { t.Run("IgnoreUnknownFields", func(t *testing.T) {
provider, generator := setupWithTestData(dataDir) provider, generator := setupWithTestData(dataDir)
value, proof, err := provider.GetPreimage(2) value, proof, err := provider.GetPreimage(context.Background(), 2)
require.NoError(t, err) require.NoError(t, err)
expected := common.Hex2Bytes("cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc") expected := common.Hex2Bytes("cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc")
require.Equal(t, expected, value) require.Equal(t, expected, value)
...@@ -149,7 +150,7 @@ type stubGenerator struct { ...@@ -149,7 +150,7 @@ type stubGenerator struct {
generated []int // Using int makes assertions easier generated []int // Using int makes assertions easier
} }
func (e *stubGenerator) GenerateProof(dir string, i uint64) error { func (e *stubGenerator) GenerateProof(ctx context.Context, dir string, i uint64) error {
e.generated = append(e.generated, int(i)) e.generated = append(e.generated, int(i))
return nil return nil
} }
package solver package solver
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
...@@ -28,21 +29,21 @@ func NewSolver(gameDepth int, traceProvider types.TraceProvider) *Solver { ...@@ -28,21 +29,21 @@ func NewSolver(gameDepth int, traceProvider types.TraceProvider) *Solver {
} }
// NextMove returns the next move to make given the current state of the game. // NextMove returns the next move to make given the current state of the game.
func (s *Solver) NextMove(claim types.Claim, agreeWithClaimLevel bool) (*types.Claim, error) { func (s *Solver) NextMove(ctx context.Context, claim types.Claim, agreeWithClaimLevel bool) (*types.Claim, error) {
if agreeWithClaimLevel { if agreeWithClaimLevel {
return nil, nil return nil, nil
} }
if claim.Depth() == s.gameDepth { if claim.Depth() == s.gameDepth {
return nil, types.ErrGameDepthReached return nil, types.ErrGameDepthReached
} }
agree, err := s.agreeWithClaim(claim.ClaimData) agree, err := s.agreeWithClaim(ctx, claim.ClaimData)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if agree { if agree {
return s.defend(claim) return s.defend(ctx, claim)
} else { } else {
return s.attack(claim) return s.attack(ctx, claim)
} }
} }
...@@ -56,14 +57,14 @@ type StepData struct { ...@@ -56,14 +57,14 @@ type StepData struct {
// AttemptStep determines what step should occur for a given leaf claim. // AttemptStep determines what step should occur for a given leaf claim.
// An error will be returned if the claim is not at the max depth. // An error will be returned if the claim is not at the max depth.
func (s *Solver) AttemptStep(claim types.Claim, agreeWithClaimLevel bool) (StepData, error) { func (s *Solver) AttemptStep(ctx context.Context, claim types.Claim, agreeWithClaimLevel bool) (StepData, error) {
if claim.Depth() != s.gameDepth { if claim.Depth() != s.gameDepth {
return StepData{}, ErrStepNonLeafNode return StepData{}, ErrStepNonLeafNode
} }
if agreeWithClaimLevel { if agreeWithClaimLevel {
return StepData{}, ErrStepAgreedClaim return StepData{}, ErrStepAgreedClaim
} }
claimCorrect, err := s.agreeWithClaim(claim.ClaimData) claimCorrect, err := s.agreeWithClaim(ctx, claim.ClaimData)
if err != nil { if err != nil {
return StepData{}, err return StepData{}, err
} }
...@@ -72,19 +73,19 @@ func (s *Solver) AttemptStep(claim types.Claim, agreeWithClaimLevel bool) (StepD ...@@ -72,19 +73,19 @@ func (s *Solver) AttemptStep(claim types.Claim, agreeWithClaimLevel bool) (StepD
var proofData []byte var proofData []byte
// If we are attacking index 0, we provide the absolute pre-state, not an intermediate state // If we are attacking index 0, we provide the absolute pre-state, not an intermediate state
if index == 0 && !claimCorrect { if index == 0 && !claimCorrect {
preState = s.trace.AbsolutePreState() preState = s.trace.AbsolutePreState(ctx)
} else { } else {
// If attacking, get the state just before, other get the state after // If attacking, get the state just before, other get the state after
if !claimCorrect { if !claimCorrect {
index = index - 1 index = index - 1
} }
preState, proofData, err = s.trace.GetPreimage(index) preState, proofData, err = s.trace.GetPreimage(ctx, index)
if err != nil { if err != nil {
return StepData{}, err return StepData{}, err
} }
} }
oracleData, err := s.trace.GetOracleData(index) oracleData, err := s.trace.GetOracleData(ctx, index)
if err != nil { if err != nil {
return StepData{}, err return StepData{}, err
} }
...@@ -99,9 +100,9 @@ func (s *Solver) AttemptStep(claim types.Claim, agreeWithClaimLevel bool) (StepD ...@@ -99,9 +100,9 @@ func (s *Solver) AttemptStep(claim types.Claim, agreeWithClaimLevel bool) (StepD
} }
// attack returns a response that attacks the claim. // attack returns a response that attacks the claim.
func (s *Solver) attack(claim types.Claim) (*types.Claim, error) { func (s *Solver) attack(ctx context.Context, claim types.Claim) (*types.Claim, error) {
position := claim.Attack() position := claim.Attack()
value, err := s.traceAtPosition(position) value, err := s.traceAtPosition(ctx, position)
if err != nil { if err != nil {
return nil, fmt.Errorf("attack claim: %w", err) return nil, fmt.Errorf("attack claim: %w", err)
} }
...@@ -113,12 +114,12 @@ func (s *Solver) attack(claim types.Claim) (*types.Claim, error) { ...@@ -113,12 +114,12 @@ func (s *Solver) attack(claim types.Claim) (*types.Claim, error) {
} }
// defend returns a response that defends the claim. // defend returns a response that defends the claim.
func (s *Solver) defend(claim types.Claim) (*types.Claim, error) { func (s *Solver) defend(ctx context.Context, claim types.Claim) (*types.Claim, error) {
if claim.IsRoot() { if claim.IsRoot() {
return nil, nil return nil, nil
} }
position := claim.Defend() position := claim.Defend()
value, err := s.traceAtPosition(position) value, err := s.traceAtPosition(ctx, position)
if err != nil { if err != nil {
return nil, fmt.Errorf("defend claim: %w", err) return nil, fmt.Errorf("defend claim: %w", err)
} }
...@@ -130,14 +131,14 @@ func (s *Solver) defend(claim types.Claim) (*types.Claim, error) { ...@@ -130,14 +131,14 @@ func (s *Solver) defend(claim types.Claim) (*types.Claim, error) {
} }
// agreeWithClaim returns true if the claim is correct according to the internal [TraceProvider]. // agreeWithClaim returns true if the claim is correct according to the internal [TraceProvider].
func (s *Solver) agreeWithClaim(claim types.ClaimData) (bool, error) { func (s *Solver) agreeWithClaim(ctx context.Context, claim types.ClaimData) (bool, error) {
ourValue, err := s.traceAtPosition(claim.Position) ourValue, err := s.traceAtPosition(ctx, claim.Position)
return ourValue == claim.Value, err return ourValue == claim.Value, err
} }
// traceAtPosition returns the [common.Hash] from internal [TraceProvider] at the given [Position]. // traceAtPosition returns the [common.Hash] from internal [TraceProvider] at the given [Position].
func (s *Solver) traceAtPosition(p types.Position) (common.Hash, error) { func (s *Solver) traceAtPosition(ctx context.Context, p types.Position) (common.Hash, error) {
index := p.TraceIndex(s.gameDepth) index := p.TraceIndex(s.gameDepth)
hash, err := s.trace.Get(index) hash, err := s.trace.Get(ctx, index)
return hash, err return hash, err
} }
package solver_test package solver_test
import ( import (
"context"
"errors" "errors"
"testing" "testing"
...@@ -84,7 +85,7 @@ func TestNextMove(t *testing.T) { ...@@ -84,7 +85,7 @@ func TestNextMove(t *testing.T) {
test := test test := test
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
solver := solver.NewSolver(maxDepth, builder.CorrectTraceProvider()) solver := solver.NewSolver(maxDepth, builder.CorrectTraceProvider())
move, err := solver.NextMove(test.claim, test.agreeWithLevel) move, err := solver.NextMove(context.Background(), test.claim, test.agreeWithLevel)
if test.expectedErr == nil { if test.expectedErr == nil {
require.NoError(t, err) require.NoError(t, err)
} else { } else {
...@@ -110,6 +111,8 @@ func TestAttemptStep(t *testing.T) { ...@@ -110,6 +111,8 @@ func TestAttemptStep(t *testing.T) {
errProvider := errors.New("provider error") errProvider := errors.New("provider error")
ctx := context.Background()
tests := []struct { tests := []struct {
name string name string
claim types.Claim claim types.Claim
...@@ -126,7 +129,7 @@ func TestAttemptStep(t *testing.T) { ...@@ -126,7 +129,7 @@ func TestAttemptStep(t *testing.T) {
name: "AttackFirstTraceIndex", name: "AttackFirstTraceIndex",
claim: builder.CreateLeafClaim(0, false), claim: builder.CreateLeafClaim(0, false),
expectAttack: true, expectAttack: true,
expectPreState: builder.CorrectTraceProvider().AbsolutePreState(), expectPreState: builder.CorrectTraceProvider().AbsolutePreState(ctx),
expectProofData: nil, expectProofData: nil,
expectedOracleKey: []byte{byte(0)}, expectedOracleKey: []byte{byte(0)},
expectedOracleData: []byte{byte(0)}, expectedOracleData: []byte{byte(0)},
...@@ -222,7 +225,7 @@ func TestAttemptStep(t *testing.T) { ...@@ -222,7 +225,7 @@ func TestAttemptStep(t *testing.T) {
} }
builder = test.NewClaimBuilder(t, maxDepth, alphabetProvider) builder = test.NewClaimBuilder(t, maxDepth, alphabetProvider)
alphabetSolver := solver.NewSolver(maxDepth, builder.CorrectTraceProvider()) alphabetSolver := solver.NewSolver(maxDepth, builder.CorrectTraceProvider())
step, err := alphabetSolver.AttemptStep(tableTest.claim, tableTest.agreeWithLevel) step, err := alphabetSolver.AttemptStep(ctx, tableTest.claim, tableTest.agreeWithLevel)
if tableTest.expectedErr == nil { if tableTest.expectedErr == nil {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, tableTest.claim, step.LeafClaim) require.Equal(t, tableTest.claim, step.LeafClaim)
......
package test package test
import ( import (
"context"
"testing" "testing"
"github.com/ethereum-optimism/optimism/op-challenger/fault/alphabet" "github.com/ethereum-optimism/optimism/op-challenger/fault/alphabet"
...@@ -24,15 +25,15 @@ type alphabetWithProofProvider struct { ...@@ -24,15 +25,15 @@ type alphabetWithProofProvider struct {
OracleError error OracleError error
} }
func (a *alphabetWithProofProvider) GetPreimage(i uint64) ([]byte, []byte, error) { func (a *alphabetWithProofProvider) GetPreimage(ctx context.Context, i uint64) ([]byte, []byte, error) {
preimage, _, err := a.AlphabetProvider.GetPreimage(i) preimage, _, err := a.AlphabetProvider.GetPreimage(ctx, i)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
return preimage, []byte{byte(i)}, nil return preimage, []byte{byte(i)}, nil
} }
func (a *alphabetWithProofProvider) GetOracleData(i uint64) (*types.PreimageOracleData, error) { func (a *alphabetWithProofProvider) GetOracleData(ctx context.Context, i uint64) (*types.PreimageOracleData, error) {
if a.OracleError != nil { if a.OracleError != nil {
return &types.PreimageOracleData{}, a.OracleError return &types.PreimageOracleData{}, a.OracleError
} }
......
package test package test
import ( import (
"context"
"math/big" "math/big"
"testing" "testing"
...@@ -32,21 +33,21 @@ func (c *ClaimBuilder) CorrectTraceProvider() types.TraceProvider { ...@@ -32,21 +33,21 @@ func (c *ClaimBuilder) CorrectTraceProvider() types.TraceProvider {
// CorrectClaim returns the canonical claim at a specified trace index // CorrectClaim returns the canonical claim at a specified trace index
func (c *ClaimBuilder) CorrectClaim(idx uint64) common.Hash { func (c *ClaimBuilder) CorrectClaim(idx uint64) common.Hash {
value, err := c.correct.Get(idx) value, err := c.correct.Get(context.Background(), idx)
c.require.NoError(err) c.require.NoError(err)
return value return value
} }
// CorrectPreState returns the pre-image of the canonical claim at the specified trace index // CorrectPreState returns the pre-image of the canonical claim at the specified trace index
func (c *ClaimBuilder) CorrectPreState(idx uint64) []byte { func (c *ClaimBuilder) CorrectPreState(idx uint64) []byte {
preimage, _, err := c.correct.GetPreimage(idx) preimage, _, err := c.correct.GetPreimage(context.Background(), idx)
c.require.NoError(err) c.require.NoError(err)
return preimage return preimage
} }
// CorrectProofData returns the proof-data for the canonical claim at the specified trace index // CorrectProofData returns the proof-data for the canonical claim at the specified trace index
func (c *ClaimBuilder) CorrectProofData(idx uint64) []byte { func (c *ClaimBuilder) CorrectProofData(idx uint64) []byte {
_, proof, err := c.correct.GetPreimage(idx) _, proof, err := c.correct.GetPreimage(context.Background(), idx)
c.require.NoError(err) c.require.NoError(err)
return proof return proof
} }
......
package types package types
import ( import (
"context"
"errors" "errors"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
...@@ -47,19 +48,19 @@ type StepCallData struct { ...@@ -47,19 +48,19 @@ type StepCallData struct {
type TraceProvider interface { type TraceProvider interface {
// Get returns the claim value at the requested index. // Get returns the claim value at the requested index.
// Get(i) = Keccak256(GetPreimage(i)) // Get(i) = Keccak256(GetPreimage(i))
Get(i uint64) (common.Hash, error) Get(ctx context.Context, i uint64) (common.Hash, error)
// GetOracleData returns preimage oracle data that can be submitted to the pre-image // GetOracleData returns preimage oracle data that can be submitted to the pre-image
// oracle and the dispute game contract. This function accepts a trace index for // oracle and the dispute game contract. This function accepts a trace index for
// which the provider returns needed preimage data. // which the provider returns needed preimage data.
GetOracleData(i uint64) (*PreimageOracleData, error) GetOracleData(ctx context.Context, i uint64) (*PreimageOracleData, error)
// GetPreimage returns the pre-image for a claim at the specified trace index, along // GetPreimage returns the pre-image for a claim at the specified trace index, along
// with any associated proof data to assist in its verification. // with any associated proof data to assist in its verification.
GetPreimage(i uint64) (preimage []byte, proofData []byte, err error) GetPreimage(ctx context.Context, i uint64) (preimage []byte, proofData []byte, err error)
// AbsolutePreState is the pre-image value of the trace that transitions to the trace value at index 0 // AbsolutePreState is the pre-image value of the trace that transitions to the trace value at index 0
AbsolutePreState() []byte AbsolutePreState(ctx context.Context) []byte
} }
// ClaimData is the core of a claim. It must be unique inside a specific game. // ClaimData is the core of a claim. It must be unique inside a specific game.
......
...@@ -68,7 +68,7 @@ func (h *FactoryHelper) StartAlphabetGame(ctx context.Context, claimedAlphabet s ...@@ -68,7 +68,7 @@ func (h *FactoryHelper) StartAlphabetGame(ctx context.Context, claimedAlphabet s
ctx, cancel := context.WithTimeout(ctx, 1*time.Minute) ctx, cancel := context.WithTimeout(ctx, 1*time.Minute)
defer cancel() defer cancel()
trace := alphabet.NewAlphabetProvider(claimedAlphabet, 4) trace := alphabet.NewAlphabetProvider(claimedAlphabet, 4)
rootClaim, err := trace.Get(lastAlphabetTraceIndex) rootClaim, err := trace.Get(ctx, lastAlphabetTraceIndex)
h.require.NoError(err, "get root claim") h.require.NoError(err, "get root claim")
tx, err := h.factory.Create(h.opts, faultGameType, rootClaim, alphaExtraData) tx, err := h.factory.Create(h.opts, faultGameType, rootClaim, alphaExtraData)
h.require.NoError(err, "create fault dispute game") h.require.NoError(err, "create fault dispute game")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment