package server

import (
	"bytes"
	"encoding/json"
	"errors"
	"fmt"
	"github.com/ethereum/go-ethereum/common"
	"github.com/golang/protobuf/proto"
	"github.com/google/uuid"
	lru "github.com/hashicorp/golang-lru"
	"github.com/odysseus/nodemanager/config"
	"github.com/odysseus/nodemanager/standardlib"
	"github.com/odysseus/nodemanager/utils"
	odysseus "github.com/odysseus/odysseus-protocol/gen/proto/go/base/v1"
	omanager "github.com/odysseus/odysseus-protocol/gen/proto/go/nodemanager/v1"
	"github.com/redis/go-redis/v9"
	log "github.com/sirupsen/logrus"
	"golang.org/x/crypto/sha3"
	"strconv"
	"sync"
	"time"
)

var (
	// Succeed is not a failure condition: its message ("succeed") is what
	// makeReceipt writes as TaskResult for a successful task.
	Succeed             = errors.New("succeed")
	// ErrWorkerExist is returned by AddNewWorker when the uuid is already registered.
	ErrWorkerExist      = errors.New("worker exist")
	// ErrHeartBeatExpired is returned by manageWorker when no heartbeat has been
	// recorded within the worker-check window.
	ErrHeartBeatExpired = errors.New("worker heartbeat expired")
)

// dispatchTask is a task queued on a worker's taskCh for manageWorker to push
// over the worker stream.
type dispatchTask struct {
	// task is the content that is converted into a PushTaskMessage.
	task  *odysseus.TaskContent
	// errCh receives the stream Send error (nil on success). manageWorker writes
	// to it without blocking, so it must be buffered (handleWorkerMsg uses
	// capacity 1) or the result is dropped.
	errCh chan error
}

// Worker is the manager-side state of one connected worker stream.
//
// NOTE(review): most fields are written by handleWorkerMsg and read by
// manageWorker, which appear to run on different goroutines, without any
// lock — confirm whether this is an accepted race.
type Worker struct {
	// quit is closed by handleWorkerMsg when Recv fails; both worker loops
	// return when it fires.
	quit           chan interface{}
	// taskCh carries tasks for manageWorker to push to the worker.
	taskCh         chan *dispatchTask
	// resultCh forwards SubmitTaskResult messages from handleWorkerMsg to manageWorker.
	resultCh       chan *omanager.SubmitTaskResult
	uuid           int64
	// publicKey is the MinerPubkey reported in the worker's device info.
	publicKey      string
	// addr is the address derived from publicKey ("" until known).
	addr           string
	// benefitAddr is the BenefitAddress reported in device info; see ProfitAccount.
	benefitAddr    string
	// status is the most recent DeviceStatus reported by the worker.
	status         []byte
	online         bool
	usageInfo      []*omanager.DeviceUsage
	deviceInfo     []*omanager.DeviceInfo
	// deviceInfoHash is the sha3-256 of the JSON-encoded deviceInfo last seen.
	deviceInfoHash []byte
	// recentTask caches dispatched tasks by TaskId so submitted results can be
	// matched to their task (capacity set in AddNewWorker).
	recentTask     *lru.Cache
	stream         omanager.NodeManagerService_RegisterWorkerServer
}

// ProfitAccount returns the worker's benefit (payout) address, parsed from the
// hex string the worker reported in its device info.
func (w *Worker) ProfitAccount() common.Address {
	benefit := w.benefitAddr
	return common.HexToAddress(benefit)
}

// WorkerAccount returns the worker's own address, parsed from the hex string
// derived from its public key.
func (w *Worker) WorkerAccount() common.Address {
	own := w.addr
	return common.HexToAddress(own)
}

// WorkerManager tracks connected workers and drives their streams.
type WorkerManager struct {
	rdb       *redis.Client
	// heartBeat maps worker uuid to the Unix time (seconds) of its last heartbeat.
	heartBeat map[int64]int64
	// hbRwLock guards heartBeat.
	hbRwLock  sync.RWMutex

	// workers indexes workers by uuid; workid indexes them by address.
	workers  map[int64]*Worker
	workid   map[string]*Worker
	// wkRwLock guards workers and workid.
	wkRwLock sync.RWMutex
	// quit is closed by Stop to make worker loops exit.
	quit     chan struct{}

	node *Node
	// std supplies standard tasks for FetchStandardTask requests.
	std  *standardlib.StandardTasks
}

// NewWorkerManager builds a WorkerManager with empty worker/heartbeat indexes,
// an open quit channel and a fresh standard-task library.
func NewWorkerManager(rdb *redis.Client, node *Node) *WorkerManager {
	wm := new(WorkerManager)
	wm.rdb = rdb
	wm.node = node
	wm.heartBeat = make(map[int64]int64)
	wm.workers = make(map[int64]*Worker)
	wm.workid = make(map[string]*Worker)
	wm.quit = make(chan struct{})
	wm.std = standardlib.NewStandardTasks()
	return wm
}

// Stop signals every worker loop owned by this manager to exit by closing
// wm.quit. Calling Stop again after it has returned is a no-op rather than a
// "close of closed channel" panic.
//
// NOTE: concurrent Stop calls are not synchronized against each other.
func (wm *WorkerManager) Stop() {
	select {
	case <-wm.quit:
		// Already stopped.
	default:
		close(wm.quit)
	}
}

// UpdateHeartBeat stamps the current Unix time (seconds) as the latest
// heartbeat of the worker identified by uuid.
func (wm *WorkerManager) UpdateHeartBeat(uuid int64) {
	stamp := time.Now().Unix()

	wm.hbRwLock.Lock()
	defer wm.hbRwLock.Unlock()
	wm.heartBeat[uuid] = stamp
}

// UpdateStatus is currently an empty placeholder: it neither reads nor
// modifies worker.
// NOTE(review): no caller is visible in this file — implement or remove.
func (wm *WorkerManager) UpdateStatus(worker *Worker) {

}

// GetHeartBeat returns the Unix time (seconds) of the worker's last recorded
// heartbeat, or 0 if none has been recorded.
func (wm *WorkerManager) GetHeartBeat(uuid int64) int64 {
	wm.hbRwLock.RLock()
	last, seen := wm.heartBeat[uuid]
	wm.hbRwLock.RUnlock()

	if !seen {
		return 0
	}
	return last
}

// GetWorker looks up a registered worker by uuid; it returns nil if unknown.
func (wm *WorkerManager) GetWorker(uuid int64) *Worker {
	wm.wkRwLock.RLock()
	found := wm.workers[uuid]
	wm.wkRwLock.RUnlock()
	return found
}

// SetWorkerAddr assigns addr to worker and indexes the worker under that
// address for GetWorkerByAddr. Any earlier address entry for the same worker
// is left in the index.
func (wm *WorkerManager) SetWorkerAddr(worker *Worker, addr string) {
	wm.wkRwLock.Lock()
	defer wm.wkRwLock.Unlock()

	worker.addr, wm.workid[addr] = addr, worker
}

// GetWorkerByAddr looks up a worker by the address recorded with
// SetWorkerAddr; it returns nil if unknown.
func (wm *WorkerManager) GetWorkerByAddr(addr string) *Worker {
	wm.wkRwLock.RLock()
	found := wm.workid[addr]
	wm.wkRwLock.RUnlock()
	return found
}

// AddNewWorker registers a new Worker for the given stream under uuid and
// starts its receive loop (handleWorkerMsg) in a new goroutine.
//
// It returns ErrWorkerExist if uuid is already registered, or the error from
// creating the worker's recent-task cache.
func (wm *WorkerManager) AddNewWorker(uuid int64, worker omanager.NodeManagerService_RegisterWorkerServer) (*Worker, error) {
	wm.wkRwLock.Lock()
	defer wm.wkRwLock.Unlock()

	if _, taken := wm.workers[uuid]; taken {
		return nil, ErrWorkerExist
	}

	recent, err := lru.New(100)
	if err != nil {
		return nil, err
	}

	w := &Worker{
		quit:       make(chan interface{}),
		taskCh:     make(chan *dispatchTask),
		resultCh:   make(chan *omanager.SubmitTaskResult),
		uuid:       uuid,
		stream:     worker,
		recentTask: recent,
	}
	wm.workers[uuid] = w

	// Only the receive side is started here; manageWorker is presumably run
	// by the caller — confirm.
	go wm.handleWorkerMsg(w)
	return w, nil
}

// Callback runs after manageWorker sends a message, receiving the Send error
// (nil on success). The bool result is currently ignored by manageWorker.
type Callback func(err error) bool

// doCallback protobuf-encodes response and POSTs it to hook. Delivery is
// best-effort: failures are logged, not returned.
//
// If marshalling fails nothing is posted; previously the nil payload was
// still sent to the hook.
func (wm *WorkerManager) doCallback(hook string, response *odysseus.TaskResponse) {
	d, err := proto.Marshal(response)
	if err != nil {
		log.WithError(err).Error("marshal task response failed")
		return
	}
	log.WithField("task-id", response.TaskId).Debug("marshal task response")

	if err := utils.Post(hook, d); err != nil {
		log.WithError(err).Error("post task result failed")
		return
	}
	log.WithField("task-id", response.TaskId).Debug("post task result")
}

// manageWorker runs the manager-to-worker half of a worker's stream. Each loop
// iteration builds at most one ManagerMessage, sends it on worker.stream and
// then runs the matching callback with the Send error. The message is one of:
//   - a heartbeat, status, device-info or device-usage request
//   - a pushed task
//   - a task-result proof
//
// It returns nil in three cases: the worker quits, its task channel is closed,
// or the manager is stopped (after one goodbye is sent). It returns
// ErrHeartBeatExpired when no heartbeat was recorded within the worker-check
// window. On every exit the worker is marked offline and InActiveWorker is
// called.
func (wm *WorkerManager) manageWorker(worker *Worker) error {
	log.WithField("worker", worker.uuid).Info("start manage worker")
	tickerConf := config.GetConfig().Tickers

	// Probe quickly right after registration. The heartbeat ticker switches to
	// its configured period on its first tick. The status, device-info and
	// device-usage tickers are switched by the worker-check tick once the
	// corresponding data has arrived.
	initialHeartBeatInterval := time.Second * 1
	initialInterval := initialHeartBeatInterval * 2

	// NOTE(review): time.NewTicker and Ticker.Reset panic on non-positive
	// durations, so the configured ticker values are assumed to be > 0.
	heartBeatDuration := time.Second * time.Duration(tickerConf.HeartBeat)
	workerCheckDuration := heartBeatDuration * 3

	heartBeatTicker := time.NewTicker(initialHeartBeatInterval)
	defer heartBeatTicker.Stop()

	workerCheckTicker := time.NewTicker(workerCheckDuration)
	defer workerCheckTicker.Stop()

	statusTicker := time.NewTicker(initialInterval)
	defer statusTicker.Stop()

	deviceInfoTicker := time.NewTicker(initialInterval)
	defer deviceInfoTicker.Stop()

	deviceUsageTicker := time.NewTicker(initialInterval)
	defer deviceUsageTicker.Stop()

	defer func() {
		log.WithFields(log.Fields{
			"worker-addr": worker.addr,
			"worker-uuid": worker.uuid,
		}).Info("exit manage worker")
		worker.online = false

		wm.InActiveWorker(worker)
	}()

	for {
		var msg = new(omanager.ManagerMessage)
		var callback = Callback(func(err error) bool {
			// do nothing
			return true
		})

		select {
		case <-wm.quit:
			// The manager is shutting down: say goodbye once, then stop.
			// wm.quit stays closed, so looping on would resend the goodbye
			// on every iteration.
			msg.Message = &omanager.ManagerMessage_GoodbyeMessage{
				GoodbyeMessage: &omanager.GoodbyeMessage{},
			}
			if err := worker.stream.Send(msg); err != nil {
				log.WithError(err).Error("send message to worker failed")
			}
			return nil

		case <-worker.quit:
			return nil

		case <-workerCheckTicker.C:
			// Housekeeping only: this tick leaves msg.Message unset.
			if worker.deviceInfo != nil && worker.addr != "" {
				deviceInfoTicker.Reset(time.Second * time.Duration(tickerConf.DeviceInfoTicker))
			}
			if worker.status != nil {
				statusTicker.Reset(time.Second * time.Duration(tickerConf.StatusTicker))
			}
			if worker.usageInfo != nil {
				deviceUsageTicker.Reset(time.Second * time.Duration(tickerConf.DeviceUsageTicker))
			}
			if time.Now().Unix()-wm.GetHeartBeat(worker.uuid) > int64(workerCheckDuration.Seconds()) {
				wm.InActiveWorker(worker)
				// todo: remove worker
				return ErrHeartBeatExpired
			}

		case <-heartBeatTicker.C:
			hb := new(omanager.ManagerMessage_HeartbeatRequest)
			hb.HeartbeatRequest = &omanager.HeartbeatRequest{
				Timestamp: uint64(time.Now().Unix()),
			}
			heartBeatTicker.Reset(heartBeatDuration)
			msg.Message = hb
			callback = func(err error) bool {
				log.WithField("worker", worker.uuid).Info("send hear beat to worker")
				return true
			}

		case <-deviceInfoTicker.C:
			deviceInfo := new(omanager.ManagerMessage_DeviceRequest)
			deviceInfo.DeviceRequest = &omanager.DeviceInfoRequest{}
			msg.Message = deviceInfo

		case <-deviceUsageTicker.C:
			deviceUsage := new(omanager.ManagerMessage_DeviceUsage)
			deviceUsage.DeviceUsage = &omanager.DeviceUsageRequest{}
			msg.Message = deviceUsage

		case <-statusTicker.C:
			status := new(omanager.ManagerMessage_StatusRequest)
			status.StatusRequest = &omanager.StatusRequest{}
			msg.Message = status

		case dtask, ok := <-worker.taskCh:
			if !ok {
				return nil
			}
			task := dtask.task
			taskMsg := new(omanager.ManagerMessage_PushTaskMessage)
			taskMsg.PushTaskMessage = &omanager.PushTaskMessage{
				TaskId:    task.TaskId,
				TaskType:  task.TaskType,
				TaskKind:  task.TaskKind,
				Workload:  uint64(task.TaskWorkload),
				TaskCmd:   task.TaskCmd,
				TaskParam: task.TaskParam,
			}
			msg.Message = taskMsg
			callback = func(err error) bool {
				if err == nil {
					// Remember the task so the worker's result can be verified.
					worker.recentTask.Add(task.TaskId, task)
				}
				log.WithField("worker", worker.uuid).Info("dispatch task to worker")

				// Report the send outcome without blocking if nobody listens.
				select {
				case dtask.errCh <- err:
				default:
					// err ch is invalid
				}
				return true
			}

		case result := <-worker.resultCh:
			// verify result and make a new signature.
			data, exist := worker.recentTask.Get(result.TaskId)
			if !exist {
				log.WithField("worker", worker.uuid).Error("task not found for verify result")
				continue
			}

			task := data.(*odysseus.TaskContent)
			if result.TaskId != task.TaskId {
				log.WithField("worker", worker.uuid).Error("task id not match")
				continue
			}

			proof, err := wm.taskResult(worker, task, result)
			if err != nil {
				continue
			}
			msg.Message = proof

			callback = func(err error) bool {
				// remove task from cache.
				worker.recentTask.Remove(result.TaskId)
				if task.TaskKind != odysseus.TaskKind_StandardTask {
					// Best-effort, as before.
					_ = wm.AddWorkerSingle(worker)
					if perr := wm.Payment(task); perr != nil {
						log.WithError(perr).WithField("task-id", task.TaskId).Error("pay for task failed")
					}
				}
				return true
			}
		}

		// msg itself is never nil, so only send when the selected case actually
		// set a message. Otherwise every worker-check tick pushes an empty
		// ManagerMessage to the worker.
		if msg.Message != nil {
			err := worker.stream.Send(msg)
			if err != nil {
				log.WithError(err).Error("send message to worker failed")
			}
			callback(err)
		}
	}
}
// handleWorkerMsg is the receive loop for a worker's stream. It dispatches
// each WorkerMessage by type:
//   - submitted results are forwarded to manageWorker on resultCh
//   - heartbeat, status, device info and device usage are recorded on the worker
//   - standard-task fetch requests are answered by queueing a generated task on taskCh
//
// It returns in four cases: the manager quits, the worker quits, the worker
// says goodbye, or Recv fails. In the last two cases worker.quit is closed.
func (wm *WorkerManager) handleWorkerMsg(worker *Worker) {
	l := log.WithField("worker-uuid", worker.uuid)
	l.WithField("worker-addr", worker.addr).Info("start handle worker message")
	// Read worker.addr at exit time. A plain defer would bind the address as it
	// was on entry, which is usually still empty: the address is learned later
	// from the device info message.
	defer func() {
		l.WithField("worker-addr", worker.addr).Info("exit handle worker message")
	}()
	for {
		select {
		case <-wm.quit:
			return
		case <-worker.quit:
			return
		default:
			// Recv blocks, so the quit checks above only run between messages.
			wmsg, err := worker.stream.Recv()
			if err != nil {
				l.WithError(err).WithField("worker-addr", worker.addr).Error("recv msg failed")
				close(worker.quit)
				return
			}
			switch msg := wmsg.Message.(type) {
			case *omanager.WorkerMessage_GoodbyeMessage:
				worker.online = false
				l.WithFields(log.Fields{
					"worker-addr": worker.addr,
					"reason":      msg.GoodbyeMessage.Reason,
				}).Info("receive worker goodbye")
				// Close rather than send. An unbuffered send on quit blocked this
				// goroutine forever if manageWorker had already returned. A close
				// also reaches every goroutine selecting on quit.
				close(worker.quit)
				close(worker.taskCh)
				return
			case *omanager.WorkerMessage_SubmitTaskResult:
				worker.resultCh <- msg.SubmitTaskResult
			case *omanager.WorkerMessage_HeartbeatResponse:
				worker.online = true
				wm.UpdateHeartBeat(worker.uuid)
				l.WithFields(log.Fields{
					"worker-addr": worker.addr,
					"hearBeat":    time.Now().Unix() - int64(msg.HeartbeatResponse.Timestamp),
				}).Debug("receive worker heartbeat")
			case *omanager.WorkerMessage_Status:
				// todo: store worker status
				worker.status = msg.Status.DeviceStatus
				l.WithFields(log.Fields{
					"worker-addr": worker.addr,
				}).Debugf("receive worker status:0x%x", msg.Status.DeviceStatus)
			case *omanager.WorkerMessage_ResourceMap:
				// todo: store worker resource map.
				l.WithFields(log.Fields{
					"worker-addr": worker.addr,
				}).Debugf("receive worker resource map:%v", msg.ResourceMap)
			case *omanager.WorkerMessage_FetchStandardTask:
				l.WithFields(log.Fields{
					"worker-addr": worker.addr,
				}).Debugf("receive worker fetch std task request:%v", msg.FetchStandardTask.TaskType)
				task, exist := wm.std.GetTask(msg.FetchStandardTask.TaskType)
				if !exist {
					l.WithField("task-type", msg.FetchStandardTask.TaskType).Warn("not found std task")
					continue
				}
				stdlib := wm.std.GetStdLib(task.TaskType)
				if stdlib == nil {
					l.WithField("task-type", task.TaskType).Warn("not found std lib")
					continue
				}
				pushTask := task
				pushTask.TaskId = uuid.NewString()
				param, err := stdlib.GenerateParam(0)
				if err != nil {
					l.WithError(err).WithField("task-type", task.TaskType).Error("generate param failed")
					continue
				}
				pushTask.TaskParam = []byte(param)
				pushTask.TaskInLen = int32(len(param))
				pushTask.TaskUid = "0"
				pushTask.TaskTimestamp = uint64(time.Now().UnixNano())
				pushTask.TaskKind = odysseus.TaskKind_StandardTask
				pushTask.TaskFee = "0"
				// errCh is buffered so manageWorker's non-blocking report succeeds.
				worker.taskCh <- &dispatchTask{
					task:  &pushTask.TaskContent,
					errCh: make(chan error, 1),
				}

			case *omanager.WorkerMessage_DeviceInfo:
				// todo: handler worker device info
				l.WithFields(log.Fields{
					"worker-addr": worker.addr,
				}).Debugf("receive worker device info:%v", msg.DeviceInfo)
				{
					// receive device info
					worker.online = true
					worker.publicKey = msg.DeviceInfo.MinerPubkey
					worker.deviceInfo = msg.DeviceInfo.Devices
					worker.benefitAddr = msg.DeviceInfo.BenefitAddress
					var addr = ""
					if pubkey, err := utils.HexToPubkey(worker.publicKey); err != nil {
						l.WithFields(log.Fields{
							"worker-addr": worker.addr,
							"error":       err,
						}).Error("parse pubkey failed")
					} else {
						addr = utils.PubkeyToAddress(pubkey)
					}
					if addr == worker.addr {
						// addr is not change.
						continue
					}

					if worker.addr != "" {
						wm.InActiveWorker(worker)
					}

					worker.addr = addr
					if worker.addr != "" {
						infoData, err := json.Marshal(msg.DeviceInfo.Devices)
						if err != nil {
							l.WithFields(log.Fields{
								"worker-addr": worker.addr,
								"error":       err,
							}).Error("marshal device info failed")
						} else if len(infoData) > 0 {
							infoHash := sha3.Sum256(infoData)
							// Only persist device info when its content changed.
							if !bytes.Equal(infoHash[:], worker.deviceInfoHash) {
								wm.UpdateWorkerDeviceInfo(worker, string(infoData))
							}
							worker.deviceInfoHash = infoHash[:]
						}
						wm.AddWorkerFirst(worker)
						wm.SetWorkerAddr(worker, worker.addr)
					}
				}

			case *omanager.WorkerMessage_DeviceUsage:
				// todo: handler worker device usage
				worker.usageInfo = msg.DeviceUsage.Usage
				l.WithFields(log.Fields{
					"worker-addr": worker.addr,
				}).Debugf("receive worker device usage:%v", msg.DeviceUsage.Usage)

			default:
				l.WithField("worker-addr", worker.addr).Error(fmt.Sprintf("unsupport msg type %T", msg))
			}
		}
	}
}

// Payment charges the task's user (TaskUid) the task fee (TaskFee) via the
// node cache when payment is enabled in the config. It is a no-op returning
// nil when payment is disabled.
//
// An unparsable fee or uid is returned as an error and nothing is charged.
// Previously parse errors were ignored, which charged uid 0 or charged a fee
// of 0.
func (wm *WorkerManager) Payment(task *odysseus.TaskContent) error {
	if !config.GetConfig().EnablePay {
		return nil
	}

	fee, err := strconv.ParseInt(task.TaskFee, 10, 64)
	if err != nil {
		return fmt.Errorf("parse task fee %q: %w", task.TaskFee, err)
	}
	uid, err := strconv.ParseInt(task.TaskUid, 10, 64)
	if err != nil {
		return fmt.Errorf("parse task uid %q: %w", task.TaskUid, err)
	}

	if err := wm.node.cache.PayforFee(uid, fee); err != nil {
		return err
	}
	return nil
}

// makeReceipt builds the TaskReceipt for a task the worker has executed.
//
// On success the receipt carries the task fee and Succeed's message as
// TaskResult. On failure TaskResult is err's message, or "task failed" when no
// error was supplied; previously a nil err panicked on err.Error().
func (wm *WorkerManager) makeReceipt(worker *Worker, task *odysseus.TaskContent, result *omanager.SubmitTaskResult, err error) *odysseus.TaskReceipt {
	now := uint64(time.Now().UnixNano())
	receipt := &odysseus.TaskReceipt{
		TaskId:        task.TaskId,
		TaskTimestamp: task.TaskTimestamp,
		TaskType:      task.TaskType,
		TaskUid:       task.TaskUid,
		TaskWorkload:  task.TaskWorkload,
		// now is in nanoseconds; assumes TaskTimestamp is too (handleWorkerMsg
		// sets it with UnixNano for standard tasks), giving microseconds.
		TaskDuration:        int64(now-task.TaskTimestamp) / 1000,
		TaskFee:             0,
		TaskOutLen:          int64(len(result.TaskResultBody)),
		TaskProfitAccount:   worker.ProfitAccount().Hex(),
		TaskWorkerAccount:   worker.WorkerAccount().Hex(),
		TaskExecuteDuration: result.TaskExecuteDuration,
	}

	switch {
	case result.IsSuccessed:
		// An unparsable TaskFee is deliberately recorded as a zero fee.
		fee, _ := strconv.ParseInt(task.TaskFee, 10, 64)
		receipt.TaskFee = fee
		receipt.TaskResult = Succeed.Error()
	case err != nil:
		receipt.TaskResult = err.Error()
	default:
		receipt.TaskResult = "task failed"
	}
	return receipt
}
