package chaindb

import (
	"errors"
	"fmt"
	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/crypto"
	"github.com/ethereum/go-ethereum/event"
	"github.com/ethereum/go-ethereum/log"
	"github.com/exchain/go-exchain/exchain"
	nebulav1 "github.com/exchain/go-exchain/exchain/protocol/gen/go/nebula/v1"
	"github.com/exchain/go-exchain/exchain/wdt"
	"github.com/exchain/go-exchain/exchain/wrapper"
	"github.com/exchain/go-exchain/metadb"
	"github.com/exchain/go-exchain/metadb/memdb"
	"github.com/exchain/go-exchain/op-service/grouptask"
	"github.com/golang/protobuf/proto"
	"github.com/holiman/uint256"
	"sync"
	"sync/atomic"
	"time"

	lru "github.com/hashicorp/golang-lru"
)

// ChainDB is the storage interface for chain data: the chain id, the current
// height, blocks, headers, transactions and receipts, plus chain-event
// subscription.
type ChainDB interface {
	// Database returns the underlying persistent store.
	Database() metadb.Database
	// ChainId reads the stored chain id.
	ChainId() (*uint256.Int, error)
	// GetBlockByLabel resolves a symbolic label (latest, earliest, finalized)
	// to a block.
	GetBlockByLabel(label ExChainBlockLabel) (*nebulav1.Block, error)
	// SaveChainId persists the chain id.
	SaveChainId(chainid *uint256.Int) error
	// CurrentHeight returns the latest known chain height.
	CurrentHeight() uint256.Int
	// GetOriginBlockData returns the serialized block stored at the given height.
	GetOriginBlockData(*uint256.Int) ([]byte, error)
	// GetBlock returns the block at the given height, or nil when it cannot be loaded.
	GetBlock(*uint256.Int) *nebulav1.Block
	// GetBlockHeader returns the header at the given height, or nil when it cannot be loaded.
	GetBlockHeader(*uint256.Int) *nebulav1.BlockHeader
	// BlockHeaderByHash looks a header up by block hash; nil when unknown.
	BlockHeaderByHash(hash common.Hash) *nebulav1.BlockHeader
	// BlockByHash looks a block up by block hash; nil when unknown.
	BlockByHash(hash common.Hash) *nebulav1.Block
	// GetTransaction returns the transaction with the given hash.
	GetTransaction(hash common.Hash) (*nebulav1.Transaction, error)
	// GetReceipt returns the receipt for the given transaction hash; nil when unknown.
	GetReceipt(hash common.Hash) *nebulav1.TransactionReceipt
	// SaveBlockHeader persists a single block header.
	SaveBlockHeader(header *nebulav1.BlockHeader) error
	// SaveBlockData stores a block together with its receipts.
	SaveBlockData(*nebulav1.Block, *nebulav1.TransactionReceiptList) error
	// SubscribeChainEvent registers ch to receive chain events.
	SubscribeChainEvent(ch chan<- exchain.ChainEvent) event.Subscription
	// EmitChain publishes a chain event for the given block and hash.
	EmitChain(block *nebulav1.Block, hash common.Hash)
	// ResetHeight moves the chain height back to the given value; the bool
	// selects whether data above that height is also deleted.
	ResetHeight(*uint256.Int, bool) error
	// GetWDT returns the Merkle tree for the checkpoint that contains the
	// given block number.
	GetWDT(number uint256.Int) (*wdt.MerkleTree, error)
}

// Small uint256 values shared by this file. They are shared pointers and must
// never be used as the destination of an arithmetic operation.
var (
	big10 = uint256.NewInt(10)
	big1  = uint256.NewInt(1)
	big0  = uint256.NewInt(0)
)

// ExChainBlockLabel is a symbolic block selector accepted by GetBlockByLabel.
type ExChainBlockLabel int

// Supported block labels. In this file's GetBlockByLabel, Latest and
// Finalized both resolve to the current height and Earliest to height 0.
const (
	ExChainBlockLatest    ExChainBlockLabel = -1
	ExChainBlockEarliest  ExChainBlockLabel = 0
	ExChainBlockFinalized ExChainBlockLabel = -2
)

// NewChainDB builds a ChainDB on top of database, allocates its in-memory
// caches and starts the background persistence goroutine.
//
// It panics when any of the LRU caches cannot be created: every cache is
// dereferenced unconditionally by later calls, so continuing with a nil cache
// would only defer the crash to the first lookup.
func NewChainDB(log log.Logger, database metadb.Database) ChainDB {
	chain := &chaindb{
		log:        log,
		database:   database,
		cache:      memdb.NewMemDB(),
		toSaveData: make(chan chainData, 1000000),
	}
	var err error
	if chain.txCache, err = lru.New(1000000); err != nil {
		panic(err)
	}
	if chain.receiptCache, err = lru.New(1000000); err != nil {
		panic(err)
	}
	// GetWDT uses wdtCache without a nil check, so fail fast here like the
	// other caches instead of logging and returning a half-built instance.
	if chain.wdtCache, err = lru.New(100); err != nil {
		panic(err)
	}
	chain.storeTask()
	return chain
}

// chainData is one unit of work queued for the background save task: a
// wrapped block and the receipts of its transactions (receipts may be nil).
type chainData struct {
	block    *wrapper.BlkWrapper
	receipts *nebulav1.TransactionReceiptList
}

// chaindb is the ChainDB implementation in this file. Writes go to in-memory
// caches first and are persisted to database by a background goroutine that
// drains toSaveData.
type chaindb struct {
	log log.Logger
	// cache holds headers, block bodies and hash->number entries until they
	// are persisted.
	cache metadb.CacheKV
	// wdtCache maps a checkpoint to its generated *wdt.MerkleTree.
	wdtCache *lru.Cache
	// txCache and receiptCache are keyed by transaction hash.
	txCache      *lru.Cache
	receiptCache *lru.Cache
	database     metadb.Database
	// height stores the latest height as a *uint256.Int.
	height atomic.Value
	// toSaveData is the queue consumed by chainDataSaveTask.
	toSaveData chan chainData
	// startHeight is set on the first cached height and advanced by the
	// periodic header-cache pruning in chainDataSaveTask.
	startHeight      *uint256.Int
	blockConfirmFeed event.Feed
	// chainFeed carries the events sent by EmitChain.
	chainFeed     event.Feed
	chainHeadFeed event.Feed
	logsFeed      event.Feed
	rmlogsFeed    event.Feed
	// scope tracks subscriptions created by SubscribeChainEvent.
	scope event.SubscriptionScope
}

// GetBlockByLabel resolves a symbolic label to a block. Earliest maps to
// height 0; Latest and Finalized both map to the current height. Any other
// label yields an error.
func (m *chaindb) GetBlockByLabel(label ExChainBlockLabel) (*nebulav1.Block, error) {
	if label == ExChainBlockEarliest {
		return m.GetBlock(big0), nil
	}
	if label == ExChainBlockLatest || label == ExChainBlockFinalized {
		head := m.CurrentHeight()
		return m.GetBlock(&head), nil
	}
	return nil, errors.New("invalid block label")
}

// ChainId loads the chain id bytes from the database and decodes them as a
// big-endian uint256. A database error is returned unchanged.
func (m *chaindb) ChainId() (*uint256.Int, error) {
	raw, err := m.database.Get([]byte(chainIdKey()))
	if err != nil {
		return nil, err
	}
	id := new(uint256.Int)
	id.SetBytes(raw)
	return id, nil
}

// GetWDT returns the Merkle tree of the checkpoint containing number. The
// checkpoint must be complete (current height >= checkpoint end); trees are
// generated on first use and kept in wdtCache.
func (m *chaindb) GetWDT(number uint256.Int) (*wdt.MerkleTree, error) {
	checkpoint := exchain.ToCheckpoint(number)

	head := m.CurrentHeight()
	if head.Cmp(checkpoint.End()) < 0 {
		return nil, errors.New("checkpoint not ready")
	}

	cached, _ := m.wdtCache.Get(checkpoint)
	if cached != nil {
		return cached.(*wdt.MerkleTree), nil
	}

	tree, err := m.generateWDT(checkpoint)
	if err != nil {
		return nil, err
	}
	m.wdtCache.Add(checkpoint, tree)
	return tree, nil
}

// generateWDT builds the Merkle tree for checkpoint cp. It collects every
// withdraw transaction from the blocks in [cp.Start(), cp.End()], hashes
// User||Coin||Amount of each with Keccak-256 and feeds the hashes, in block
// and transaction order, to wdt.GenerateTreeFromHashedItems.
//
// It returns an error when a block in the range is missing or the tree cannot
// be generated.
func (m *chaindb) generateWDT(cp exchain.Checkpoint) (*wdt.MerkleTree, error) {
	var withdrawals []*nebulav1.Transaction
	for i := cp.Start().Uint64(); i <= cp.End().Uint64(); i++ {
		blk := m.GetBlock(uint256.NewInt(i))
		if blk == nil {
			return nil, errors.New("block not found in exchain")
		}
		wblk := wrapper.NewBlkWrapper(blk)
		withdrawals = append(withdrawals, wblk.FilterTransactions(wrapper.TxTypeFilter(nebulav1.TxType_WithdrawTx))...)
	}

	// Length 0 with capacity n: the leaves are appended below. Allocating
	// with length n would put n nil leaves in front of the real hashes.
	leaves := make([][]byte, 0, len(withdrawals))
	for _, tx := range withdrawals {
		content := tx.GetWithdrawTx()
		data := make([]byte, 0, len(content.User)+len(content.Coin)+len(content.Amount))
		data = append(data, content.User...)
		data = append(data, content.Coin...)
		data = append(data, content.Amount...)
		leaves = append(leaves, crypto.Keccak256Hash(data).Bytes())
	}

	tree, err := wdt.GenerateTreeFromHashedItems(leaves)
	if err != nil {
		m.log.Error("failed to generate wdt tree", "err", err)
		return nil, err
	}
	return tree, nil
}

// SaveChainId writes the big-endian bytes of chainid under the chain id key.
func (m *chaindb) SaveChainId(chainid *uint256.Int) error {
	key := []byte(chainIdKey())
	value := chainid.Bytes()
	return m.database.Put(key, value)
}

// SubscribeChainEvent subscribes ch to the chain feed and registers the
// subscription with the instance's subscription scope.
func (m *chaindb) SubscribeChainEvent(ch chan<- exchain.ChainEvent) event.Subscription {
	sub := m.chainFeed.Subscribe(ch)
	return m.scope.Track(sub)
}
// GetTransaction returns the transaction with the given hash. The tx cache is
// consulted first; on a miss the stored tx entry (block number + index) is
// used to read the transaction out of its block.
func (m *chaindb) GetTransaction(hash common.Hash) (*nebulav1.Transaction, error) {
	if cached, ok := m.txCache.Get(hash); ok {
		return cached.(*nebulav1.Transaction), nil
	}

	entry, err := m.getTxEntry(hash)
	if err != nil {
		return nil, err
	}
	blockNum := uint256.NewInt(entry.BlockNumber)
	return m.getTransaction(blockNum, int(entry.Index))
}
// GetReceipt returns the receipt of the transaction with the given hash, or
// nil when it is neither cached nor resolvable through a stored tx entry.
func (m *chaindb) GetReceipt(txhash common.Hash) *nebulav1.TransactionReceipt {
	if cached, ok := m.receiptCache.Get(txhash); ok {
		return cached.(*nebulav1.TransactionReceipt)
	}

	entry, err := m.getTxEntry(txhash)
	if err != nil {
		return nil
	}
	blockNum := uint256.NewInt(entry.BlockNumber)
	return m.getReceipt(blockNum, int(entry.Index))
}
// CurrentHeight returns the latest chain height. The in-memory value is used
// when present; otherwise the decimal height string is loaded from the
// database, cached, and returned. A missing or unparsable stored value yields
// height 0.
func (m *chaindb) CurrentHeight() uint256.Int {
	if v := m.height.Load(); v != nil {
		return *v.(*uint256.Int)
	}

	height := uint256.NewInt(0)
	if raw, err := m.database.Get([]byte(chainHeightKey())); err == nil {
		// FromDecimal returns a nil *Int on failure; never let that reach
		// the atomic or the dereference below.
		parsed, perr := uint256.FromDecimal(string(raw))
		if perr != nil {
			m.log.Error("parse stored chain height failed", "err", perr)
		} else {
			height = parsed
		}
	}
	m.height.Store(height)
	return *height
}
// GetBlock returns the block at height num; it is an alias for GetBlockBody.
func (m *chaindb) GetBlock(num *uint256.Int) *nebulav1.Block {
	blk := m.GetBlockBody(num)
	return blk
}
// GetBlockBody returns the block stored at height num. The in-memory cache is
// checked first, then the database. It returns nil when the block is not
// stored or cannot be decoded.
func (m *chaindb) GetBlockBody(num *uint256.Int) *nebulav1.Block {
	key := blockBodyKey(num)
	if cached, ok := m.cache.Get(key); ok {
		return cached.(*nebulav1.Block)
	}

	raw, err := m.database.Get([]byte(key))
	if err != nil {
		return nil
	}
	blk := new(nebulav1.Block)
	if decodeErr := proto.Unmarshal(raw, blk); decodeErr != nil {
		return nil
	}
	return blk
}
// GetBlockHeader returns the header stored at height num. The in-memory cache
// is checked first, then the database. It returns nil when the header is not
// stored or cannot be decoded.
func (m *chaindb) GetBlockHeader(num *uint256.Int) *nebulav1.BlockHeader {
	key := blockHeaderKey(num)
	if cached, ok := m.cache.Get(key); ok {
		return cached.(*nebulav1.BlockHeader)
	}

	raw, err := m.database.Get([]byte(key))
	if err != nil {
		return nil
	}
	hdr := new(nebulav1.BlockHeader)
	if decodeErr := proto.Unmarshal(raw, hdr); decodeErr != nil {
		return nil
	}
	return hdr
}
// GetOriginBlockData returns the serialized block at height num. A cached
// block is marshalled on the fly; otherwise the raw bytes are read from the
// database.
func (m *chaindb) GetOriginBlockData(num *uint256.Int) (block []byte, err error) {
	key := blockBodyKey(num)
	cached, hit := m.cache.Get(key)
	if !hit {
		return m.database.Get([]byte(key))
	}
	return proto.Marshal(cached.(*nebulav1.Block))
}

// BlockHeaderByHash resolves hash to a block number and returns that block's
// header, or nil when the hash is unknown.
func (m *chaindb) BlockHeaderByHash(hash common.Hash) *nebulav1.BlockHeader {
	if num := m.blockNumberByHash(hash); num != nil {
		return m.GetBlockHeader(num)
	}
	return nil
}
// BlockByHash resolves hash to a block number and returns that block, or nil
// when the hash is unknown.
func (m *chaindb) BlockByHash(hash common.Hash) *nebulav1.Block {
	if num := m.blockNumberByHash(hash); num != nil {
		return m.GetBlock(num)
	}
	return nil
}
// GetBlockTransactions returns the transaction list of the block at height
// num, or nil when the block cannot be loaded.
func (m *chaindb) GetBlockTransactions(num *uint256.Int) *nebulav1.TransactionList {
	return m.getBlockTxs(num)
}
// GetBlockReceipts returns the receipt list stored for the block at height
// num, or nil when it cannot be loaded.
func (m *chaindb) GetBlockReceipts(num *uint256.Int) *nebulav1.TransactionReceiptList {
	receipts := m.getBlockReceipts(num)
	return receipts
}

// Database returns the persistent store this chaindb was constructed with.
func (m *chaindb) Database() metadb.Database {
	return m.database
}

// chainDataSaveTask is the background loop started by storeTask. It persists
// every chainData item received on toSaveData and, every 10 seconds, prunes
// old block headers from the in-memory cache. It never returns.
func (m *chaindb) chainDataSaveTask() {
	var duration = time.Second * 10
	tm := time.NewTicker(duration)
	defer tm.Stop()

	for {
		select {
		case data := <-m.toSaveData:
			m.persistChainData(data)
		case <-tm.C:
			m.pruneCachedHeaders()
			tm.Reset(duration)
		}
	}
}

// persistChainData writes one block to the database in this order:
// blockHash -> number, header, body, receipts, txhash -> entry, chain height.
// Entries are evicted from the in-memory caches once written. Failures to
// marshal or store the header, body or tx entries panic; failures to store
// the block-number key, the receipts or the height are logged.
func (m *chaindb) persistChainData(data chainData) {
	block := data.block

	m.log.Info("save block", "block number", block.Height())
	blockHash := block.Hash()
	blockHeight := uint256.NewInt(block.Height())

	// blockHash -> number.Bytes()
	start := time.Now()
	if err := m.database.Put([]byte(blockNumKey(blockHash)), blockHeight.Bytes()); err != nil {
		m.log.Error("save block number key failed", "err", err)
	}
	m.log.Debug("save block number key", "cost", time.Since(start).Milliseconds())

	// Block header, keyed by number.
	start = time.Now()
	hk := blockHeaderKey(blockHeight)
	dh, err := proto.Marshal(block.Header())
	if err != nil {
		m.log.Error("marshal block header failed", "err", err)
		// todo: vicotor to handle error.
		panic(fmt.Sprintf("marshal block header failed with err %v", err))
	}
	if err := m.database.Put([]byte(hk), dh); err != nil {
		m.log.Error("save block header failed", "err", err)
		panic(fmt.Sprintf("save block header failed with err %v", err))
	}
	m.log.Debug("save block header", "cost", time.Since(start).Milliseconds())
	// The header is on disk now, so the cached copy is no longer needed.
	m.cache.Delete(hk)

	// Block body, keyed by number.
	start = time.Now()
	bodyk := blockBodyKey(blockHeight)
	dbody, err := proto.Marshal(block.Block())
	if err != nil {
		m.log.Error("marshal block data failed", "err", err)
		panic(fmt.Sprintf("marshal block data failed with err %v", err))
	}
	if err := m.database.Put([]byte(bodyk), dbody); err != nil {
		m.log.Error("save block body failed", "err", err)
		panic(fmt.Sprintf("save block body failed with err %v", err))
	}
	m.cache.Delete(bodyk)
	m.log.Debug("save block body", "cost", time.Since(start).Milliseconds(), "number", blockHeight.String(), "blk size", len(dbody))

	if data.receipts != nil && len(data.receipts.Receipts) > 0 {
		// Block receipts, keyed by number.
		start = time.Now()
		rk := blockReceiptsKey(blockHeight)
		dreceipts, err := proto.Marshal(data.receipts)
		if err != nil {
			m.log.Error("marshal block receipts", "err", err)
			panic(err)
		}
		if err := m.database.Put([]byte(rk), dreceipts); err != nil {
			m.log.Error("save block receipts", "err", err)
		} else {
			m.cache.Delete(rk)
		}
		m.log.Debug("save block receipts", "cost", time.Since(start).Milliseconds())

		// txhash -> (block number, index)
		m.persistTxEntries(block.Height(), data.receipts.Receipts)
	}

	// Latest height, stored as a string (CurrentHeight parses it back).
	if err := m.database.Put([]byte(chainHeightKey()), []byte(blockHeight.String())); err != nil {
		m.log.Error("save chain height failed", "err", err)
	}

	m.log.Info("save block finished", "block number", block.Height())
}

// persistTxEntries stores a txhash -> txEntry record for every receipt of
// block number through a single batch, and evicts the corresponding entries
// from the tx and receipt caches. It panics when the batch write fails.
func (m *chaindb) persistTxEntries(number uint64, receipts []*nebulav1.TransactionReceipt) {
	type itemData struct {
		hash  common.Hash
		entry txEntry
	}

	begin := time.Now()
	batch := m.database.NewBatch()

	items := make([]interface{}, len(receipts))
	for i, r := range receipts {
		items[i] = &itemData{
			hash: common.BytesToHash(r.Hash),
			entry: txEntry{
				BlockNumber: number,
				Index:       int64(i),
			},
		}
	}
	m.log.Debug("save block txentry sequence", "cost", time.Since(begin).Milliseconds())

	// NOTE(review): DoMultiTasks presumably runs handler on 4 workers, so
	// batch.Put is called concurrently; confirm the batch implementation is
	// safe for concurrent Put.
	handler := func(item interface{}, chanIdx uint) {
		idata := item.(*itemData)
		batch.Put([]byte(txEntryKey(idata.hash)), idata.entry.Bytes())
		m.txCache.Remove(idata.hash)
		m.receiptCache.Remove(idata.hash)
	}
	grouptask.DoMultiTasks(4, handler, items...)
	m.log.Debug("save block txentry batch put", "cost", time.Since(begin).Milliseconds())

	writeStart := time.Now()
	if err := batch.Write(); err != nil {
		m.log.Error("save block txentry batch write error", "err", err)
		panic(err)
	}
	m.log.Debug("save block txentry batch write", "cost", time.Since(writeStart).Milliseconds())
}

// pruneCachedHeaders drops cached block headers that are more than 10 blocks
// behind the current height, advancing startHeight one block at a time.
//
// NOTE(review): startHeight is also written by cacheChainHeight from the
// caller's goroutine without synchronization; confirm whether that is safe.
func (m *chaindb) pruneCachedHeaders() {
	if m.startHeight == nil {
		return
	}
	h := m.CurrentHeight()
	for {
		// Compute the limit into a fresh Int: Add writes its result into
		// the receiver, so using startHeight (or big10) as the receiver
		// would corrupt the prune cursor on every check.
		limit := new(uint256.Int).Add(m.startHeight, big10)
		if h.Cmp(limit) <= 0 {
			return
		}
		m.cache.Delete(blockHeaderKey(m.startHeight))
		m.startHeight = new(uint256.Int).AddUint64(m.startHeight, 1)
	}
}

// storeTask starts chainDataSaveTask in its own goroutine.
func (m *chaindb) storeTask() {
	go m.chainDataSaveTask()
}

// toStoreChainData queues data for the background save task. The send blocks
// when the toSaveData buffer is full.
func (m *chaindb) toStoreChainData(data chainData) {
	m.toSaveData <- data
}

// blockNumberByHash maps a block hash to its number, checking the in-memory
// cache before the database. It returns nil when the hash is unknown.
func (m *chaindb) blockNumberByHash(hash common.Hash) *uint256.Int {
	key := blockNumKey(hash)
	if cached, ok := m.cache.Get(key); ok {
		return cached.(*uint256.Int)
	}

	raw, err := m.database.Get([]byte(key))
	if err != nil {
		return nil
	}
	return new(uint256.Int).SetBytes(raw)
}

// getBlockTxs returns the transaction list of the block at height num, or nil
// when the block cannot be loaded.
func (m *chaindb) getBlockTxs(num *uint256.Int) *nebulav1.TransactionList {
	if body := m.GetBlockBody(num); body != nil {
		return body.Transactions
	}
	return nil
}

// getTransaction returns the transaction at position index of the block at
// height num. A missing block or an out-of-range index yields a
// "transaction not found" error.
func (m *chaindb) getTransaction(num *uint256.Int, index int) (*nebulav1.Transaction, error) {
	list := m.getBlockTxs(num)
	if list == nil || index < 0 || index >= len(list.Txs) {
		return nil, errors.New("transaction not found")
	}
	return list.Txs[index], nil
}

// getBlockReceipts loads and decodes the receipt list stored for the block at
// height num. It logs the failure and returns nil when the list is missing or
// cannot be decoded.
func (m *chaindb) getBlockReceipts(num *uint256.Int) *nebulav1.TransactionReceiptList {
	k := blockReceiptsKey(num)
	d, err := m.database.Get([]byte(k))
	if err != nil {
		// The logger takes key/value pairs after the message; pass the
		// error under an explicit key.
		m.log.Error("GetBlockReceipts failed", "err", err)
		return nil
	}
	receipts := new(nebulav1.TransactionReceiptList)
	if err := proto.Unmarshal(d, receipts); err != nil {
		m.log.Error("GetBlockReceipts failed", "err", err)
		return nil
	}
	return receipts
}

// getReceipt returns the receipt at position index of the block at height
// num, or nil when the block's receipts are unavailable or index is out of
// range.
func (m *chaindb) getReceipt(num *uint256.Int, index int) *nebulav1.TransactionReceipt {
	list := m.getBlockReceipts(num)
	if list == nil {
		return nil
	}
	if index < 0 || index >= len(list.Receipts) {
		return nil
	}
	return list.Receipts[index]
}

// cacheChainHeight records n as the in-memory chain height. The first call
// also initializes startHeight. Private copies of n are stored in both places.
func (m *chaindb) cacheChainHeight(n *uint256.Int) {
	if m.startHeight == nil {
		first := new(uint256.Int)
		first.Set(n)
		m.startHeight = first
	}
	latest := new(uint256.Int)
	latest.Set(n)
	m.height.Store(latest) // cache save uint256.Int
}

// cacheBlockHeader puts header into the in-memory cache under its height key.
// It always returns nil.
func (m *chaindb) cacheBlockHeader(header *nebulav1.BlockHeader) error {
	height := uint256.NewInt(header.Height)
	m.cache.Set(blockHeaderKey(height), header)
	return nil
}

// cacheBlock puts the wrapped block into the in-memory cache under its body
// key. It always returns nil.
func (m *chaindb) cacheBlock(wBlk *wrapper.BlkWrapper) error {
	height := uint256.NewInt(wBlk.Height())
	m.cache.Set(blockBodyKey(height), wBlk.Block())
	return nil
}

// cacheBlockNumber caches the hash -> number mapping of the wrapped block. It
// always returns nil.
func (m *chaindb) cacheBlockNumber(wBlk *wrapper.BlkWrapper) error {
	key := blockNumKey(wBlk.Hash())
	m.cache.Set(key, uint256.NewInt(wBlk.Height()))
	return nil
}

// cacheBlockTxsInfo caches every transaction and its receipt under the
// transaction hash taken from the receipt. Nothing is cached when either list
// is nil; an error is returned when the two lists differ in length.
func (m *chaindb) cacheBlockTxsInfo(txs *nebulav1.TransactionList, rs *nebulav1.TransactionReceiptList) error {
	if txs == nil || rs == nil {
		return nil
	}
	receipts := rs.Receipts
	if len(receipts) != len(txs.Txs) {
		return errors.New("txs and receipts not match")
	}
	for idx := range txs.Txs {
		key := common.BytesToHash(receipts[idx].Hash)
		m.txCache.Add(key, txs.Txs[idx])
		m.receiptCache.Add(key, receipts[idx])
	}
	return nil
}

// cacheReceipts adds every receipt to the receipt cache, keyed by its
// transaction hash. It always returns nil.
func (m *chaindb) cacheReceipts(rs []*nebulav1.TransactionReceipt) error {
	for _, receipt := range rs {
		// receipt.Hash is a byte slice: slices are not hashable and cannot be
		// used as a cache key, and every lookup in this file uses a
		// common.Hash key, so convert before adding.
		m.receiptCache.Add(common.BytesToHash(receipt.Hash), receipt)
	}
	return nil
}

// EmitChain sends a ChainEvent for block on the chain feed. TxCount is the
// number of transactions in the block, or 0 when it has no transaction list.
func (m *chaindb) EmitChain(block *nebulav1.Block, hash common.Hash) {
	ev := exchain.ChainEvent{
		Block:     block,
		BlockHash: hash,
	}
	if txs := block.Transactions; txs != nil {
		ev.TxCount = int64(len(txs.Txs))
	}
	m.chainFeed.Send(ev)
}

// ResetHeight moves the chain height back to height.
//
// When clear is false only the stored and in-memory height are updated. When
// clear is true, the tx entries, transactions, receipts, header and body of
// every block above height are deleted in one batch together with the height
// update.
//
// It returns an error when height is above the current height or when the
// database write fails (the underlying error is wrapped).
func (m *chaindb) ResetHeight(height *uint256.Int, clear bool) error {
	cur := m.CurrentHeight()
	if cur.Cmp(height) < 0 {
		return errors.New("given height too high")
	}
	// Keep a private copy so later changes to the caller's value cannot
	// alter the cached height.
	target := new(uint256.Int).Set(height)

	if !clear {
		if err := m.database.Put([]byte(chainHeightKey()), []byte(target.String())); err != nil {
			return fmt.Errorf("reset height failed:%w", err)
		}
		m.height.Store(target)
		return nil
	}

	// NOTE(review): the blockHash -> number index and the in-memory caches
	// are not cleared here; confirm whether callers rely on that.
	batch := m.database.NewBatch()
	for h := new(uint256.Int).Set(&cur); h.Cmp(target) > 0; h = new(uint256.Int).SubUint64(h, 1) {
		if receipts := m.GetBlockReceipts(h); receipts != nil {
			for _, r := range receipts.Receipts {
				txhash := common.BytesToHash(r.Hash)
				if err := batch.Delete([]byte(txEntryKey(txhash))); err != nil {
					m.log.Debug("delete tx entry failed", "txhash", txhash.String(), "err", err)
				}
				if err := batch.Delete([]byte(transactionKey(txhash))); err != nil {
					m.log.Debug("delete tx failed", "txhash", txhash.String(), "err", err)
				}
				if err := batch.Delete([]byte(receiptKey(txhash))); err != nil {
					m.log.Debug("delete tx receipt failed", "txhash", txhash.String(), "err", err)
				}
			}
		}

		if err := batch.Delete([]byte(blockHeaderKey(h))); err != nil {
			m.log.Debug("delete block header failed", "height", h.String(), "err", err)
		}
		if err := batch.Delete([]byte(blockBodyKey(h))); err != nil {
			m.log.Debug("delete block body failed", "height", h.String(), "err", err)
		}
		// Drop the per-block receipt list too, so a later block at this
		// height cannot be served stale receipts.
		if err := batch.Delete([]byte(blockReceiptsKey(h))); err != nil {
			m.log.Debug("delete block receipts failed", "height", h.String(), "err", err)
		}
		m.log.Debug("reset height delete data for height", "height", h.String())
	}
	// One height update at the end is enough; every intermediate value
	// would be overwritten within the same batch.
	batch.Put([]byte(chainHeightKey()), []byte(target.String()))
	if err := batch.Write(); err != nil {
		return fmt.Errorf("reset height failed:%w", err)
	}
	m.height.Store(target)
	return nil
}

// SaveBlockHeader serializes header and writes it to the database under its
// height key. Marshal and database errors are returned; a database error is
// also logged.
func (m *chaindb) SaveBlockHeader(header *nebulav1.BlockHeader) error {
	key := blockHeaderKey(uint256.NewInt(header.Height))
	raw, err := proto.Marshal(header)
	if err != nil {
		return err
	}
	err = m.database.Put([]byte(key), raw)
	if err != nil {
		m.log.Error("database save block header", "err", err)
		return err
	}
	return nil
}

// SaveBlockData makes block and its receipts visible through the in-memory
// caches, advances the cached chain height and queues the data for the
// background save task, which persists it asynchronously.
//
// It returns an error when block or its header is nil. A mismatch between
// the transaction and receipt lists is logged; the block is still queued.
func (m *chaindb) SaveBlockData(block *nebulav1.Block, rs *nebulav1.TransactionReceiptList) error {
	if block == nil || block.Header == nil {
		return errors.New("block or block header is nil")
	}

	wblk := wrapper.NewBlkWrapper(block)
	m.cacheBlockHeader(block.Header)
	m.cacheBlock(wblk)
	m.cacheBlockNumber(wblk)

	wg := sync.WaitGroup{}

	var cacheErr error
	wg.Add(1)
	go func() {
		defer wg.Done()
		// cache txs and receipt.
		cacheErr = m.cacheBlockTxsInfo(block.Transactions, rs)
	}()

	height := uint256.NewInt(block.Header.Height)
	m.cacheChainHeight(height)
	wg.Wait()
	// cacheErr is only read after Wait, so the goroutine's write is visible.
	if cacheErr != nil {
		m.log.Error("cache block txs info failed", "height", block.Header.Height, "err", cacheErr)
	}

	data := chainData{
		block:    wblk,
		receipts: rs,
	}
	m.toStoreChainData(data)

	return nil
}

// storeTxEntry writes the serialized entry to the database under the tx entry
// key derived from hash.
func (m *chaindb) storeTxEntry(hash common.Hash, entry txEntry) error {
	key := []byte(txEntryKey(hash))
	return m.database.Put(key, entry.Bytes())
}

// getTxEntry loads and decodes the tx entry stored for hash. A zero txEntry
// and the error are returned when the entry is missing or cannot be decoded.
func (m *chaindb) getTxEntry(hash common.Hash) (txEntry, error) {
	raw, err := m.database.Get([]byte(txEntryKey(hash)))
	if err != nil {
		return txEntry{}, err
	}

	decoded := new(txEntry)
	if decodeErr := decoded.SetBytes(raw); decodeErr != nil {
		return txEntry{}, decodeErr
	}
	return *decoded, nil
}
