Commit 672adac1 authored by Adrian Sutton's avatar Adrian Sutton

op-challenger: Allow multiple trace types to be specified on the CLI.

parent bbd8b86f
...@@ -91,6 +91,41 @@ func TestTraceType(t *testing.T) { ...@@ -91,6 +91,41 @@ func TestTraceType(t *testing.T) {
}) })
} }
func TestMultipleTraceTypes(t *testing.T) {
t.Run("WithAllOptions", func(t *testing.T) {
argsMap := requiredArgs(config.TraceTypeCannon)
addRequiredOutputCannonArgs(argsMap)
addRequiredAlphabetArgs(argsMap)
args := toArgList(argsMap)
// Add extra trace types (cannon is already specified)
args = append(args,
"--trace-type", config.TraceTypeOutputCannon.String(),
"--trace-type", config.TraceTypeAlphabet.String())
cfg := configForArgs(t, args)
require.Equal(t, []config.TraceType{config.TraceTypeCannon, config.TraceTypeOutputCannon, config.TraceTypeAlphabet}, cfg.TraceTypes)
})
t.Run("WithSomeOptions", func(t *testing.T) {
argsMap := requiredArgs(config.TraceTypeCannon)
addRequiredAlphabetArgs(argsMap)
args := toArgList(argsMap)
// Add extra trace types (cannon is already specified)
args = append(args,
"--trace-type", config.TraceTypeAlphabet.String())
cfg := configForArgs(t, args)
require.Equal(t, []config.TraceType{config.TraceTypeCannon, config.TraceTypeAlphabet}, cfg.TraceTypes)
})
t.Run("SpecifySameOptionMultipleTimes", func(t *testing.T) {
argsMap := requiredArgs(config.TraceTypeCannon)
args := toArgList(argsMap)
// Add cannon trace type again
args = append(args, "--trace-type", config.TraceTypeCannon.String())
// We're fine with the same option being listed multiple times, just deduplicate them.
cfg := configForArgs(t, args)
require.Equal(t, []config.TraceType{config.TraceTypeCannon}, cfg.TraceTypes)
})
}
func TestGameFactoryAddress(t *testing.T) { func TestGameFactoryAddress(t *testing.T) {
t.Run("Required", func(t *testing.T) { t.Run("Required", func(t *testing.T) {
verifyArgsInvalid(t, "flag game-factory-address is required", addRequiredArgsExcept(config.TraceTypeAlphabet, "--game-factory-address")) verifyArgsInvalid(t, "flag game-factory-address is required", addRequiredArgsExcept(config.TraceTypeAlphabet, "--game-factory-address"))
...@@ -441,18 +476,30 @@ func requiredArgs(traceType config.TraceType) map[string]string { ...@@ -441,18 +476,30 @@ func requiredArgs(traceType config.TraceType) map[string]string {
} }
switch traceType { switch traceType {
case config.TraceTypeAlphabet: case config.TraceTypeAlphabet:
addRequiredAlphabetArgs(args)
case config.TraceTypeCannon:
addRequiredCannonArgs(args)
case config.TraceTypeOutputCannon:
addRequiredOutputCannonArgs(args)
}
return args
}
func addRequiredAlphabetArgs(args map[string]string) {
args["--alphabet"] = alphabetTrace args["--alphabet"] = alphabetTrace
case config.TraceTypeCannon, config.TraceTypeOutputCannon: }
func addRequiredOutputCannonArgs(args map[string]string) {
addRequiredCannonArgs(args)
args["--rollup-rpc"] = rollupRpc
}
func addRequiredCannonArgs(args map[string]string) {
args["--cannon-network"] = cannonNetwork args["--cannon-network"] = cannonNetwork
args["--cannon-bin"] = cannonBin args["--cannon-bin"] = cannonBin
args["--cannon-server"] = cannonServer args["--cannon-server"] = cannonServer
args["--cannon-prestate"] = cannonPreState args["--cannon-prestate"] = cannonPreState
args["--cannon-l2"] = cannonL2 args["--cannon-l2"] = cannonL2
}
if traceType == config.TraceTypeOutputCannon {
args["--rollup-rpc"] = rollupRpc
}
return args
} }
func toArgList(req map[string]string) []string { func toArgList(req map[string]string) []string {
......
...@@ -3,6 +3,7 @@ package flags ...@@ -3,6 +3,7 @@ package flags
import ( import (
"fmt" "fmt"
"runtime" "runtime"
"slices"
"strings" "strings"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
...@@ -44,14 +45,10 @@ var ( ...@@ -44,14 +45,10 @@ var (
"If empty, the challenger will play all games.", "If empty, the challenger will play all games.",
EnvVars: prefixEnvVars("GAME_ALLOWLIST"), EnvVars: prefixEnvVars("GAME_ALLOWLIST"),
} }
TraceTypeFlag = &cli.GenericFlag{ TraceTypeFlag = &cli.StringSliceFlag{
Name: "trace-type", Name: "trace-type",
Usage: "The trace type. Valid options: " + openum.EnumString(config.TraceTypes), Usage: "The trace types to support. Valid options: " + openum.EnumString(config.TraceTypes),
EnvVars: prefixEnvVars("TRACE_TYPE"), EnvVars: prefixEnvVars("TRACE_TYPE"),
Value: func() *config.TraceType {
out := config.TraceType("") // No default value
return &out
}(),
} }
AgreeWithProposedOutputFlag = &cli.BoolFlag{ AgreeWithProposedOutputFlag = &cli.BoolFlag{
Name: "agree-with-proposed-output", Name: "agree-with-proposed-output",
...@@ -210,14 +207,14 @@ func CheckCannonFlags(ctx *cli.Context) error { ...@@ -210,14 +207,14 @@ func CheckCannonFlags(ctx *cli.Context) error {
return nil return nil
} }
func CheckRequired(ctx *cli.Context) error { func CheckRequired(ctx *cli.Context, traceTypes []config.TraceType) error {
for _, f := range requiredFlags { for _, f := range requiredFlags {
if !ctx.IsSet(f.Names()[0]) { if !ctx.IsSet(f.Names()[0]) {
return fmt.Errorf("flag %s is required", f.Names()[0]) return fmt.Errorf("flag %s is required", f.Names()[0])
} }
} }
gameType := config.TraceType(strings.ToLower(ctx.String(TraceTypeFlag.Name))) for _, traceType := range traceTypes {
switch gameType { switch traceType {
case config.TraceTypeCannon: case config.TraceTypeCannon:
if err := CheckCannonFlags(ctx); err != nil { if err := CheckCannonFlags(ctx); err != nil {
return err return err
...@@ -236,12 +233,31 @@ func CheckRequired(ctx *cli.Context) error { ...@@ -236,12 +233,31 @@ func CheckRequired(ctx *cli.Context) error {
default: default:
return fmt.Errorf("invalid trace type. must be one of %v", config.TraceTypes) return fmt.Errorf("invalid trace type. must be one of %v", config.TraceTypes)
} }
}
return nil return nil
} }
func parseTraceTypes(ctx *cli.Context) ([]config.TraceType, error) {
var traceTypes []config.TraceType
for _, typeName := range ctx.StringSlice(TraceTypeFlag.Name) {
traceType := new(config.TraceType)
if err := traceType.Set(typeName); err != nil {
return nil, err
}
if !slices.Contains(traceTypes, *traceType) {
traceTypes = append(traceTypes, *traceType)
}
}
return traceTypes, nil
}
// NewConfigFromCLI parses the Config from the provided flags or environment variables. // NewConfigFromCLI parses the Config from the provided flags or environment variables.
func NewConfigFromCLI(ctx *cli.Context) (*config.Config, error) { func NewConfigFromCLI(ctx *cli.Context) (*config.Config, error) {
if err := CheckRequired(ctx); err != nil { traceTypes, err := parseTraceTypes(ctx)
if err != nil {
return nil, err
}
if err := CheckRequired(ctx, traceTypes); err != nil {
return nil, err return nil, err
} }
gameFactoryAddress, err := opservice.ParseAddress(ctx.String(FactoryAddressFlag.Name)) gameFactoryAddress, err := opservice.ParseAddress(ctx.String(FactoryAddressFlag.Name))
...@@ -263,8 +279,6 @@ func NewConfigFromCLI(ctx *cli.Context) (*config.Config, error) { ...@@ -263,8 +279,6 @@ func NewConfigFromCLI(ctx *cli.Context) (*config.Config, error) {
metricsConfig := opmetrics.ReadCLIConfig(ctx) metricsConfig := opmetrics.ReadCLIConfig(ctx)
pprofConfig := oppprof.ReadCLIConfig(ctx) pprofConfig := oppprof.ReadCLIConfig(ctx)
traceTypeFlag := config.TraceType(strings.ToLower(ctx.String(TraceTypeFlag.Name)))
maxConcurrency := ctx.Uint(MaxConcurrencyFlag.Name) maxConcurrency := ctx.Uint(MaxConcurrencyFlag.Name)
if maxConcurrency == 0 { if maxConcurrency == 0 {
return nil, fmt.Errorf("%v must not be 0", MaxConcurrencyFlag.Name) return nil, fmt.Errorf("%v must not be 0", MaxConcurrencyFlag.Name)
...@@ -272,7 +286,7 @@ func NewConfigFromCLI(ctx *cli.Context) (*config.Config, error) { ...@@ -272,7 +286,7 @@ func NewConfigFromCLI(ctx *cli.Context) (*config.Config, error) {
return &config.Config{ return &config.Config{
// Required Flags // Required Flags
L1EthRpc: ctx.String(L1EthRpcFlag.Name), L1EthRpc: ctx.String(L1EthRpcFlag.Name),
TraceTypes: []config.TraceType{traceTypeFlag}, TraceTypes: traceTypes,
GameFactoryAddress: gameFactoryAddress, GameFactoryAddress: gameFactoryAddress,
GameAllowlist: allowedGames, GameAllowlist: allowedGames,
GameWindow: ctx.Duration(GameWindowFlag.Name), GameWindow: ctx.Duration(GameWindowFlag.Name),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment