package chain

import (
	"code.wuban.net.cn/movabridge/token-bridge/config"
	"code.wuban.net.cn/movabridge/token-bridge/constant"
	"code.wuban.net.cn/movabridge/token-bridge/contract/bridge"
	"code.wuban.net.cn/movabridge/token-bridge/dao"
	dbModel "code.wuban.net.cn/movabridge/token-bridge/model/db"
	"fmt"
	"golang.org/x/crypto/sha3"
	"math/big"
	"strings"
	"sync"
	"time"

	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/core/types"
	log "github.com/sirupsen/logrus"
)

// ChainSync scans one chain's bridge contract logs in block batches and
// records the bridge events it finds in the database.
type ChainSync struct {
	// chain is the configuration of the chain being scanned.
	chain *config.ChainConfig
	// d is the data-access layer used both for chain queries (GetBlockHeight,
	// GetLogs) and for database reads/writes.
	d *dao.Dao
	// name is the chain name used in log fields; heightKey is the storage key
	// under which the last processed block height is persisted.
	name, heightKey string
	// bridgeCa is the contract binding used to decode bridge event logs.
	bridgeCa *bridge.BridgeContract
	// quit is closed (exactly once, guarded by stopOnce) to stop Start.
	quit     chan struct{}
	stopOnce sync.Once
}

// NewChainSync creates a ChainSync for _chain backed by _d.
//
// The bridge contract binding is built for _chain.BridgeContract, and the
// sync height is persisted under the key "<chainId>_height".
// It panics if the contract binding cannot be created.
func NewChainSync(_chain *config.ChainConfig, _d *dao.Dao) *ChainSync {
	contractAddr := common.HexToAddress(_chain.BridgeContract)
	bridgeCa, err := bridge.NewBridgeContract(contractAddr, nil)
	if err != nil {
		panic(err)
	}

	return &ChainSync{
		chain:     _chain,
		d:         _d,
		name:      _chain.Name,
		heightKey: fmt.Sprintf("%d_height", _chain.ChainId),
		bridgeCa:  bridgeCa,
		quit:      make(chan struct{}),
	}
}

// Start runs the block-scanning loop for this chain until Stop is called.
//
// It resumes after the height persisted under heightKey, which is the last
// fully processed block. If nothing has been persisted yet, it starts at the
// chain's configured InitialHeight.
//
// On every one-second tick it fetches the latest block height (via
// GetBlockHeight with BehindBlock). Once the head is far enough ahead, it
// processes the range beginHeight..endHeight through SyncLogs and persists
// endHeight. The next batch then starts at endHeight+1. A range that fails is
// retried on a later tick.
//
// Start returns only when Stop is called or when the stored height cannot be
// read.
func (s *ChainSync) Start() {
	lastHeight, err := s.d.GetStorageHeight(s.heightKey)
	switch {
	case err == nil:
		// The database holds the last completed block, so resume just after it.
		lastHeight++
	case err == dao.ErrRecordNotFound:
		// Nothing processed yet: the configured initial height is the first
		// block to scan, so it must not be skipped.
		lastHeight = s.chain.InitialHeight
	default:
		log.WithField("chain", s.name).WithError(err).Error("get last block height")
		return
	}

	log.WithField("chain", s.name).WithField("height", lastHeight).Info("last validator block height")

	batch := int64(s.chain.BatchBlock)
	var latestHeight int64
	beginHeight := lastHeight
	endHeight := beginHeight + batch

	// wait sleeps for d but gives up early (returning false) once Stop is
	// called, so retry back-offs never delay shutdown.
	wait := func(d time.Duration) bool {
		t := time.NewTimer(d)
		defer t.Stop()
		select {
		case <-s.quit:
			return false
		case <-t.C:
			return true
		}
	}

	tm := time.NewTicker(time.Second)
	defer tm.Stop()

loop:
	for {
		select {
		case <-s.quit:
			break loop
		case <-tm.C:
		}

		latestHeight, err = s.d.GetBlockHeight(s.chain, s.chain.BehindBlock)
		if err != nil {
			log.WithField("chain", s.name).WithError(err).Error("get latest block height")
			if !wait(time.Second) {
				break loop
			}
			continue
		}

		// Only sync once the head leads this batch by a safety margin:
		// (latest - batch - 1) must be at least one full batch past beginHeight.
		if (latestHeight-batch-1)-beginHeight < batch {
			if !wait(2 * time.Second) {
				break loop
			}
			continue
		}

		if err := s.SyncLogs(beginHeight, endHeight); err != nil {
			log.WithField("chain", s.name).WithFields(log.Fields{
				"begin height": beginHeight,
				"end height":   endHeight,
			}).WithError(err).Error("sync logs failed")
			if !wait(time.Second) {
				break loop
			}
			continue
		}

		if err := s.d.SetStorageHeight(s.heightKey, endHeight); err != nil {
			log.WithField("chain", s.name).WithError(err).Error("set last block height")
		}

		beginHeight = endHeight + 1
		endHeight = beginHeight + batch
		log.WithField("chain", s.name).WithFields(log.Fields{
			"begin height":  beginHeight,
			"end height":    endHeight,
			"latest height": latestHeight,
			"diff height":   latestHeight - endHeight,
		}).Info("validator block")
	}

	log.WithField("chain", s.name).Info("chain sync stopped")
}

// Stop signals a running Start loop to return by closing the quit channel.
// Calling it more than once is safe: only the first call has any effect.
func (s *ChainSync) Stop() {
	shutdown := func() {
		log.WithField("chain", s.name).Info("stopping chain sync")
		close(s.quit)
	}
	s.stopOnce.Do(shutdown)
}

// SyncLogs fetches the bridge contract's logs for the five bridge event topics
// in the block range beginHeight..endHeight. It applies them to the database
// inside a single transaction.
//
// A negative endHeight is a no-op, and a negative beginHeight is clamped to 0.
//
// Any failure is returned so the caller can retry the same range instead of
// advancing past it. This covers loading the ABI, fetching logs, opening the
// transaction, handling a log and committing. When handling a log fails, the
// transaction is rolled back before the error is returned.
func (s *ChainSync) SyncLogs(beginHeight, endHeight int64) error {
	if endHeight < 0 {
		return nil
	}

	if beginHeight < 0 {
		beginHeight = 0
	}

	abi, err := bridge.BridgeContractMetaData.GetAbi()
	if err != nil {
		return err
	}
	topics := []string{
		abi.Events["TransferOut"].ID.Hex(),
		abi.Events["TransferIn"].ID.Hex(),
		abi.Events["TransferInConfirmation"].ID.Hex(),
		abi.Events["TransferInRejection"].ID.Hex(),
		abi.Events["TransferInExecution"].ID.Hex(),
	}

	logs, err := s.d.GetLogs(s.chain, beginHeight, endHeight, topics, []string{
		s.chain.BridgeContract,
	})
	if err != nil {
		log.WithField("chain", s.name).WithFields(log.Fields{"begin": beginHeight, "end": endHeight}).WithError(err).Error("rpc: get logs")
		return err
	}

	if len(logs) == 0 {
		// Nothing to write; no need to open a database transaction.
		return nil
	}
	log.WithField("chain", s.name).WithFields(log.Fields{"begin": beginHeight, "end": endHeight}).Infof("get %d logs", len(logs))

	ormTx, err := s.d.BeginTx()
	if err != nil {
		log.WithField("chain", s.name).WithError(err).Error("begin db transaction")
		return err
	}

	var ormTxErr error
	for _, txLog := range logs {
		if ormTxErr = s.FilterTransferOut(txLog, ormTx); ormTxErr != nil {
			break
		}
		if ormTxErr = s.FilterTransferIn(txLog, ormTx); ormTxErr != nil {
			break
		}
		if ormTxErr = s.FilterValidatorEvents(txLog, ormTx); ormTxErr != nil {
			break
		}
	}

	if ormTxErr != nil {
		if rbErr := ormTx.Rollback(); rbErr != nil {
			log.WithField("chain", s.name).WithError(rbErr).Error("failed to rollback transaction")
		}
		log.WithField("chain", s.name).WithError(ormTxErr).Error("error processing logs, transaction rolled back")
		// Report the failure so the caller does not persist this range as done.
		return ormTxErr
	}

	if cmtErr := ormTx.Commit(); cmtErr != nil {
		log.WithField("chain", s.name).WithError(cmtErr).Error("failed to commit transaction")
		return cmtErr
	}
	return nil
}

// FilterTransferOut records a TransferOut event, emitted when a user bridges
// tokens out of the current chain.
//
// Logs with a different first topic (or no topics) are ignored. For a
// TransferOut log, a BridgeEvent row is created through tx with status
// constant.TransferChainNoProcess. Its ReceiveAmount is Amount minus Fee. Its
// Hash is derived from (fromChain, outId, contract address) and identifies the
// event so it is not stored twice.
func (s *ChainSync) FilterTransferOut(txLog types.Log, tx *dao.Transaction) error {
	if len(txLog.Topics) == 0 {
		return nil
	}

	contractAbi, _ := bridge.BridgeContractMetaData.GetAbi()
	if txLog.Topics[0] != contractAbi.Events["TransferOut"].ID {
		return nil
	}

	event, err := s.bridgeCa.ParseTransferOut(txLog)
	if err != nil {
		log.WithField("chain", s.name).WithError(err).Error("parse TransferOut log")
		return err
	}

	fromContract := strings.ToLower(txLog.Address.String())
	record := &dbModel.BridgeEvent{
		FromChain:       event.FromChainID.Int64(),
		OutTimestamp:    int64(txLog.BlockTimestamp),
		FromContract:    fromContract,
		FromAddress:     strings.ToLower(event.Sender.String()),
		FromToken:       strings.ToLower(event.Token.String()),
		FromChainTxHash: strings.ToLower(txLog.TxHash.String()),
		SendAmount:      event.Amount.Text(10),
		FeeAmount:       event.Fee.Text(10),
		ToToken:         strings.ToLower(event.ReceiveToken.String()),
		ReceiveAmount:   new(big.Int).Sub(event.Amount, event.Fee).Text(10),
		OutId:           event.OutId.Int64(),
		Receiver:        strings.ToLower(event.Receiver.String()),
		ToChain:         event.ToChainID.Int64(),
		ToChainStatus:   constant.TransferChainNoProcess,
		// Identifies this transfer-out so it is not inserted twice.
		Hash: transferOutEventHash(event.FromChainID.Int64(), event.OutId.Int64(), fromContract),
	}

	if err := s.d.CreateBridgeEventTx(tx, record); err != nil {
		log.WithField("chain", s.name).WithFields(log.Fields{
			"error": err.Error(),
		}).Error("db create bridge in event")
		return err
	}
	log.WithField("chain", s.name).WithField("txHash", txLog.TxHash.Hex()).Info("db create, TransferOut event")
	return nil
}

// FilterTransferIn handles transfer-in events on the destination chain: the
// TransferIn event itself and the TransferInExecution and TransferInRejection
// events that finish it.
//
//   - TransferIn: loads the BridgeEvent by (fromChainId, outId). It fills in
//     the transfer-in fields and sets status constant.TransferChainWaitConfirm.
//   - TransferInExecution / TransferInRejection: loads the BridgeEvent by
//     (this chain's id, inId). It sets the finishing tx hash and status
//     TransferChainExecuted / TransferChainRejected.
//
// A missing record is logged and skipped (nil is returned). Any other lookup
// or update error is returned so the caller can roll back and retry. Logs with
// other topics are ignored.
func (s *ChainSync) FilterTransferIn(txLog types.Log, tx *dao.Transaction) error {
	if len(txLog.Topics) == 0 {
		return nil
	}

	abi, err := bridge.BridgeContractMetaData.GetAbi()
	if err != nil {
		return err
	}

	switch txLog.Topics[0] {
	case abi.Events["TransferIn"].ID:
		event, err := s.bridgeCa.ParseTransferIn(txLog)
		if err != nil {
			log.WithField("chain", s.name).WithError(err).Error("parse TransferIn log")
			return err
		}
		// Find the matching transfer-out record (keyed by source chain and out id).
		dbEvent, err := s.d.GetBridgeEventWithOutInfoTx(tx, event.FromChainID.Int64(), event.OutId.Int64())
		if err == dao.ErrRecordNotFound {
			log.WithField("chain", s.name).WithField("outId", event.OutId.Int64()).Error("transfer out event not found")
			return nil
		}
		if err != nil {
			// dbEvent is not usable here; dereferencing it would panic.
			log.WithField("chain", s.name).WithError(err).Error("db get transfer out event")
			return err
		}
		dbEvent.ToContract = strings.ToLower(txLog.Address.String())
		dbEvent.InTimestamp = int64(txLog.BlockTimestamp)
		dbEvent.InId = event.InId.Int64()
		dbEvent.ToChainTxHash = strings.ToLower(txLog.TxHash.String())
		dbEvent.ToChainStatus = constant.TransferChainWaitConfirm
		if err := s.d.UpdateBridgeWithTransferInTx(tx, dbEvent); err != nil {
			log.WithField("chain", s.name).WithFields(log.Fields{
				"error": err.Error(),
			}).Error("db update transfer in event")
			return err
		}
		return nil

	case abi.Events["TransferInExecution"].ID:
		event, err := s.bridgeCa.ParseTransferInExecution(txLog)
		if err != nil {
			log.WithField("chain", s.name).WithError(err).Error("parse TransferInExecution log")
			return err
		}
		dbEvent, err := s.d.GetBridgeEventWithInInfoTx(tx, s.chain.ChainId, event.InId.Int64())
		if err == dao.ErrRecordNotFound {
			log.WithField("chain", s.name).WithField("inId", event.InId.Int64()).Error("transfer in event not found")
			return nil
		}
		if err != nil {
			log.WithField("chain", s.name).WithError(err).Error("db get transfer in event")
			return err
		}
		dbEvent.ToChainStatus = constant.TransferChainExecuted
		dbEvent.FinishTxHash = strings.ToLower(txLog.TxHash.String())
		if err := s.d.UpdateBridgeResultTx(tx, dbEvent, dbEvent.FinishTxHash, dbEvent.ToChainStatus); err != nil {
			log.WithField("chain", s.name).WithFields(log.Fields{
				"error": err.Error(),
			}).Error("db update transfer in execution event")
			return err
		}

	case abi.Events["TransferInRejection"].ID:
		event, err := s.bridgeCa.ParseTransferInRejection(txLog)
		if err != nil {
			log.WithField("chain", s.name).WithError(err).Error("parse TransferInRejection log")
			return err
		}
		dbEvent, err := s.d.GetBridgeEventWithInInfoTx(tx, s.chain.ChainId, event.InId.Int64())
		if err == dao.ErrRecordNotFound {
			log.WithField("chain", s.name).WithField("inId", event.InId.Int64()).Error("transfer in event not found")
			return nil
		}
		if err != nil {
			log.WithField("chain", s.name).WithError(err).Error("db get transfer in event")
			return err
		}
		dbEvent.ToChainStatus = constant.TransferChainRejected
		dbEvent.FinishTxHash = strings.ToLower(txLog.TxHash.String())
		if err := s.d.UpdateBridgeResultTx(tx, dbEvent, dbEvent.FinishTxHash, dbEvent.ToChainStatus); err != nil {
			log.WithField("chain", s.name).WithFields(log.Fields{
				"error": err.Error(),
			}).Error("db update transfer in rejection event")
			return err
		}
	}
	return nil
}

