Commit bbd8b86f authored by Adrian Sutton's avatar Adrian Sutton

op-challenger: Support multiple trace types in the config.

parent 7203324a
...@@ -46,14 +46,14 @@ func TestLogLevel(t *testing.T) { ...@@ -46,14 +46,14 @@ func TestLogLevel(t *testing.T) {
func TestDefaultCLIOptionsMatchDefaultConfig(t *testing.T) { func TestDefaultCLIOptionsMatchDefaultConfig(t *testing.T) {
cfg := configForArgs(t, addRequiredArgs(config.TraceTypeAlphabet)) cfg := configForArgs(t, addRequiredArgs(config.TraceTypeAlphabet))
defaultCfg := config.NewConfig(common.HexToAddress(gameFactoryAddressValue), l1EthRpc, config.TraceTypeAlphabet, true, datadir) defaultCfg := config.NewConfig(common.HexToAddress(gameFactoryAddressValue), l1EthRpc, true, datadir, config.TraceTypeAlphabet)
// Add in the extra CLI options required when using alphabet trace type // Add in the extra CLI options required when using alphabet trace type
defaultCfg.AlphabetTrace = alphabetTrace defaultCfg.AlphabetTrace = alphabetTrace
require.Equal(t, defaultCfg, cfg) require.Equal(t, defaultCfg, cfg)
} }
func TestDefaultConfigIsValid(t *testing.T) { func TestDefaultConfigIsValid(t *testing.T) {
cfg := config.NewConfig(common.HexToAddress(gameFactoryAddressValue), l1EthRpc, config.TraceTypeAlphabet, true, datadir) cfg := config.NewConfig(common.HexToAddress(gameFactoryAddressValue), l1EthRpc, true, datadir, config.TraceTypeAlphabet)
// Add in options that are required based on the specific trace type // Add in options that are required based on the specific trace type
// To avoid needing to specify unused options, these aren't included in the params for NewConfig // To avoid needing to specify unused options, these aren't included in the params for NewConfig
cfg.AlphabetTrace = alphabetTrace cfg.AlphabetTrace = alphabetTrace
...@@ -82,7 +82,7 @@ func TestTraceType(t *testing.T) { ...@@ -82,7 +82,7 @@ func TestTraceType(t *testing.T) {
traceType := traceType traceType := traceType
t.Run("Valid_"+traceType.String(), func(t *testing.T) { t.Run("Valid_"+traceType.String(), func(t *testing.T) {
cfg := configForArgs(t, addRequiredArgs(traceType)) cfg := configForArgs(t, addRequiredArgs(traceType))
require.Equal(t, traceType, cfg.TraceType) require.Equal(t, []config.TraceType{traceType}, cfg.TraceTypes)
}) })
} }
......
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"runtime" "runtime"
"slices"
"time" "time"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
...@@ -15,7 +16,7 @@ import ( ...@@ -15,7 +16,7 @@ import (
) )
var ( var (
ErrMissingTraceType = errors.New("missing trace type") ErrMissingTraceType = errors.New("no supported trace types specified")
ErrMissingDatadir = errors.New("missing datadir") ErrMissingDatadir = errors.New("missing datadir")
ErrMaxConcurrencyZero = errors.New("max concurrency must not be 0") ErrMaxConcurrencyZero = errors.New("max concurrency must not be 0")
ErrMissingCannonL2 = errors.New("missing cannon L2") ErrMissingCannonL2 = errors.New("missing cannon L2")
...@@ -108,7 +109,7 @@ type Config struct { ...@@ -108,7 +109,7 @@ type Config struct {
MaxConcurrency uint // Maximum number of threads to use when progressing games MaxConcurrency uint // Maximum number of threads to use when progressing games
PollInterval time.Duration // Polling interval for latest-block subscription when using an HTTP RPC provider PollInterval time.Duration // Polling interval for latest-block subscription when using an HTTP RPC provider
TraceType TraceType // Type of trace TraceTypes []TraceType // Type of traces supported
// Specific to the alphabet trace provider // Specific to the alphabet trace provider
AlphabetTrace string // String for the AlphabetTraceProvider AlphabetTrace string // String for the AlphabetTraceProvider
...@@ -135,9 +136,9 @@ type Config struct { ...@@ -135,9 +136,9 @@ type Config struct {
func NewConfig( func NewConfig(
gameFactoryAddress common.Address, gameFactoryAddress common.Address,
l1EthRpc string, l1EthRpc string,
traceType TraceType,
agreeWithProposedOutput bool, agreeWithProposedOutput bool,
datadir string, datadir string,
supportedTraceTypes ...TraceType,
) Config { ) Config {
return Config{ return Config{
L1EthRpc: l1EthRpc, L1EthRpc: l1EthRpc,
...@@ -147,7 +148,7 @@ func NewConfig( ...@@ -147,7 +148,7 @@ func NewConfig(
AgreeWithProposedOutput: agreeWithProposedOutput, AgreeWithProposedOutput: agreeWithProposedOutput,
TraceType: traceType, TraceTypes: supportedTraceTypes,
TxMgrConfig: txmgr.NewCLIConfig(l1EthRpc, txmgr.DefaultChallengerFlagValues), TxMgrConfig: txmgr.NewCLIConfig(l1EthRpc, txmgr.DefaultChallengerFlagValues),
MetricsConfig: opmetrics.DefaultCLIConfig(), MetricsConfig: opmetrics.DefaultCLIConfig(),
...@@ -161,6 +162,10 @@ func NewConfig( ...@@ -161,6 +162,10 @@ func NewConfig(
} }
} }
func (c Config) TraceTypeEnabled(t TraceType) bool {
return slices.Contains(c.TraceTypes, t)
}
func (c Config) Check() error { func (c Config) Check() error {
if c.L1EthRpc == "" { if c.L1EthRpc == "" {
return ErrMissingL1EthRPC return ErrMissingL1EthRPC
...@@ -168,7 +173,7 @@ func (c Config) Check() error { ...@@ -168,7 +173,7 @@ func (c Config) Check() error {
if c.GameFactoryAddress == (common.Address{}) { if c.GameFactoryAddress == (common.Address{}) {
return ErrMissingGameFactoryAddress return ErrMissingGameFactoryAddress
} }
if c.TraceType == "" { if len(c.TraceTypes) == 0 {
return ErrMissingTraceType return ErrMissingTraceType
} }
if c.Datadir == "" { if c.Datadir == "" {
...@@ -177,12 +182,12 @@ func (c Config) Check() error { ...@@ -177,12 +182,12 @@ func (c Config) Check() error {
if c.MaxConcurrency == 0 { if c.MaxConcurrency == 0 {
return ErrMaxConcurrencyZero return ErrMaxConcurrencyZero
} }
if c.TraceType == TraceTypeOutputCannon { if c.TraceTypeEnabled(TraceTypeOutputCannon) {
if c.RollupRpc == "" { if c.RollupRpc == "" {
return ErrMissingRollupRpc return ErrMissingRollupRpc
} }
} }
if c.TraceType == TraceTypeCannon || c.TraceType == TraceTypeOutputCannon { if c.TraceTypeEnabled(TraceTypeCannon) || c.TraceTypeEnabled(TraceTypeOutputCannon) {
if c.CannonBin == "" { if c.CannonBin == "" {
return ErrMissingCannonBin return ErrMissingCannonBin
} }
...@@ -220,7 +225,7 @@ func (c Config) Check() error { ...@@ -220,7 +225,7 @@ func (c Config) Check() error {
return ErrMissingCannonInfoFreq return ErrMissingCannonInfoFreq
} }
} }
if c.TraceType == TraceTypeAlphabet && c.AlphabetTrace == "" { if c.TraceTypeEnabled(TraceTypeAlphabet) && c.AlphabetTrace == "" {
return ErrMissingAlphabetTrace return ErrMissingAlphabetTrace
} }
if err := c.TxMgrConfig.Check(); err != nil { if err := c.TxMgrConfig.Check(); err != nil {
......
...@@ -25,7 +25,7 @@ var ( ...@@ -25,7 +25,7 @@ var (
) )
func validConfig(traceType TraceType) Config { func validConfig(traceType TraceType) Config {
cfg := NewConfig(validGameFactoryAddress, validL1EthRpc, traceType, agreeWithProposedOutput, validDatadir) cfg := NewConfig(validGameFactoryAddress, validL1EthRpc, agreeWithProposedOutput, validDatadir, traceType)
switch traceType { switch traceType {
case TraceTypeAlphabet: case TraceTypeAlphabet:
cfg.AlphabetTrace = validAlphabetTrace cfg.AlphabetTrace = validAlphabetTrace
...@@ -194,3 +194,26 @@ func TestNetworkMustBeValid(t *testing.T) { ...@@ -194,3 +194,26 @@ func TestNetworkMustBeValid(t *testing.T) {
cfg.CannonNetwork = "unknown" cfg.CannonNetwork = "unknown"
require.ErrorIs(t, cfg.Check(), ErrCannonNetworkUnknown) require.ErrorIs(t, cfg.Check(), ErrCannonNetworkUnknown)
} }
func TestRequireConfigForAllSupportedTraceTypes(t *testing.T) {
cfg := validConfig(TraceTypeCannon)
cfg.TraceTypes = []TraceType{TraceTypeCannon, TraceTypeOutputCannon, TraceTypeAlphabet}
// Set all required options and check its valid
cfg.RollupRpc = validRollupRpc
cfg.AlphabetTrace = validAlphabetTrace
require.NoError(t, cfg.Check())
// Require output cannon specific args
cfg.RollupRpc = ""
require.ErrorIs(t, cfg.Check(), ErrMissingRollupRpc)
cfg.RollupRpc = validRollupRpc
// Require cannon specific args
cfg.CannonL2 = ""
require.ErrorIs(t, cfg.Check(), ErrMissingCannonL2)
cfg.CannonL2 = validCannonL2
// Require alphabet specific args
cfg.AlphabetTrace = ""
require.ErrorIs(t, cfg.Check(), ErrMissingAlphabetTrace)
}
...@@ -272,7 +272,7 @@ func NewConfigFromCLI(ctx *cli.Context) (*config.Config, error) { ...@@ -272,7 +272,7 @@ func NewConfigFromCLI(ctx *cli.Context) (*config.Config, error) {
return &config.Config{ return &config.Config{
// Required Flags // Required Flags
L1EthRpc: ctx.String(L1EthRpcFlag.Name), L1EthRpc: ctx.String(L1EthRpcFlag.Name),
TraceType: traceTypeFlag, TraceTypes: []config.TraceType{traceTypeFlag},
GameFactoryAddress: gameFactoryAddress, GameFactoryAddress: gameFactoryAddress,
GameAllowlist: allowedGames, GameAllowlist: allowedGames,
GameWindow: ctx.Duration(GameWindowFlag.Name), GameWindow: ctx.Duration(GameWindowFlag.Name),
......
...@@ -35,8 +35,7 @@ func RegisterGameTypes( ...@@ -35,8 +35,7 @@ func RegisterGameTypes(
txMgr txmgr.TxManager, txMgr txmgr.TxManager,
client bind.ContractCaller, client bind.ContractCaller,
) { ) {
switch cfg.TraceType { if cfg.TraceTypeEnabled(config.TraceTypeCannon) {
case config.TraceTypeCannon:
resourceCreator := func(addr common.Address, gameDepth uint64, dir string) (faultTypes.TraceProvider, faultTypes.OracleUpdater, error) { resourceCreator := func(addr common.Address, gameDepth uint64, dir string) (faultTypes.TraceProvider, faultTypes.OracleUpdater, error) {
provider, err := cannon.NewTraceProvider(ctx, logger, m, cfg, client, dir, addr, gameDepth) provider, err := cannon.NewTraceProvider(ctx, logger, m, cfg, client, dir, addr, gameDepth)
if err != nil { if err != nil {
...@@ -52,7 +51,8 @@ func RegisterGameTypes( ...@@ -52,7 +51,8 @@ func RegisterGameTypes(
return NewGamePlayer(ctx, logger, m, cfg, dir, game.Proxy, txMgr, client, resourceCreator) return NewGamePlayer(ctx, logger, m, cfg, dir, game.Proxy, txMgr, client, resourceCreator)
} }
registry.RegisterGameType(cannonGameType, playerCreator) registry.RegisterGameType(cannonGameType, playerCreator)
case config.TraceTypeAlphabet: }
if cfg.TraceTypeEnabled(config.TraceTypeAlphabet) {
resourceCreator := func(addr common.Address, gameDepth uint64, dir string) (faultTypes.TraceProvider, faultTypes.OracleUpdater, error) { resourceCreator := func(addr common.Address, gameDepth uint64, dir string) (faultTypes.TraceProvider, faultTypes.OracleUpdater, error) {
provider := alphabet.NewTraceProvider(cfg.AlphabetTrace, gameDepth) provider := alphabet.NewTraceProvider(cfg.AlphabetTrace, gameDepth)
updater := alphabet.NewOracleUpdater(logger) updater := alphabet.NewOracleUpdater(logger)
......
...@@ -24,7 +24,7 @@ func TestGenerateProof(t *testing.T) { ...@@ -24,7 +24,7 @@ func TestGenerateProof(t *testing.T) {
input := "starting.json" input := "starting.json"
tempDir := t.TempDir() tempDir := t.TempDir()
dir := filepath.Join(tempDir, "gameDir") dir := filepath.Join(tempDir, "gameDir")
cfg := config.NewConfig(common.Address{0xbb}, "http://localhost:8888", config.TraceTypeCannon, true, tempDir) cfg := config.NewConfig(common.Address{0xbb}, "http://localhost:8888", true, tempDir, config.TraceTypeCannon)
cfg.CannonAbsolutePreState = "pre.json" cfg.CannonAbsolutePreState = "pre.json"
cfg.CannonBin = "./bin/cannon" cfg.CannonBin = "./bin/cannon"
cfg.CannonServer = "./bin/op-program" cfg.CannonServer = "./bin/op-program"
......
...@@ -126,7 +126,7 @@ func NewChallenger(t *testing.T, ctx context.Context, l1Endpoint string, name st ...@@ -126,7 +126,7 @@ func NewChallenger(t *testing.T, ctx context.Context, l1Endpoint string, name st
func NewChallengerConfig(t *testing.T, l1Endpoint string, options ...Option) *config.Config { func NewChallengerConfig(t *testing.T, l1Endpoint string, options ...Option) *config.Config {
// Use the NewConfig method to ensure we pick up any defaults that are set. // Use the NewConfig method to ensure we pick up any defaults that are set.
cfg := config.NewConfig(common.Address{}, l1Endpoint, config.TraceTypeAlphabet, true, t.TempDir()) cfg := config.NewConfig(common.Address{}, l1Endpoint, true, t.TempDir(), config.TraceTypeAlphabet)
cfg.TxMgrConfig.NumConfirmations = 1 cfg.TxMgrConfig.NumConfirmations = 1
cfg.TxMgrConfig.ReceiptQueryInterval = 1 * time.Second cfg.TxMgrConfig.ReceiptQueryInterval = 1 * time.Second
if cfg.MaxConcurrency > 4 { if cfg.MaxConcurrency > 4 {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment