package dao

import (
	"errors"

	dbModel "code.wuban.net.cn/movabridge/token-bridge/model/db"
	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)

var (
	// ErrRecordNotFound is an alias of gorm.ErrRecordNotFound. The Get*Tx
	// lookups in this file return it when no row matches the query.
	ErrRecordNotFound = gorm.ErrRecordNotFound
)

// Transaction represents a database transaction. It wraps the *gorm.DB handle
// produced by Dao.BeginTx; the *Tx methods on Dao run their statements through
// this handle.
type Transaction struct {
	// tx is the gorm handle the wrapped statements, Commit and Rollback act on.
	tx *gorm.DB
}

// BeginTx opens a new database transaction on d.db and wraps the resulting
// handle in a Transaction. If the handle carries an error after Begin, that
// error is returned and the Transaction is nil.
func (d *Dao) BeginTx() (*Transaction, error) {
	handle := d.db.Begin()
	if err := handle.Error; err != nil {
		return nil, err
	}
	return &Transaction{tx: handle}, nil
}

// Commit commits the wrapped gorm transaction and returns the error recorded
// on the result, if any.
func (tx *Transaction) Commit() error {
	committed := tx.tx.Commit()
	return committed.Error
}

// Rollback rolls back the wrapped gorm transaction and returns the error
// recorded on the result, if any.
func (tx *Transaction) Rollback() error {
	rolledBack := tx.tx.Rollback()
	return rolledBack.Error
}

// Transaction-aware versions of the database methods

// CreateBridgeEventTx inserts event through the transaction tx. The insert
// carries an OnConflict clause with DoNothing set, so a row that conflicts
// with an existing one is skipped by the database instead of being written.
func (d *Dao) CreateBridgeEventTx(tx *Transaction, event *dbModel.BridgeEvent) error {
	ignoreConflict := clause.OnConflict{DoNothing: true}
	result := tx.tx.Clauses(ignoreConflict).Create(event)
	return result.Error
}

// GetBridgeEventWithOutInfoTx looks up, inside tx, the first bridge event whose
// from_chain and out_id columns equal fromChain and outId.
//
// It returns (nil, ErrRecordNotFound) when no row matches and (nil, err) for
// any other query failure; the returned event is non-nil only when err is nil.
func (d *Dao) GetBridgeEventWithOutInfoTx(tx *Transaction, fromChain int64, outId int64) (event *dbModel.BridgeEvent, err error) {
	found := new(dbModel.BridgeEvent)
	err = tx.tx.Model(found).Where("`from_chain` = ? AND `out_id` = ?", fromChain, outId).First(found).Error
	switch {
	case err == nil:
		return found, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		// errors.Is also matches a wrapped not-found error; callers always
		// receive the bare sentinel so they can keep comparing against it.
		return nil, ErrRecordNotFound
	default:
		// Never hand back an unpopulated event together with an error.
		return nil, err
	}
}

// GetBridgeEventWithInInfoTx looks up, inside tx, the first bridge event whose
// to_chain and in_id columns equal chain and inId.
//
// It returns (nil, ErrRecordNotFound) when no row matches and (nil, err) for
// any other query failure; the returned event is non-nil only when err is nil.
func (d *Dao) GetBridgeEventWithInInfoTx(tx *Transaction, chain int64, inId int64) (event *dbModel.BridgeEvent, err error) {
	found := new(dbModel.BridgeEvent)
	err = tx.tx.Model(found).Where("`to_chain` = ? AND `in_id` = ?", chain, inId).First(found).Error
	switch {
	case err == nil:
		return found, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		// errors.Is also matches a wrapped not-found error; callers always
		// receive the bare sentinel so they can keep comparing against it.
		return nil, ErrRecordNotFound
	default:
		// Never hand back an unpopulated event together with an error.
		return nil, err
	}
}

// UpdateFullBridgeTx updates, inside tx, the bridge event row whose id equals
// event.ID, passing the event struct itself to Updates.
//
// NOTE(review): per the GORM documentation, Updates with a struct argument
// only writes non-zero fields, so zero-valued fields of event are presumably
// left unchanged in the row — confirm this is the intended meaning of "full".
func (d *Dao) UpdateFullBridgeTx(tx *Transaction, event *dbModel.BridgeEvent) error {
	byID := tx.tx.Model(event).Where("`id` = ?", event.ID)
	return byID.Updates(event).Error
}

// GetBridgeEventByHashTx looks up, inside tx, the first bridge event whose
// hash column equals hash.
//
// It returns (nil, ErrRecordNotFound) when no row matches and (nil, err) for
// any other query failure; the returned event is non-nil only when err is nil.
func (d *Dao) GetBridgeEventByHashTx(tx *Transaction, hash string) (event *dbModel.BridgeEvent, err error) {
	found := new(dbModel.BridgeEvent)
	err = tx.tx.Model(found).Where("`hash` = ?", hash).First(found).Error
	switch {
	case err == nil:
		return found, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		// errors.Is also matches a wrapped not-found error; callers always
		// receive the bare sentinel so they can keep comparing against it.
		return nil, ErrRecordNotFound
	default:
		// Never hand back an unpopulated event together with an error.
		return nil, err
	}
}

// UpdateBridgeWithTransferInTx writes the transfer-in columns of the bridge
// event row whose id equals event.ID, inside tx. Only to_contract,
// in_timestamp, in_id, to_chain_tx_hash and to_chain_status are set, each
// taken from the matching field of event.
func (d *Dao) UpdateBridgeWithTransferInTx(tx *Transaction, event *dbModel.BridgeEvent) error {
	// A map (rather than the struct) names the exact columns to write.
	transferIn := map[string]interface{}{
		"to_contract":      event.ToContract,
		"in_timestamp":     event.InTimestamp,
		"in_id":            event.InId,
		"to_chain_tx_hash": event.ToChainTxHash,
		"to_chain_status":  event.ToChainStatus,
	}
	result := tx.tx.Model(event).Where("`id` = ?", event.ID).Updates(transferIn)
	return result.Error
}

// UpdateBridgeValidatorOperationTx sets the validator_status column to op on
// the bridge event row whose id equals event.ID, inside tx. Only event.ID is
// read from event.
func (d *Dao) UpdateBridgeValidatorOperationTx(tx *Transaction, event *dbModel.BridgeEvent, op int) (err error) {
	statusOnly := map[string]interface{}{"validator_status": op}
	byID := tx.tx.Model(&dbModel.BridgeEvent{}).Where("`id` = ?", event.ID)
	return byID.Updates(statusOnly).Error
}

// UpdateBridgeResultTx records the outcome of a bridge event inside tx: on the
// row whose id equals event.ID it sets to_chain_status to status and
// finish_tx_hash to toChainHash. Only event.ID is read from event.
func (d *Dao) UpdateBridgeResultTx(tx *Transaction, event *dbModel.BridgeEvent, toChainHash string, status int) error {
	outcome := map[string]interface{}{
		"to_chain_status": status,
		"finish_tx_hash":  toChainHash,
	}
	byID := tx.tx.Model(&dbModel.BridgeEvent{}).Where("`id` = ?", event.ID)
	return byID.Updates(outcome).Error
}

// CreateValidatorEventTx builds a ValidatorEvent from the given arguments and
// inserts it through the transaction tx. The insert carries an OnConflict
// clause with DoNothing set, so a row that conflicts with an existing one is
// skipped by the database instead of being written.
func (d *Dao) CreateValidatorEventTx(tx *Transaction, contract string, hash string, chain int64, validator string, txHash string, eventType string, transferInId int64) error {
	record := dbModel.ValidatorEvent{
		Contract:     contract,
		Hash:         hash,
		ChainId:      chain,
		Validator:    validator,
		TxHash:       txHash,
		Event:        eventType,
		TransferInId: transferInId,
	}
	ignoreConflict := clause.OnConflict{DoNothing: true}
	return tx.tx.Clauses(ignoreConflict).Create(&record).Error
}