// FilterValidatorEvents records validator votes on the current chain: the
// TransferInConfirmation and TransferInRejection events.
//
// For either event, a validator-event row is written through tx. The row holds
// the event hash, this chain's id, the lowercased validator address, the tx
// hash, the event type name and the inId. All other logs (including
// TransferOut, TransferIn and TransferInExecution) are ignored. Previously they
// fell through and produced empty validator rows.
func (s *ChainSync) FilterValidatorEvents(txLog types.Log, tx *dao.Transaction) error {
	if len(txLog.Topics) == 0 {
		return nil
	}

	abi, err := bridge.BridgeContractMetaData.GetAbi()
	if err != nil {
		return err
	}

	var (
		rawValidator string
		eventType    string
		inId         int64
	)
	switch txLog.Topics[0] {
	case abi.Events["TransferInConfirmation"].ID:
		event, err := s.bridgeCa.ParseTransferInConfirmation(txLog)
		if err != nil {
			log.WithField("chain", s.name).WithError(err).Error("parse TransferInConfirmation log")
			return err
		}
		rawValidator = event.Validator.String()
		inId = event.InId.Int64()
		eventType = "TransferInConfirmation"
	case abi.Events["TransferInRejection"].ID:
		event, err := s.bridgeCa.ParseTransferInRejection(txLog)
		if err != nil {
			log.WithField("chain", s.name).WithError(err).Error("parse TransferInRejection log")
			return err
		}
		rawValidator = event.Validator.String()
		inId = event.InId.Int64()
		eventType = "TransferInRejection"
	default:
		// Not a validator vote: nothing to record.
		return nil
	}

	txHash := txLog.TxHash.Hex()
	// The hash is computed over the validator string exactly as before (not
	// lowercased) so hashes of already-stored rows stay stable; only the stored
	// validator column is lowercased.
	eventHash := validatorEventHash(s.chain.ChainId, rawValidator, txHash, inId, eventType)
	validator := strings.ToLower(rawValidator)

	if err := s.d.CreateValidatorEventTx(tx, eventHash, s.chain.ChainId, validator, txHash, eventType, inId); err != nil {
		log.WithField("chain", s.name).WithFields(log.Fields{
			"error": err.Error(),
		}).Error("db create validator event")
		return err
	}
	return nil
}

// validatorEventHash returns the Keccak-256 hash, in common.Hash string form,
// of chainId, validator, txHash, inId and event concatenated with no
// separators. It is used as the identifying hash of a validator event.
func validatorEventHash(chainId int64, validator string, txHash string, inId int64, event string) string {
	hasher := sha3.NewLegacyKeccak256()
	// Format the fields straight into the hash state; the bytes are identical
	// to hashing the equivalent Sprintf result.
	fmt.Fprintf(hasher, "%d%s%s%d%s", chainId, validator, txHash, inId, event)
	return common.BytesToHash(hasher.Sum(nil)).String()
}

// transferOutEventHash returns the Keccak-256 hash, in common.Hash string
// form, of fromChain, outId and fromContract concatenated with no separators.
// It is used to keep a TransferOut event from being stored twice.
func transferOutEventHash(fromChain int64, outId int64, fromContract string) string {
	hasher := sha3.NewLegacyKeccak256()
	fmt.Fprintf(hasher, "%d%d%s", fromChain, outId, fromContract)
	digest := hasher.Sum(nil)
	return common.BytesToHash(digest).String()
}
