Commit c829e0cb authored by pcw109550's avatar pcw109550

op-node: Remove Panic while Span Batch Derivation

parent 918459c4
......@@ -392,7 +392,9 @@ func (b *RawSpanBatch) derive(blockTime, genesisTimestamp uint64, chainID *big.I
}
}
b.txs.recoverV(chainID)
if err := b.txs.recoverV(chainID); err != nil {
return nil, err
}
fullTxs, err := b.txs.fullTxs(chainID)
if err != nil {
return nil, err
......
......@@ -219,7 +219,8 @@ func TestSpanBatchPayload(t *testing.T) {
err = sb.decodePayload(r)
require.NoError(t, err)
sb.txs.recoverV(chainID)
err = sb.txs.recoverV(chainID)
require.NoError(t, err)
require.Equal(t, rawSpanBatch.spanBatchPayload, sb.spanBatchPayload)
}
......@@ -283,7 +284,8 @@ func TestSpanBatchTxs(t *testing.T) {
err = sb.decodeTxs(r)
require.NoError(t, err)
sb.txs.recoverV(chainID)
err = sb.txs.recoverV(chainID)
require.NoError(t, err)
require.Equal(t, rawSpanBatch.txs, sb.txs)
}
......@@ -302,7 +304,8 @@ func TestSpanBatchRoundTrip(t *testing.T) {
err = sb.decode(bytes.NewReader(result.Bytes()))
require.NoError(t, err)
sb.txs.recoverV(chainID)
err = sb.txs.recoverV(chainID)
require.NoError(t, err)
require.Equal(t, rawSpanBatch, &sb)
}
......
......@@ -91,9 +91,9 @@ func (btx *spanBatchTxs) decodeContractCreationBits(r *bytes.Reader) error {
return nil
}
func (btx *spanBatchTxs) contractCreationCount() uint64 {
func (btx *spanBatchTxs) contractCreationCount() (uint64, error) {
if btx.contractCreationBits == nil {
panic("contract creation bits not set")
return 0, errors.New("dev error: contract creation bits not set")
}
var result uint64 = 0
for i := 0; i < int(btx.totalBlockTxCount); i++ {
......@@ -102,7 +102,7 @@ func (btx *spanBatchTxs) contractCreationCount() uint64 {
result++
}
}
return result
return result, nil
}
// yParityBits is bitlist right-padded to a multiple of 8 bits
......@@ -264,7 +264,10 @@ func (btx *spanBatchTxs) decodeTxGases(r *bytes.Reader) error {
func (btx *spanBatchTxs) decodeTxTos(r *bytes.Reader) error {
var txTos []common.Address
txToBuffer := make([]byte, common.AddressLength)
contractCreationCount := btx.contractCreationCount()
contractCreationCount, err := btx.contractCreationCount()
if err != nil {
return err
}
for i := 0; i < int(btx.totalBlockTxCount-contractCreationCount); i++ {
_, err := io.ReadFull(r, txToBuffer)
if err != nil {
......@@ -293,9 +296,9 @@ func (btx *spanBatchTxs) decodeTxDatas(r *bytes.Reader) error {
return nil
}
func (btx *spanBatchTxs) recoverV(chainID *big.Int) {
func (btx *spanBatchTxs) recoverV(chainID *big.Int) error {
if len(btx.txTypes) != len(btx.txSigs) {
panic("tx type length and tx sigs length mismatch")
return errors.New("tx type length and tx sigs length mismatch")
}
for idx, txType := range btx.txTypes {
bit := uint64(btx.yParityBits.Bit(idx))
......@@ -309,10 +312,11 @@ func (btx *spanBatchTxs) recoverV(chainID *big.Int) {
case types.DynamicFeeTxType:
v = bit
default:
panic(fmt.Sprintf("invalid tx type: %d", txType))
return fmt.Errorf("invalid tx type: %d", txType)
}
btx.txSigs[idx].v = v
}
return nil
}
func (btx *spanBatchTxs) encode(w io.Writer) error {
......@@ -400,7 +404,7 @@ func (btx *spanBatchTxs) fullTxs(chainID *big.Int) ([][]byte, error) {
return txs, nil
}
func convertVToYParity(v uint64, txType int) uint {
func convertVToYParity(v uint64, txType int) (uint, error) {
var yParityBit uint
switch txType {
case types.LegacyTxType:
......@@ -412,9 +416,9 @@ func convertVToYParity(v uint64, txType int) uint {
case types.DynamicFeeTxType:
yParityBit = uint(v)
default:
panic(fmt.Sprintf("invalid tx type: %d", txType))
return 0, fmt.Errorf("invalid tx type: %d", txType)
}
return yParityBit
return yParityBit, nil
}
func newSpanBatchTxs(txs [][]byte, chainID *big.Int) (*spanBatchTxs, error) {
......@@ -449,7 +453,10 @@ func newSpanBatchTxs(txs [][]byte, chainID *big.Int) (*spanBatchTxs, error) {
contractCreationBit = uint(0)
}
contractCreationBits.SetBit(contractCreationBits, idx, contractCreationBit)
yParityBit := convertVToYParity(txSig.v, int(tx.Type()))
yParityBit, err := convertVToYParity(txSig.v, int(tx.Type()))
if err != nil {
return nil, err
}
yParityBits.SetBit(yParityBits, idx, yParityBit)
txNonces = append(txNonces, tx.Nonce())
txGases = append(txGases, tx.Gas())
......
......@@ -54,7 +54,8 @@ func TestSpanBatchTxsContractCreationCount(t *testing.T) {
rawSpanBatch := RandomRawSpanBatch(rng, chainID)
contractCreationBits := rawSpanBatch.txs.contractCreationBits
contractCreationCount := rawSpanBatch.txs.contractCreationCount()
contractCreationCount, err := rawSpanBatch.txs.contractCreationCount()
require.NoError(t, err)
totalBlockTxCount := rawSpanBatch.txs.totalBlockTxCount
var sbt spanBatchTxs
......@@ -62,7 +63,7 @@ func TestSpanBatchTxsContractCreationCount(t *testing.T) {
sbt.totalBlockTxCount = totalBlockTxCount
var buf bytes.Buffer
err := sbt.encodeContractCreationBits(&buf)
err = sbt.encodeContractCreationBits(&buf)
require.NoError(t, err)
result := buf.Bytes()
......@@ -72,7 +73,10 @@ func TestSpanBatchTxsContractCreationCount(t *testing.T) {
err = sbt.decodeContractCreationBits(r)
require.NoError(t, err)
require.Equal(t, contractCreationCount, sbt.contractCreationCount())
contractCreationCount2, err := sbt.contractCreationCount()
require.NoError(t, err)
require.Equal(t, contractCreationCount, contractCreationCount2)
}
func TestSpanBatchTxsYParityBits(t *testing.T) {
......@@ -277,7 +281,8 @@ func TestSpanBatchTxsRecoverV(t *testing.T) {
txSig.s, _ = uint256.FromBig(s)
txSigs = append(txSigs, txSig)
originalVs = append(originalVs, v.Uint64())
yParityBit := convertVToYParity(v.Uint64(), int(tx.Type()))
yParityBit, err := convertVToYParity(v.Uint64(), int(tx.Type()))
require.NoError(t, err)
yParityBits.SetBit(yParityBits, idx, yParityBit)
}
......@@ -285,7 +290,8 @@ func TestSpanBatchTxsRecoverV(t *testing.T) {
spanBatchTxs.txSigs = txSigs
spanBatchTxs.txTypes = txTypes
// recover txSig.v
spanBatchTxs.recoverV(chainID)
err := spanBatchTxs.recoverV(chainID)
require.NoError(t, err)
var recoveredVs []uint64
for _, txSig := range spanBatchTxs.txSigs {
......@@ -315,7 +321,9 @@ func TestSpanBatchTxsRoundTrip(t *testing.T) {
sbt2.totalBlockTxCount = totalBlockTxCount
err = sbt2.decode(r)
require.NoError(t, err)
sbt2.recoverV(chainID)
err = sbt2.recoverV(chainID)
require.NoError(t, err)
require.Equal(t, sbt, &sbt2)
}
......@@ -346,11 +354,6 @@ func TestSpanBatchTxsRoundTripFullTxs(t *testing.T) {
}
func TestSpanBatchTxsRecoverVInvalidTxType(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Errorf("The code did not panic")
}
}()
rng := rand.New(rand.NewSource(0x321))
chainID := big.NewInt(rng.Int63n(1000))
......@@ -360,8 +363,8 @@ func TestSpanBatchTxsRecoverVInvalidTxType(t *testing.T) {
sbt.txSigs = []spanBatchSignature{{v: 0, r: nil, s: nil}}
sbt.yParityBits = new(big.Int)
// expect panic
sbt.recoverV(chainID)
err := sbt.recoverV(chainID)
require.ErrorContains(t, err, "invalid tx type")
}
func TestSpanBatchTxsFullTxNotEnoughTxTos(t *testing.T) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment