package tokenrepo

import (
	"code.wuban.net.cn/movabridge/bridge-backend/chainlist"
	"code.wuban.net.cn/movabridge/bridge-backend/constant"
	"code.wuban.net.cn/movabridge/bridge-backend/contract/token"
	"context"
	"errors"
	"fmt"
	"github.com/ethereum/go-ethereum/accounts/abi/bind"
	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/ethclient"
	"math/big"
	"strings"
	"sync"
)

// TokenInfo holds the display metadata of a token (or of a chain's native
// currency). It is serialized to JSON and BSON using the field names given in
// the struct tags.
type TokenInfo struct {
	// Name is the value returned by the token contract's Name call, or the
	// chain's native-currency name for the native coin.
	Name string `json:"name" bson:"name"`
	// Symbol is the value returned by the token contract's Symbol call, or the
	// chain's native-currency symbol for the native coin.
	Symbol string `json:"symbol" bson:"symbol"`
	// Decimals is the value returned by the token contract's Decimals call, or
	// the chain's native-currency decimals for the native coin.
	Decimals int64 `json:"decimals" bson:"decimals"`
	// Address is the token address exactly as the caller supplied it (its
	// casing is not normalized).
	Address string `json:"address" bson:"address"`
}
// TokenRepo caches TokenInfo entries keyed by lower-cased token address. It
// uses a chain repository to look up each chain's RPC endpoint and
// native-currency metadata when an entry has to be resolved.
//
// A TokenRepo contains a mutex and must not be copied; use the *TokenRepo
// returned by NewTokenRepo.
type TokenRepo struct {
	// mux guards repo.
	mux sync.RWMutex
	// repo maps a lower-cased token address to its cached metadata.
	repo map[string]TokenInfo
	// chainRepo resolves a chain ID to that chain's configuration.
	chainRepo *chainlist.ChainRepo
}

// NewTokenRepo returns an empty TokenRepo that resolves chain information
// through chainRepo.
func NewTokenRepo(chainRepo *chainlist.ChainRepo) *TokenRepo {
	tr := new(TokenRepo)
	tr.repo = map[string]TokenInfo{}
	tr.chainRepo = chainRepo
	return tr
}

// GetTokenInfo returns the cached metadata for address. The address is
// normalized with strings.ToLower before the lookup. The boolean result
// reports whether an entry was present.
func (tr *TokenRepo) GetTokenInfo(address string) (TokenInfo, bool) {
	key := strings.ToLower(address)

	tr.mux.RLock()
	defer tr.mux.RUnlock()

	cached, found := tr.repo[key]
	return cached, found
}

// SetTokenInfo stores info under the lower-cased form of address and
// replaces any entry already cached for that key.
func (tr *TokenRepo) SetTokenInfo(address string, info TokenInfo) {
	key := strings.ToLower(address)

	tr.mux.Lock()
	defer tr.mux.Unlock()

	tr.repo[key] = info
}

// RetriveTokenInfo returns the name, symbol and decimals of the token at
// address on chain chainId.
//
// The native coin (constant.CoinAddress) is resolved from the chain's
// NativeCurrency entry in the chain repo. Any other address is first looked
// up in the in-memory cache. On a miss, it is read from the token contract
// through a client dialed to the chain's RPC endpoint. Successful results are
// stored in the cache.
//
// It returns an error when chainId is not in the chain repo, when the RPC
// endpoint cannot be dialed, or when any contract call fails.
//
// NOTE(review): the cache is keyed by address only. The same contract address
// on two different chains therefore shares one entry; confirm this is intended.
func (tr *TokenRepo) RetriveTokenInfo(chainId int64, address string) (TokenInfo, error) {
	// The native coin uses one sentinel address for every chain. It must be
	// resolved per chain rather than served from the address-keyed cache;
	// otherwise the first chain queried would win for all chains.
	if strings.EqualFold(address, constant.CoinAddress) {
		cinfo, ok := tr.chainRepo.Get(chainId)
		if !ok {
			return TokenInfo{}, errors.New("chain not found")
		}
		info := TokenInfo{
			Name:     cinfo.NativeCurrency.Name,
			Symbol:   cinfo.NativeCurrency.Symbol,
			Decimals: int64(cinfo.NativeCurrency.Decimals),
			Address:  address,
		}
		tr.SetTokenInfo(address, info)
		return info, nil
	}

	if info, ok := tr.GetTokenInfo(address); ok {
		return info, nil
	}

	chain, ok := tr.chainRepo.Get(chainId)
	if !ok {
		return TokenInfo{}, errors.New("chain not found")
	}
	client, err := ethclient.Dial(chain.Rpc)
	if err != nil {
		return TokenInfo{}, fmt.Errorf("fail to connect chain with url: %v, err: %w", chain.Rpc, err)
	}
	// A new client is dialed on every cache miss; close it so its RPC
	// connection is released instead of leaking.
	defer client.Close()

	contract, err := token.NewTokenCaller(common.HexToAddress(address), client)
	if err != nil {
		return TokenInfo{}, fmt.Errorf("bind token contract %s: %w", address, err)
	}
	callOpt := &bind.CallOpts{Context: context.Background()}

	name, err := contract.Name(callOpt)
	if err != nil {
		return TokenInfo{}, fmt.Errorf("query name of token %s: %w", address, err)
	}
	symbol, err := contract.Symbol(callOpt)
	if err != nil {
		return TokenInfo{}, fmt.Errorf("query symbol of token %s: %w", address, err)
	}
	decimals, err := contract.Decimals(callOpt)
	if err != nil {
		return TokenInfo{}, fmt.Errorf("query decimals of token %s: %w", address, err)
	}

	info := TokenInfo{
		Name:     name,
		Symbol:   symbol,
		Decimals: decimals.Int64(),
		Address:  address,
	}
	tr.SetTokenInfo(address, info)
	return info, nil
}

// RetriveTokenInfoAndBalance returns the metadata of the token at address on
// chain chainId, together with user's balance of that token queried through
// client.
//
// For the native coin (constant.CoinAddress), the metadata comes from the
// chain's NativeCurrency entry in the chain repo and is cached. The balance
// comes from client.BalanceAt.
//
// For any other address, the balance is read with the contract's BalanceOf.
// Name, symbol and decimals are served from the cache when present; otherwise
// they are fetched from the contract and cached.
//
// On error the returned TokenInfo and balance may be empty or partially
// populated (the balance defaults to 0), and callers should not rely on them.
// An unknown chainId is reported as an error for the native coin.
func (tr *TokenRepo) RetriveTokenInfoAndBalance(client *ethclient.Client, chainId int64, address string, user string) (TokenInfo, *big.Int, error) {
	info := TokenInfo{}
	balance := big.NewInt(0)
	userAddr := common.HexToAddress(user)

	if strings.EqualFold(address, constant.CoinAddress) {
		cinfo, ok := tr.chainRepo.Get(chainId)
		if !ok {
			return info, balance, errors.New("chain not found")
		}
		info = TokenInfo{
			Name:     cinfo.NativeCurrency.Name,
			Symbol:   cinfo.NativeCurrency.Symbol,
			Decimals: int64(cinfo.NativeCurrency.Decimals),
			Address:  address,
		}
		tr.SetTokenInfo(address, info)
		userBalance, err := client.BalanceAt(context.Background(), userAddr, nil)
		if err != nil {
			return info, balance, err
		}
		return info, userBalance, nil
	}

	contract, err := token.NewTokenCaller(common.HexToAddress(address), client)
	if err != nil {
		return info, balance, fmt.Errorf("fail to connect contract err: %w", err)
	}
	callOpt := &bind.CallOpts{
		From:    userAddr,
		Context: context.Background(),
	}

	userBalance, err := contract.BalanceOf(callOpt, userAddr)
	if err != nil {
		return info, balance, err
	}
	balance = userBalance

	if cached, ok := tr.GetTokenInfo(address); ok {
		return cached, balance, nil
	}

	name, err := contract.Name(callOpt)
	if err != nil {
		return info, balance, err
	}
	symbol, err := contract.Symbol(callOpt)
	if err != nil {
		return info, balance, err
	}
	decimals, err := contract.Decimals(callOpt)
	if err != nil {
		return info, balance, err
	}
	// Assign rather than redeclare: the original `info := ...` here shadowed
	// the outer variable, so a cache miss returned an empty TokenInfo.
	info = TokenInfo{
		Name:     name,
		Symbol:   symbol,
		Decimals: decimals.Int64(),
		Address:  address,
	}
	tr.SetTokenInfo(address, info)
	return info, balance, nil
}
