Commit e7685d1b authored by refcell.eth's avatar refcell.eth Committed by GitHub

Merge pull request #7321 from ethereum-optimism/refcell/deep-positions

feat(op-challenger): Deep Positions
parents 44761db3 19f79596
...@@ -119,10 +119,6 @@ func TestLoader_FetchClaims(t *testing.T) { ...@@ -119,10 +119,6 @@ func TestLoader_FetchClaims(t *testing.T) {
Value: expectedClaims[0].Claim, Value: expectedClaims[0].Claim,
Position: types.NewPositionFromGIndex(expectedClaims[0].Position.Uint64()), Position: types.NewPositionFromGIndex(expectedClaims[0].Position.Uint64()),
}, },
Parent: types.ClaimData{
Value: expectedClaims[0].Claim,
Position: types.NewPositionFromGIndex(expectedClaims[0].Position.Uint64()),
},
Countered: false, Countered: false,
Clock: uint64(0), Clock: uint64(0),
ContractIndex: 0, ContractIndex: 0,
...@@ -134,11 +130,12 @@ func TestLoader_FetchClaims(t *testing.T) { ...@@ -134,11 +130,12 @@ func TestLoader_FetchClaims(t *testing.T) {
}, },
Parent: types.ClaimData{ Parent: types.ClaimData{
Value: expectedClaims[0].Claim, Value: expectedClaims[0].Claim,
Position: types.NewPositionFromGIndex(expectedClaims[1].Position.Uint64()), Position: types.NewPositionFromGIndex(expectedClaims[0].Position.Uint64()),
}, },
Countered: false, Countered: false,
Clock: uint64(0), Clock: uint64(0),
ContractIndex: 1, ContractIndex: 1,
ParentContractIndex: 0,
}, },
{ {
ClaimData: types.ClaimData{ ClaimData: types.ClaimData{
...@@ -146,12 +143,13 @@ func TestLoader_FetchClaims(t *testing.T) { ...@@ -146,12 +143,13 @@ func TestLoader_FetchClaims(t *testing.T) {
Position: types.NewPositionFromGIndex(expectedClaims[2].Position.Uint64()), Position: types.NewPositionFromGIndex(expectedClaims[2].Position.Uint64()),
}, },
Parent: types.ClaimData{ Parent: types.ClaimData{
Value: expectedClaims[0].Claim, Value: expectedClaims[1].Claim,
Position: types.NewPositionFromGIndex(expectedClaims[2].Position.Uint64()), Position: types.NewPositionFromGIndex(expectedClaims[1].Position.Uint64()),
}, },
Countered: false, Countered: false,
Clock: uint64(0), Clock: uint64(0),
ContractIndex: 2, ContractIndex: 2,
ParentContractIndex: 1,
}, },
}, claims) }, claims)
}) })
...@@ -204,21 +202,23 @@ func newMockCaller() *mockCaller { ...@@ -204,21 +202,23 @@ func newMockCaller() *mockCaller {
}{ }{
{ {
Claim: [32]byte{0x00}, Claim: [32]byte{0x00},
Position: big.NewInt(0), Position: big.NewInt(1),
Countered: false, Countered: false,
Clock: big.NewInt(0), Clock: big.NewInt(0),
}, },
{ {
Claim: [32]byte{0x01}, Claim: [32]byte{0x01},
Position: big.NewInt(0), Position: big.NewInt(2),
Countered: false, Countered: false,
Clock: big.NewInt(0), Clock: big.NewInt(0),
ParentIndex: 0,
}, },
{ {
Claim: [32]byte{0x02}, Claim: [32]byte{0x02},
Position: big.NewInt(0), Position: big.NewInt(3),
Countered: false, Countered: false,
Clock: big.NewInt(0), Clock: big.NewInt(0),
ParentIndex: 1,
}, },
}, },
} }
...@@ -240,7 +240,7 @@ func (m *mockCaller) ClaimData(opts *bind.CallOpts, arg0 *big.Int) (struct { ...@@ -240,7 +240,7 @@ func (m *mockCaller) ClaimData(opts *bind.CallOpts, arg0 *big.Int) (struct {
Clock *big.Int Clock *big.Int
}{}, mockClaimDataError }{}, mockClaimDataError
} }
returnClaim := m.returnClaims[m.currentIndex] returnClaim := m.returnClaims[arg0.Uint64()]
m.currentIndex++ m.currentIndex++
return returnClaim, nil return returnClaim, nil
} }
......
...@@ -56,7 +56,6 @@ func TestCalculateNextActions(t *testing.T) { ...@@ -56,7 +56,6 @@ func TestCalculateNextActions(t *testing.T) {
builder.Seq().AttackCorrect() builder.Seq().AttackCorrect()
}, },
}, },
{ {
name: "RespondToAllClaimsAtDisagreeingLevel", name: "RespondToAllClaimsAtDisagreeingLevel",
agreeWithOutputRoot: true, agreeWithOutputRoot: true,
...@@ -70,7 +69,6 @@ func TestCalculateNextActions(t *testing.T) { ...@@ -70,7 +69,6 @@ func TestCalculateNextActions(t *testing.T) {
honestClaim.Defend(common.Hash{0xdd}).ExpectAttack() honestClaim.Defend(common.Hash{0xdd}).ExpectAttack()
}, },
}, },
{ {
name: "StepAtMaxDepth", name: "StepAtMaxDepth",
agreeWithOutputRoot: true, agreeWithOutputRoot: true,
...@@ -83,7 +81,6 @@ func TestCalculateNextActions(t *testing.T) { ...@@ -83,7 +81,6 @@ func TestCalculateNextActions(t *testing.T) {
lastHonestClaim.Attack(common.Hash{0xdd}).ExpectStepAttack() lastHonestClaim.Attack(common.Hash{0xdd}).ExpectStepAttack()
}, },
}, },
{ {
name: "PoisonedPreState", name: "PoisonedPreState",
agreeWithOutputRoot: true, agreeWithOutputRoot: true,
......
...@@ -2,6 +2,7 @@ package solver ...@@ -2,6 +2,7 @@ package solver
import ( import (
"context" "context"
"math/big"
"testing" "testing"
faulttest "github.com/ethereum-optimism/optimism/op-challenger/game/fault/test" faulttest "github.com/ethereum-optimism/optimism/op-challenger/game/fault/test"
...@@ -16,7 +17,8 @@ func TestAttemptStep(t *testing.T) { ...@@ -16,7 +17,8 @@ func TestAttemptStep(t *testing.T) {
// Last accessible leaf is the second last trace index // Last accessible leaf is the second last trace index
// The root node is used for the last trace index and can only be attacked. // The root node is used for the last trace index and can only be attacked.
lastLeafTraceIndex := uint64(1<<maxDepth - 2) lastLeafTraceIndex := big.NewInt(1<<maxDepth - 2)
lastLeafTraceIndexPlusOne := big.NewInt(1<<maxDepth - 1)
ctx := context.Background() ctx := context.Background()
tests := []struct { tests := []struct {
...@@ -32,9 +34,9 @@ func TestAttemptStep(t *testing.T) { ...@@ -32,9 +34,9 @@ func TestAttemptStep(t *testing.T) {
{ {
name: "AttackFirstTraceIndex", name: "AttackFirstTraceIndex",
expectAttack: true, expectAttack: true,
expectPreState: claimBuilder.CorrectPreState(0), expectPreState: claimBuilder.CorrectPreState(common.Big0),
expectProofData: claimBuilder.CorrectProofData(0), expectProofData: claimBuilder.CorrectProofData(common.Big0),
expectedOracleData: claimBuilder.CorrectOracleData(0), expectedOracleData: claimBuilder.CorrectOracleData(common.Big0),
setupGame: func(builder *faulttest.GameBuilder) { setupGame: func(builder *faulttest.GameBuilder) {
builder.Seq(). builder.Seq().
Attack(common.Hash{0xaa}). Attack(common.Hash{0xaa}).
...@@ -45,9 +47,9 @@ func TestAttemptStep(t *testing.T) { ...@@ -45,9 +47,9 @@ func TestAttemptStep(t *testing.T) {
{ {
name: "DefendFirstTraceIndex", name: "DefendFirstTraceIndex",
expectAttack: false, expectAttack: false,
expectPreState: claimBuilder.CorrectPreState(1), expectPreState: claimBuilder.CorrectPreState(big.NewInt(1)),
expectProofData: claimBuilder.CorrectProofData(1), expectProofData: claimBuilder.CorrectProofData(big.NewInt(1)),
expectedOracleData: claimBuilder.CorrectOracleData(1), expectedOracleData: claimBuilder.CorrectOracleData(big.NewInt(1)),
setupGame: func(builder *faulttest.GameBuilder) { setupGame: func(builder *faulttest.GameBuilder) {
builder.Seq(). builder.Seq().
Attack(common.Hash{0xaa}). Attack(common.Hash{0xaa}).
...@@ -58,9 +60,9 @@ func TestAttemptStep(t *testing.T) { ...@@ -58,9 +60,9 @@ func TestAttemptStep(t *testing.T) {
{ {
name: "AttackMiddleTraceIndex", name: "AttackMiddleTraceIndex",
expectAttack: true, expectAttack: true,
expectPreState: claimBuilder.CorrectPreState(4), expectPreState: claimBuilder.CorrectPreState(big.NewInt(4)),
expectProofData: claimBuilder.CorrectProofData(4), expectProofData: claimBuilder.CorrectProofData(big.NewInt(4)),
expectedOracleData: claimBuilder.CorrectOracleData(4), expectedOracleData: claimBuilder.CorrectOracleData(big.NewInt(4)),
setupGame: func(builder *faulttest.GameBuilder) { setupGame: func(builder *faulttest.GameBuilder) {
builder.Seq(). builder.Seq().
AttackCorrect(). AttackCorrect().
...@@ -71,9 +73,9 @@ func TestAttemptStep(t *testing.T) { ...@@ -71,9 +73,9 @@ func TestAttemptStep(t *testing.T) {
{ {
name: "DefendMiddleTraceIndex", name: "DefendMiddleTraceIndex",
expectAttack: false, expectAttack: false,
expectPreState: claimBuilder.CorrectPreState(5), expectPreState: claimBuilder.CorrectPreState(big.NewInt(5)),
expectProofData: claimBuilder.CorrectProofData(5), expectProofData: claimBuilder.CorrectProofData(big.NewInt(5)),
expectedOracleData: claimBuilder.CorrectOracleData(5), expectedOracleData: claimBuilder.CorrectOracleData(big.NewInt(5)),
setupGame: func(builder *faulttest.GameBuilder) { setupGame: func(builder *faulttest.GameBuilder) {
builder.Seq(). builder.Seq().
AttackCorrect(). AttackCorrect().
...@@ -97,9 +99,9 @@ func TestAttemptStep(t *testing.T) { ...@@ -97,9 +99,9 @@ func TestAttemptStep(t *testing.T) {
{ {
name: "DefendLastTraceIndex", name: "DefendLastTraceIndex",
expectAttack: false, expectAttack: false,
expectPreState: claimBuilder.CorrectPreState(lastLeafTraceIndex + 1), expectPreState: claimBuilder.CorrectPreState(lastLeafTraceIndexPlusOne),
expectProofData: claimBuilder.CorrectProofData(lastLeafTraceIndex + 1), expectProofData: claimBuilder.CorrectProofData(lastLeafTraceIndexPlusOne),
expectedOracleData: claimBuilder.CorrectOracleData(lastLeafTraceIndex + 1), expectedOracleData: claimBuilder.CorrectOracleData(lastLeafTraceIndexPlusOne),
setupGame: func(builder *faulttest.GameBuilder) { setupGame: func(builder *faulttest.GameBuilder) {
builder.Seq(). builder.Seq().
AttackCorrect(). AttackCorrect().
...@@ -140,9 +142,9 @@ func TestAttemptStep(t *testing.T) { ...@@ -140,9 +142,9 @@ func TestAttemptStep(t *testing.T) {
{ {
name: "CannotStepNearlyValidPath", name: "CannotStepNearlyValidPath",
expectAttack: true, expectAttack: true,
expectPreState: claimBuilder.CorrectPreState(4), expectPreState: claimBuilder.CorrectPreState(big.NewInt(4)),
expectProofData: claimBuilder.CorrectProofData(4), expectProofData: claimBuilder.CorrectProofData(big.NewInt(4)),
expectedOracleData: claimBuilder.CorrectOracleData(4), expectedOracleData: claimBuilder.CorrectOracleData(big.NewInt(4)),
setupGame: func(builder *faulttest.GameBuilder) { setupGame: func(builder *faulttest.GameBuilder) {
builder.Seq(). builder.Seq().
AttackCorrect(). AttackCorrect().
......
...@@ -32,7 +32,7 @@ func (a *alphabetWithProofProvider) GetStepData(ctx context.Context, i types.Pos ...@@ -32,7 +32,7 @@ func (a *alphabetWithProofProvider) GetStepData(ctx context.Context, i types.Pos
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
traceIndex := i.TraceIndex(int(a.depth)) traceIndex := i.TraceIndex(int(a.depth)).Uint64()
data := types.NewPreimageOracleData([]byte{byte(traceIndex)}, []byte{byte(traceIndex - 1)}, uint32(traceIndex-1)) data := types.NewPreimageOracleData([]byte{byte(traceIndex)}, []byte{byte(traceIndex - 1)}, uint32(traceIndex-1))
return preimage, []byte{byte(traceIndex - 1)}, data, nil return preimage, []byte{byte(traceIndex - 1)}, data, nil
} }
...@@ -39,30 +39,30 @@ func (c *ClaimBuilder) CorrectClaimAtPosition(pos types.Position) common.Hash { ...@@ -39,30 +39,30 @@ func (c *ClaimBuilder) CorrectClaimAtPosition(pos types.Position) common.Hash {
} }
// CorrectPreState returns the pre-state (not hashed) required to execute the valid step at the specified trace index // CorrectPreState returns the pre-state (not hashed) required to execute the valid step at the specified trace index
func (c *ClaimBuilder) CorrectPreState(idx uint64) []byte { func (c *ClaimBuilder) CorrectPreState(idx *big.Int) []byte {
pos := types.NewPosition(c.maxDepth, int(idx)) pos := types.NewPosition(c.maxDepth, idx)
preimage, _, _, err := c.correct.GetStepData(context.Background(), pos) preimage, _, _, err := c.correct.GetStepData(context.Background(), pos)
c.require.NoError(err) c.require.NoError(err)
return preimage return preimage
} }
// CorrectProofData returns the proof-data required to execute the valid step at the specified trace index // CorrectProofData returns the proof-data required to execute the valid step at the specified trace index
func (c *ClaimBuilder) CorrectProofData(idx uint64) []byte { func (c *ClaimBuilder) CorrectProofData(idx *big.Int) []byte {
pos := types.NewPosition(c.maxDepth, int(idx)) pos := types.NewPosition(c.maxDepth, idx)
_, proof, _, err := c.correct.GetStepData(context.Background(), pos) _, proof, _, err := c.correct.GetStepData(context.Background(), pos)
c.require.NoError(err) c.require.NoError(err)
return proof return proof
} }
func (c *ClaimBuilder) CorrectOracleData(idx uint64) *types.PreimageOracleData { func (c *ClaimBuilder) CorrectOracleData(idx *big.Int) *types.PreimageOracleData {
pos := types.NewPosition(c.maxDepth, int(idx)) pos := types.NewPosition(c.maxDepth, idx)
_, _, data, err := c.correct.GetStepData(context.Background(), pos) _, _, data, err := c.correct.GetStepData(context.Background(), pos)
c.require.NoError(err) c.require.NoError(err)
return data return data
} }
func (c *ClaimBuilder) incorrectClaim(pos types.Position) common.Hash { func (c *ClaimBuilder) incorrectClaim(pos types.Position) common.Hash {
return common.BigToHash(new(big.Int).SetUint64(pos.TraceIndex(c.maxDepth))) return common.BigToHash(pos.TraceIndex(c.maxDepth))
} }
func (c *ClaimBuilder) claim(pos types.Position, correct bool) common.Hash { func (c *ClaimBuilder) claim(pos types.Position, correct bool) common.Hash {
...@@ -78,15 +78,15 @@ func (c *ClaimBuilder) CreateRootClaim(correct bool) types.Claim { ...@@ -78,15 +78,15 @@ func (c *ClaimBuilder) CreateRootClaim(correct bool) types.Claim {
claim := types.Claim{ claim := types.Claim{
ClaimData: types.ClaimData{ ClaimData: types.ClaimData{
Value: value, Value: value,
Position: types.NewPosition(0, 0), Position: types.NewPosition(0, common.Big0),
}, },
} }
return claim return claim
} }
func (c *ClaimBuilder) CreateLeafClaim(traceIndex uint64, correct bool) types.Claim { func (c *ClaimBuilder) CreateLeafClaim(traceIndex *big.Int, correct bool) types.Claim {
parentPos := types.NewPosition(c.maxDepth-1, 0) parentPos := types.NewPosition(c.maxDepth-1, common.Big0)
pos := types.NewPosition(c.maxDepth, int(traceIndex)) pos := types.NewPosition(c.maxDepth, traceIndex)
return types.Claim{ return types.Claim{
ClaimData: types.ClaimData{ ClaimData: types.ClaimData{
Value: c.claim(pos, correct), Value: c.claim(pos, correct),
......
package test package test
import ( import (
"math/big"
"github.com/ethereum-optimism/optimism/op-challenger/game/fault/types" "github.com/ethereum-optimism/optimism/op-challenger/game/fault/types"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
) )
...@@ -120,7 +122,7 @@ func (s *GameBuilderSeq) ExpectStepAttack() *GameBuilderSeq { ...@@ -120,7 +122,7 @@ func (s *GameBuilderSeq) ExpectStepAttack() *GameBuilderSeq {
} }
func (s *GameBuilderSeq) ExpectStepDefend() *GameBuilderSeq { func (s *GameBuilderSeq) ExpectStepDefend() *GameBuilderSeq {
traceIdx := s.lastClaim.TraceIndex(s.builder.maxDepth) + 1 traceIdx := new(big.Int).Add(s.lastClaim.TraceIndex(s.builder.maxDepth), big.NewInt(1))
s.gameBuilder.ExpectedActions = append(s.gameBuilder.ExpectedActions, types.Action{ s.gameBuilder.ExpectedActions = append(s.gameBuilder.ExpectedActions, types.Action{
Type: types.ActionTypeStep, Type: types.ActionTypeStep,
ParentIdx: s.lastClaim.ContractIndex, ParentIdx: s.lastClaim.ContractIndex,
......
...@@ -35,7 +35,7 @@ func NewTraceProvider(state string, depth uint64) *AlphabetTraceProvider { ...@@ -35,7 +35,7 @@ func NewTraceProvider(state string, depth uint64) *AlphabetTraceProvider {
func (ap *AlphabetTraceProvider) GetStepData(ctx context.Context, i types.Position) ([]byte, []byte, *types.PreimageOracleData, error) { func (ap *AlphabetTraceProvider) GetStepData(ctx context.Context, i types.Position) ([]byte, []byte, *types.PreimageOracleData, error) {
traceIndex := i.TraceIndex(int(ap.depth)) traceIndex := i.TraceIndex(int(ap.depth))
if traceIndex == 0 { if traceIndex.Cmp(common.Big0) == 0 {
prestate, err := ap.AbsolutePreState(ctx) prestate, err := ap.AbsolutePreState(ctx)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
...@@ -43,22 +43,23 @@ func (ap *AlphabetTraceProvider) GetStepData(ctx context.Context, i types.Positi ...@@ -43,22 +43,23 @@ func (ap *AlphabetTraceProvider) GetStepData(ctx context.Context, i types.Positi
return prestate, []byte{}, nil, nil return prestate, []byte{}, nil, nil
} }
// We want the pre-state which is the value prior to the one requested // We want the pre-state which is the value prior to the one requested
traceIndex-- traceIndex = traceIndex.Sub(traceIndex, big.NewInt(1))
// The index cannot be larger than the maximum index as computed by the depth. // The index cannot be larger than the maximum index as computed by the depth.
if traceIndex >= ap.maxLen { if traceIndex.Cmp(big.NewInt(int64(ap.maxLen))) >= 0 {
return nil, nil, nil, ErrIndexTooLarge return nil, nil, nil, ErrIndexTooLarge
} }
// We extend the deepest hash to the maximum depth if the trace is not expansive. // We extend the deepest hash to the maximum depth if the trace is not expansive.
if traceIndex >= uint64(len(ap.state)) { if traceIndex.Cmp(big.NewInt(int64(len(ap.state)))) >= 0 {
return ap.GetStepData(ctx, types.NewPosition(int(ap.depth), len(ap.state))) return ap.GetStepData(ctx, types.NewPosition(int(ap.depth), big.NewInt(int64(len(ap.state)))))
} }
return BuildAlphabetPreimage(traceIndex, ap.state[traceIndex]), []byte{}, nil, nil return BuildAlphabetPreimage(traceIndex, ap.state[traceIndex.Uint64()]), []byte{}, nil, nil
} }
// Get returns the claim value at the given index in the trace. // Get returns the claim value at the given index in the trace.
func (ap *AlphabetTraceProvider) Get(ctx context.Context, i types.Position) (common.Hash, error) { func (ap *AlphabetTraceProvider) Get(ctx context.Context, i types.Position) (common.Hash, error) {
// Step data returns the pre-state, so add 1 to get the state for index i // Step data returns the pre-state, so add 1 to get the state for index i
postPosition := types.NewPosition(int(ap.depth), int(i.TraceIndex(int(ap.depth)))+1) ti := i.TraceIndex(int(ap.depth))
postPosition := types.NewPosition(int(ap.depth), new(big.Int).Add(ti, big.NewInt(1)))
claimBytes, _, _, err := ap.GetStepData(ctx, postPosition) claimBytes, _, _, err := ap.GetStepData(ctx, postPosition)
if err != nil { if err != nil {
return common.Hash{}, err return common.Hash{}, err
...@@ -82,8 +83,8 @@ func (ap *AlphabetTraceProvider) AbsolutePreStateCommitment(ctx context.Context) ...@@ -82,8 +83,8 @@ func (ap *AlphabetTraceProvider) AbsolutePreStateCommitment(ctx context.Context)
} }
// BuildAlphabetPreimage constructs the claim bytes for the index and state item. // BuildAlphabetPreimage constructs the claim bytes for the index and state item.
func BuildAlphabetPreimage(i uint64, letter string) []byte { func BuildAlphabetPreimage(i *big.Int, letter string) []byte {
return append(IndexToBytes(i), LetterToBytes(letter)...) return append(i.FillBytes(make([]byte, 32)), LetterToBytes(letter)...)
} }
func alphabetStateHash(state []byte) common.Hash { func alphabetStateHash(state []byte) common.Hash {
...@@ -92,14 +93,6 @@ func alphabetStateHash(state []byte) common.Hash { ...@@ -92,14 +93,6 @@ func alphabetStateHash(state []byte) common.Hash {
return h return h
} }
// IndexToBytes converts an index to a byte slice big endian
func IndexToBytes(i uint64) []byte {
big := new(big.Int)
big.SetUint64(i)
out := make([]byte, 32)
return big.FillBytes(out)
}
// LetterToBytes converts a letter to a 32 byte array // LetterToBytes converts a letter to a 32 byte array
func LetterToBytes(letter string) []byte { func LetterToBytes(letter string) []byte {
out := make([]byte, 32) out := make([]byte, 32)
......
...@@ -10,7 +10,7 @@ import ( ...@@ -10,7 +10,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func alphabetClaim(index uint64, letter string) common.Hash { func alphabetClaim(index *big.Int, letter string) common.Hash {
return alphabetStateHash(BuildAlphabetPreimage(index, letter)) return alphabetStateHash(BuildAlphabetPreimage(index, letter))
} }
...@@ -26,16 +26,16 @@ func TestAlphabetProvider_Get_ClaimsByTraceIndex(t *testing.T) { ...@@ -26,16 +26,16 @@ func TestAlphabetProvider_Get_ClaimsByTraceIndex(t *testing.T) {
expectedHash common.Hash expectedHash common.Hash
}{ }{
{ {
types.NewPosition(depth, 7), types.NewPosition(depth, big.NewInt(7)),
alphabetClaim(7, "h"), alphabetClaim(big.NewInt(7), "h"),
}, },
{ {
types.NewPosition(depth, 3), types.NewPosition(depth, big.NewInt(3)),
alphabetClaim(3, "d"), alphabetClaim(big.NewInt(3), "d"),
}, },
{ {
types.NewPosition(depth, 5), types.NewPosition(depth, big.NewInt(5)),
alphabetClaim(5, "f"), alphabetClaim(big.NewInt(5), "f"),
}, },
} }
...@@ -47,23 +47,13 @@ func TestAlphabetProvider_Get_ClaimsByTraceIndex(t *testing.T) { ...@@ -47,23 +47,13 @@ func TestAlphabetProvider_Get_ClaimsByTraceIndex(t *testing.T) {
} }
} }
// FuzzIndexToBytes tests the IndexToBytes function.
func FuzzIndexToBytes(f *testing.F) {
f.Fuzz(func(t *testing.T, index uint64) {
translated := IndexToBytes(index)
original := new(big.Int)
original.SetBytes(translated)
require.Equal(t, original.Uint64(), index)
})
}
// TestGetPreimage_Succeeds tests the GetPreimage function // TestGetPreimage_Succeeds tests the GetPreimage function
// returns the correct pre-image for a index. // returns the correct pre-image for a index.
func TestGetStepData_Succeeds(t *testing.T) { func TestGetStepData_Succeeds(t *testing.T) {
depth := 2 depth := 2
ap := NewTraceProvider("abc", uint64(depth)) ap := NewTraceProvider("abc", uint64(depth))
expected := BuildAlphabetPreimage(0, "a") expected := BuildAlphabetPreimage(big.NewInt(0), "a")
pos := types.NewPosition(depth, 1) pos := types.NewPosition(depth, big.NewInt(1))
retrieved, proof, data, err := ap.GetStepData(context.Background(), pos) retrieved, proof, data, err := ap.GetStepData(context.Background(), pos)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, expected, retrieved) require.Equal(t, expected, retrieved)
...@@ -76,7 +66,7 @@ func TestGetStepData_Succeeds(t *testing.T) { ...@@ -76,7 +66,7 @@ func TestGetStepData_Succeeds(t *testing.T) {
func TestGetStepData_TooLargeIndex_Fails(t *testing.T) { func TestGetStepData_TooLargeIndex_Fails(t *testing.T) {
depth := 2 depth := 2
ap := NewTraceProvider("abc", uint64(depth)) ap := NewTraceProvider("abc", uint64(depth))
pos := types.NewPosition(depth, 5) pos := types.NewPosition(depth, big.NewInt(5))
_, _, _, err := ap.GetStepData(context.Background(), pos) _, _, _, err := ap.GetStepData(context.Background(), pos)
require.ErrorIs(t, err, ErrIndexTooLarge) require.ErrorIs(t, err, ErrIndexTooLarge)
} }
...@@ -85,10 +75,10 @@ func TestGetStepData_TooLargeIndex_Fails(t *testing.T) { ...@@ -85,10 +75,10 @@ func TestGetStepData_TooLargeIndex_Fails(t *testing.T) {
func TestGet_Succeeds(t *testing.T) { func TestGet_Succeeds(t *testing.T) {
depth := 2 depth := 2
ap := NewTraceProvider("abc", uint64(depth)) ap := NewTraceProvider("abc", uint64(depth))
pos := types.NewPosition(depth, 0) pos := types.NewPosition(depth, big.NewInt(0))
claim, err := ap.Get(context.Background(), pos) claim, err := ap.Get(context.Background(), pos)
require.NoError(t, err) require.NoError(t, err)
expected := alphabetClaim(0, "a") expected := alphabetClaim(big.NewInt(0), "a")
require.Equal(t, expected, claim) require.Equal(t, expected, claim)
} }
...@@ -97,7 +87,7 @@ func TestGet_Succeeds(t *testing.T) { ...@@ -97,7 +87,7 @@ func TestGet_Succeeds(t *testing.T) {
func TestGet_IndexTooLarge(t *testing.T) { func TestGet_IndexTooLarge(t *testing.T) {
depth := 2 depth := 2
ap := NewTraceProvider("abc", uint64(depth)) ap := NewTraceProvider("abc", uint64(depth))
pos := types.NewPosition(depth, 4) pos := types.NewPosition(depth, big.NewInt(4))
_, err := ap.Get(context.Background(), pos) _, err := ap.Get(context.Background(), pos)
require.ErrorIs(t, err, ErrIndexTooLarge) require.ErrorIs(t, err, ErrIndexTooLarge)
} }
...@@ -107,9 +97,9 @@ func TestGet_IndexTooLarge(t *testing.T) { ...@@ -107,9 +97,9 @@ func TestGet_IndexTooLarge(t *testing.T) {
func TestGet_Extends(t *testing.T) { func TestGet_Extends(t *testing.T) {
depth := 2 depth := 2
ap := NewTraceProvider("abc", uint64(depth)) ap := NewTraceProvider("abc", uint64(depth))
pos := types.NewPosition(depth, 3) pos := types.NewPosition(depth, big.NewInt(3))
claim, err := ap.Get(context.Background(), pos) claim, err := ap.Get(context.Background(), pos)
require.NoError(t, err) require.NoError(t, err)
expected := alphabetClaim(2, "c") expected := alphabetClaim(big.NewInt(2), "c")
require.Equal(t, expected, claim) require.Equal(t, expected, claim)
} }
...@@ -88,7 +88,7 @@ func (p *CannonTraceProvider) SetMaxDepth(gameDepth uint64) { ...@@ -88,7 +88,7 @@ func (p *CannonTraceProvider) SetMaxDepth(gameDepth uint64) {
} }
func (p *CannonTraceProvider) Get(ctx context.Context, pos types.Position) (common.Hash, error) { func (p *CannonTraceProvider) Get(ctx context.Context, pos types.Position) (common.Hash, error) {
proof, err := p.loadProof(ctx, pos.TraceIndex(int(p.gameDepth))) proof, err := p.loadProof(ctx, pos.UnsafeTraceIndex(int(p.gameDepth)))
if err != nil { if err != nil {
return common.Hash{}, err return common.Hash{}, err
} }
...@@ -101,7 +101,7 @@ func (p *CannonTraceProvider) Get(ctx context.Context, pos types.Position) (comm ...@@ -101,7 +101,7 @@ func (p *CannonTraceProvider) Get(ctx context.Context, pos types.Position) (comm
} }
func (p *CannonTraceProvider) GetStepData(ctx context.Context, pos types.Position) ([]byte, []byte, *types.PreimageOracleData, error) { func (p *CannonTraceProvider) GetStepData(ctx context.Context, pos types.Position) ([]byte, []byte, *types.PreimageOracleData, error) {
proof, err := p.loadProof(ctx, pos.TraceIndex(int(p.gameDepth))) proof, err := p.loadProof(ctx, pos.UnsafeTraceIndex(int(p.gameDepth)))
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
......
...@@ -6,6 +6,7 @@ import ( ...@@ -6,6 +6,7 @@ import (
_ "embed" _ "embed"
"encoding/json" "encoding/json"
"fmt" "fmt"
"math/big"
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
...@@ -22,7 +23,7 @@ import ( ...@@ -22,7 +23,7 @@ import (
//go:embed test_data //go:embed test_data
var testData embed.FS var testData embed.FS
func PositionFromTraceIndex(provider *CannonTraceProvider, idx int) types.Position { func PositionFromTraceIndex(provider *CannonTraceProvider, idx *big.Int) types.Position {
return types.NewPosition(int(provider.gameDepth), idx) return types.NewPosition(int(provider.gameDepth), idx)
} }
...@@ -30,7 +31,7 @@ func TestGet(t *testing.T) { ...@@ -30,7 +31,7 @@ func TestGet(t *testing.T) {
dataDir, prestate := setupTestData(t) dataDir, prestate := setupTestData(t)
t.Run("ExistingProof", func(t *testing.T) { t.Run("ExistingProof", func(t *testing.T) {
provider, generator := setupWithTestData(t, dataDir, prestate) provider, generator := setupWithTestData(t, dataDir, prestate)
value, err := provider.Get(context.Background(), PositionFromTraceIndex(provider, 0)) value, err := provider.Get(context.Background(), PositionFromTraceIndex(provider, common.Big0))
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, common.HexToHash("0x45fd9aa59768331c726e719e76aa343e73123af888804604785ae19506e65e87"), value) require.Equal(t, common.HexToHash("0x45fd9aa59768331c726e719e76aa343e73123af888804604785ae19506e65e87"), value)
require.Empty(t, generator.generated) require.Empty(t, generator.generated)
...@@ -43,7 +44,7 @@ func TestGet(t *testing.T) { ...@@ -43,7 +44,7 @@ func TestGet(t *testing.T) {
Step: 10, Step: 10,
Exited: true, Exited: true,
} }
value, err := provider.Get(context.Background(), PositionFromTraceIndex(provider, 7000)) value, err := provider.Get(context.Background(), PositionFromTraceIndex(provider, big.NewInt(7000)))
require.NoError(t, err) require.NoError(t, err)
require.Contains(t, generator.generated, 7000, "should have tried to generate the proof") require.Contains(t, generator.generated, 7000, "should have tried to generate the proof")
stateHash, err := generator.finalState.EncodeWitness().StateHash() stateHash, err := generator.finalState.EncodeWitness().StateHash()
...@@ -53,14 +54,14 @@ func TestGet(t *testing.T) { ...@@ -53,14 +54,14 @@ func TestGet(t *testing.T) {
t.Run("MissingPostHash", func(t *testing.T) { t.Run("MissingPostHash", func(t *testing.T) {
provider, generator := setupWithTestData(t, dataDir, prestate) provider, generator := setupWithTestData(t, dataDir, prestate)
_, err := provider.Get(context.Background(), PositionFromTraceIndex(provider, 1)) _, err := provider.Get(context.Background(), PositionFromTraceIndex(provider, big.NewInt(1)))
require.ErrorContains(t, err, "missing post hash") require.ErrorContains(t, err, "missing post hash")
require.Empty(t, generator.generated) require.Empty(t, generator.generated)
}) })
t.Run("IgnoreUnknownFields", func(t *testing.T) { t.Run("IgnoreUnknownFields", func(t *testing.T) {
provider, generator := setupWithTestData(t, dataDir, prestate) provider, generator := setupWithTestData(t, dataDir, prestate)
value, err := provider.Get(context.Background(), PositionFromTraceIndex(provider, 2)) value, err := provider.Get(context.Background(), PositionFromTraceIndex(provider, big.NewInt(2)))
require.NoError(t, err) require.NoError(t, err)
expected := common.HexToHash("bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb") expected := common.HexToHash("bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb")
require.Equal(t, expected, value) require.Equal(t, expected, value)
...@@ -72,7 +73,7 @@ func TestGetStepData(t *testing.T) { ...@@ -72,7 +73,7 @@ func TestGetStepData(t *testing.T) {
t.Run("ExistingProof", func(t *testing.T) { t.Run("ExistingProof", func(t *testing.T) {
dataDir, prestate := setupTestData(t) dataDir, prestate := setupTestData(t)
provider, generator := setupWithTestData(t, dataDir, prestate) provider, generator := setupWithTestData(t, dataDir, prestate)
value, proof, data, err := provider.GetStepData(context.Background(), PositionFromTraceIndex(provider, 0)) value, proof, data, err := provider.GetStepData(context.Background(), PositionFromTraceIndex(provider, new(big.Int)))
require.NoError(t, err) require.NoError(t, err)
expected := common.Hex2Bytes("b8f068de604c85ea0e2acd437cdb47add074a2d70b81d018390c504b71fe26f400000000000000000000000000000000000000000000000000000000000000000000000000") expected := common.Hex2Bytes("b8f068de604c85ea0e2acd437cdb47add074a2d70b81d018390c504b71fe26f400000000000000000000000000000000000000000000000000000000000000000000000000")
require.Equal(t, expected, value) require.Equal(t, expected, value)
...@@ -99,7 +100,7 @@ func TestGetStepData(t *testing.T) { ...@@ -99,7 +100,7 @@ func TestGetStepData(t *testing.T) {
OracleValue: []byte{0xdd}, OracleValue: []byte{0xdd},
OracleOffset: 10, OracleOffset: 10,
} }
preimage, proof, data, err := provider.GetStepData(context.Background(), PositionFromTraceIndex(provider, 4)) preimage, proof, data, err := provider.GetStepData(context.Background(), PositionFromTraceIndex(provider, big.NewInt(4)))
require.NoError(t, err) require.NoError(t, err)
require.Contains(t, generator.generated, 4, "should have tried to generate the proof") require.Contains(t, generator.generated, 4, "should have tried to generate the proof")
...@@ -125,7 +126,7 @@ func TestGetStepData(t *testing.T) { ...@@ -125,7 +126,7 @@ func TestGetStepData(t *testing.T) {
OracleValue: []byte{0xdd}, OracleValue: []byte{0xdd},
OracleOffset: 10, OracleOffset: 10,
} }
preimage, proof, data, err := provider.GetStepData(context.Background(), PositionFromTraceIndex(provider, 7000)) preimage, proof, data, err := provider.GetStepData(context.Background(), PositionFromTraceIndex(provider, big.NewInt(7000)))
require.NoError(t, err) require.NoError(t, err)
require.Contains(t, generator.generated, 7000, "should have tried to generate the proof") require.Contains(t, generator.generated, 7000, "should have tried to generate the proof")
...@@ -151,7 +152,7 @@ func TestGetStepData(t *testing.T) { ...@@ -151,7 +152,7 @@ func TestGetStepData(t *testing.T) {
OracleValue: []byte{0xdd}, OracleValue: []byte{0xdd},
OracleOffset: 10, OracleOffset: 10,
} }
_, _, _, err := provider.GetStepData(context.Background(), PositionFromTraceIndex(provider, 7000)) _, _, _, err := provider.GetStepData(context.Background(), PositionFromTraceIndex(provider, big.NewInt(7000)))
require.NoError(t, err) require.NoError(t, err)
require.Contains(t, initGenerator.generated, 7000, "should have tried to generate the proof") require.Contains(t, initGenerator.generated, 7000, "should have tried to generate the proof")
...@@ -166,7 +167,7 @@ func TestGetStepData(t *testing.T) { ...@@ -166,7 +167,7 @@ func TestGetStepData(t *testing.T) {
StateData: []byte{0xbb}, StateData: []byte{0xbb},
ProofData: []byte{0xcc}, ProofData: []byte{0xcc},
} }
preimage, proof, data, err := provider.GetStepData(context.Background(), PositionFromTraceIndex(provider, 7000)) preimage, proof, data, err := provider.GetStepData(context.Background(), PositionFromTraceIndex(provider, big.NewInt(7000)))
require.NoError(t, err) require.NoError(t, err)
require.Empty(t, generator.generated, "should not have to generate the proof again") require.Empty(t, generator.generated, "should not have to generate the proof again")
...@@ -178,7 +179,7 @@ func TestGetStepData(t *testing.T) { ...@@ -178,7 +179,7 @@ func TestGetStepData(t *testing.T) {
t.Run("MissingStateData", func(t *testing.T) { t.Run("MissingStateData", func(t *testing.T) {
dataDir, prestate := setupTestData(t) dataDir, prestate := setupTestData(t)
provider, generator := setupWithTestData(t, dataDir, prestate) provider, generator := setupWithTestData(t, dataDir, prestate)
_, _, _, err := provider.GetStepData(context.Background(), PositionFromTraceIndex(provider, 1)) _, _, _, err := provider.GetStepData(context.Background(), PositionFromTraceIndex(provider, big.NewInt(1)))
require.ErrorContains(t, err, "missing state data") require.ErrorContains(t, err, "missing state data")
require.Empty(t, generator.generated) require.Empty(t, generator.generated)
}) })
...@@ -186,7 +187,7 @@ func TestGetStepData(t *testing.T) { ...@@ -186,7 +187,7 @@ func TestGetStepData(t *testing.T) {
t.Run("IgnoreUnknownFields", func(t *testing.T) { t.Run("IgnoreUnknownFields", func(t *testing.T) {
dataDir, prestate := setupTestData(t) dataDir, prestate := setupTestData(t)
provider, generator := setupWithTestData(t, dataDir, prestate) provider, generator := setupWithTestData(t, dataDir, prestate)
value, proof, data, err := provider.GetStepData(context.Background(), PositionFromTraceIndex(provider, 2)) value, proof, data, err := provider.GetStepData(context.Background(), PositionFromTraceIndex(provider, big.NewInt(2)))
require.NoError(t, err) require.NoError(t, err)
expected := common.Hex2Bytes("cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc") expected := common.Hex2Bytes("cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc")
require.Equal(t, expected, value) require.Equal(t, expected, value)
......
...@@ -3,6 +3,7 @@ package outputs ...@@ -3,6 +3,7 @@ package outputs
import ( import (
"context" "context"
"fmt" "fmt"
"math"
"github.com/ethereum-optimism/optimism/op-challenger/game/fault/types" "github.com/ethereum-optimism/optimism/op-challenger/game/fault/types"
"github.com/ethereum-optimism/optimism/op-service/client" "github.com/ethereum-optimism/optimism/op-service/client"
...@@ -52,7 +53,11 @@ func NewTraceProviderFromInputs(logger log.Logger, rollupClient OutputRollupClie ...@@ -52,7 +53,11 @@ func NewTraceProviderFromInputs(logger log.Logger, rollupClient OutputRollupClie
} }
func (o *OutputTraceProvider) Get(ctx context.Context, pos types.Position) (common.Hash, error) { func (o *OutputTraceProvider) Get(ctx context.Context, pos types.Position) (common.Hash, error) {
outputBlock := pos.TraceIndex(int(o.gameDepth)) + o.prestateBlock + 1 traceIndex := pos.TraceIndex(int(o.gameDepth))
if traceIndex.Cmp(common.Big0.SetUint64(math.MaxUint64)) > 0 {
return common.Hash{}, fmt.Errorf("trace index %v is greater than max uint64", traceIndex)
}
outputBlock := traceIndex.Uint64() + o.prestateBlock + 1
if outputBlock > o.poststateBlock { if outputBlock > o.poststateBlock {
outputBlock = o.poststateBlock outputBlock = o.poststateBlock
} }
......
...@@ -3,6 +3,7 @@ package outputs ...@@ -3,6 +3,7 @@ package outputs
import ( import (
"context" "context"
"fmt" "fmt"
"math/big"
"testing" "testing"
"github.com/ethereum-optimism/optimism/op-challenger/game/fault/types" "github.com/ethereum-optimism/optimism/op-challenger/game/fault/types"
...@@ -26,13 +27,21 @@ var ( ...@@ -26,13 +27,21 @@ var (
func TestGet(t *testing.T) { func TestGet(t *testing.T) {
t.Run("PrePrestateErrors", func(t *testing.T) { t.Run("PrePrestateErrors", func(t *testing.T) {
provider, _ := setupWithTestData(t, 0, poststateBlock) provider, _ := setupWithTestData(t, 0, poststateBlock)
_, err := provider.Get(context.Background(), types.NewPosition(1, 0)) _, err := provider.Get(context.Background(), types.NewPosition(1, common.Big0))
require.ErrorAs(t, fmt.Errorf("no output at block %d", 1), &err) require.ErrorAs(t, fmt.Errorf("no output at block %d", 1), &err)
}) })
t.Run("ErrorsTraceIndexOutOfBounds", func(t *testing.T) {
deepGame := uint64(64)
provider, _ := setupWithTestData(t, prestateBlock, poststateBlock, deepGame)
pos := types.NewPosition(0, big.NewInt(0))
_, err := provider.Get(context.Background(), pos)
require.ErrorAs(t, fmt.Errorf("trace index %v is greater than max uint64", pos.TraceIndex(int(deepGame))), &err)
})
t.Run("MisconfiguredPoststateErrors", func(t *testing.T) { t.Run("MisconfiguredPoststateErrors", func(t *testing.T) {
provider, _ := setupWithTestData(t, 0, 0) provider, _ := setupWithTestData(t, 0, 0)
_, err := provider.Get(context.Background(), types.NewPosition(1, 0)) _, err := provider.Get(context.Background(), types.NewPosition(1, common.Big0))
require.ErrorAs(t, fmt.Errorf("no output at block %d", 0), &err) require.ErrorAs(t, fmt.Errorf("no output at block %d", 0), &err)
}) })
...@@ -82,7 +91,7 @@ func TestAbsolutePreStateCommitment(t *testing.T) { ...@@ -82,7 +91,7 @@ func TestAbsolutePreStateCommitment(t *testing.T) {
func TestGetStepData(t *testing.T) { func TestGetStepData(t *testing.T) {
provider, _ := setupWithTestData(t, prestateBlock, poststateBlock) provider, _ := setupWithTestData(t, prestateBlock, poststateBlock)
_, _, _, err := provider.GetStepData(context.Background(), types.NewPosition(1, 0)) _, _, _, err := provider.GetStepData(context.Background(), types.NewPosition(1, common.Big0))
require.ErrorIs(t, err, GetStepDataErr) require.ErrorIs(t, err, GetStepDataErr)
} }
...@@ -92,7 +101,7 @@ func TestAbsolutePreState(t *testing.T) { ...@@ -92,7 +101,7 @@ func TestAbsolutePreState(t *testing.T) {
require.ErrorIs(t, err, AbsolutePreStateErr) require.ErrorIs(t, err, AbsolutePreStateErr)
} }
func setupWithTestData(t *testing.T, prestateBlock, poststateBlock uint64) (*OutputTraceProvider, *stubRollupClient) { func setupWithTestData(t *testing.T, prestateBlock, poststateBlock uint64, customGameDepth ...uint64) (*OutputTraceProvider, *stubRollupClient) {
rollupClient := stubRollupClient{ rollupClient := stubRollupClient{
outputs: map[uint64]*eth.OutputResponse{ outputs: map[uint64]*eth.OutputResponse{
prestateBlock: { prestateBlock: {
...@@ -106,12 +115,16 @@ func setupWithTestData(t *testing.T, prestateBlock, poststateBlock uint64) (*Out ...@@ -106,12 +115,16 @@ func setupWithTestData(t *testing.T, prestateBlock, poststateBlock uint64) (*Out
}, },
}, },
} }
inputGameDepth := gameDepth
if len(customGameDepth) > 0 {
inputGameDepth = customGameDepth[0]
}
return &OutputTraceProvider{ return &OutputTraceProvider{
logger: testlog.Logger(t, log.LvlInfo), logger: testlog.Logger(t, log.LvlInfo),
rollupClient: &rollupClient, rollupClient: &rollupClient,
prestateBlock: prestateBlock, prestateBlock: prestateBlock,
poststateBlock: poststateBlock, poststateBlock: poststateBlock,
gameDepth: gameDepth, gameDepth: inputGameDepth,
}, &rollupClient }, &rollupClient
} }
......
...@@ -3,6 +3,7 @@ package split ...@@ -3,6 +3,7 @@ package split
import ( import (
"context" "context"
"errors" "errors"
"math/big"
"testing" "testing"
"github.com/ethereum-optimism/optimism/op-challenger/game/fault/types" "github.com/ethereum-optimism/optimism/op-challenger/game/fault/types"
...@@ -23,26 +24,26 @@ func TestGet(t *testing.T) { ...@@ -23,26 +24,26 @@ func TestGet(t *testing.T) {
t.Run("ErrorBubblesUp", func(t *testing.T) { t.Run("ErrorBubblesUp", func(t *testing.T) {
mockOutputProvider := mockTraceProvider{getError: mockGetError} mockOutputProvider := mockTraceProvider{getError: mockGetError}
splitProvider := newSplitTraceProvider(t, &mockOutputProvider, nil, 40) splitProvider := newSplitTraceProvider(t, &mockOutputProvider, nil, 40)
_, err := splitProvider.Get(context.Background(), types.NewPosition(1, 0)) _, err := splitProvider.Get(context.Background(), types.NewPosition(1, common.Big0))
require.ErrorIs(t, err, mockGetError) require.ErrorIs(t, err, mockGetError)
}) })
t.Run("ReturnsCorrectOutputFromTopProvider", func(t *testing.T) { t.Run("ReturnsCorrectOutputFromTopProvider", func(t *testing.T) {
mockOutputProvider := mockTraceProvider{getOutput: mockOutput} mockOutputProvider := mockTraceProvider{getOutput: mockOutput}
splitProvider := newSplitTraceProvider(t, &mockOutputProvider, &mockTraceProvider{}, 40) splitProvider := newSplitTraceProvider(t, &mockOutputProvider, &mockTraceProvider{}, 40)
output, err := splitProvider.Get(context.Background(), types.NewPosition(6, 3)) output, err := splitProvider.Get(context.Background(), types.NewPosition(6, big.NewInt(3)))
require.NoError(t, err) require.NoError(t, err)
expectedGIndex := types.NewPosition(6, 3).ToGIndex() expectedGIndex := types.NewPosition(6, big.NewInt(3)).ToGIndex()
require.Equal(t, common.BytesToHash([]byte{byte(expectedGIndex)}), output) require.Equal(t, common.BigToHash(expectedGIndex), output)
}) })
t.Run("ReturnsCorrectOutputWithMultipleProviders", func(t *testing.T) { t.Run("ReturnsCorrectOutputWithMultipleProviders", func(t *testing.T) {
bottomProvider := mockTraceProvider{getOutput: mockOutput} bottomProvider := mockTraceProvider{getOutput: mockOutput}
splitProvider := newSplitTraceProvider(t, &mockTraceProvider{}, &bottomProvider, 40) splitProvider := newSplitTraceProvider(t, &mockTraceProvider{}, &bottomProvider, 40)
output, err := splitProvider.Get(context.Background(), types.NewPosition(42, 17)) output, err := splitProvider.Get(context.Background(), types.NewPosition(42, big.NewInt(17)))
require.NoError(t, err) require.NoError(t, err)
expectedGIndex := types.NewPosition(2, 1).ToGIndex() expectedGIndex := types.NewPosition(2, big.NewInt(1)).ToGIndex()
require.Equal(t, common.BytesToHash([]byte{byte(expectedGIndex)}), output) require.Equal(t, common.BigToHash(expectedGIndex), output)
}) })
} }
...@@ -85,7 +86,7 @@ func TestGetStepData(t *testing.T) { ...@@ -85,7 +86,7 @@ func TestGetStepData(t *testing.T) {
t.Run("ErrorBubblesUp", func(t *testing.T) { t.Run("ErrorBubblesUp", func(t *testing.T) {
mockOutputProvider := mockTraceProvider{getStepDataError: mockGetError} mockOutputProvider := mockTraceProvider{getStepDataError: mockGetError}
splitProvider := newSplitTraceProvider(t, &mockOutputProvider, nil, 40) splitProvider := newSplitTraceProvider(t, &mockOutputProvider, nil, 40)
_, _, _, err := splitProvider.GetStepData(context.Background(), types.NewPosition(0, 0)) _, _, _, err := splitProvider.GetStepData(context.Background(), types.NewPosition(0, common.Big0))
require.ErrorIs(t, err, mockGetError) require.ErrorIs(t, err, mockGetError)
}) })
...@@ -93,7 +94,7 @@ func TestGetStepData(t *testing.T) { ...@@ -93,7 +94,7 @@ func TestGetStepData(t *testing.T) {
expectedStepData := []byte{1, 2, 3, 4} expectedStepData := []byte{1, 2, 3, 4}
mockOutputProvider := mockTraceProvider{stepPrestateData: expectedStepData} mockOutputProvider := mockTraceProvider{stepPrestateData: expectedStepData}
splitProvider := newSplitTraceProvider(t, nil, &mockOutputProvider, 40) splitProvider := newSplitTraceProvider(t, nil, &mockOutputProvider, 40)
output, _, _, err := splitProvider.GetStepData(context.Background(), types.NewPosition(41, 0)) output, _, _, err := splitProvider.GetStepData(context.Background(), types.NewPosition(41, common.Big0))
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, expectedStepData, output) require.Equal(t, expectedStepData, output)
}) })
...@@ -123,7 +124,7 @@ func (m *mockTraceProvider) Get(ctx context.Context, pos types.Position) (common ...@@ -123,7 +124,7 @@ func (m *mockTraceProvider) Get(ctx context.Context, pos types.Position) (common
if m.getError != nil { if m.getError != nil {
return common.Hash{}, m.getError return common.Hash{}, m.getError
} }
return common.BytesToHash([]byte{byte(pos.ToGIndex())}), nil return common.BigToHash(pos.ToGIndex()), nil
} }
func (m *mockTraceProvider) AbsolutePreStateCommitment(ctx context.Context) (hash common.Hash, err error) { func (m *mockTraceProvider) AbsolutePreStateCommitment(ctx context.Context) (hash common.Hash, err error) {
......
...@@ -38,7 +38,7 @@ type claimID common.Hash ...@@ -38,7 +38,7 @@ type claimID common.Hash
func computeClaimID(claim Claim) claimID { func computeClaimID(claim Claim) claimID {
return claimID(crypto.Keccak256Hash( return claimID(crypto.Keccak256Hash(
new(big.Int).SetUint64(claim.Position.ToGIndex()).Bytes(), claim.Position.ToGIndex().Bytes(),
claim.Value.Bytes(), claim.Value.Bytes(),
big.NewInt(int64(claim.ParentContractIndex)).Bytes(), big.NewInt(int64(claim.ParentContractIndex)).Bytes(),
)) ))
......
package types package types
import ( import (
"math/big"
"testing" "testing"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
...@@ -15,14 +16,14 @@ func createTestClaims() (Claim, Claim, Claim, Claim) { ...@@ -15,14 +16,14 @@ func createTestClaims() (Claim, Claim, Claim, Claim) {
root := Claim{ root := Claim{
ClaimData: ClaimData{ ClaimData: ClaimData{
Value: common.HexToHash("0x000000000000000000000000000000000000000000000000000000000000077a"), Value: common.HexToHash("0x000000000000000000000000000000000000000000000000000000000000077a"),
Position: NewPosition(0, 0), Position: NewPosition(0, common.Big0),
}, },
// Root claim has no parent // Root claim has no parent
} }
top := Claim{ top := Claim{
ClaimData: ClaimData{ ClaimData: ClaimData{
Value: common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000364"), Value: common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000364"),
Position: NewPosition(1, 0), Position: NewPosition(1, common.Big0),
}, },
Parent: root.ClaimData, Parent: root.ClaimData,
ContractIndex: 1, ContractIndex: 1,
...@@ -31,7 +32,7 @@ func createTestClaims() (Claim, Claim, Claim, Claim) { ...@@ -31,7 +32,7 @@ func createTestClaims() (Claim, Claim, Claim, Claim) {
middle := Claim{ middle := Claim{
ClaimData: ClaimData{ ClaimData: ClaimData{
Value: common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000578"), Value: common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000578"),
Position: NewPosition(2, 2), Position: NewPosition(2, big.NewInt(2)),
}, },
Parent: top.ClaimData, Parent: top.ClaimData,
ContractIndex: 2, ContractIndex: 2,
...@@ -41,7 +42,7 @@ func createTestClaims() (Claim, Claim, Claim, Claim) { ...@@ -41,7 +42,7 @@ func createTestClaims() (Claim, Claim, Claim, Claim) {
bottom := Claim{ bottom := Claim{
ClaimData: ClaimData{ ClaimData: ClaimData{
Value: common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000465"), Value: common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000465"),
Position: NewPosition(3, 4), Position: NewPosition(3, big.NewInt(4)),
}, },
Parent: middle.ClaimData, Parent: middle.ClaimData,
ContractIndex: 3, ContractIndex: 3,
...@@ -52,7 +53,6 @@ func createTestClaims() (Claim, Claim, Claim, Claim) { ...@@ -52,7 +53,6 @@ func createTestClaims() (Claim, Claim, Claim, Claim) {
} }
func TestIsDuplicate(t *testing.T) { func TestIsDuplicate(t *testing.T) {
// Setup the game state.
root, top, middle, bottom := createTestClaims() root, top, middle, bottom := createTestClaims()
g := NewGameState(false, []Claim{root, top}, testMaxDepth) g := NewGameState(false, []Claim{root, top}, testMaxDepth)
......
...@@ -3,6 +3,9 @@ package types ...@@ -3,6 +3,9 @@ package types
import ( import (
"errors" "errors"
"fmt" "fmt"
"math/big"
"github.com/ethereum/go-ethereum/common"
) )
var ( var (
...@@ -12,23 +15,33 @@ var ( ...@@ -12,23 +15,33 @@ var (
// Position is a golang wrapper around the dispute game Position type. // Position is a golang wrapper around the dispute game Position type.
type Position struct { type Position struct {
depth int depth int
indexAtDepth int indexAtDepth *big.Int
}
func NewPosition(depth int, indexAtDepth *big.Int) Position {
return Position{
depth: depth,
indexAtDepth: indexAtDepth,
}
} }
func NewPosition(depth, indexAtDepth int) Position { func NewLargePositionFromGIndex(x *big.Int) Position {
return Position{depth, indexAtDepth} depth := bigMSB(x)
indexAtDepth := new(big.Int).Sub(x, new(big.Int).Lsh(big.NewInt(1), uint(depth)))
return NewPosition(depth, indexAtDepth)
} }
// todo(client-pod#80): remove this to use the NewLargePositionFromGIndex.
func NewPositionFromGIndex(x uint64) Position { func NewPositionFromGIndex(x uint64) Position {
depth := MSBIndex(x) depth := MSBIndex(x)
indexAtDepth := ^(1 << depth) & x indexAtDepth := ^(1 << depth) & x
return NewPosition(depth, int(indexAtDepth)) return NewPosition(depth, big.NewInt(int64(indexAtDepth)))
} }
func (p Position) MoveRight() Position { func (p Position) MoveRight() Position {
return Position{ return Position{
depth: p.depth, depth: p.depth,
indexAtDepth: int(p.indexAtDepth + 1), indexAtDepth: new(big.Int).Add(p.indexAtDepth, big.NewInt(1)),
} }
} }
...@@ -40,7 +53,7 @@ func (p Position) RelativeToAncestorAtDepth(ancestor uint64) (Position, error) { ...@@ -40,7 +53,7 @@ func (p Position) RelativeToAncestorAtDepth(ancestor uint64) (Position, error) {
} }
newPosDepth := uint64(p.depth) - ancestor newPosDepth := uint64(p.depth) - ancestor
nodesAtDepth := 1 << newPosDepth nodesAtDepth := 1 << newPosDepth
newIndexAtDepth := p.indexAtDepth % nodesAtDepth newIndexAtDepth := new(big.Int).Mod(p.indexAtDepth, big.NewInt(int64(nodesAtDepth)))
return NewPosition(int(newPosDepth), newIndexAtDepth), nil return NewPosition(int(newPosDepth), newIndexAtDepth), nil
} }
...@@ -48,28 +61,44 @@ func (p Position) Depth() int { ...@@ -48,28 +61,44 @@ func (p Position) Depth() int {
return p.depth return p.depth
} }
func (p Position) IndexAtDepth() int { func (p Position) IndexAtDepth() *big.Int {
if p.indexAtDepth == nil {
return common.Big0
}
return p.indexAtDepth return p.indexAtDepth
} }
func (p Position) IsRootPosition() bool { func (p Position) IsRootPosition() bool {
return p.depth == 0 && p.indexAtDepth == 0 return p.depth == 0 && common.Big0.Cmp(p.indexAtDepth) == 0
}
func (p Position) lshIndex(amount int) *big.Int {
return new(big.Int).Lsh(p.IndexAtDepth(), uint(amount))
} }
// TraceIndex calculates the what the index of the claim value would be inside the trace. // TraceIndex calculates the what the index of the claim value would be inside the trace.
// It is equivalent to going right until the final depth has been reached. // It is equivalent to going right until the final depth has been reached.
func (p Position) TraceIndex(maxDepth int) uint64 { func (p Position) TraceIndex(maxDepth int) *big.Int {
// When we go right, we do a shift left and set the bottom bit to be 1. // When we go right, we do a shift left and set the bottom bit to be 1.
// To do this in a single step, do all the shifts at once & or in all 1s for the bottom bits. // To do this in a single step, do all the shifts at once & or in all 1s for the bottom bits.
rd := maxDepth - p.depth rd := maxDepth - p.depth
return uint64(p.indexAtDepth<<rd | ((1 << rd) - 1)) rhs := new(big.Int).Sub(new(big.Int).Lsh(big.NewInt(1), uint(rd)), big.NewInt(1))
ti := new(big.Int).Or(p.lshIndex(rd), rhs)
return ti
}
// UnsafeTraceIndex returns a uint64 representation of the trace index.
// todo(refcell): This should be removed in a follow-on pr and any invocations
// should be updated to use TraceIndex.
func (p Position) UnsafeTraceIndex(maxDepth int) uint64 {
return p.TraceIndex(maxDepth).Uint64()
} }
// move returns a new position at the left or right child. // move returns a new position at the left or right child.
func (p Position) move(right bool) Position { func (p Position) move(right bool) Position {
return Position{ return Position{
depth: p.depth + 1, depth: p.depth + 1,
indexAtDepth: (p.indexAtDepth << 1) | boolToInt(right), indexAtDepth: new(big.Int).Or(p.lshIndex(1), big.NewInt(int64(boolToInt(right)))),
} }
} }
...@@ -81,11 +110,19 @@ func boolToInt(b bool) int { ...@@ -81,11 +110,19 @@ func boolToInt(b bool) int {
} }
} }
func (p Position) parentIndexAtDepth() *big.Int {
return new(big.Int).Div(p.IndexAtDepth(), big.NewInt(2))
}
func (p Position) RightOf(parent Position) bool {
return p.parentIndexAtDepth().Cmp(parent.IndexAtDepth()) != 0
}
// parent return a new position that is the parent of this Position. // parent return a new position that is the parent of this Position.
func (p Position) parent() Position { func (p Position) parent() Position {
return Position{ return Position{
depth: p.depth - 1, depth: p.depth - 1,
indexAtDepth: p.indexAtDepth >> 1, indexAtDepth: p.parentIndexAtDepth(),
} }
} }
...@@ -103,8 +140,20 @@ func (p Position) Print(maxDepth int) { ...@@ -103,8 +140,20 @@ func (p Position) Print(maxDepth int) {
fmt.Printf("GIN: %4b\tTrace Position is %4b\tTrace Depth is: %d\tTrace Index is: %d\n", p.ToGIndex(), p.indexAtDepth, p.depth, p.TraceIndex(maxDepth)) fmt.Printf("GIN: %4b\tTrace Position is %4b\tTrace Depth is: %d\tTrace Index is: %d\n", p.ToGIndex(), p.indexAtDepth, p.depth, p.TraceIndex(maxDepth))
} }
func (p Position) ToGIndex() uint64 { func (p Position) ToGIndex() *big.Int {
return uint64(1<<p.depth | p.indexAtDepth) return new(big.Int).Or(new(big.Int).Lsh(big.NewInt(1), uint(p.depth)), p.IndexAtDepth())
}
// bigMSB returns the index of the most significant bit
func bigMSB(x *big.Int) int {
if x.Cmp(common.Big0) == 0 {
return 0
}
out := 0
for ; x.Cmp(common.Big0) != 0; out++ {
x = new(big.Int).Rsh(x, 1)
}
return out - 1
} }
// MSBIndex returns the index of the most significant bit // MSBIndex returns the index of the most significant bit
......
package types package types
import ( import (
"math"
"math/big"
"testing" "testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
...@@ -28,79 +30,93 @@ func TestMSBIndex(t *testing.T) { ...@@ -28,79 +30,93 @@ func TestMSBIndex(t *testing.T) {
t.Errorf("MSBIndex(%d) expected %d, but got %d", test.input, test.expected, result) t.Errorf("MSBIndex(%d) expected %d, but got %d", test.input, test.expected, result)
} }
} }
}
func bi(i int) *big.Int {
return big.NewInt(int64(i))
} }
type testNodeInfo struct { type testNodeInfo struct {
GIndex uint64 GIndex *big.Int
Depth int Depth int
IndexAtDepth int MaxDepth int
TraceIndex uint64 IndexAtDepth *big.Int
AttackGIndex uint64 // 0 indicates attack is not possible from this node TraceIndex *big.Int
DefendGIndex uint64 // 0 indicates defend is not possible from this node AttackGIndex *big.Int // 0 indicates attack is not possible from this node
DefendGIndex *big.Int // 0 indicates defend is not possible from this node
} }
var treeNodesMaxDepth4 = []testNodeInfo{ var treeNodes = []testNodeInfo{
{GIndex: 1, Depth: 0, IndexAtDepth: 0, TraceIndex: 15, AttackGIndex: 2}, {GIndex: bi(1), Depth: 0, MaxDepth: 4, IndexAtDepth: bi(0), TraceIndex: bi(15), AttackGIndex: bi(2)},
{GIndex: 2, Depth: 1, IndexAtDepth: 0, TraceIndex: 7, AttackGIndex: 4, DefendGIndex: 6}, {GIndex: bi(2), Depth: 1, MaxDepth: 4, IndexAtDepth: bi(0), TraceIndex: bi(7), AttackGIndex: bi(4), DefendGIndex: bi(6)},
{GIndex: 3, Depth: 1, IndexAtDepth: 1, TraceIndex: 15, AttackGIndex: 6}, {GIndex: bi(3), Depth: 1, MaxDepth: 4, IndexAtDepth: bi(1), TraceIndex: bi(15), AttackGIndex: bi(6)},
{GIndex: 4, Depth: 2, IndexAtDepth: 0, TraceIndex: 3, AttackGIndex: 8, DefendGIndex: 10}, {GIndex: bi(4), Depth: 2, MaxDepth: 4, IndexAtDepth: bi(0), TraceIndex: bi(3), AttackGIndex: bi(8), DefendGIndex: bi(10)},
{GIndex: 5, Depth: 2, IndexAtDepth: 1, TraceIndex: 7, AttackGIndex: 10}, {GIndex: bi(5), Depth: 2, MaxDepth: 4, IndexAtDepth: bi(1), TraceIndex: bi(7), AttackGIndex: bi(10)},
{GIndex: 6, Depth: 2, IndexAtDepth: 2, TraceIndex: 11, AttackGIndex: 12, DefendGIndex: 14}, {GIndex: bi(6), Depth: 2, MaxDepth: 4, IndexAtDepth: bi(2), TraceIndex: bi(11), AttackGIndex: bi(12), DefendGIndex: bi(14)},
{GIndex: 7, Depth: 2, IndexAtDepth: 3, TraceIndex: 15, AttackGIndex: 14}, {GIndex: bi(7), Depth: 2, MaxDepth: 4, IndexAtDepth: bi(3), TraceIndex: bi(15), AttackGIndex: bi(14)},
{GIndex: 8, Depth: 3, IndexAtDepth: 0, TraceIndex: 1, AttackGIndex: 16, DefendGIndex: 18}, {GIndex: bi(8), Depth: 3, MaxDepth: 4, IndexAtDepth: bi(0), TraceIndex: bi(1), AttackGIndex: bi(16), DefendGIndex: bi(18)},
{GIndex: 9, Depth: 3, IndexAtDepth: 1, TraceIndex: 3, AttackGIndex: 18}, {GIndex: bi(9), Depth: 3, MaxDepth: 4, IndexAtDepth: bi(1), TraceIndex: bi(3), AttackGIndex: bi(18)},
{GIndex: 10, Depth: 3, IndexAtDepth: 2, TraceIndex: 5, AttackGIndex: 20, DefendGIndex: 22}, {GIndex: bi(10), Depth: 3, MaxDepth: 4, IndexAtDepth: bi(2), TraceIndex: bi(5), AttackGIndex: bi(20), DefendGIndex: bi(22)},
{GIndex: 11, Depth: 3, IndexAtDepth: 3, TraceIndex: 7, AttackGIndex: 22}, {GIndex: bi(11), Depth: 3, MaxDepth: 4, IndexAtDepth: bi(3), TraceIndex: bi(7), AttackGIndex: bi(22)},
{GIndex: 12, Depth: 3, IndexAtDepth: 4, TraceIndex: 9, AttackGIndex: 24, DefendGIndex: 26}, {GIndex: bi(12), Depth: 3, MaxDepth: 4, IndexAtDepth: bi(4), TraceIndex: bi(9), AttackGIndex: bi(24), DefendGIndex: bi(26)},
{GIndex: 13, Depth: 3, IndexAtDepth: 5, TraceIndex: 11, AttackGIndex: 26}, {GIndex: bi(13), Depth: 3, MaxDepth: 4, IndexAtDepth: bi(5), TraceIndex: bi(11), AttackGIndex: bi(26)},
{GIndex: 14, Depth: 3, IndexAtDepth: 6, TraceIndex: 13, AttackGIndex: 28, DefendGIndex: 30}, {GIndex: bi(14), Depth: 3, MaxDepth: 4, IndexAtDepth: bi(6), TraceIndex: bi(13), AttackGIndex: bi(28), DefendGIndex: bi(30)},
{GIndex: 15, Depth: 3, IndexAtDepth: 7, TraceIndex: 15, AttackGIndex: 30}, {GIndex: bi(15), Depth: 3, MaxDepth: 4, IndexAtDepth: bi(7), TraceIndex: bi(15), AttackGIndex: bi(30)},
{GIndex: 16, Depth: 4, IndexAtDepth: 0, TraceIndex: 0}, {GIndex: bi(16), Depth: 4, MaxDepth: 4, IndexAtDepth: bi(0), TraceIndex: bi(0)},
{GIndex: 17, Depth: 4, IndexAtDepth: 1, TraceIndex: 1}, {GIndex: bi(17), Depth: 4, MaxDepth: 4, IndexAtDepth: bi(1), TraceIndex: bi(1)},
{GIndex: 18, Depth: 4, IndexAtDepth: 2, TraceIndex: 2}, {GIndex: bi(18), Depth: 4, MaxDepth: 4, IndexAtDepth: bi(2), TraceIndex: bi(2)},
{GIndex: 19, Depth: 4, IndexAtDepth: 3, TraceIndex: 3}, {GIndex: bi(19), Depth: 4, MaxDepth: 4, IndexAtDepth: bi(3), TraceIndex: bi(3)},
{GIndex: 20, Depth: 4, IndexAtDepth: 4, TraceIndex: 4}, {GIndex: bi(20), Depth: 4, MaxDepth: 4, IndexAtDepth: bi(4), TraceIndex: bi(4)},
{GIndex: 21, Depth: 4, IndexAtDepth: 5, TraceIndex: 5}, {GIndex: bi(21), Depth: 4, MaxDepth: 4, IndexAtDepth: bi(5), TraceIndex: bi(5)},
{GIndex: 22, Depth: 4, IndexAtDepth: 6, TraceIndex: 6}, {GIndex: bi(22), Depth: 4, MaxDepth: 4, IndexAtDepth: bi(6), TraceIndex: bi(6)},
{GIndex: 23, Depth: 4, IndexAtDepth: 7, TraceIndex: 7}, {GIndex: bi(23), Depth: 4, MaxDepth: 4, IndexAtDepth: bi(7), TraceIndex: bi(7)},
{GIndex: 24, Depth: 4, IndexAtDepth: 8, TraceIndex: 8}, {GIndex: bi(24), Depth: 4, MaxDepth: 4, IndexAtDepth: bi(8), TraceIndex: bi(8)},
{GIndex: 25, Depth: 4, IndexAtDepth: 9, TraceIndex: 9}, {GIndex: bi(25), Depth: 4, MaxDepth: 4, IndexAtDepth: bi(9), TraceIndex: bi(9)},
{GIndex: 26, Depth: 4, IndexAtDepth: 10, TraceIndex: 10}, {GIndex: bi(26), Depth: 4, MaxDepth: 4, IndexAtDepth: bi(10), TraceIndex: bi(10)},
{GIndex: 27, Depth: 4, IndexAtDepth: 11, TraceIndex: 11}, {GIndex: bi(27), Depth: 4, MaxDepth: 4, IndexAtDepth: bi(11), TraceIndex: bi(11)},
{GIndex: 28, Depth: 4, IndexAtDepth: 12, TraceIndex: 12}, {GIndex: bi(28), Depth: 4, MaxDepth: 4, IndexAtDepth: bi(12), TraceIndex: bi(12)},
{GIndex: 29, Depth: 4, IndexAtDepth: 13, TraceIndex: 13}, {GIndex: bi(29), Depth: 4, MaxDepth: 4, IndexAtDepth: bi(13), TraceIndex: bi(13)},
{GIndex: 30, Depth: 4, IndexAtDepth: 14, TraceIndex: 14}, {GIndex: bi(30), Depth: 4, MaxDepth: 4, IndexAtDepth: bi(14), TraceIndex: bi(14)},
{GIndex: 31, Depth: 4, IndexAtDepth: 15, TraceIndex: 15}, {GIndex: bi(31), Depth: 4, MaxDepth: 4, IndexAtDepth: bi(15), TraceIndex: bi(15)},
{GIndex: bi(0).Mul(bi(math.MaxInt64), bi(2)), Depth: 63, MaxDepth: 64, IndexAtDepth: bi(9223372036854775806), TraceIndex: bi(0).Sub(bi(0).Mul(bi(math.MaxInt64), bi(2)), bi(1))},
} }
// TestGINConversions does To & From the generalized index on the treeNodesMaxDepth4 data // TestGINConversions does To & From the generalized index on the treeNodesMaxDepth4 data
func TestGINConversions(t *testing.T) { func TestGINConversions(t *testing.T) {
for _, test := range treeNodesMaxDepth4 { for _, test := range treeNodes {
from := NewPositionFromGIndex(test.GIndex) from := NewLargePositionFromGIndex(test.GIndex)
pos := NewPosition(test.Depth, test.IndexAtDepth) pos := NewPosition(test.Depth, test.IndexAtDepth)
require.Equal(t, pos, from) require.EqualValuesf(t, pos.Depth(), from.Depth(), "From GIndex %v vs pos %v", from.Depth(), pos.Depth())
require.Zerof(t, pos.IndexAtDepth().Cmp(from.IndexAtDepth()), "From GIndex %v vs pos %v", from.IndexAtDepth(), pos.IndexAtDepth())
to := pos.ToGIndex() to := pos.ToGIndex()
require.Equal(t, test.GIndex, to) require.Equal(t, test.GIndex, to)
} }
} }
func TestTraceIndexOfRootWithLargeDepth(t *testing.T) {
traceIdx := new(big.Int).Sub(new(big.Int).Lsh(big.NewInt(1), 100), big.NewInt(1))
pos := NewLargePositionFromGIndex(big.NewInt(1))
actual := pos.TraceIndex(100)
require.Equal(t, traceIdx, actual)
}
// TestTraceIndex creates the position & then tests the trace index function on the treeNodesMaxDepth4 data // TestTraceIndex creates the position & then tests the trace index function on the treeNodesMaxDepth4 data
func TestTraceIndex(t *testing.T) { func TestTraceIndex(t *testing.T) {
for _, test := range treeNodesMaxDepth4 { for _, test := range treeNodes {
pos := NewPosition(test.Depth, test.IndexAtDepth) pos := NewPosition(test.Depth, test.IndexAtDepth)
result := pos.TraceIndex(4) result := pos.TraceIndex(test.MaxDepth)
require.Equal(t, test.TraceIndex, result) require.Equal(t, test.TraceIndex, result)
} }
} }
func TestAttack(t *testing.T) { func TestAttack(t *testing.T) {
for _, test := range treeNodesMaxDepth4 { for _, test := range treeNodes {
if test.AttackGIndex == 0 { if test.AttackGIndex == nil || test.AttackGIndex.Cmp(big.NewInt(0)) == 0 {
continue continue
} }
pos := NewPosition(test.Depth, test.IndexAtDepth) pos := NewPosition(test.Depth, test.IndexAtDepth)
...@@ -110,8 +126,8 @@ func TestAttack(t *testing.T) { ...@@ -110,8 +126,8 @@ func TestAttack(t *testing.T) {
} }
func TestDefend(t *testing.T) { func TestDefend(t *testing.T) {
for _, test := range treeNodesMaxDepth4 { for _, test := range treeNodes {
if test.DefendGIndex == 0 { if test.DefendGIndex == nil || test.DefendGIndex.Cmp(big.NewInt(0)) == 0 {
continue continue
} }
pos := NewPosition(test.Depth, test.IndexAtDepth) pos := NewPosition(test.Depth, test.IndexAtDepth)
...@@ -122,14 +138,14 @@ func TestDefend(t *testing.T) { ...@@ -122,14 +138,14 @@ func TestDefend(t *testing.T) {
func TestRelativeToAncestorAtDepth(t *testing.T) { func TestRelativeToAncestorAtDepth(t *testing.T) {
t.Run("ErrorsForDeepAncestor", func(t *testing.T) { t.Run("ErrorsForDeepAncestor", func(t *testing.T) {
pos := NewPosition(1, 1) pos := NewPosition(1, big.NewInt(1))
_, err := pos.RelativeToAncestorAtDepth(2) _, err := pos.RelativeToAncestorAtDepth(2)
require.ErrorIs(t, err, ErrPositionDepthTooSmall) require.ErrorIs(t, err, ErrPositionDepthTooSmall)
}) })
t.Run("Success", func(t *testing.T) { t.Run("Success", func(t *testing.T) {
pos := NewPosition(2, 1) pos := NewPosition(2, big.NewInt(1))
expectedRelativePosition := NewPosition(1, 1) expectedRelativePosition := NewPosition(1, big.NewInt(1))
relativePosition, err := pos.RelativeToAncestorAtDepth(1) relativePosition, err := pos.RelativeToAncestorAtDepth(1)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, expectedRelativePosition, relativePosition) require.Equal(t, expectedRelativePosition, relativePosition)
......
...@@ -119,5 +119,5 @@ func (c *Claim) IsRoot() bool { ...@@ -119,5 +119,5 @@ func (c *Claim) IsRoot() bool {
// DefendsParent returns true if the the claim is a defense (i.e. goes right) of the // DefendsParent returns true if the the claim is a defense (i.e. goes right) of the
// parent. It returns false if the claim is an attack (i.e. goes left) of the parent. // parent. It returns false if the claim is an attack (i.e. goes left) of the parent.
func (c *Claim) DefendsParent() bool { func (c *Claim) DefendsParent() bool {
return (c.IndexAtDepth() >> 1) != c.Parent.IndexAtDepth() return c.RightOf(c.Parent.Position)
} }
...@@ -23,3 +23,89 @@ func TestNewPreimageOracleData(t *testing.T) { ...@@ -23,3 +23,89 @@ func TestNewPreimageOracleData(t *testing.T) {
require.Equal(t, uint32(7), data.OracleOffset) require.Equal(t, uint32(7), data.OracleOffset)
}) })
} }
func TestIsRootPosition(t *testing.T) {
tests := []struct {
name string
position Position
expected bool
}{
{
name: "ZeroRoot",
position: NewPositionFromGIndex(0),
expected: true,
},
{
name: "ValidRoot",
position: NewPositionFromGIndex(1),
expected: true,
},
{
name: "NotRoot",
position: NewPositionFromGIndex(2),
expected: false,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
require.Equal(t, test.expected, test.position.IsRootPosition())
})
}
}
func buildClaim(gindex uint64, parentGIndex uint64) Claim {
return Claim{
ClaimData: ClaimData{
Position: NewPositionFromGIndex(gindex),
},
Parent: ClaimData{
Position: NewPositionFromGIndex(parentGIndex),
},
}
}
func TestDefendsParent(t *testing.T) {
tests := []struct {
name string
claim Claim
expected bool
}{
{
name: "LeftChildAttacks",
claim: buildClaim(2, 1),
expected: false,
},
{
name: "RightChildDoesntDefend",
claim: buildClaim(3, 1),
expected: false,
},
{
name: "SubChildDoesntDefend",
claim: buildClaim(4, 1),
expected: false,
},
{
name: "SubSecondChildDoesntDefend",
claim: buildClaim(5, 1),
expected: false,
},
{
name: "RightLeftChildDefendsParent",
claim: buildClaim(6, 1),
expected: true,
},
{
name: "SubThirdChildDefends",
claim: buildClaim(7, 1),
expected: true,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
require.Equal(t, test.expected, test.claim.DefendsParent())
})
}
}
...@@ -34,7 +34,8 @@ import ( ...@@ -34,7 +34,8 @@ import (
const alphabetGameType uint8 = 255 const alphabetGameType uint8 = 255
const cannonGameType uint8 = 0 const cannonGameType uint8 = 0
const alphabetGameDepth = 4 const alphabetGameDepth = 4
const lastAlphabetTraceIndex = 1<<alphabetGameDepth - 1
var lastAlphabetTraceIndex = big.NewInt(1<<alphabetGameDepth - 1)
// rootPosition is the position of the root claim. // rootPosition is the position of the root claim.
var rootPosition = faultTypes.NewPositionFromGIndex(1) var rootPosition = faultTypes.NewPositionFromGIndex(1)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment