Commit ab90a101 authored by Andreas Bigger's avatar Andreas Bigger

fix: tests and nits :test_tube:

parent b256b358
...@@ -127,11 +127,13 @@ func (n *OpNode) initL1(ctx context.Context, cfg *Config) error { ...@@ -127,11 +127,13 @@ func (n *OpNode) initL1(ctx context.Context, cfg *Config) error {
// Validate the L1 Client Chain ID // Validate the L1 Client Chain ID
if err := cfg.Rollup.CheckL1ChainID(ctx, n.l1Source); err != nil { if err := cfg.Rollup.CheckL1ChainID(ctx, n.l1Source); err != nil {
n.log.Error("failed to verify L1 RPC chain id", "err", err) n.log.Error("failed to verify L1 RPC chain id", "err", err)
return err
} }
// Validate the Rollup L1 Genesis Blockhash // Validate the Rollup L1 Genesis Blockhash
if err := cfg.Rollup.CheckL1GenesisBlockHash(ctx, n.l1Source); err != nil { if err := cfg.Rollup.CheckL1GenesisBlockHash(ctx, n.l1Source); err != nil {
n.log.Error("failed to verify L1 genesis block hash", "err", err) n.log.Error("failed to verify L1 genesis block hash", "err", err)
return err
} }
// Keep subscribed to the L1 heads, which keeps the L1 maintainer pointing to the best headers to sync // Keep subscribed to the L1 heads, which keeps the L1 maintainer pointing to the best headers to sync
...@@ -202,11 +204,13 @@ func (n *OpNode) initL2(ctx context.Context, cfg *Config, snapshotLog log.Logger ...@@ -202,11 +204,13 @@ func (n *OpNode) initL2(ctx context.Context, cfg *Config, snapshotLog log.Logger
// Validate the L2 Client Chain ID // Validate the L2 Client Chain ID
if err := cfg.Rollup.CheckL2ChainID(ctx, n.l2Source); err != nil { if err := cfg.Rollup.CheckL2ChainID(ctx, n.l2Source); err != nil {
n.log.Error("failed to verify L2 RPC chain id", "err", err) n.log.Error("failed to verify L2 RPC chain id", "err", err)
return err
} }
// Validate the Rollup L2 Genesis Blockhash // Validate the Rollup L2 Genesis Blockhash
if err := cfg.Rollup.CheckL2GenesisBlockHash(ctx, n.l2Source); err != nil { if err := cfg.Rollup.CheckL2GenesisBlockHash(ctx, n.l2Source); err != nil {
n.log.Error("failed to verify L2 genesis block hash", "err", err) n.log.Error("failed to verify L2 genesis block hash", "err", err)
return err
} }
n.l2Driver = driver.NewDriver(&cfg.Driver, &cfg.Rollup, n.l2Source, n.l1Source, n, n.log, snapshotLog, n.metrics) n.l2Driver = driver.NewDriver(&cfg.Driver, &cfg.Rollup, n.l2Source, n.l1Source, n, n.log, snapshotLog, n.metrics)
......
...@@ -67,7 +67,7 @@ func (cfg *Config) CheckL1ChainID(ctx context.Context, client L1Client) error { ...@@ -67,7 +67,7 @@ func (cfg *Config) CheckL1ChainID(ctx context.Context, client L1Client) error {
if err != nil { if err != nil {
return err return err
} }
if cfg.L1ChainID != id { if cfg.L1ChainID.Cmp(id) != 0 {
return fmt.Errorf("incorrect L1 RPC chain id %d, expected %d", cfg.L1ChainID, id) return fmt.Errorf("incorrect L1 RPC chain id %d, expected %d", cfg.L1ChainID, id)
} }
return nil return nil
...@@ -96,7 +96,7 @@ func (cfg *Config) CheckL2ChainID(ctx context.Context, client L2Client) error { ...@@ -96,7 +96,7 @@ func (cfg *Config) CheckL2ChainID(ctx context.Context, client L2Client) error {
if err != nil { if err != nil {
return err return err
} }
if cfg.L2ChainID != id { if cfg.L2ChainID.Cmp(id) != 0 {
return fmt.Errorf("incorrect L2 RPC chain id %d, expected %d", cfg.L2ChainID, id) return fmt.Errorf("incorrect L2 RPC chain id %d, expected %d", cfg.L2ChainID, id)
} }
return nil return nil
......
package rollup package rollup
import ( import (
"context"
"encoding/json" "encoding/json"
"math/big" "math/big"
"math/rand" "math/rand"
...@@ -55,3 +56,87 @@ func TestConfigJSON(t *testing.T) { ...@@ -55,3 +56,87 @@ func TestConfigJSON(t *testing.T) {
assert.NoError(t, json.Unmarshal(data, &roundTripped)) assert.NoError(t, json.Unmarshal(data, &roundTripped))
assert.Equal(t, &roundTripped, config) assert.Equal(t, &roundTripped, config)
} }
type mockL1Client struct {
chainID *big.Int
Hash common.Hash
}
func (m *mockL1Client) L1ChainID(context.Context) (*big.Int, error) {
return m.chainID, nil
}
func (m *mockL1Client) L1BlockRefByNumber(ctx context.Context, number uint64) (eth.L1BlockRef, error) {
return eth.L1BlockRef{
Hash: m.Hash,
Number: 100,
}, nil
}
func TestCheckL1ChainID(t *testing.T) {
config := randConfig()
config.L1ChainID = big.NewInt(100)
err := config.CheckL1ChainID(context.TODO(), &mockL1Client{chainID: big.NewInt(100)})
assert.NoError(t, err)
err = config.CheckL1ChainID(context.TODO(), &mockL1Client{chainID: big.NewInt(101)})
assert.Error(t, err)
err = config.CheckL1ChainID(context.TODO(), &mockL1Client{chainID: big.NewInt(99)})
assert.Error(t, err)
}
func TestCheckL1BlockRefByNumber(t *testing.T) {
config := randConfig()
config.Genesis.L1.Number = 100
config.Genesis.L1.Hash = [32]byte{0x01}
mockClient := mockL1Client{chainID: big.NewInt(100), Hash: common.Hash{0x01}}
err := config.CheckL1GenesisBlockHash(context.TODO(), &mockClient)
assert.NoError(t, err)
mockClient.Hash = common.Hash{0x02}
err = config.CheckL1GenesisBlockHash(context.TODO(), &mockClient)
assert.Error(t, err)
mockClient.Hash = common.Hash{0x00}
err = config.CheckL1GenesisBlockHash(context.TODO(), &mockClient)
assert.Error(t, err)
}
type mockL2Client struct {
chainID *big.Int
Hash common.Hash
}
func (m *mockL2Client) L2ChainID(context.Context) (*big.Int, error) {
return m.chainID, nil
}
func (m *mockL2Client) L2BlockRefByNumber(ctx context.Context, number uint64) (eth.L2BlockRef, error) {
return eth.L2BlockRef{
Hash: m.Hash,
Number: 100,
}, nil
}
func TestCheckL2ChainID(t *testing.T) {
config := randConfig()
config.L2ChainID = big.NewInt(100)
err := config.CheckL2ChainID(context.TODO(), &mockL2Client{chainID: big.NewInt(100)})
assert.NoError(t, err)
err = config.CheckL2ChainID(context.TODO(), &mockL2Client{chainID: big.NewInt(101)})
assert.Error(t, err)
err = config.CheckL2ChainID(context.TODO(), &mockL2Client{chainID: big.NewInt(99)})
assert.Error(t, err)
}
func TestCheckL2BlockRefByNumber(t *testing.T) {
config := randConfig()
config.Genesis.L2.Number = 100
config.Genesis.L2.Hash = [32]byte{0x01}
mockClient := mockL2Client{chainID: big.NewInt(100), Hash: common.Hash{0x01}}
err := config.CheckL2GenesisBlockHash(context.TODO(), &mockClient)
assert.NoError(t, err)
mockClient.Hash = common.Hash{0x02}
err = config.CheckL2GenesisBlockHash(context.TODO(), &mockClient)
assert.Error(t, err)
mockClient.Hash = common.Hash{0x00}
err = config.CheckL2GenesisBlockHash(context.TODO(), &mockClient)
assert.Error(t, err)
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment