Commit 142e4f53 authored by Hamdi Allam's avatar Hamdi Allam

database transaction support. Return nil on gorm.ErrRecordNotFound

parent c3e54652
...@@ -46,8 +46,8 @@ type LegacyStateBatch struct { ...@@ -46,8 +46,8 @@ type LegacyStateBatch struct {
} }
type BlocksView interface { type BlocksView interface {
FinalizedL1BlockHeight() (*big.Int, error) FinalizedL1BlockHeader() (*L1BlockHeader, error)
FinalizedL2BlockHeight() (*big.Int, error) FinalizedL2BlockHeader() (*L2BlockHeader, error)
} }
type BlocksDB interface { type BlocksDB interface {
...@@ -80,9 +80,6 @@ func (db *blocksDB) StoreL1BlockHeaders(headers []*L1BlockHeader) error { ...@@ -80,9 +80,6 @@ func (db *blocksDB) StoreL1BlockHeaders(headers []*L1BlockHeader) error {
} }
func (db *blocksDB) StoreLegacyStateBatch(stateBatch *LegacyStateBatch) error { func (db *blocksDB) StoreLegacyStateBatch(stateBatch *LegacyStateBatch) error {
// Even though transaction control flow is managed, could we benefit
// from a nested transaction here?
result := db.gorm.Create(stateBatch) result := db.gorm.Create(stateBatch)
if result.Error != nil { if result.Error != nil {
return result.Error return result.Error
...@@ -111,14 +108,19 @@ func (db *blocksDB) StoreLegacyStateBatch(stateBatch *LegacyStateBatch) error { ...@@ -111,14 +108,19 @@ func (db *blocksDB) StoreLegacyStateBatch(stateBatch *LegacyStateBatch) error {
return result.Error return result.Error
} }
func (db *blocksDB) FinalizedL1BlockHeight() (*big.Int, error) { // FinalizedL1BlockHeader returns the latest L1 block header stored in the database, nil otherwise
func (db *blocksDB) FinalizedL1BlockHeader() (*L1BlockHeader, error) {
var l1Header L1BlockHeader var l1Header L1BlockHeader
result := db.gorm.Order("number DESC").Take(&l1Header) result := db.gorm.Order("number DESC").Take(&l1Header)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, result.Error return nil, result.Error
} }
return l1Header.Number.Int, nil return &l1Header, nil
} }
// L2 // L2
...@@ -128,17 +130,24 @@ func (db *blocksDB) StoreL2BlockHeaders(headers []*L2BlockHeader) error { ...@@ -128,17 +130,24 @@ func (db *blocksDB) StoreL2BlockHeaders(headers []*L2BlockHeader) error {
return result.Error return result.Error
} }
func (db *blocksDB) FinalizedL2BlockHeight() (*big.Int, error) { // FinalizedL2BlockHeader returns the latest L2 block header stored in the database, nil otherwise
func (db *blocksDB) FinalizedL2BlockHeader() (*L2BlockHeader, error) {
var l2Header L2BlockHeader var l2Header L2BlockHeader
result := db.gorm.Order("number DESC").Take(&l2Header) result := db.gorm.Order("number DESC").Take(&l2Header)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, result.Error return nil, result.Error
} }
result.Logger.Info(context.Background(), "number ", l2Header.Number) result.Logger.Info(context.Background(), "number ", l2Header.Number)
return l2Header.Number.Int, nil return &l2Header, nil
} }
// MarkFinalizedL1RootForL2Block updates the stored L2 block header with the L1 block
// that contains the output proposal for the L2 root.
func (db *blocksDB) MarkFinalizedL1RootForL2Block(l2Root, l1Root common.Hash) error { func (db *blocksDB) MarkFinalizedL1RootForL2Block(l2Root, l1Root common.Hash) error {
var l2Header L2BlockHeader var l2Header L2BlockHeader
l2Header.Hash = l2Root // set the primary key l2Header.Hash = l2Root // set the primary key
......
package database package database
import ( import (
"errors"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/common/hexutil"
"gorm.io/gorm" "gorm.io/gorm"
...@@ -99,6 +101,10 @@ func (db *bridgeDB) DepositsByAddress(address common.Address) ([]*DepositWithTra ...@@ -99,6 +101,10 @@ func (db *bridgeDB) DepositsByAddress(address common.Address) ([]*DepositWithTra
deposits := make([]*DepositWithTransactionHash, 100) deposits := make([]*DepositWithTransactionHash, 100)
result := filteredQuery.Scan(&deposits) result := filteredQuery.Scan(&deposits)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, result.Error return nil, result.Error
} }
...@@ -115,27 +121,29 @@ func (db *bridgeDB) StoreWithdrawals(withdrawals []*Withdrawal) error { ...@@ -115,27 +121,29 @@ func (db *bridgeDB) StoreWithdrawals(withdrawals []*Withdrawal) error {
func (db *bridgeDB) MarkProvenWithdrawalEvent(guid, provenL1EventGuid string) error { func (db *bridgeDB) MarkProvenWithdrawalEvent(guid, provenL1EventGuid string) error {
var withdrawal Withdrawal var withdrawal Withdrawal
result := db.gorm.First(&withdrawal, "guid = ?", guid) result := db.gorm.First(&withdrawal, "guid = ?", guid)
if result.Error == nil { if result.Error != nil {
withdrawal.ProvenL1EventGUID = &provenL1EventGuid return result.Error
db.gorm.Save(&withdrawal)
} }
withdrawal.ProvenL1EventGUID = &provenL1EventGuid
result = db.gorm.Save(&withdrawal)
return result.Error return result.Error
} }
func (db *bridgeDB) MarkFinalizedWithdrawalEvent(guid, finalizedL1EventGuid string) error { func (db *bridgeDB) MarkFinalizedWithdrawalEvent(guid, finalizedL1EventGuid string) error {
var withdrawal Withdrawal var withdrawal Withdrawal
result := db.gorm.First(&withdrawal, "guid = ?", guid) result := db.gorm.First(&withdrawal, "guid = ?", guid)
if result.Error == nil { if result.Error != nil {
withdrawal.FinalizedL1EventGUID = &finalizedL1EventGuid return result.Error
db.gorm.Save(&withdrawal)
} }
withdrawal.FinalizedL1EventGUID = &finalizedL1EventGuid
result = db.gorm.Save(&withdrawal)
return result.Error return result.Error
} }
func (db *bridgeDB) WithdrawalsByAddress(address common.Address) ([]*WithdrawalWithTransactionHashes, error) { func (db *bridgeDB) WithdrawalsByAddress(address common.Address) ([]*WithdrawalWithTransactionHashes, error) {
withdrawalsQuery := db.gorm.Debug().Table("withdrawals").Select("withdrawals.*, l2_contract_events.transaction_hash AS l2_transaction_hash, proven_l1_contract_events.transaction_hash AS proven_l1_transaction_hash, finalized_l1_contract_events.transaction_hash AS finalized_l1_transaction_hash") withdrawalsQuery := db.gorm.Table("withdrawals").Select("withdrawals.*, l2_contract_events.transaction_hash AS l2_transaction_hash, proven_l1_contract_events.transaction_hash AS proven_l1_transaction_hash, finalized_l1_contract_events.transaction_hash AS finalized_l1_transaction_hash")
eventsJoinQuery := withdrawalsQuery.Joins("LEFT JOIN l2_contract_events ON withdrawals.initiated_l2_event_guid = l2_contract_events.guid") eventsJoinQuery := withdrawalsQuery.Joins("LEFT JOIN l2_contract_events ON withdrawals.initiated_l2_event_guid = l2_contract_events.guid")
provenJoinQuery := eventsJoinQuery.Joins("LEFT JOIN l1_contract_events AS proven_l1_contract_events ON withdrawals.proven_l1_event_guid = proven_l1_contract_events.guid") provenJoinQuery := eventsJoinQuery.Joins("LEFT JOIN l1_contract_events AS proven_l1_contract_events ON withdrawals.proven_l1_event_guid = proven_l1_contract_events.guid")
...@@ -147,6 +155,10 @@ func (db *bridgeDB) WithdrawalsByAddress(address common.Address) ([]*WithdrawalW ...@@ -147,6 +155,10 @@ func (db *bridgeDB) WithdrawalsByAddress(address common.Address) ([]*WithdrawalW
withdrawals := make([]*WithdrawalWithTransactionHashes, 100) withdrawals := make([]*WithdrawalWithTransactionHashes, 100)
result := filteredQuery.Scan(&withdrawals) result := filteredQuery.Scan(&withdrawals)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, result.Error return nil, result.Error
} }
......
...@@ -29,8 +29,6 @@ type L2ContractEvent struct { ...@@ -29,8 +29,6 @@ type L2ContractEvent struct {
} }
type ContractEventsView interface { type ContractEventsView interface {
L1ContractEventByGUID(string) (*L1ContractEvent, error)
L2ContractEventByGUID(string) (*L2ContractEvent, error)
} }
type ContractEventsDB interface { type ContractEventsDB interface {
...@@ -59,29 +57,9 @@ func (db *contractEventsDB) StoreL1ContractEvents(events []*L1ContractEvent) err ...@@ -59,29 +57,9 @@ func (db *contractEventsDB) StoreL1ContractEvents(events []*L1ContractEvent) err
return result.Error return result.Error
} }
func (db *contractEventsDB) L1ContractEventByGUID(guid string) (*L1ContractEvent, error) {
var event L1ContractEvent
result := db.gorm.First(&event, "guid = ?", guid)
if result.Error != nil {
return nil, result.Error
}
return &event, nil
}
// L2 // L2
func (db *contractEventsDB) StoreL2ContractEvents(events []*L2ContractEvent) error { func (db *contractEventsDB) StoreL2ContractEvents(events []*L2ContractEvent) error {
result := db.gorm.Create(&events) result := db.gorm.Create(&events)
return result.Error return result.Error
} }
func (db *contractEventsDB) L2ContractEventByGUID(guid string) (*L2ContractEvent, error) {
var event L2ContractEvent
result := db.gorm.First(&event, "guid = ?", guid)
if result.Error != nil {
return nil, result.Error
}
return &event, nil
}
...@@ -34,3 +34,20 @@ func NewDB(dsn string) (*DB, error) { ...@@ -34,3 +34,20 @@ func NewDB(dsn string) (*DB, error) {
return db, nil return db, nil
} }
// Transaction executes all operations conducted with the supplied database in a single
// transaction. If the supplied function errors, the transaction is rolled back.
func (db *DB) Transaction(fn func(db *DB) error) error {
return db.gorm.Transaction(func(tx *gorm.DB) error {
return fn(dbFromGormTx(tx))
})
}
func dbFromGormTx(tx *gorm.DB) *DB {
return &DB{
gorm: tx,
Blocks: newBlocksDB(tx),
ContractEvents: newContractEventsDB(tx),
Bridge: newBridgeDB(tx),
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment