package chaindb

import (
	"errors"
	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/log"
	"github.com/exchain/go-exchain/exchain"
	nebulav1 "github.com/exchain/go-exchain/exchain/protocol/gen/go/nebula/v1"
	"github.com/exchain/go-exchain/exchain/wdt"
	"github.com/exchain/go-exchain/exchain/wrapper"
	"github.com/exchain/go-exchain/metadb"
	"github.com/golang/protobuf/proto"
	lru "github.com/hashicorp/golang-lru"
	"github.com/holiman/uint256"
)

// ChainReader is the read-only view of the chain database: chain id, current
// height, blocks, headers, transactions, receipts and WDTs.
//
// Methods that have no error result report a miss through a nil return value
// in the chainReader implementation found in this file.
type ChainReader interface {
	// ChainId returns the chain id, or an error when it cannot be read.
	ChainId() (*uint256.Int, error)
	// GetBlockByLabel resolves a symbolic block label to a block and returns
	// an error for a label it does not recognise.
	GetBlockByLabel(label ExChainBlockLabel) (*nebulav1.Block, error)
	// CurrentHeight returns the current chain height.
	CurrentHeight() uint256.Int
	// GetOriginBlockData returns the serialized bytes of the block at the
	// given height.
	GetOriginBlockData(*uint256.Int) ([]byte, error)
	// GetBlock returns the block at the given height.
	GetBlock(*uint256.Int) *nebulav1.Block
	// GetBlockHeader returns the header of the block at the given height.
	GetBlockHeader(*uint256.Int) *nebulav1.BlockHeader
	// BlockHeaderByHash returns the header of the block with the given hash.
	BlockHeaderByHash(hash common.Hash) *nebulav1.BlockHeader
	// BlockByHash returns the block with the given hash.
	BlockByHash(hash common.Hash) *nebulav1.Block
	// GetTransaction returns the transaction with the given hash, or an error
	// when it cannot be located.
	GetTransaction(hash common.Hash) (*nebulav1.Transaction, error)
	// GetReceipt returns the receipt of the transaction with the given hash.
	GetReceipt(hash common.Hash) *nebulav1.TransactionReceipt
	// GetWDT returns the WDT of the checkpoint that number maps to.
	GetWDT(number uint64) (*wdt.WDT, error)
}

// NewChainReader builds a ChainReader backed by database, with in-memory LRU
// caches for transactions, receipts, block bodies, block headers and WDTs.
//
// It panics if any of the caches cannot be created. Every lookup dereferences
// its cache unconditionally, so returning a reader with a missing cache would
// only postpone the crash to the first query that touches it.
func NewChainReader(log log.Logger, database metadb.Database) ChainReader {
	// mustCache creates an LRU cache of the given size or panics.
	mustCache := func(size int) *lru.Cache {
		cache, err := lru.New(size)
		if err != nil {
			panic(err)
		}
		return cache
	}

	return &chainReader{
		log:          log,
		database:     database,
		txCache:      mustCache(1000000),
		receiptCache: mustCache(1000000),
		blkCache:     mustCache(1000),
		headerCache:  mustCache(1000),
		wdtCache:     mustCache(100),
	}
}

// chainReader is the database-backed ChainReader implementation. It keeps
// LRU caches in front of the database for the objects it decodes.
type chainReader struct {
	// log receives diagnostics for failed reads.
	log          log.Logger
	// txCache holds transactions keyed by transaction hash. blockNumberByHash
	// also stores block numbers in it, keyed by blockNumKey(hash).
	txCache      *lru.Cache
	// receiptCache holds transaction receipts keyed by transaction hash.
	receiptCache *lru.Cache
	// blkCache holds decoded blocks keyed by blockBodyKey(num).
	blkCache     *lru.Cache
	// headerCache holds decoded block headers keyed by blockHeaderKey(num).
	headerCache  *lru.Cache
	// wdtCache holds generated WDTs keyed by checkpoint.
	wdtCache     *lru.Cache
	// database is the underlying key/value store all reads fall back to.
	database     metadb.Database
	// chainId memoizes the chain id after the first successful ChainId call.
	chainId      *uint256.Int
}

// GetBlockByLabel maps a symbolic label onto a concrete block.
//
// The latest and finalized labels both resolve to the block at the current
// height; the earliest label resolves to the block at big0. Any other label
// yields an error. The returned block may be nil when the lookup misses.
func (m *chainReader) GetBlockByLabel(label ExChainBlockLabel) (*nebulav1.Block, error) {
	if label == ExChainBlockLatest || label == ExChainBlockFinalized {
		height := m.CurrentHeight()
		return m.GetBlock(&height), nil
	}
	if label == ExChainBlockEarliest {
		return m.GetBlock(big0), nil
	}
	return nil, errors.New("invalid block label")
}

// ChainId returns the chain id. The first successful call reads it from the
// database under chainIdKey() and memoizes it on the reader; later calls
// return the memoized value. A database error is returned unchanged.
func (m *chainReader) ChainId() (*uint256.Int, error) {
	if cached := m.chainId; cached != nil {
		return cached, nil
	}

	raw, err := m.database.Get([]byte(chainIdKey()))
	if err != nil {
		return nil, err
	}

	m.chainId = new(uint256.Int).SetBytes(raw)
	return m.chainId, nil
}

// GetTransaction returns the transaction identified by hash.
//
// A cache hit is served directly from txCache. Otherwise the transaction's
// index entry is read to find its block number and position, and the
// transaction is taken from that block.
func (m *chainReader) GetTransaction(hash common.Hash) (*nebulav1.Transaction, error) {
	if cached, hit := m.txCache.Get(hash); hit {
		return cached.(*nebulav1.Transaction), nil
	}

	location, err := m.getTxEntry(hash)
	if err != nil {
		return nil, err
	}

	blockNum := uint256.NewInt(location.BlockNumber)
	return m.getTransaction(blockNum, int(location.Index))
}
// GetReceipt returns the receipt of the transaction identified by txhash, or
// nil when it cannot be found.
//
// A cache hit is served directly from receiptCache. Otherwise the
// transaction's index entry is read to find its block number and position,
// and the receipt is taken from that block's receipt list.
func (m *chainReader) GetReceipt(txhash common.Hash) *nebulav1.TransactionReceipt {
	if cached, hit := m.receiptCache.Get(txhash); hit {
		return cached.(*nebulav1.TransactionReceipt)
	}

	location, err := m.getTxEntry(txhash)
	if err != nil {
		return nil
	}

	blockNum := uint256.NewInt(location.BlockNumber)
	return m.getReceipt(blockNum, int(location.Index))
}
// CurrentHeight returns the chain height stored under chainHeightKey().
//
// The stored value is parsed as a decimal string. Zero is returned when the
// value cannot be read or is not a valid decimal number; uint256.FromDecimal
// returns a nil pointer on a parse failure, so the result must be checked
// before it is dereferenced.
func (m *chainReader) CurrentHeight() uint256.Int {
	var zero uint256.Int

	raw, err := m.database.Get([]byte(chainHeightKey()))
	if err != nil {
		return zero
	}

	height, err := uint256.FromDecimal(string(raw))
	if err != nil || height == nil {
		return zero
	}
	return *height
}
// GetBlock returns the block stored at height num. It delegates to
// GetBlockBody, so the result is nil when the block cannot be read or decoded.
func (m *chainReader) GetBlock(num *uint256.Int) *nebulav1.Block {
	return m.GetBlockBody(num)
}
// GetBlockBody returns the block stored under blockBodyKey(num).
//
// The block cache is consulted first. On a miss the raw value is read from
// the database, decoded as a protobuf Block and added to the cache. Nil is
// returned when the read or the decoding fails.
func (m *chainReader) GetBlockBody(num *uint256.Int) *nebulav1.Block {
	key := blockBodyKey(num)

	if cached, hit := m.blkCache.Get(key); hit {
		return cached.(*nebulav1.Block)
	}

	raw, err := m.database.Get([]byte(key))
	if err != nil {
		return nil
	}

	decoded := &nebulav1.Block{}
	if err := proto.Unmarshal(raw, decoded); err != nil {
		return nil
	}

	m.blkCache.Add(key, decoded)
	return decoded
}
// GetBlockHeader returns the header stored under blockHeaderKey(num).
//
// The header cache is consulted first. On a miss the raw value is read from
// the database, decoded as a protobuf BlockHeader and added to the cache. Nil
// is returned when the read or the decoding fails.
func (m *chainReader) GetBlockHeader(num *uint256.Int) *nebulav1.BlockHeader {
	key := blockHeaderKey(num)

	if cached, hit := m.headerCache.Get(key); hit {
		return cached.(*nebulav1.BlockHeader)
	}

	raw, err := m.database.Get([]byte(key))
	if err != nil {
		return nil
	}

	decoded := &nebulav1.BlockHeader{}
	if err := proto.Unmarshal(raw, decoded); err != nil {
		return nil
	}

	m.headerCache.Add(key, decoded)
	return decoded
}
// GetOriginBlockData returns the serialized form of the block at height num.
//
// When the decoded block is already in the block cache it is re-marshaled;
// otherwise the bytes stored under blockBodyKey(num) are returned as read
// from the database.
func (m *chainReader) GetOriginBlockData(num *uint256.Int) (block []byte, err error) {
	key := blockBodyKey(num)

	cached, hit := m.blkCache.Get(key)
	if !hit {
		return m.database.Get([]byte(key))
	}
	return proto.Marshal(cached.(*nebulav1.Block))
}

// BlockHeaderByHash returns the header of the block identified by hash, or
// nil when the hash cannot be resolved to a block number.
func (m *chainReader) BlockHeaderByHash(hash common.Hash) *nebulav1.BlockHeader {
	if num := m.blockNumberByHash(hash); num != nil {
		return m.GetBlockHeader(num)
	}
	return nil
}

// BlockByHash returns the block identified by hash, or nil when the hash
// cannot be resolved to a block number.
func (m *chainReader) BlockByHash(hash common.Hash) *nebulav1.Block {
	if num := m.blockNumberByHash(hash); num != nil {
		return m.GetBlock(num)
	}
	return nil
}

// GetBlockTransactions returns the transaction list of the block at height
// num, or nil when the block cannot be loaded.
func (m *chainReader) GetBlockTransactions(num *uint256.Int) *nebulav1.TransactionList {
	return m.getBlockTxs(num)
}

// GetBlockReceipts returns the receipt list of the block at height num. It
// delegates to getBlockReceipts, which returns nil when the list cannot be
// read or decoded.
func (m *chainReader) GetBlockReceipts(num *uint256.Int) *nebulav1.TransactionReceiptList {
	return m.getBlockReceipts(num)
}

// GetWDT returns the WDT of the checkpoint that number maps to.
//
// It fails with "checkpoint not ready" while the current height is below
// number. Otherwise the WDT is served from the cache when present, or
// generated from the checkpoint's blocks and then cached.
func (m *chainReader) GetWDT(number uint64) (*wdt.WDT, error) {
	checkpoint := exchain.ToCheckpoint(number)

	if height := m.CurrentHeight(); height.Uint64() < number {
		return nil, errors.New("checkpoint not ready")
	}

	if cached, _ := m.wdtCache.Get(checkpoint); cached != nil {
		return cached.(*wdt.WDT), nil
	}

	tree, err := m.generateWDT(checkpoint)
	if err != nil {
		return nil, err
	}

	m.wdtCache.Add(checkpoint, tree)
	return tree, nil
}

// generateWDT builds the WDT for checkpoint cp.
//
// It works in two passes: first it walks every block from cp.Start() to
// cp.End() inclusive and collects the transactions that pass the
// withdraw-type filter, failing if any block is missing; then it adds the
// collected transactions to a fresh WDT in order, failing on the first
// transaction that cannot be added.
func (m *chainReader) generateWDT(cp exchain.Checkpoint) (*wdt.WDT, error) {
	var withdrawals []*nebulav1.Transaction
	tree := wdt.NewWdt()

	// Pass 1: gather withdraw transactions from every block of the checkpoint.
	for height := cp.Start(); height <= cp.End(); height++ {
		block := m.GetBlock(uint256.NewInt(height))
		if block == nil {
			return nil, errors.New("block not found in exchain")
		}
		matched := wrapper.NewBlkWrapper(block).FilterTransactions(wrapper.TxTypeFilter(nebulav1.TxType_WithdrawTx))
		withdrawals = append(withdrawals, matched...)
	}

	// Pass 2: insert them into the tree in block order.
	for idx := range withdrawals {
		if err := tree.AddTx(withdrawals[idx]); err != nil {
			return nil, err
		}
	}

	return tree, nil
}

// blockNumberByHash resolves a block hash to its block number, or nil when
// the database has no readable entry under blockNumKey(hash).
//
// Resolved numbers are memoized in txCache under the same key.
func (m *chainReader) blockNumberByHash(hash common.Hash) *uint256.Int {
	key := blockNumKey(hash)

	if cached, hit := m.txCache.Get(key); hit {
		return cached.(*uint256.Int)
	}

	raw, err := m.database.Get([]byte(key))
	if err != nil {
		return nil
	}

	num := new(uint256.Int).SetBytes(raw)
	m.txCache.Add(key, num)
	return num
}

// getBlockTxs returns the transaction list carried by the block at height
// num, or nil when the block cannot be loaded.
func (m *chainReader) getBlockTxs(num *uint256.Int) *nebulav1.TransactionList {
	if body := m.GetBlockBody(num); body != nil {
		return body.Transactions
	}
	return nil
}

// getTransaction returns the transaction at position index within the block
// at height num. It reports "transaction not found" when the block has no
// transaction list or index is outside that list.
func (m *chainReader) getTransaction(num *uint256.Int, index int) (*nebulav1.Transaction, error) {
	list := m.getBlockTxs(num)
	if list == nil || index < 0 || index >= len(list.Txs) {
		return nil, errors.New("transaction not found")
	}
	return list.Txs[index], nil
}

// getBlockReceipts loads and decodes the receipt list stored under
// blockReceiptsKey(num). It returns nil, after logging the cause, when the
// value cannot be read or is not a valid TransactionReceiptList.
func (m *chainReader) getBlockReceipts(num *uint256.Int) *nebulav1.TransactionReceiptList {
	k := blockReceiptsKey(num)
	d, err := m.database.Get([]byte(k))
	if err != nil {
		// The logger takes alternating key/value context, so the error must
		// be passed under an explicit key to be rendered correctly.
		m.log.Error("GetBlockReceipts failed", "err", err)
		return nil
	}
	receipts := new(nebulav1.TransactionReceiptList)
	if err := proto.Unmarshal(d, receipts); err != nil {
		m.log.Error("GetBlockReceipts failed", "err", err)
		return nil
	}
	return receipts
}

// getReceipt returns the receipt at position index within the receipt list
// of the block at height num, or nil when the list is unavailable or index
// is outside it.
func (m *chainReader) getReceipt(num *uint256.Int, index int) *nebulav1.TransactionReceipt {
	list := m.getBlockReceipts(num)
	if list == nil || index < 0 || index >= len(list.Receipts) {
		return nil
	}
	return list.Receipts[index]
}

// cacheBlockTxsInfo primes the transaction and receipt caches for one block.
//
// txs and rs are paired by position, and each pair is cached under the hash
// carried by the receipt. Nothing is cached when either list is nil; an
// error is returned when the two lists differ in length.
func (m *chainReader) cacheBlockTxsInfo(txs *nebulav1.TransactionList, rs *nebulav1.TransactionReceiptList) error {
	switch {
	case txs == nil || rs == nil:
		return nil
	case len(txs.Txs) != len(rs.Receipts):
		return errors.New("txs and receipts not match")
	}

	for idx := range txs.Txs {
		rcpt := rs.Receipts[idx]
		key := common.BytesToHash(rcpt.Hash)
		m.txCache.Add(key, txs.Txs[idx])
		m.receiptCache.Add(key, rcpt)
	}
	return nil
}

// cacheReceipts adds every receipt in rs to the receipt cache, keyed by the
// common.Hash form of the receipt's hash. Nil entries are skipped. It always
// returns nil.
//
// The key must be a common.Hash rather than the raw byte slice: a slice is
// not a valid map key, and GetReceipt and cacheBlockTxsInfo both use
// common.Hash keys for this cache.
func (m *chainReader) cacheReceipts(rs []*nebulav1.TransactionReceipt) error {
	for _, receipt := range rs {
		if receipt == nil {
			continue
		}
		m.receiptCache.Add(common.BytesToHash(receipt.Hash), receipt)
	}
	return nil
}

// storeTxEntry writes the index entry of a transaction to the database under
// txEntryKey(hash) and returns the result of that write.
func (m *chainReader) storeTxEntry(hash common.Hash, entry txEntry) error {
	key := []byte(txEntryKey(hash))
	return m.database.Put(key, entry.Bytes())
}

// getTxEntry reads and decodes the index entry stored under txEntryKey(hash).
// A zero txEntry and the error are returned when the read or the decoding
// fails.
func (m *chainReader) getTxEntry(hash common.Hash) (txEntry, error) {
	raw, err := m.database.Get([]byte(txEntryKey(hash)))
	if err != nil {
		return txEntry{}, err
	}

	var decoded txEntry
	if err := decoded.SetBytes(raw); err != nil {
		return txEntry{}, err
	}
	return decoded, nil
}
