Commit 981d6b2a authored by Nickqiao's avatar Nickqiao

use io.ReadFull to ensure read enough & add read/writeSolidityABIUint64 helper...

use io.ReadFull to ensure read enough & add read/writeSolidityABIUint64 helper function & add some formatted error
parent 813f7b4b
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"io"
"math/big" "math/big"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
...@@ -65,13 +66,17 @@ type L1BlockInfo struct { ...@@ -65,13 +66,17 @@ type L1BlockInfo struct {
func (info *L1BlockInfo) MarshalBinary() ([]byte, error) { func (info *L1BlockInfo) MarshalBinary() ([]byte, error) {
writer := bytes.NewBuffer(make([]byte, 0, L1InfoLen)) writer := bytes.NewBuffer(make([]byte, 0, L1InfoLen))
writer.Write(L1InfoFuncBytes4)
var padding [24]byte // Helper function to write Solidity ABI encode Uint64
writer.Write(padding[:]) writeSolidityABIUint64 := func(num uint64) {
_ = binary.Write(writer, binary.BigEndian, info.Number) var padding [24]byte
writer.Write(padding[:]) writer.Write(padding[:])
_ = binary.Write(writer, binary.BigEndian, info.Time) _ = binary.Write(writer, binary.BigEndian, num)
}
writer.Write(L1InfoFuncBytes4)
writeSolidityABIUint64(info.Number)
writeSolidityABIUint64(info.Time)
// Ensure that the baseFee is not too large. // Ensure that the baseFee is not too large.
if info.BaseFee.BitLen() > 256 { if info.BaseFee.BitLen() > 256 {
return nil, fmt.Errorf("base fee exceeds 256 bits: %d", info.BaseFee) return nil, fmt.Errorf("base fee exceeds 256 bits: %d", info.BaseFee)
...@@ -80,8 +85,7 @@ func (info *L1BlockInfo) MarshalBinary() ([]byte, error) { ...@@ -80,8 +85,7 @@ func (info *L1BlockInfo) MarshalBinary() ([]byte, error) {
info.BaseFee.FillBytes(baseFeeBuf[:]) info.BaseFee.FillBytes(baseFeeBuf[:])
writer.Write(baseFeeBuf[:]) writer.Write(baseFeeBuf[:])
writer.Write(info.BlockHash.Bytes()) writer.Write(info.BlockHash.Bytes())
writer.Write(padding[:]) writeSolidityABIUint64(info.SequenceNumber)
_ = binary.Write(writer, binary.BigEndian, info.SequenceNumber)
var addrPadding [12]byte var addrPadding [12]byte
writer.Write(addrPadding[:]) writer.Write(addrPadding[:])
...@@ -91,64 +95,63 @@ func (info *L1BlockInfo) MarshalBinary() ([]byte, error) { ...@@ -91,64 +95,63 @@ func (info *L1BlockInfo) MarshalBinary() ([]byte, error) {
return writer.Bytes(), nil return writer.Bytes(), nil
} }
func (info *L1BlockInfo) UnmarshalBinary(data []byte) error { func (info *L1BlockInfo) UnmarshalBinary(data []byte) (err error) {
if len(data) != L1InfoLen { if len(data) != L1InfoLen {
return fmt.Errorf("data is unexpected length: %d", len(data)) return fmt.Errorf("data is unexpected length: %d", len(data))
} }
reader := bytes.NewReader(data) reader := bytes.NewReader(data)
funcSignature := make([]byte, 4) // Helper function to read Solidity ABI encode Uint64
if _, err := reader.Read(funcSignature); err != nil || !bytes.Equal(funcSignature, L1InfoFuncBytes4) { readSolidityABIUint64 := func() (num uint64, err error) {
return fmt.Errorf("data does not match L1 info function signature: 0x%x", funcSignature) var padding, readPadding [24]byte
if _, err := io.ReadFull(reader, readPadding[:]); err != nil || !bytes.Equal(readPadding[:], padding[:]) {
return 0, fmt.Errorf("L1BlockInfo number exceeds uint64 bounds: %x", readPadding[:])
}
if err := binary.Read(reader, binary.BigEndian, &num); err != nil {
return 0, fmt.Errorf("L1BlockInfo expected number length to be 8 bytes")
}
return num, nil
} }
var padding, readPadding [24]byte funcSignature := make([]byte, 4)
if _, err = io.ReadFull(reader, funcSignature); err != nil || !bytes.Equal(funcSignature, L1InfoFuncBytes4) {
if _, err := reader.Read(readPadding[:]); err != nil || !bytes.Equal(readPadding[:], padding[:]) { return fmt.Errorf("data does not match L1 info function signature: 0x%x", funcSignature)
return fmt.Errorf("l1 info number exceeds uint64 bounds: %x", readPadding[:])
}
if err := binary.Read(reader, binary.BigEndian, &info.Number); err != nil {
return err
} }
if _, err := reader.Read(readPadding[:]); err != nil || !bytes.Equal(readPadding[:], padding[:]) { if info.Number, err = readSolidityABIUint64(); err != nil {
return fmt.Errorf("l1 info time exceeds uint64 bounds: %x", readPadding[:]) return
} }
if err := binary.Read(reader, binary.BigEndian, &info.Time); err != nil { if info.Time, err = readSolidityABIUint64(); err != nil {
return err return
} }
var baseFeeBytes [32]byte var baseFeeBytes [32]byte
if _, err := reader.Read(baseFeeBytes[:]); err != nil { if _, err = io.ReadFull(reader, baseFeeBytes[:]); err != nil {
return err return fmt.Errorf("expected BaseFee length to be 32 bytes, but got %x", baseFeeBytes)
} }
info.BaseFee = new(big.Int).SetBytes(baseFeeBytes[:]) info.BaseFee = new(big.Int).SetBytes(baseFeeBytes[:])
var blockHashBytes [32]byte var blockHashBytes [32]byte
if _, err := reader.Read(blockHashBytes[:]); err != nil { if _, err = io.ReadFull(reader, blockHashBytes[:]); err != nil {
return err return fmt.Errorf("expected BlockHash length to be 32 bytes, but got %x", blockHashBytes)
} }
info.BlockHash.SetBytes(blockHashBytes[:]) info.BlockHash.SetBytes(blockHashBytes[:])
if _, err := reader.Read(readPadding[:]); err != nil || !bytes.Equal(readPadding[:], padding[:]) { if info.SequenceNumber, err = readSolidityABIUint64(); err != nil {
return fmt.Errorf("l1 info sequence number exceeds uint64 bounds: %x", readPadding[:]) return
}
if err := binary.Read(reader, binary.BigEndian, &info.SequenceNumber); err != nil {
return err
} }
var addrPadding [12]byte var addrPadding [12]byte
if _, err := reader.Read(addrPadding[:]); err != nil { if _, err = io.ReadFull(reader, addrPadding[:]); err != nil {
return err return fmt.Errorf("expected addrPadding length to be 12 bytes, but got %x", addrPadding)
} }
if _, err := reader.Read(info.BatcherAddr[:]); err != nil { if _, err = io.ReadFull(reader, info.BatcherAddr[:]); err != nil {
return err return fmt.Errorf("expected BatcherAddr length to be 20 bytes, but got %x", info.BatcherAddr)
} }
if _, err := reader.Read(info.L1FeeOverhead[:]); err != nil { if _, err = io.ReadFull(reader, info.L1FeeOverhead[:]); err != nil {
return err return fmt.Errorf("expected L1FeeOverhead length to be 32 bytes, but got %x", info.L1FeeOverhead)
} }
if _, err := reader.Read(info.L1FeeScalar[:]); err != nil { if _, err = io.ReadFull(reader, info.L1FeeScalar[:]); err != nil {
return err return fmt.Errorf("expected L1FeeScalar length to be 32 bytes, but got %x", info.L1FeeScalar)
} }
return nil return nil
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment