Commit 67ba188e authored by Adrian Sutton's avatar Adrian Sutton Committed by GitHub

cannon: Support binary serialisation for snapshots (#11718)

* cannon: Add serialize utils for binary formats and automatic binary/json detection.

* cannon: Support reading and writing states as binary or JSON

* cannon: Generate mt prestate as gzipped binary.

Use different versions for singlethreaded and multithreaded states.

* cannon: Improve comments for serialization

* cannon: Review feedback

* cannon: Introduce reader and writer helpers to simplify code.
parent ac19f2f9
...@@ -146,8 +146,8 @@ cannon-prestate: op-program cannon ## Generates prestate using cannon and op-pro ...@@ -146,8 +146,8 @@ cannon-prestate: op-program cannon ## Generates prestate using cannon and op-pro
.PHONY: cannon-prestate .PHONY: cannon-prestate
cannon-prestate-mt: op-program cannon ## Generates prestate using cannon and op-program in the multithreaded cannon format cannon-prestate-mt: op-program cannon ## Generates prestate using cannon and op-program in the multithreaded cannon format
./cannon/bin/cannon load-elf --type mt --path op-program/bin/op-program-client.elf --out op-program/bin/prestate-mt.json --meta op-program/bin/meta-mt.json ./cannon/bin/cannon load-elf --type cannon-mt --path op-program/bin/op-program-client.elf --out op-program/bin/prestate-mt.bin.gz --meta op-program/bin/meta-mt.json
./cannon/bin/cannon run --type mt --proof-at '=0' --stop-at '=1' --input op-program/bin/prestate-mt.json --meta op-program/bin/meta-mt.json --proof-fmt 'op-program/bin/%d-mt.json' --output "" ./cannon/bin/cannon run --type cannon-mt --proof-at '=0' --stop-at '=1' --input op-program/bin/prestate-mt.bin.gz --meta op-program/bin/meta-mt.json --proof-fmt 'op-program/bin/%d-mt.json' --output ""
mv op-program/bin/0-mt.json op-program/bin/prestate-proof-mt.json mv op-program/bin/0-mt.json op-program/bin/prestate-proof-mt.json
.PHONY: cannon-prestate .PHONY: cannon-prestate
......
...@@ -6,6 +6,7 @@ import ( ...@@ -6,6 +6,7 @@ import (
"github.com/ethereum-optimism/optimism/cannon/mipsevm" "github.com/ethereum-optimism/optimism/cannon/mipsevm"
"github.com/ethereum-optimism/optimism/cannon/mipsevm/multithreaded" "github.com/ethereum-optimism/optimism/cannon/mipsevm/multithreaded"
"github.com/ethereum-optimism/optimism/cannon/serialize"
"github.com/urfave/cli/v2" "github.com/urfave/cli/v2"
"github.com/ethereum-optimism/optimism/cannon/mipsevm/program" "github.com/ethereum-optimism/optimism/cannon/mipsevm/program"
...@@ -51,14 +52,14 @@ func LoadELF(ctx *cli.Context) error { ...@@ -51,14 +52,14 @@ func LoadELF(ctx *cli.Context) error {
return program.LoadELF(f, singlethreaded.CreateInitialState) return program.LoadELF(f, singlethreaded.CreateInitialState)
} }
writeState = func(path string, state mipsevm.FPVMState) error { writeState = func(path string, state mipsevm.FPVMState) error {
return jsonutil.WriteJSON[*singlethreaded.State](path, state.(*singlethreaded.State), OutFilePerm) return serialize.Write[*singlethreaded.State](path, state.(*singlethreaded.State), OutFilePerm)
} }
} else if vmType == mtVMType { } else if vmType == mtVMType {
createInitialState = func(f *elf.File) (mipsevm.FPVMState, error) { createInitialState = func(f *elf.File) (mipsevm.FPVMState, error) {
return program.LoadELF(f, multithreaded.CreateInitialState) return program.LoadELF(f, multithreaded.CreateInitialState)
} }
writeState = func(path string, state mipsevm.FPVMState) error { writeState = func(path string, state mipsevm.FPVMState) error {
return jsonutil.WriteJSON[*multithreaded.State](path, state.(*multithreaded.State), OutFilePerm) return serialize.Write[*multithreaded.State](path, state.(*multithreaded.State), OutFilePerm)
} }
} else { } else {
return fmt.Errorf("invalid VM type: %q", vmType) return fmt.Errorf("invalid VM type: %q", vmType)
......
...@@ -11,6 +11,7 @@ import ( ...@@ -11,6 +11,7 @@ import (
"time" "time"
"github.com/ethereum-optimism/optimism/cannon/mipsevm/multithreaded" "github.com/ethereum-optimism/optimism/cannon/mipsevm/multithreaded"
"github.com/ethereum-optimism/optimism/cannon/serialize"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/common/hexutil"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
...@@ -454,7 +455,7 @@ func Run(ctx *cli.Context) error { ...@@ -454,7 +455,7 @@ func Run(ctx *cli.Context) error {
} }
if snapshotAt(state) { if snapshotAt(state) {
if err := jsonutil.WriteJSON(fmt.Sprintf(snapshotFmt, step), state, OutFilePerm); err != nil { if err := serialize.Write(fmt.Sprintf(snapshotFmt, step), state, OutFilePerm); err != nil {
return fmt.Errorf("failed to write state snapshot: %w", err) return fmt.Errorf("failed to write state snapshot: %w", err)
} }
} }
...@@ -511,7 +512,7 @@ func Run(ctx *cli.Context) error { ...@@ -511,7 +512,7 @@ func Run(ctx *cli.Context) error {
vm.Traceback() vm.Traceback()
} }
if err := jsonutil.WriteJSON(ctx.Path(RunOutputFlag.Name), state, OutFilePerm); err != nil { if err := serialize.Write(ctx.Path(RunOutputFlag.Name), state, OutFilePerm); err != nil {
return fmt.Errorf("failed to write state output: %w", err) return fmt.Errorf("failed to write state output: %w", err)
} }
if debugInfoFile := ctx.Path(RunDebugInfoFlag.Name); debugInfoFile != "" { if debugInfoFile := ctx.Path(RunDebugInfoFlag.Name); debugInfoFile != "" {
......
...@@ -6,10 +6,10 @@ import ( ...@@ -6,10 +6,10 @@ import (
"github.com/ethereum-optimism/optimism/cannon/mipsevm" "github.com/ethereum-optimism/optimism/cannon/mipsevm"
"github.com/ethereum-optimism/optimism/cannon/mipsevm/multithreaded" "github.com/ethereum-optimism/optimism/cannon/mipsevm/multithreaded"
"github.com/ethereum-optimism/optimism/cannon/serialize"
"github.com/urfave/cli/v2" "github.com/urfave/cli/v2"
"github.com/ethereum-optimism/optimism/cannon/mipsevm/singlethreaded" "github.com/ethereum-optimism/optimism/cannon/mipsevm/singlethreaded"
"github.com/ethereum-optimism/optimism/op-service/jsonutil"
) )
var ( var (
...@@ -33,12 +33,12 @@ func Witness(ctx *cli.Context) error { ...@@ -33,12 +33,12 @@ func Witness(ctx *cli.Context) error {
if vmType, err := vmTypeFromString(ctx); err != nil { if vmType, err := vmTypeFromString(ctx); err != nil {
return err return err
} else if vmType == cannonVMType { } else if vmType == cannonVMType {
state, err = jsonutil.LoadJSON[singlethreaded.State](input) state, err = serialize.Load[singlethreaded.State](input)
if err != nil { if err != nil {
return fmt.Errorf("invalid input state (%v): %w", input, err) return fmt.Errorf("invalid input state (%v): %w", input, err)
} }
} else if vmType == mtVMType { } else if vmType == mtVMType {
state, err = jsonutil.LoadJSON[multithreaded.State](input) state, err = serialize.Load[multithreaded.State](input)
if err != nil { if err != nil {
return fmt.Errorf("invalid input state (%v): %w", input, err) return fmt.Errorf("invalid input state (%v): %w", input, err)
} }
......
package mipsevm package mipsevm
import ( import (
"github.com/ethereum-optimism/optimism/cannon/serialize"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/common/hexutil"
...@@ -8,6 +9,8 @@ import ( ...@@ -8,6 +9,8 @@ import (
) )
type FPVMState interface { type FPVMState interface {
serialize.Serializable
GetMemory() *memory.Memory GetMemory() *memory.Memory
// GetHeap returns the current memory address at the top of the heap // GetHeap returns the current memory address at the top of the heap
......
...@@ -286,6 +286,48 @@ func (m *Memory) SetMemoryRange(addr uint32, r io.Reader) error { ...@@ -286,6 +286,48 @@ func (m *Memory) SetMemoryRange(addr uint32, r io.Reader) error {
} }
} }
// Serialize writes the memory in a simple binary format which can be read again using Deserialize
// The format is a simple concatenation of fields, with prefixed item count for repeating items and using big endian
// encoding for numbers.
//
// len(PageCount) uint32
// For each page (order is arbitrary):
//
// page index uint32
// page Data [PageSize]byte
func (m *Memory) Serialize(out io.Writer) error {
if err := binary.Write(out, binary.BigEndian, uint32(m.PageCount())); err != nil {
return err
}
for pageIndex, page := range m.pages {
if err := binary.Write(out, binary.BigEndian, pageIndex); err != nil {
return err
}
if _, err := out.Write(page.Data[:]); err != nil {
return err
}
}
return nil
}
func (m *Memory) Deserialize(in io.Reader) error {
var pageCount uint32
if err := binary.Read(in, binary.BigEndian, &pageCount); err != nil {
return err
}
for i := uint32(0); i < pageCount; i++ {
var pageIndex uint32
if err := binary.Read(in, binary.BigEndian, &pageIndex); err != nil {
return err
}
page := m.AllocPage(pageIndex)
if _, err := io.ReadFull(in, page.Data[:]); err != nil {
return err
}
}
return nil
}
type memReader struct { type memReader struct {
m *Memory m *Memory
addr uint32 addr uint32
......
...@@ -3,13 +3,13 @@ package multithreaded ...@@ -3,13 +3,13 @@ package multithreaded
import ( import (
"io" "io"
"github.com/ethereum-optimism/optimism/cannon/serialize"
"github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/common/hexutil"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
"github.com/ethereum-optimism/optimism/cannon/mipsevm" "github.com/ethereum-optimism/optimism/cannon/mipsevm"
"github.com/ethereum-optimism/optimism/cannon/mipsevm/exec" "github.com/ethereum-optimism/optimism/cannon/mipsevm/exec"
"github.com/ethereum-optimism/optimism/cannon/mipsevm/program" "github.com/ethereum-optimism/optimism/cannon/mipsevm/program"
"github.com/ethereum-optimism/optimism/op-service/jsonutil"
) )
type InstrumentedState struct { type InstrumentedState struct {
...@@ -41,7 +41,7 @@ func NewInstrumentedState(state *State, po mipsevm.PreimageOracle, stdOut, stdEr ...@@ -41,7 +41,7 @@ func NewInstrumentedState(state *State, po mipsevm.PreimageOracle, stdOut, stdEr
} }
func NewInstrumentedStateFromFile(stateFile string, po mipsevm.PreimageOracle, stdOut, stdErr io.Writer, log log.Logger) (*InstrumentedState, error) { func NewInstrumentedStateFromFile(stateFile string, po mipsevm.PreimageOracle, stdOut, stdErr io.Writer, log log.Logger) (*InstrumentedState, error) {
state, err := jsonutil.LoadJSON[State](stateFile) state, err := serialize.Load[State](stateFile)
if err != nil { if err != nil {
return nil, err return nil, err
} }
......
...@@ -3,7 +3,10 @@ package multithreaded ...@@ -3,7 +3,10 @@ package multithreaded
import ( import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"io"
"github.com/ethereum-optimism/optimism/cannon/mipsevm/versions"
"github.com/ethereum-optimism/optimism/cannon/serialize"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/common/hexutil"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
...@@ -219,6 +222,166 @@ func (s *State) ThreadCount() int { ...@@ -219,6 +222,166 @@ func (s *State) ThreadCount() int {
return len(s.LeftThreadStack) + len(s.RightThreadStack) return len(s.LeftThreadStack) + len(s.RightThreadStack)
} }
// Serialize writes the state in a simple binary format which can be read again using Deserialize
// The format is a simple concatenation of fields, with prefixed item count for repeating items and using big endian
// encoding for numbers.
//
// StateVersion uint8(1)
// Memory As per Memory.Serialize
// PreimageKey [32]byte
// PreimageOffset uint32
// Heap uint32
// ExitCode uint8
// Exited uint8 - 0 for false, 1 for true
// Step uint64
// StepsSinceLastContextSwitch uint64
// Wakeup uint32
// TraverseRight uint8 - 0 for false, 1 for true
// NextThreadId uint32
// len(LeftThreadStack) uint32
// LeftThreadStack entries as per ThreadState.Serialize
// len(RightThreadStack) uint32
// RightThreadStack entries as per ThreadState.Serialize
// len(LastHint) uint32 (0 when LastHint is nil)
// LastHint []byte
func (s *State) Serialize(out io.Writer) error {
bout := serialize.NewBinaryWriter(out)
if err := bout.WriteUInt(versions.VersionMultiThreaded); err != nil {
return err
}
if err := s.Memory.Serialize(out); err != nil {
return err
}
if err := bout.WriteHash(s.PreimageKey); err != nil {
return err
}
if err := bout.WriteUInt(s.PreimageOffset); err != nil {
return err
}
if err := bout.WriteUInt(s.Heap); err != nil {
return err
}
if err := bout.WriteUInt(s.ExitCode); err != nil {
return err
}
if err := bout.WriteBool(s.Exited); err != nil {
return err
}
if err := bout.WriteUInt(s.Step); err != nil {
return err
}
if err := bout.WriteUInt(s.StepsSinceLastContextSwitch); err != nil {
return err
}
if err := bout.WriteUInt(s.Wakeup); err != nil {
return err
}
if err := bout.WriteBool(s.TraverseRight); err != nil {
return err
}
if err := bout.WriteUInt(s.NextThreadId); err != nil {
return err
}
if err := bout.WriteUInt(uint32(len(s.LeftThreadStack))); err != nil {
return err
}
for _, stack := range s.LeftThreadStack {
if err := stack.Serialize(out); err != nil {
return err
}
}
if err := bout.WriteUInt(uint32(len(s.RightThreadStack))); err != nil {
return err
}
for _, stack := range s.RightThreadStack {
if err := stack.Serialize(out); err != nil {
return err
}
}
if err := bout.WriteBytes(s.LastHint); err != nil {
return err
}
return nil
}
func (s *State) Deserialize(in io.Reader) error {
bin := serialize.NewBinaryReader(in)
var version versions.StateVersion
if err := bin.ReadUInt(&version); err != nil {
return err
}
if version != versions.VersionMultiThreaded {
return fmt.Errorf("invalid state encoding version %d", version)
}
s.Memory = memory.NewMemory()
if err := s.Memory.Deserialize(in); err != nil {
return err
}
if err := bin.ReadHash(&s.PreimageKey); err != nil {
return err
}
if err := bin.ReadUInt(&s.PreimageOffset); err != nil {
return err
}
if err := bin.ReadUInt(&s.Heap); err != nil {
return err
}
if err := bin.ReadUInt(&s.ExitCode); err != nil {
return err
}
if err := bin.ReadBool(&s.Exited); err != nil {
return err
}
if err := bin.ReadUInt(&s.Step); err != nil {
return err
}
if err := bin.ReadUInt(&s.StepsSinceLastContextSwitch); err != nil {
return err
}
if err := bin.ReadUInt(&s.Wakeup); err != nil {
return err
}
if err := bin.ReadBool(&s.TraverseRight); err != nil {
return err
}
if err := bin.ReadUInt(&s.NextThreadId); err != nil {
return err
}
var leftThreadStackSize uint32
if err := bin.ReadUInt(&leftThreadStackSize); err != nil {
return err
}
s.LeftThreadStack = make([]*ThreadState, leftThreadStackSize)
for i := range s.LeftThreadStack {
s.LeftThreadStack[i] = &ThreadState{}
if err := s.LeftThreadStack[i].Deserialize(in); err != nil {
return err
}
}
var rightThreadStackSize uint32
if err := bin.ReadUInt(&rightThreadStackSize); err != nil {
return err
}
s.RightThreadStack = make([]*ThreadState, rightThreadStackSize)
for i := range s.RightThreadStack {
s.RightThreadStack[i] = &ThreadState{}
if err := s.RightThreadStack[i].Deserialize(in); err != nil {
return err
}
}
if err := bin.ReadBytes((*[]byte)(&s.LastHint)); err != nil {
return err
}
return nil
}
type StateWitness []byte type StateWitness []byte
func (sw StateWitness) StateHash() (common.Hash, error) { func (sw StateWitness) StateHash() (common.Hash, error) {
......
package multithreaded package multithreaded
import ( import (
"bytes"
"debug/elf" "debug/elf"
"encoding/json" "encoding/json"
"testing" "testing"
"github.com/ethereum-optimism/optimism/cannon/mipsevm/memory"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/hexutil"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
...@@ -123,6 +127,154 @@ func TestState_JSONCodec(t *testing.T) { ...@@ -123,6 +127,154 @@ func TestState_JSONCodec(t *testing.T) {
require.Equal(t, state.LastHint, newState.LastHint) require.Equal(t, state.LastHint, newState.LastHint)
} }
func TestState_Binary(t *testing.T) {
elfProgram, err := elf.Open("../../testdata/example/bin/hello.elf")
require.NoError(t, err, "open ELF file")
state, err := program.LoadELF(elfProgram, CreateInitialState)
require.NoError(t, err, "load ELF into state")
// Set a few additional fields
state.PreimageKey = crypto.Keccak256Hash([]byte{1, 2, 3, 4})
state.PreimageOffset = 4
state.Heap = 555
state.Step = 99_999
state.StepsSinceLastContextSwitch = 123
state.Exited = true
state.ExitCode = 2
state.LastHint = []byte{11, 12, 13}
buf := new(bytes.Buffer)
err = state.Serialize(buf)
require.NoError(t, err)
newState := new(State)
require.NoError(t, newState.Deserialize(bytes.NewReader(buf.Bytes())))
require.Equal(t, state.PreimageKey, newState.PreimageKey)
require.Equal(t, state.PreimageOffset, newState.PreimageOffset)
require.Equal(t, state.Heap, newState.Heap)
require.Equal(t, state.ExitCode, newState.ExitCode)
require.Equal(t, state.Exited, newState.Exited)
require.Equal(t, state.Memory.MerkleRoot(), newState.Memory.MerkleRoot())
require.Equal(t, state.Step, newState.Step)
require.Equal(t, state.StepsSinceLastContextSwitch, newState.StepsSinceLastContextSwitch)
require.Equal(t, state.Wakeup, newState.Wakeup)
require.Equal(t, state.TraverseRight, newState.TraverseRight)
require.Equal(t, state.LeftThreadStack, newState.LeftThreadStack)
require.Equal(t, state.RightThreadStack, newState.RightThreadStack)
require.Equal(t, state.NextThreadId, newState.NextThreadId)
require.Equal(t, state.LastHint, newState.LastHint)
}
func TestSerializeStateRoundTrip(t *testing.T) {
// Construct a test case with populated fields
mem := memory.NewMemory()
mem.AllocPage(5)
p := mem.AllocPage(123)
p.Data[2] = 0x01
state := &State{
Memory: mem,
PreimageKey: common.Hash{0xFF},
PreimageOffset: 5,
Heap: 0xc0ffee,
ExitCode: 1,
Exited: true,
Step: 0xdeadbeef,
StepsSinceLastContextSwitch: 334,
Wakeup: 42,
TraverseRight: true,
LeftThreadStack: []*ThreadState{
{
ThreadId: 45,
ExitCode: 46,
Exited: true,
FutexAddr: 47,
FutexVal: 48,
FutexTimeoutStep: 49,
Cpu: mipsevm.CpuScalars{
PC: 0xFF,
NextPC: 0xFF + 4,
LO: 0xbeef,
HI: 0xbabe,
},
Registers: [32]uint32{
0xdeadbeef,
0xdeadbeef,
0xc0ffee,
0xbeefbabe,
0xdeadc0de,
0xbadc0de,
0xdeaddead,
},
},
{
ThreadId: 55,
ExitCode: 56,
Exited: false,
FutexAddr: 57,
FutexVal: 58,
FutexTimeoutStep: 59,
Cpu: mipsevm.CpuScalars{
PC: 0xEE,
NextPC: 0xEE + 4,
LO: 0xeeef,
HI: 0xeabe,
},
Registers: [32]uint32{
0xabcdef,
0x123456,
},
},
},
RightThreadStack: []*ThreadState{
{
ThreadId: 65,
ExitCode: 66,
Exited: false,
FutexAddr: 67,
FutexVal: 68,
FutexTimeoutStep: 69,
Cpu: mipsevm.CpuScalars{
PC: 0xdd,
NextPC: 0xdd + 4,
LO: 0xdeef,
HI: 0xdabe,
},
Registers: [32]uint32{
0x654321,
},
},
{
ThreadId: 75,
ExitCode: 76,
Exited: true,
FutexAddr: 77,
FutexVal: 78,
FutexTimeoutStep: 79,
Cpu: mipsevm.CpuScalars{
PC: 0xcc,
NextPC: 0xcc + 4,
LO: 0xceef,
HI: 0xcabe,
},
Registers: [32]uint32{
0x987653,
0xfedbca,
},
},
},
NextThreadId: 489,
LastHint: hexutil.Bytes{1, 2, 3, 4, 5},
}
ser := new(bytes.Buffer)
err := state.Serialize(ser)
require.NoError(t, err, "must serialize state")
state2 := &State{}
err = state2.Deserialize(ser)
require.NoError(t, err, "must deserialize state")
require.Equal(t, state, state2, "must roundtrip state")
}
func TestState_EmptyThreadsRoot(t *testing.T) { func TestState_EmptyThreadsRoot(t *testing.T) {
data := [64]byte{} data := [64]byte{}
expectedEmptyRoot := crypto.Keccak256Hash(data[:]) expectedEmptyRoot := crypto.Keccak256Hash(data[:])
......
...@@ -2,6 +2,7 @@ package multithreaded ...@@ -2,6 +2,7 @@ package multithreaded
import ( import (
"encoding/binary" "encoding/binary"
"io"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
...@@ -74,6 +75,55 @@ func (t *ThreadState) serializeThread() []byte { ...@@ -74,6 +75,55 @@ func (t *ThreadState) serializeThread() []byte {
return out return out
} }
// Serialize writes the ThreadState in a simple binary format which can be read again using Deserialize
// The format exactly matches the serialization generated by serializeThread used for thread proofs.
func (t *ThreadState) Serialize(out io.Writer) error {
_, err := out.Write(t.serializeThread())
return err
}
func (t *ThreadState) Deserialize(in io.Reader) error {
if err := binary.Read(in, binary.BigEndian, &t.ThreadId); err != nil {
return err
}
if err := binary.Read(in, binary.BigEndian, &t.ExitCode); err != nil {
return err
}
var exited uint8
if err := binary.Read(in, binary.BigEndian, &exited); err != nil {
return err
}
t.Exited = exited != 0
if err := binary.Read(in, binary.BigEndian, &t.FutexAddr); err != nil {
return err
}
if err := binary.Read(in, binary.BigEndian, &t.FutexVal); err != nil {
return err
}
if err := binary.Read(in, binary.BigEndian, &t.FutexTimeoutStep); err != nil {
return err
}
if err := binary.Read(in, binary.BigEndian, &t.Cpu.PC); err != nil {
return err
}
if err := binary.Read(in, binary.BigEndian, &t.Cpu.NextPC); err != nil {
return err
}
if err := binary.Read(in, binary.BigEndian, &t.Cpu.LO); err != nil {
return err
}
if err := binary.Read(in, binary.BigEndian, &t.Cpu.HI); err != nil {
return err
}
// Read the registers as big endian uint32s
for i := range t.Registers {
if err := binary.Read(in, binary.BigEndian, &t.Registers[i]); err != nil {
return err
}
}
return nil
}
func computeThreadRoot(prevStackRoot common.Hash, threadToPush *ThreadState) common.Hash { func computeThreadRoot(prevStackRoot common.Hash, threadToPush *ThreadState) common.Hash {
hashedThread := crypto.Keccak256Hash(threadToPush.serializeThread()) hashedThread := crypto.Keccak256Hash(threadToPush.serializeThread())
......
...@@ -6,7 +6,7 @@ import ( ...@@ -6,7 +6,7 @@ import (
"github.com/ethereum-optimism/optimism/cannon/mipsevm" "github.com/ethereum-optimism/optimism/cannon/mipsevm"
"github.com/ethereum-optimism/optimism/cannon/mipsevm/exec" "github.com/ethereum-optimism/optimism/cannon/mipsevm/exec"
"github.com/ethereum-optimism/optimism/cannon/mipsevm/program" "github.com/ethereum-optimism/optimism/cannon/mipsevm/program"
"github.com/ethereum-optimism/optimism/op-service/jsonutil" "github.com/ethereum-optimism/optimism/cannon/serialize"
"github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/common/hexutil"
) )
...@@ -48,7 +48,7 @@ func NewInstrumentedState(state *State, po mipsevm.PreimageOracle, stdOut, stdEr ...@@ -48,7 +48,7 @@ func NewInstrumentedState(state *State, po mipsevm.PreimageOracle, stdOut, stdEr
} }
func NewInstrumentedStateFromFile(stateFile string, po mipsevm.PreimageOracle, stdOut, stdErr io.Writer, meta *program.Metadata) (*InstrumentedState, error) { func NewInstrumentedStateFromFile(stateFile string, po mipsevm.PreimageOracle, stdOut, stdErr io.Writer, meta *program.Metadata) (*InstrumentedState, error) {
state, err := jsonutil.LoadJSON[State](stateFile) state, err := serialize.Load[State](stateFile)
if err != nil { if err != nil {
return nil, err return nil, err
} }
......
...@@ -4,7 +4,10 @@ import ( ...@@ -4,7 +4,10 @@ import (
"encoding/binary" "encoding/binary"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"github.com/ethereum-optimism/optimism/cannon/mipsevm/versions"
"github.com/ethereum-optimism/optimism/cannon/serialize"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/common/hexutil"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
...@@ -177,6 +180,129 @@ func (s *State) EncodeWitness() ([]byte, common.Hash) { ...@@ -177,6 +180,129 @@ func (s *State) EncodeWitness() ([]byte, common.Hash) {
return out, stateHashFromWitness(out) return out, stateHashFromWitness(out)
} }
// Serialize writes the state in a simple binary format which can be read again using Deserialize
// The format is a simple concatenation of fields, with prefixed item count for repeating items and using big endian
// encoding for numbers.
//
// StateVersion uint8(0)
// Memory As per Memory.Serialize
// PreimageKey [32]byte
// PreimageOffset uint32
// Cpu.PC uint32
// Cpu.NextPC uint32
// Cpu.LO uint32
// Cpu.HI uint32
// Heap uint32
// ExitCode uint8
// Exited uint8 - 0 for false, 1 for true
// Step uint64
// Registers [32]uint32
// len(LastHint) uint32 (0 when LastHint is nil)
// LastHint []byte
func (s *State) Serialize(out io.Writer) error {
bout := serialize.NewBinaryWriter(out)
if err := bout.WriteUInt(versions.VersionSingleThreaded); err != nil {
return err
}
if err := s.Memory.Serialize(out); err != nil {
return err
}
if err := bout.WriteHash(s.PreimageKey); err != nil {
return err
}
if err := bout.WriteUInt(s.PreimageOffset); err != nil {
return err
}
if err := bout.WriteUInt(s.Cpu.PC); err != nil {
return err
}
if err := bout.WriteUInt(s.Cpu.NextPC); err != nil {
return err
}
if err := bout.WriteUInt(s.Cpu.LO); err != nil {
return err
}
if err := bout.WriteUInt(s.Cpu.HI); err != nil {
return err
}
if err := bout.WriteUInt(s.Heap); err != nil {
return err
}
if err := bout.WriteUInt(s.ExitCode); err != nil {
return err
}
if err := bout.WriteBool(s.Exited); err != nil {
return err
}
if err := bout.WriteUInt(s.Step); err != nil {
return err
}
for _, r := range s.Registers {
if err := bout.WriteUInt(r); err != nil {
return err
}
}
if err := bout.WriteBytes(s.LastHint); err != nil {
return err
}
return nil
}
func (s *State) Deserialize(in io.Reader) error {
bin := serialize.NewBinaryReader(in)
var version versions.StateVersion
if err := bin.ReadUInt(&version); err != nil {
return err
}
if version != versions.VersionSingleThreaded {
return fmt.Errorf("invalid state encoding version %d", version)
}
s.Memory = memory.NewMemory()
if err := s.Memory.Deserialize(in); err != nil {
return err
}
if err := bin.ReadHash(&s.PreimageKey); err != nil {
return err
}
if err := bin.ReadUInt(&s.PreimageOffset); err != nil {
return err
}
if err := bin.ReadUInt(&s.Cpu.PC); err != nil {
return err
}
if err := bin.ReadUInt(&s.Cpu.NextPC); err != nil {
return err
}
if err := bin.ReadUInt(&s.Cpu.LO); err != nil {
return err
}
if err := bin.ReadUInt(&s.Cpu.HI); err != nil {
return err
}
if err := bin.ReadUInt(&s.Heap); err != nil {
return err
}
if err := bin.ReadUInt(&s.ExitCode); err != nil {
return err
}
if err := bin.ReadBool(&s.Exited); err != nil {
return err
}
if err := bin.ReadUInt(&s.Step); err != nil {
return err
}
for i := range s.Registers {
if err := bin.ReadUInt(&s.Registers[i]); err != nil {
return err
}
}
if err := bin.ReadBytes((*[]byte)(&s.LastHint)); err != nil {
return err
}
return nil
}
type StateWitness []byte type StateWitness []byte
func (sw StateWitness) StateHash() (common.Hash, error) { func (sw StateWitness) StateHash() (common.Hash, error) {
......
package singlethreaded package singlethreaded
import ( import (
"bytes"
"debug/elf" "debug/elf"
"testing" "testing"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/hexutil"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
...@@ -79,3 +82,69 @@ func TestStateJSONCodec(t *testing.T) { ...@@ -79,3 +82,69 @@ func TestStateJSONCodec(t *testing.T) {
require.Equal(t, state.Registers, newState.Registers) require.Equal(t, state.Registers, newState.Registers)
require.Equal(t, state.Step, newState.Step) require.Equal(t, state.Step, newState.Step)
} }
func TestStateBinaryCodec(t *testing.T) {
elfProgram, err := elf.Open("../../testdata/example/bin/hello.elf")
require.NoError(t, err, "open ELF file")
state, err := program.LoadELF(elfProgram, CreateInitialState)
require.NoError(t, err, "load ELF into state")
buf := new(bytes.Buffer)
err = state.Serialize(buf)
require.NoError(t, err)
newState := new(State)
require.NoError(t, newState.Deserialize(bytes.NewReader(buf.Bytes())))
require.Equal(t, state.PreimageKey, newState.PreimageKey)
require.Equal(t, state.PreimageOffset, newState.PreimageOffset)
require.Equal(t, state.Cpu, newState.Cpu)
require.Equal(t, state.Heap, newState.Heap)
require.Equal(t, state.ExitCode, newState.ExitCode)
require.Equal(t, state.Exited, newState.Exited)
require.Equal(t, state.Memory.PageCount(), newState.Memory.PageCount())
require.Equal(t, state.Memory.MerkleRoot(), newState.Memory.MerkleRoot())
require.Equal(t, state.Registers, newState.Registers)
require.Equal(t, state.Step, newState.Step)
}
func TestSerializeStateRoundTrip(t *testing.T) {
// Construct a test case with populated fields
mem := memory.NewMemory()
mem.AllocPage(5)
p := mem.AllocPage(123)
p.Data[2] = 0x01
state := &State{
Memory: mem,
PreimageKey: common.Hash{0xFF},
PreimageOffset: 5,
Cpu: mipsevm.CpuScalars{
PC: 0xFF,
NextPC: 0xFF + 4,
LO: 0xbeef,
HI: 0xbabe,
},
Heap: 0xc0ffee,
ExitCode: 1,
Exited: true,
Step: 0xdeadbeef,
Registers: [32]uint32{
0xdeadbeef,
0xdeadbeef,
0xc0ffee,
0xbeefbabe,
0xdeadc0de,
0xbadc0de,
0xdeaddead,
},
LastHint: hexutil.Bytes{1, 2, 3, 4, 5},
}
ser := new(bytes.Buffer)
err := state.Serialize(ser)
require.NoError(t, err, "must serialize state")
state2 := &State{}
err = state2.Deserialize(ser)
require.NoError(t, err, "must deserialize state")
require.Equal(t, state, state2, "must roundtrip state")
}
package versions
type StateVersion uint8
const (
VersionSingleThreaded StateVersion = iota
VersionMultiThreaded
)
package serialize
import (
"errors"
"fmt"
"io"
"os"
"reflect"
"github.com/ethereum-optimism/optimism/op-service/ioutil"
)
// Serializable defines functionality for a type that may be serialized to raw bytes.
type Serializable interface {
// Serialize encodes the type as raw bytes.
Serialize(out io.Writer) error
// Deserialize decodes raw bytes into the type.
Deserialize(in io.Reader) error
}
func LoadSerializedBinary[X any](inputPath string) (*X, error) {
if inputPath == "" {
return nil, errors.New("no path specified")
}
var f io.ReadCloser
f, err := ioutil.OpenDecompressed(inputPath)
if err != nil {
return nil, fmt.Errorf("failed to open file %q: %w", inputPath, err)
}
defer f.Close()
var x X
serializable, ok := reflect.ValueOf(&x).Interface().(Serializable)
if !ok {
return nil, fmt.Errorf("%T is not a Serializable", x)
}
err = serializable.Deserialize(f)
if err != nil {
return nil, err
}
return &x, nil
}
func WriteSerializedBinary(outputPath string, value Serializable, perm os.FileMode) error {
if outputPath == "" {
return nil
}
var out io.Writer
finish := func() error { return nil }
if outputPath == "-" {
out = os.Stdout
} else {
f, err := ioutil.NewAtomicWriterCompressed(outputPath, perm)
if err != nil {
return fmt.Errorf("failed to create temp file when writing: %w", err)
}
// Ensure we close the stream without renaming even if failures occur.
defer func() {
_ = f.Abort()
}()
out = f
// Closing the file causes it to be renamed to the final destination
// so make sure we handle any errors it returns
finish = f.Close
}
err := value.Serialize(out)
if err != nil {
return fmt.Errorf("failed to write binary: %w", err)
}
if err := finish(); err != nil {
return fmt.Errorf("failed to finish write: %w", err)
}
return nil
}
package serialize
import (
"encoding/binary"
"io"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/require"
)
func TestRoundTripBinary(t *testing.T) {
dir := t.TempDir()
file := filepath.Join(dir, "test.bin")
data := &serializableTestData{A: []byte{0xde, 0xad}, B: 3}
err := WriteSerializedBinary(file, data, 0644)
require.NoError(t, err)
hasGzip, err := hasGzipHeader(file)
require.NoError(t, err)
require.False(t, hasGzip)
result, err := LoadSerializedBinary[serializableTestData](file)
require.NoError(t, err)
require.EqualValues(t, data, result)
}
func TestRoundTripBinaryWithGzip(t *testing.T) {
dir := t.TempDir()
file := filepath.Join(dir, "test.bin.gz")
data := &serializableTestData{A: []byte{0xde, 0xad}, B: 3}
err := WriteSerializedBinary(file, data, 0644)
require.NoError(t, err)
hasGzip, err := hasGzipHeader(file)
require.NoError(t, err)
require.True(t, hasGzip)
result, err := LoadSerializedBinary[serializableTestData](file)
require.NoError(t, err)
require.EqualValues(t, data, result)
}
func hasGzipHeader(filename string) (bool, error) {
file, err := os.Open(filename)
if err != nil {
return false, err
}
defer file.Close()
header := make([]byte, 2)
_, err = file.Read(header)
if err != nil {
return false, err
}
// Gzip header magic numbers: 1F 8B
return header[0] == 0x1F && header[1] == 0x8B, nil
}
type serializableTestData struct {
A []byte
B uint8
}
func (s *serializableTestData) Serialize(w io.Writer) error {
if err := binary.Write(w, binary.BigEndian, uint64(len(s.A))); err != nil {
return err
}
if _, err := w.Write(s.A); err != nil {
return err
}
if err := binary.Write(w, binary.BigEndian, s.B); err != nil {
return err
}
return nil
}
func (s *serializableTestData) Deserialize(in io.Reader) error {
var lenA uint64
if err := binary.Read(in, binary.BigEndian, &lenA); err != nil {
return err
}
s.A = make([]byte, lenA)
if _, err := io.ReadFull(in, s.A); err != nil {
return err
}
if err := binary.Read(in, binary.BigEndian, &s.B); err != nil {
return err
}
return nil
}
package serialize
import (
"os"
"strings"
"github.com/ethereum-optimism/optimism/op-service/jsonutil"
)
func Load[X any](inputPath string) (*X, error) {
if isBinary(inputPath) {
return LoadSerializedBinary[X](inputPath)
}
return jsonutil.LoadJSON[X](inputPath)
}
func Write[X Serializable](outputPath string, x X, perm os.FileMode) error {
if isBinary(outputPath) {
return WriteSerializedBinary(outputPath, x, perm)
}
return jsonutil.WriteJSON[X](outputPath, x, perm)
}
func isBinary(path string) bool {
return strings.HasSuffix(path, ".bin") || strings.HasSuffix(path, ".bin.gz")
}
package serialize
import (
"io"
"path/filepath"
"testing"
"github.com/ethereum-optimism/optimism/op-service/ioutil"
"github.com/stretchr/testify/require"
)
func TestRoundtrip(t *testing.T) {
tests := []struct {
filename string
expectJSON bool
expectGzip bool
}{
{filename: "test.json", expectJSON: true, expectGzip: false},
{filename: "test.json.gz", expectJSON: true, expectGzip: true},
{filename: "test.foo", expectJSON: true, expectGzip: false},
{filename: "test.foo.gz", expectJSON: true, expectGzip: true},
{filename: "test.bin", expectJSON: false, expectGzip: false},
{filename: "test.bin.gz", expectJSON: false, expectGzip: true},
}
for _, test := range tests {
test := test
t.Run(test.filename, func(t *testing.T) {
path := filepath.Join(t.TempDir(), test.filename)
data := &serializableTestData{A: []byte{0xde, 0xad}, B: 3}
err := Write[*serializableTestData](path, data, 0644)
require.NoError(t, err)
hasGzip, err := hasGzipHeader(path)
require.NoError(t, err)
require.Equal(t, test.expectGzip, hasGzip)
decompressed, err := ioutil.OpenDecompressed(path)
require.NoError(t, err)
defer decompressed.Close()
start := make([]byte, 1)
_, err = io.ReadFull(decompressed, start)
require.NoError(t, err)
if test.expectJSON {
require.Equal(t, "{", string(start))
} else {
require.NotEqual(t, "{", string(start))
}
result, err := Load[serializableTestData](path)
require.NoError(t, err)
require.EqualValues(t, data, result)
})
}
}
package serialize
import (
"encoding/binary"
"fmt"
"io"
"github.com/ethereum/go-ethereum/common"
)
// BinaryReader provides methods to decode content written by BinaryWriter.
type BinaryReader struct {
in io.Reader
}
func NewBinaryReader(in io.Reader) *BinaryReader {
return &BinaryReader{in: in}
}
func (r *BinaryReader) ReadUInt(target any) error {
return binary.Read(r.in, binary.BigEndian, target)
}
func (r *BinaryReader) ReadHash(target *common.Hash) error {
_, err := io.ReadFull(r.in, target[:])
return err
}
func (r *BinaryReader) ReadBool(target *bool) error {
var v uint8
if err := r.ReadUInt(&v); err != nil {
return err
}
switch v {
case 0:
*target = false
case 1:
*target = true
default:
return fmt.Errorf("invalid boolean value: %v", v)
}
return nil
}
func (r *BinaryReader) ReadBytes(target *[]byte) error {
var size uint32
if err := r.ReadUInt(&size); err != nil {
return err
}
if size == 0 {
*target = nil
return nil
}
data := make([]byte, size)
if _, err := io.ReadFull(r.in, data); err != nil {
return err
}
*target = data
return nil
}
package serialize
import (
"bytes"
"testing"
"github.com/ethereum/go-ethereum/common"
"github.com/stretchr/testify/require"
)
func TestRoundTripWithReader(t *testing.T) {
// Test that reader can read the data written by BinaryWriter.
// The writer tests check that the generated data is what is expected, so simply check that the reader correctly
// parses a range of data here rather than duplicating the expected binary serialization.
buf := new(bytes.Buffer)
out := NewBinaryWriter(buf)
require.NoError(t, out.WriteBool(true))
require.NoError(t, out.WriteBool(false))
require.NoError(t, out.WriteUInt(uint8(5)))
require.NoError(t, out.WriteUInt(uint32(76)))
require.NoError(t, out.WriteUInt(uint64(24824424)))
expectedHash := common.HexToHash("0x5a8f75b8e1c1529d1d1c596464d17b99763604f4c00b280436fc0dffacc60efd")
require.NoError(t, out.WriteHash(expectedHash))
expectedBytes := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9}
require.NoError(t, out.WriteBytes(expectedBytes))
in := NewBinaryReader(buf)
var b bool
require.NoError(t, in.ReadBool(&b))
require.True(t, b)
require.NoError(t, in.ReadBool(&b))
require.False(t, b)
var vUInt8 uint8
require.NoError(t, in.ReadUInt(&vUInt8))
require.Equal(t, uint8(5), vUInt8)
var vUInt32 uint32
require.NoError(t, in.ReadUInt(&vUInt32))
require.Equal(t, uint32(76), vUInt32)
var vUInt64 uint64
require.NoError(t, in.ReadUInt(&vUInt64))
require.Equal(t, uint64(24824424), vUInt64)
var hash common.Hash
require.NoError(t, in.ReadHash(&hash))
require.Equal(t, expectedHash, hash)
var data []byte
require.NoError(t, in.ReadBytes(&data))
require.Equal(t, expectedBytes, data)
}
package serialize
import (
"encoding/binary"
"io"
"github.com/ethereum/go-ethereum/common"
)
// BinaryWriter writes a simple binary format which can be read again using BinaryReader.
// The format is a simple concatenation of values, with prefixed length for variable length items.
// All numbers are encoded using big endian.
type BinaryWriter struct {
out io.Writer
}
func NewBinaryWriter(out io.Writer) *BinaryWriter {
return &BinaryWriter{out: out}
}
func (w *BinaryWriter) WriteUInt(v any) error {
return binary.Write(w.out, binary.BigEndian, v)
}
func (w *BinaryWriter) WriteHash(v common.Hash) error {
_, err := w.out.Write(v[:])
return err
}
func (w *BinaryWriter) WriteBool(v bool) error {
if v {
return w.WriteUInt(uint8(1))
} else {
return w.WriteUInt(uint8(0))
}
}
func (w *BinaryWriter) WriteBytes(v []byte) error {
if err := w.WriteUInt(uint32(len(v))); err != nil {
return err
}
_, err := w.out.Write(v)
return err
}
package serialize
import (
"bytes"
"fmt"
"math"
"testing"
"github.com/ethereum/go-ethereum/common"
"github.com/stretchr/testify/require"
)
func TestUInt(t *testing.T) {
tests := []struct {
name string
val any
expected []byte
}{
{name: "uint8-zero", val: uint8(0), expected: []byte{0}},
{name: "uint8-one", val: uint8(1), expected: []byte{1}},
{name: "uint8-big", val: uint8(156), expected: []byte{156}},
{name: "uint8-max", val: uint8(math.MaxUint8), expected: []byte{255}},
{name: "uint16-zero", val: uint16(0), expected: []byte{0, 0}},
{name: "uint16-one", val: uint16(1), expected: []byte{0, 1}},
{name: "uint16-big", val: uint16(1283), expected: []byte{5, 3}},
{name: "uint16-max", val: uint16(math.MaxUint16), expected: []byte{255, 255}},
{name: "uint32-zero", val: uint32(0), expected: []byte{0, 0, 0, 0}},
{name: "uint32-one", val: uint32(1), expected: []byte{0, 0, 0, 1}},
{name: "uint32-big", val: uint32(1283424245), expected: []byte{0x4c, 0x7f, 0x7f, 0xf5}},
{name: "uint32-max", val: uint32(math.MaxUint32), expected: []byte{255, 255, 255, 255}},
{name: "uint64-zero", val: uint64(0), expected: []byte{0, 0, 0, 0, 0, 0, 0, 0}},
{name: "uint64-one", val: uint64(1), expected: []byte{0, 0, 0, 0, 0, 0, 0, 1}},
{name: "uint64-big", val: uint64(1283424245242429284), expected: []byte{0x11, 0xcf, 0xa3, 0x8d, 0x19, 0xcc, 0x7f, 0x64}},
{name: "uint64-max", val: uint64(math.MaxUint64), expected: []byte{255, 255, 255, 255, 255, 255, 255, 255}},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
out := new(bytes.Buffer)
bout := NewBinaryWriter(out)
require.NoError(t, bout.WriteUInt(test.val))
result := out.Bytes()
require.Equal(t, test.expected, result)
})
}
}
func TestWriteHash(t *testing.T) {
out := new(bytes.Buffer)
bout := NewBinaryWriter(out)
hash := common.HexToHash("0x5a8f75b8e1c1529d1d1c596464d17b99763604f4c00b280436fc0dffacc60efd")
require.NoError(t, bout.WriteHash(hash))
result := out.Bytes()
require.Equal(t, hash[:], result)
}
func TestWriteBool(t *testing.T) {
for _, val := range []bool{true, false} {
val := val
t.Run(fmt.Sprintf("%t", val), func(t *testing.T) {
out := new(bytes.Buffer)
bout := NewBinaryWriter(out)
require.NoError(t, bout.WriteBool(val))
result := out.Bytes()
require.Len(t, result, 1)
if val {
require.Equal(t, result[0], uint8(1))
} else {
require.Equal(t, result[0], uint8(0))
}
})
}
}
func TestWriteBytes(t *testing.T) {
tests := []struct {
name string
val []byte
expected []byte
}{
{name: "nil", val: nil, expected: []byte{0, 0, 0, 0}},
{name: "empty", val: []byte{}, expected: []byte{0, 0, 0, 0}},
{name: "non-empty", val: []byte{1, 2, 3, 4, 5}, expected: []byte{0, 0, 0, 5, 1, 2, 3, 4, 5}},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
out := new(bytes.Buffer)
bout := NewBinaryWriter(out)
require.NoError(t, bout.WriteBytes(test.val))
result := out.Bytes()
require.Equal(t, test.expected, result)
})
}
}
...@@ -38,7 +38,9 @@ func WriteJSON[X any](outputPath string, value X, perm os.FileMode) error { ...@@ -38,7 +38,9 @@ func WriteJSON[X any](outputPath string, value X, perm os.FileMode) error {
} }
var out io.Writer var out io.Writer
finish := func() error { return nil } finish := func() error { return nil }
if outputPath != "-" { if outputPath == "-" {
out = os.Stdout
} else {
f, err := ioutil.NewAtomicWriterCompressed(outputPath, perm) f, err := ioutil.NewAtomicWriterCompressed(outputPath, perm)
if err != nil { if err != nil {
return fmt.Errorf("failed to open output file: %w", err) return fmt.Errorf("failed to open output file: %w", err)
...@@ -51,8 +53,6 @@ func WriteJSON[X any](outputPath string, value X, perm os.FileMode) error { ...@@ -51,8 +53,6 @@ func WriteJSON[X any](outputPath string, value X, perm os.FileMode) error {
// Closing the file causes it to be renamed to the final destination // Closing the file causes it to be renamed to the final destination
// so make sure we handle any errors it returns // so make sure we handle any errors it returns
finish = f.Close finish = f.Close
} else {
out = os.Stdout
} }
enc := json.NewEncoder(out) enc := json.NewEncoder(out)
enc.SetIndent("", " ") enc.SetIndent("", " ")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment