package proposer

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"github.com/ethereum/go-ethereum/common"
	"github.com/exchain/go-exchain/metadb"
	"math/big"
	"sync"
	"time"

	"github.com/ethereum/go-ethereum/accounts/abi"
	"github.com/ethereum/go-ethereum/accounts/abi/bind"
	"github.com/ethereum/go-ethereum/core/types"
	"github.com/ethereum/go-ethereum/log"
	"github.com/exchain/go-exchain/op-proposer/bindings"
	"github.com/exchain/go-exchain/op-service/dial"
	"github.com/exchain/go-exchain/op-service/txmgr"
)

// Keys under which the operator persists its state via Cfg.Database.
const (
	// progressKey holds the JSON-encoded OperatorProgress.
	progressKey     = "operatorprogress"
	// progressTaskKey holds the JSON-encoded OperatorTaskList.
	progressTaskKey = "optask"
)

// WithdrawalContract is the part of the portal contract binding that the
// operator calls directly. newOperator invokes Version once at construction
// time and logs the result.
type WithdrawalContract interface {
	// Version returns a version string for the bound contract.
	Version(*bind.CallOpts) (string, error)
}

// Operator is responsible for proposing withdrawal tx to the L1 chain.
//
// It is created with NewOperator, started with StartOperator (which launches
// the polling loop) and stopped with StopOperator / StopOperatorIfRunning.
type Operator struct {
	DriverSetup

	// wg tracks the polling goroutine started by StartOperator.
	wg   sync.WaitGroup
	// done is closed by StopOperator to make the polling loop return.
	done chan struct{}

	// ctx is passed to every DoOperator pass; cancel is invoked by
	// StopOperator to abort in-flight work.
	ctx    context.Context
	cancel context.CancelFunc

	// mutex guards running, which records whether the operator has been started.
	mutex   sync.Mutex
	running bool

	// portalContract is used to query the portal version at construction.
	portalContract WithdrawalContract
	// l2ooContract is queried for the latest submitted L2 block number and for
	// output indices.
	l2ooContract   L2OOContract
	// portalABI is used to pack calldata for the portal withdrawal methods.
	portalABI      *abi.ABI
}

// NewOperator creates a new L2 Operator.
//
// setup must carry a non-nil Cfg.PortalAddr; otherwise an error is returned.
// The db parameter is currently unused (the operator reads and writes its
// state through setup.Cfg.Database) and is kept for interface compatibility.
//
// A panic raised while constructing the operator is recovered and reported as
// an error, so callers never receive a nil *Operator together with a nil error.
// On any failure the operator's internal context is cancelled.
func NewOperator(setup DriverSetup, db metadb.Database) (_ *Operator, err error) {
	ctx, cancel := context.WithCancel(context.Background())
	defer func() {
		// Convert a panic into an error instead of silently swallowing it:
		// without this the function would return (nil, nil) after a panic.
		if r := recover(); r != nil {
			err = fmt.Errorf("panic while creating operator: %v", r)
		}
		if err != nil {
			cancel()
		}
	}()

	if setup.Cfg.PortalAddr == nil {
		return nil, errors.New("the `PortalAddr` were not provided")
	}
	return newOperator(ctx, cancel, setup)
}

// OperatorProgress is the operator's persisted cursor. It is stored as JSON
// under progressKey.
type OperatorProgress struct {
	// BlockNumber is the L2 block whose withdrawals are being (or were last) processed.
	BlockNumber uint64 `json:"block_number"`
	// Finished is the number of tasks of that block already submitted.
	Finished    uint64 `json:"finished"`
	// Total is the number of tasks recorded for that block; zero means the
	// block had no withdrawals.
	Total       uint64 `json:"total"`
}

// OperatorTask identifies one withdrawal to submit.
type OperatorTask struct {
	// TxHash is the hex-encoded hash of the withdrawal transaction.
	TxHash string `json:"tx_hash"`
}

// OperatorTaskList is the ordered list of withdrawal tasks of the block that is
// currently being processed. It is stored as JSON under progressTaskKey.
type OperatorTaskList []OperatorTask

// GetOperatorTaskList loads the persisted task list from the database.
//
// A database read error yields an empty, non-nil list. A decoding error is
// logged and whatever was decoded so far is returned.
func (l *Operator) GetOperatorTaskList() OperatorTaskList {
	raw, readErr := l.Cfg.Database.Get([]byte(progressTaskKey))
	if readErr != nil {
		return OperatorTaskList{}
	}

	var tasks OperatorTaskList
	decodeErr := json.Unmarshal(raw, &tasks)
	if decodeErr != nil {
		l.Log.Error("failed to unmarshal operator task list", "err", decodeErr)
	}
	return tasks
}

// SetOperatorTaskList persists taskList as JSON under progressTaskKey.
// Encoding and storage failures are logged; nothing is returned to the caller.
func (l *Operator) SetOperatorTaskList(taskList OperatorTaskList) {
	encoded, encodeErr := json.Marshal(taskList)
	if encodeErr != nil {
		l.Log.Error("failed to marshal operator task list", "err", encodeErr)
		return
	}

	storeErr := l.Cfg.Database.Put([]byte(progressTaskKey), encoded)
	if storeErr != nil {
		l.Log.Error("failed to set operator task list", "err", storeErr)
	}
}

// GetProgress loads the persisted progress cursor from the database.
//
// A database read error yields the zero OperatorProgress. A decoding error is
// logged and whatever was decoded so far is returned.
func (l *Operator) GetProgress() OperatorProgress {
	raw, readErr := l.Cfg.Database.Get([]byte(progressKey))
	if readErr != nil {
		return OperatorProgress{}
	}

	var cursor OperatorProgress
	decodeErr := json.Unmarshal(raw, &cursor)
	if decodeErr != nil {
		l.Log.Error("failed to unmarshal operator progress", "err", decodeErr)
	}
	return cursor
}

// SetProgress persists progress as JSON under progressKey.
// Encoding and storage failures are logged; nothing is returned to the caller.
func (l *Operator) SetProgress(progress OperatorProgress) {
	encoded, encodeErr := json.Marshal(progress)
	if encodeErr != nil {
		l.Log.Error("failed to marshal operator progress", "err", encodeErr)
		return
	}

	storeErr := l.Cfg.Database.Put([]byte(progressKey), encoded)
	if storeErr != nil {
		l.Log.Error("failed to set operator progress", "err", storeErr)
	}
}

// newOperator builds an Operator bound to the L2OutputOracle and Portal
// contracts configured in setup.
//
// It creates both contract bindings, calls the portal's Version method (bounded
// by Cfg.NetworkTimeout) and logs the result, and loads the portal ABI used
// later to pack withdrawal calldata. ctx and cancel become the operator's
// lifetime context; cancel is invoked if construction fails.
//
// An error is returned when either contract address is missing, a binding
// cannot be created, the version call fails, or the portal ABI cannot be parsed.
func newOperator(ctx context.Context, cancel context.CancelFunc, setup DriverSetup) (_ *Operator, err error) {
	// Release the lifetime context on every failure path.
	defer func() {
		if err != nil {
			cancel()
		}
	}()

	// Both addresses are dereferenced below; reject nil pointers up front
	// instead of panicking.
	if setup.Cfg.L2OutputOracleAddr == nil {
		return nil, errors.New("the `L2OutputOracleAddr` was not provided")
	}
	if setup.Cfg.PortalAddr == nil {
		return nil, errors.New("the `PortalAddr` were not provided")
	}

	l2ooContract, err := bindings.NewL2OutputOracleCaller(*setup.Cfg.L2OutputOracleAddr, setup.L1Client)
	if err != nil {
		return nil, fmt.Errorf("failed to create L2OO at address %s: %w", setup.Cfg.L2OutputOracleAddr, err)
	}

	portalContract, err := bindings.NewPortalCaller(*setup.Cfg.PortalAddr, setup.L1Client)
	if err != nil {
		return nil, fmt.Errorf("failed to create Portal at address %s: %w", setup.Cfg.PortalAddr, err)
	}

	cCtx, cCancel := context.WithTimeout(ctx, setup.Cfg.NetworkTimeout)
	defer cCancel()
	version, err := portalContract.Version(&bind.CallOpts{Context: cCtx})
	if err != nil {
		return nil, fmt.Errorf("failed to get Portal version: %w", err)
	}
	log.Info("Connected to Portal", "address", setup.Cfg.PortalAddr, "version", version)

	parsed, err := bindings.PortalMetaData.GetAbi()
	if err != nil {
		return nil, fmt.Errorf("failed to load Portal ABI: %w", err)
	}

	return &Operator{
		DriverSetup: setup,
		done:        make(chan struct{}),
		ctx:         ctx,
		cancel:      cancel,

		portalContract: portalContract,
		l2ooContract:   l2ooContract,
		portalABI:      parsed,
	}, nil
}

// StartOperator starts the operator's polling loop in a new goroutine.
//
// When Cfg.WaitNodeSync is set, it first blocks until waitNodeSync succeeds.
// It returns an error if the operator is already running or if waiting for the
// node sync fails. The operator is only marked as running once the loop is
// actually launched, so a failed start can be retried.
func (l *Operator) StartOperator() error {
	l.Log.Info("Starting Proposer operator")

	l.mutex.Lock()
	defer l.mutex.Unlock()

	if l.running {
		return errors.New("proposer operator is already running")
	}

	if l.Cfg.WaitNodeSync {
		if err := l.waitNodeSync(); err != nil {
			return fmt.Errorf("error waiting for node sync: %w", err)
		}
	}

	// Flip the flag only after every fallible step has succeeded; the mutex is
	// held throughout, so no concurrent caller can observe the gap.
	l.running = true
	l.wg.Add(1)
	go l.loop()

	l.Log.Info("Proposer operator started")
	return nil
}

// StopOperatorIfRunning stops the operator like StopOperator, but treats
// ErrProposerNotRunning as success. Any other error is returned unchanged.
func (l *Operator) StopOperatorIfRunning() error {
	if stopErr := l.StopOperator(); !errors.Is(stopErr, ErrProposerNotRunning) {
		return stopErr
	}
	return nil
}

// getL2ooIndex asks the L2OutputOracle contract for the output index returned
// by GetL2OutputIndexAfter for l2BlockNumber.
//
// callOpts carries the caller's context and sender. On failure the error is
// logged and returned.
func (l *Operator) getL2ooIndex(callOpts *bind.CallOpts, l2BlockNumber uint64) (*big.Int, error) {
	// SetUint64 keeps the full uint64 range; converting through int64 would
	// turn block numbers above math.MaxInt64 into negative values.
	blockNumber := new(big.Int).SetUint64(l2BlockNumber)
	index, err := l.l2ooContract.GetL2OutputIndexAfter(callOpts, blockNumber)
	if err != nil {
		l.Log.Error("failed to get l2 output index after", "err", err, "l2blk", l2BlockNumber)
		return nil, err
	}
	return index, nil
}

// StopOperator stops the polling loop and waits for it to return.
//
// It returns ErrProposerNotRunning when the operator is not marked as running.
// NOTE(review): done is closed and never re-created in this file, so the
// operator does not look restartable after a stop — confirm before reusing an
// instance.
func (l *Operator) StopOperator() error {
	l.Log.Info("Stopping Proposer operator")

	l.mutex.Lock()
	defer l.mutex.Unlock()

	if !l.running {
		return ErrProposerNotRunning
	}
	l.running = false

	// Cancel the context handed to DoOperator, signal the loop to exit, then
	// wait for the loop goroutine to finish.
	l.cancel()
	close(l.done)
	l.wg.Wait()

	l.Log.Info("Proposer operator stopped")
	return nil
}

// DoOperator runs one pass of the withdrawal operator.
//
// The pass works as follows:
//  1. Read the latest L2 block number submitted to the L2OutputOracle.
//  2. If the persisted progress shows an unfinished block, resume it using the
//     persisted task list.
//  3. Stop if the persisted block number has reached the latest submitted one.
//  4. Otherwise walk every following block up to the latest: fetch its
//     withdrawal transactions, persist them as the new task list, and
//  5. submit one portal withdrawal per task, persisting progress after each.
//
// Any failure is logged and ends the pass; the persisted progress lets the next
// pass resume at the task that failed. ctx bounds all remote calls; nothing is
// returned.
//
// NOTE(review): a withdrawal whose transaction is mined but reverted is still
// counted as finished, because sendTransactionOne only logs that case —
// confirm this is intended.
func (l *Operator) DoOperator(ctx context.Context) {
	progress := l.GetProgress()

	// newCallOpts gives every contract call its own NetworkTimeout budget. A
	// single shared timeout would expire while withdrawals (which may take
	// minutes each) are being submitted and fail all later calls of the pass.
	newCallOpts := func() (*bind.CallOpts, context.CancelFunc) {
		callCtx, cancel := context.WithTimeout(ctx, l.Cfg.NetworkTimeout)
		return &bind.CallOpts{From: l.Txmgr.From(), Context: callCtx}, cancel
	}
	// outputIndexAfter resolves the L2OO output index for blockNumber.
	outputIndexAfter := func(blockNumber uint64) (*big.Int, error) {
		opts, cancel := newCallOpts()
		defer cancel()
		return l.getL2ooIndex(opts, blockNumber)
	}

	// 1. check current latest submitted block number on l2oo.
	latestOpts, cancelLatest := newCallOpts()
	l2Latest, err := l.l2ooContract.LatestBlockNumber(latestOpts)
	cancelLatest()
	if err != nil {
		l.Log.Error("failed to get latest block number from l2oo", "err", err)
		return
	}
	l.Log.Info("operator latest block number from l2oo", "l2Latest", l2Latest, "progress", progress)

	rollupClient, err := l.RollupProvider.RollupClient(ctx)
	if err != nil {
		l.Log.Error("failed to get rollup client in withdrawal tx operator", "err", err)
		return
	}

	// processTasks submits tasks[finished:] for blockNumber one by one and
	// persists the progress after each success. It reports whether every task
	// was submitted; on false the pass must stop so that the block is retried.
	processTasks := func(blockNumber uint64, tasks OperatorTaskList, finished uint64, ooIndex *big.Int) bool {
		total := uint64(len(tasks))
		for finished < total {
			txHash := common.HexToHash(tasks[finished].TxHash)

			proof, err := rollupClient.WithdrawalProof(ctx, txHash)
			if err != nil {
				l.Log.Error("failed to get withdrawal proof", "err", err)
				return false
			}

			withdrawalProof := make([][]byte, len(proof.Proof))
			for i, p := range proof.Proof {
				withdrawalProof[i] = common.FromHex(p)
			}

			// call contract to do withdrawal.
			param := bindings.TypesExChainWithdrawalParam{
				Value:  proof.Value,
				User:   proof.User,
				Coin:   []byte(proof.Coin),
				TxHash: txHash,
			}
			ooProof := bindings.TypesOutputRootProof{
				StateRoot:                proof.Output.StateRoot,
				MessagePasserStorageRoot: proof.Output.MessagePasserStorageRoot,
				LatestBlockhash:          proof.Output.BlockHash,
			}
			if err := l.doWithdrawalOne(ctx, param, ooIndex, ooProof, withdrawalProof); err != nil {
				l.Log.Error("failed to do withdrawal", "err", err)
				return false
			}

			finished++
			// update process.
			l.SetProgress(OperatorProgress{
				BlockNumber: blockNumber,
				Finished:    finished,
				Total:       total,
			})
		}
		return true
	}

	// 2. if current progress is not finished, we need finish it first.
	if progress.Total != 0 && progress.Finished != progress.Total {
		ooIndex, err := outputIndexAfter(progress.BlockNumber)
		if err != nil {
			l.Log.Error("failed to get l2oo index and output", "err", err)
			return
		}

		// get saved withdrawal tx hashes from db; the stored list is the
		// source of truth when it disagrees with the stored total.
		tasks := l.GetOperatorTaskList()
		if uint64(len(tasks)) != progress.Total {
			l.Log.Warn("task list length is not equal to progress total", "task count", len(tasks), "progress", progress)
			progress.Total = uint64(len(tasks))
		}

		if !processTasks(progress.BlockNumber, tasks, progress.Finished, ooIndex) {
			return
		}
	}

	// 3. check local progress is less than the latest submitted block number.
	latest := l2Latest.Uint64()
	if progress.BlockNumber >= latest {
		l.Log.Debug("progress is reach latest")
		return
	}

	// 4. to process next block.
	for blkNum := progress.BlockNumber + 1; blkNum <= latest; blkNum++ {
		ooIndex, err := outputIndexAfter(blkNum)
		if err != nil {
			l.Log.Error("operator failed to get l2oo index and output", "err", err)
			return
		}

		txs, err := rollupClient.WithdrawalTxs(ctx, blkNum)
		if err != nil {
			l.Log.Error("failed to get withdrawal txs", "blknum", blkNum, "err", err)
			return
		}
		if len(txs) == 0 {
			l.Log.Info("ignore block no txs to withdrawal", "blknum", blkNum)
			l.SetOperatorTaskList(OperatorTaskList{})
			l.SetProgress(OperatorProgress{
				BlockNumber: blkNum,
				Finished:    0,
				Total:       0,
			})
			continue
		}

		tasks := make(OperatorTaskList, 0, len(txs))
		for _, tx := range txs {
			tasks = append(tasks, OperatorTask{TxHash: tx.Hex()})
		}
		// Persist the task list before the progress record that refers to it.
		l.SetOperatorTaskList(tasks)
		l.SetProgress(OperatorProgress{
			BlockNumber: blkNum,
			Finished:    0,
			Total:       uint64(len(tasks)),
		})

		// 5. to process each withdrawal tx. Stop the whole pass on failure:
		// moving on to the next block would overwrite the task list and
		// progress and silently drop this block's remaining withdrawals.
		if !processTasks(blkNum, tasks, 0, ooIndex) {
			return
		}
	}
}

// BatchExChainWithdrawalTxData packs the arguments into calldata for the portal
// ABI method "batchExChainWithdrawal" and returns the packing error, if any.
func (l *Operator) BatchExChainWithdrawalTxData(params []bindings.TypesBatchExChainWithdrawalParam, l2OutputIndex *big.Int, outputRootProof bindings.TypesOutputRootProof) ([]byte, error) {
	const method = "batchExChainWithdrawal"
	calldata, packErr := l.portalABI.Pack(method, params, l2OutputIndex, outputRootProof)
	return calldata, packErr
}

// ExChainWithdrawalTxData packs the arguments into calldata for the portal ABI
// method "exChainWithdrawal" and returns the packing error, if any.
func (l *Operator) ExChainWithdrawalTxData(param bindings.TypesExChainWithdrawalParam, l2OutputIndex *big.Int, outputRoot bindings.TypesOutputRootProof, proof [][]byte) ([]byte, error) {
	const method = "exChainWithdrawal"
	calldata, packErr := l.portalABI.Pack(method, param, l2OutputIndex, outputRoot, proof)
	return calldata, packErr
}

// sendTransaction packs a batch withdrawal call and sends it to the portal
// address through the transaction manager.
//
// It returns an error when packing or sending fails. A transaction that was
// sent but has a failed receipt status is only logged; nil is returned.
func (l *Operator) sendTransaction(ctx context.Context, params []bindings.TypesBatchExChainWithdrawalParam, l2OutputIndex *big.Int, outputRootProof bindings.TypesOutputRootProof) error {
	l.Log.Info("DoWithdrawal", "l2OutputIndex", l2OutputIndex)

	calldata, err := l.BatchExChainWithdrawalTxData(params, l2OutputIndex, outputRootProof)
	if err != nil {
		return err
	}

	candidate := txmgr.TxCandidate{
		TxData:   calldata,
		To:       l.Cfg.PortalAddr,
		GasLimit: 10000000,
	}
	receipt, err := l.Txmgr.Send(ctx, candidate)
	if err != nil {
		return err
	}

	switch receipt.Status {
	case types.ReceiptStatusFailed:
		l.Log.Error("Proposer withdrawal tx successfully published but reverted", "tx_hash", receipt.TxHash)
	default:
		l.Log.Info("Proposer withdrawal tx successfully published",
			"tx_hash", receipt.TxHash)
	}
	return nil
}

// sendTransactionOne packs a single withdrawal call and sends it to the portal
// address through the transaction manager.
//
// It returns an error when packing or sending fails. A transaction that was
// sent but has a failed receipt status is only logged; nil is returned.
func (l *Operator) sendTransactionOne(ctx context.Context, param bindings.TypesExChainWithdrawalParam, l2OutputIndex *big.Int, outputRoot bindings.TypesOutputRootProof, proof [][]byte) error {
	calldata, err := l.ExChainWithdrawalTxData(param, l2OutputIndex, outputRoot, proof)
	if err != nil {
		return err
	}

	candidate := txmgr.TxCandidate{
		TxData:   calldata,
		To:       l.Cfg.PortalAddr,
		GasLimit: 10000000,
	}
	receipt, err := l.Txmgr.Send(ctx, candidate)
	if err != nil {
		return err
	}

	switch receipt.Status {
	case types.ReceiptStatusFailed:
		l.Log.Error("Proposer withdrawal tx successfully published but reverted", "tx_hash", receipt.TxHash)
	default:
		l.Log.Info("Proposer withdrawal tx successfully published",
			"tx_hash", receipt.TxHash)
	}
	return nil
}

// loop is the operator's polling goroutine. On every tick of Cfg.PollInterval
// it runs one DoOperator pass with the operator's context, and it returns as
// soon as done is closed. A pending stop signal takes precedence over a tick.
func (l *Operator) loop() {
	l.Log.Info("Operator loop started")
	defer l.wg.Done()
	defer l.Log.Info("loop returning")

	runCtx := l.ctx
	pollTicker := time.NewTicker(l.Cfg.PollInterval)
	defer pollTicker.Stop()

	// stopRequested reports, without blocking, whether done has been closed.
	stopRequested := func() bool {
		select {
		case <-l.done:
			return true
		default:
			return false
		}
	}

	for {
		select {
		case <-l.done:
			return
		case <-pollTicker.C:
			// prioritize quit signal
			if stopRequested() {
				return
			}
			l.DoOperator(runCtx)
		}
	}
}

// waitNodeSync reads the current L1 block number from the transaction manager
// (bounded by Cfg.NetworkTimeout) and then delegates to dial.WaitRollupSync
// with that block number and the rollup client.
//
// It returns an error if the L1 block number or the rollup client cannot be
// obtained, or whatever dial.WaitRollupSync returns.
func (l *Operator) waitNodeSync() error {
	headCtx, cancelHead := context.WithTimeout(l.ctx, l.Cfg.NetworkTimeout)
	defer cancelHead()

	l1Head, headErr := l.Txmgr.BlockNumber(headCtx)
	if headErr != nil {
		return fmt.Errorf("failed to retrieve current L1 block number: %w", headErr)
	}

	client, clientErr := l.RollupProvider.RollupClient(l.ctx)
	if clientErr != nil {
		return fmt.Errorf("failed to get rollup client: %w", clientErr)
	}

	// Duration handed to WaitRollupSync; presumably its polling interval — verify
	// against the dial package.
	const syncWait = 12 * time.Second
	return dial.WaitRollupSync(l.ctx, l.Log, client, l1Head, syncWait)
}

// doWithdrawalBatch sends one batch withdrawal transaction, giving it at most
// ten minutes on top of ctx. A failure is logged and returned.
func (l *Operator) doWithdrawalBatch(ctx context.Context, params []bindings.TypesBatchExChainWithdrawalParam, l2OutputIndex *big.Int, outputRootProof bindings.TypesOutputRootProof) error {
	sendCtx, release := context.WithTimeout(ctx, 10*time.Minute)
	defer release()

	sendErr := l.sendTransaction(sendCtx, params, l2OutputIndex, outputRootProof)
	if sendErr == nil {
		return nil
	}
	l.Log.Error("Failed to send proposal withdrawal transaction",
		"err", sendErr)
	return sendErr
}

// doWithdrawalOne sends one single-withdrawal transaction, giving it at most
// ten minutes on top of ctx. A failure is logged and returned.
func (l *Operator) doWithdrawalOne(ctx context.Context, param bindings.TypesExChainWithdrawalParam, l2OutputIndex *big.Int, outputRoot bindings.TypesOutputRootProof, proof [][]byte) error {
	sendCtx, release := context.WithTimeout(ctx, 10*time.Minute)
	defer release()

	sendErr := l.sendTransactionOne(sendCtx, param, l2OutputIndex, outputRoot, proof)
	if sendErr == nil {
		return nil
	}
	l.Log.Error("Failed to send proposal withdrawal transaction",
		"err", sendErr)
	return sendErr
}
