Commit c829e0cb authored by pcw109550's avatar pcw109550

op-node: Remove Panic while Span Batch Derivation

parent 918459c4
...@@ -392,7 +392,9 @@ func (b *RawSpanBatch) derive(blockTime, genesisTimestamp uint64, chainID *big.I ...@@ -392,7 +392,9 @@ func (b *RawSpanBatch) derive(blockTime, genesisTimestamp uint64, chainID *big.I
} }
} }
b.txs.recoverV(chainID) if err := b.txs.recoverV(chainID); err != nil {
return nil, err
}
fullTxs, err := b.txs.fullTxs(chainID) fullTxs, err := b.txs.fullTxs(chainID)
if err != nil { if err != nil {
return nil, err return nil, err
......
...@@ -219,7 +219,8 @@ func TestSpanBatchPayload(t *testing.T) { ...@@ -219,7 +219,8 @@ func TestSpanBatchPayload(t *testing.T) {
err = sb.decodePayload(r) err = sb.decodePayload(r)
require.NoError(t, err) require.NoError(t, err)
sb.txs.recoverV(chainID) err = sb.txs.recoverV(chainID)
require.NoError(t, err)
require.Equal(t, rawSpanBatch.spanBatchPayload, sb.spanBatchPayload) require.Equal(t, rawSpanBatch.spanBatchPayload, sb.spanBatchPayload)
} }
...@@ -283,7 +284,8 @@ func TestSpanBatchTxs(t *testing.T) { ...@@ -283,7 +284,8 @@ func TestSpanBatchTxs(t *testing.T) {
err = sb.decodeTxs(r) err = sb.decodeTxs(r)
require.NoError(t, err) require.NoError(t, err)
sb.txs.recoverV(chainID) err = sb.txs.recoverV(chainID)
require.NoError(t, err)
require.Equal(t, rawSpanBatch.txs, sb.txs) require.Equal(t, rawSpanBatch.txs, sb.txs)
} }
...@@ -302,7 +304,8 @@ func TestSpanBatchRoundTrip(t *testing.T) { ...@@ -302,7 +304,8 @@ func TestSpanBatchRoundTrip(t *testing.T) {
err = sb.decode(bytes.NewReader(result.Bytes())) err = sb.decode(bytes.NewReader(result.Bytes()))
require.NoError(t, err) require.NoError(t, err)
sb.txs.recoverV(chainID) err = sb.txs.recoverV(chainID)
require.NoError(t, err)
require.Equal(t, rawSpanBatch, &sb) require.Equal(t, rawSpanBatch, &sb)
} }
......
...@@ -91,9 +91,9 @@ func (btx *spanBatchTxs) decodeContractCreationBits(r *bytes.Reader) error { ...@@ -91,9 +91,9 @@ func (btx *spanBatchTxs) decodeContractCreationBits(r *bytes.Reader) error {
return nil return nil
} }
func (btx *spanBatchTxs) contractCreationCount() uint64 { func (btx *spanBatchTxs) contractCreationCount() (uint64, error) {
if btx.contractCreationBits == nil { if btx.contractCreationBits == nil {
panic("contract creation bits not set") return 0, errors.New("dev error: contract creation bits not set")
} }
var result uint64 = 0 var result uint64 = 0
for i := 0; i < int(btx.totalBlockTxCount); i++ { for i := 0; i < int(btx.totalBlockTxCount); i++ {
...@@ -102,7 +102,7 @@ func (btx *spanBatchTxs) contractCreationCount() uint64 { ...@@ -102,7 +102,7 @@ func (btx *spanBatchTxs) contractCreationCount() uint64 {
result++ result++
} }
} }
return result return result, nil
} }
// yParityBits is bitlist right-padded to a multiple of 8 bits // yParityBits is bitlist right-padded to a multiple of 8 bits
...@@ -264,7 +264,10 @@ func (btx *spanBatchTxs) decodeTxGases(r *bytes.Reader) error { ...@@ -264,7 +264,10 @@ func (btx *spanBatchTxs) decodeTxGases(r *bytes.Reader) error {
func (btx *spanBatchTxs) decodeTxTos(r *bytes.Reader) error { func (btx *spanBatchTxs) decodeTxTos(r *bytes.Reader) error {
var txTos []common.Address var txTos []common.Address
txToBuffer := make([]byte, common.AddressLength) txToBuffer := make([]byte, common.AddressLength)
contractCreationCount := btx.contractCreationCount() contractCreationCount, err := btx.contractCreationCount()
if err != nil {
return err
}
for i := 0; i < int(btx.totalBlockTxCount-contractCreationCount); i++ { for i := 0; i < int(btx.totalBlockTxCount-contractCreationCount); i++ {
_, err := io.ReadFull(r, txToBuffer) _, err := io.ReadFull(r, txToBuffer)
if err != nil { if err != nil {
...@@ -293,9 +296,9 @@ func (btx *spanBatchTxs) decodeTxDatas(r *bytes.Reader) error { ...@@ -293,9 +296,9 @@ func (btx *spanBatchTxs) decodeTxDatas(r *bytes.Reader) error {
return nil return nil
} }
func (btx *spanBatchTxs) recoverV(chainID *big.Int) { func (btx *spanBatchTxs) recoverV(chainID *big.Int) error {
if len(btx.txTypes) != len(btx.txSigs) { if len(btx.txTypes) != len(btx.txSigs) {
panic("tx type length and tx sigs length mismatch") return errors.New("tx type length and tx sigs length mismatch")
} }
for idx, txType := range btx.txTypes { for idx, txType := range btx.txTypes {
bit := uint64(btx.yParityBits.Bit(idx)) bit := uint64(btx.yParityBits.Bit(idx))
...@@ -309,10 +312,11 @@ func (btx *spanBatchTxs) recoverV(chainID *big.Int) { ...@@ -309,10 +312,11 @@ func (btx *spanBatchTxs) recoverV(chainID *big.Int) {
case types.DynamicFeeTxType: case types.DynamicFeeTxType:
v = bit v = bit
default: default:
panic(fmt.Sprintf("invalid tx type: %d", txType)) return fmt.Errorf("invalid tx type: %d", txType)
} }
btx.txSigs[idx].v = v btx.txSigs[idx].v = v
} }
return nil
} }
func (btx *spanBatchTxs) encode(w io.Writer) error { func (btx *spanBatchTxs) encode(w io.Writer) error {
...@@ -400,7 +404,7 @@ func (btx *spanBatchTxs) fullTxs(chainID *big.Int) ([][]byte, error) { ...@@ -400,7 +404,7 @@ func (btx *spanBatchTxs) fullTxs(chainID *big.Int) ([][]byte, error) {
return txs, nil return txs, nil
} }
func convertVToYParity(v uint64, txType int) uint { func convertVToYParity(v uint64, txType int) (uint, error) {
var yParityBit uint var yParityBit uint
switch txType { switch txType {
case types.LegacyTxType: case types.LegacyTxType:
...@@ -412,9 +416,9 @@ func convertVToYParity(v uint64, txType int) uint { ...@@ -412,9 +416,9 @@ func convertVToYParity(v uint64, txType int) uint {
case types.DynamicFeeTxType: case types.DynamicFeeTxType:
yParityBit = uint(v) yParityBit = uint(v)
default: default:
panic(fmt.Sprintf("invalid tx type: %d", txType)) return 0, fmt.Errorf("invalid tx type: %d", txType)
} }
return yParityBit return yParityBit, nil
} }
func newSpanBatchTxs(txs [][]byte, chainID *big.Int) (*spanBatchTxs, error) { func newSpanBatchTxs(txs [][]byte, chainID *big.Int) (*spanBatchTxs, error) {
...@@ -449,7 +453,10 @@ func newSpanBatchTxs(txs [][]byte, chainID *big.Int) (*spanBatchTxs, error) { ...@@ -449,7 +453,10 @@ func newSpanBatchTxs(txs [][]byte, chainID *big.Int) (*spanBatchTxs, error) {
contractCreationBit = uint(0) contractCreationBit = uint(0)
} }
contractCreationBits.SetBit(contractCreationBits, idx, contractCreationBit) contractCreationBits.SetBit(contractCreationBits, idx, contractCreationBit)
yParityBit := convertVToYParity(txSig.v, int(tx.Type())) yParityBit, err := convertVToYParity(txSig.v, int(tx.Type()))
if err != nil {
return nil, err
}
yParityBits.SetBit(yParityBits, idx, yParityBit) yParityBits.SetBit(yParityBits, idx, yParityBit)
txNonces = append(txNonces, tx.Nonce()) txNonces = append(txNonces, tx.Nonce())
txGases = append(txGases, tx.Gas()) txGases = append(txGases, tx.Gas())
......
...@@ -54,7 +54,8 @@ func TestSpanBatchTxsContractCreationCount(t *testing.T) { ...@@ -54,7 +54,8 @@ func TestSpanBatchTxsContractCreationCount(t *testing.T) {
rawSpanBatch := RandomRawSpanBatch(rng, chainID) rawSpanBatch := RandomRawSpanBatch(rng, chainID)
contractCreationBits := rawSpanBatch.txs.contractCreationBits contractCreationBits := rawSpanBatch.txs.contractCreationBits
contractCreationCount := rawSpanBatch.txs.contractCreationCount() contractCreationCount, err := rawSpanBatch.txs.contractCreationCount()
require.NoError(t, err)
totalBlockTxCount := rawSpanBatch.txs.totalBlockTxCount totalBlockTxCount := rawSpanBatch.txs.totalBlockTxCount
var sbt spanBatchTxs var sbt spanBatchTxs
...@@ -62,7 +63,7 @@ func TestSpanBatchTxsContractCreationCount(t *testing.T) { ...@@ -62,7 +63,7 @@ func TestSpanBatchTxsContractCreationCount(t *testing.T) {
sbt.totalBlockTxCount = totalBlockTxCount sbt.totalBlockTxCount = totalBlockTxCount
var buf bytes.Buffer var buf bytes.Buffer
err := sbt.encodeContractCreationBits(&buf) err = sbt.encodeContractCreationBits(&buf)
require.NoError(t, err) require.NoError(t, err)
result := buf.Bytes() result := buf.Bytes()
...@@ -72,7 +73,10 @@ func TestSpanBatchTxsContractCreationCount(t *testing.T) { ...@@ -72,7 +73,10 @@ func TestSpanBatchTxsContractCreationCount(t *testing.T) {
err = sbt.decodeContractCreationBits(r) err = sbt.decodeContractCreationBits(r)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, contractCreationCount, sbt.contractCreationCount()) contractCreationCount2, err := sbt.contractCreationCount()
require.NoError(t, err)
require.Equal(t, contractCreationCount, contractCreationCount2)
} }
func TestSpanBatchTxsYParityBits(t *testing.T) { func TestSpanBatchTxsYParityBits(t *testing.T) {
...@@ -277,7 +281,8 @@ func TestSpanBatchTxsRecoverV(t *testing.T) { ...@@ -277,7 +281,8 @@ func TestSpanBatchTxsRecoverV(t *testing.T) {
txSig.s, _ = uint256.FromBig(s) txSig.s, _ = uint256.FromBig(s)
txSigs = append(txSigs, txSig) txSigs = append(txSigs, txSig)
originalVs = append(originalVs, v.Uint64()) originalVs = append(originalVs, v.Uint64())
yParityBit := convertVToYParity(v.Uint64(), int(tx.Type())) yParityBit, err := convertVToYParity(v.Uint64(), int(tx.Type()))
require.NoError(t, err)
yParityBits.SetBit(yParityBits, idx, yParityBit) yParityBits.SetBit(yParityBits, idx, yParityBit)
} }
...@@ -285,7 +290,8 @@ func TestSpanBatchTxsRecoverV(t *testing.T) { ...@@ -285,7 +290,8 @@ func TestSpanBatchTxsRecoverV(t *testing.T) {
spanBatchTxs.txSigs = txSigs spanBatchTxs.txSigs = txSigs
spanBatchTxs.txTypes = txTypes spanBatchTxs.txTypes = txTypes
// recover txSig.v // recover txSig.v
spanBatchTxs.recoverV(chainID) err := spanBatchTxs.recoverV(chainID)
require.NoError(t, err)
var recoveredVs []uint64 var recoveredVs []uint64
for _, txSig := range spanBatchTxs.txSigs { for _, txSig := range spanBatchTxs.txSigs {
...@@ -315,7 +321,9 @@ func TestSpanBatchTxsRoundTrip(t *testing.T) { ...@@ -315,7 +321,9 @@ func TestSpanBatchTxsRoundTrip(t *testing.T) {
sbt2.totalBlockTxCount = totalBlockTxCount sbt2.totalBlockTxCount = totalBlockTxCount
err = sbt2.decode(r) err = sbt2.decode(r)
require.NoError(t, err) require.NoError(t, err)
sbt2.recoverV(chainID)
err = sbt2.recoverV(chainID)
require.NoError(t, err)
require.Equal(t, sbt, &sbt2) require.Equal(t, sbt, &sbt2)
} }
...@@ -346,11 +354,6 @@ func TestSpanBatchTxsRoundTripFullTxs(t *testing.T) { ...@@ -346,11 +354,6 @@ func TestSpanBatchTxsRoundTripFullTxs(t *testing.T) {
} }
func TestSpanBatchTxsRecoverVInvalidTxType(t *testing.T) { func TestSpanBatchTxsRecoverVInvalidTxType(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Errorf("The code did not panic")
}
}()
rng := rand.New(rand.NewSource(0x321)) rng := rand.New(rand.NewSource(0x321))
chainID := big.NewInt(rng.Int63n(1000)) chainID := big.NewInt(rng.Int63n(1000))
...@@ -360,8 +363,8 @@ func TestSpanBatchTxsRecoverVInvalidTxType(t *testing.T) { ...@@ -360,8 +363,8 @@ func TestSpanBatchTxsRecoverVInvalidTxType(t *testing.T) {
sbt.txSigs = []spanBatchSignature{{v: 0, r: nil, s: nil}} sbt.txSigs = []spanBatchSignature{{v: 0, r: nil, s: nil}}
sbt.yParityBits = new(big.Int) sbt.yParityBits = new(big.Int)
// expect panic err := sbt.recoverV(chainID)
sbt.recoverV(chainID) require.ErrorContains(t, err, "invalid tx type")
} }
func TestSpanBatchTxsFullTxNotEnoughTxTos(t *testing.T) { func TestSpanBatchTxsFullTxNotEnoughTxTos(t *testing.T) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment