package server

import (
	"bytes"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"github.com/ethereum/go-ethereum/common"
	"github.com/golang/protobuf/proto"
	"github.com/google/uuid"
	lru "github.com/hashicorp/golang-lru"
	"github.com/odysseus/nodemanager/config"
	"github.com/odysseus/nodemanager/standardlib"
	"github.com/odysseus/nodemanager/utils"
	odysseus "github.com/odysseus/odysseus-protocol/gen/proto/go/base/v1"
	omanager "github.com/odysseus/odysseus-protocol/gen/proto/go/nodemanager/v1"
	"github.com/odysseus/service-registry/registry"
	"github.com/redis/go-redis/v9"
	log "github.com/sirupsen/logrus"
	"golang.org/x/crypto/sha3"
	"strconv"
	"sync"
	"time"
)

// Sentinel errors used by the worker manager.
var (
	// Succeed is not a failure: its message is stored as the TaskResult of a
	// receipt when the worker reports the task as successful.
	Succeed                = errors.New("succeed")
	// ErrWorkerExist is returned by AddNewWorker when the uuid is already tracked.
	ErrWorkerExist         = errors.New("worker exist")
	// ErrHeartBeatExpired is sent on a worker's quit channel when its last
	// heartbeat is older than the check window.
	ErrHeartBeatExpired    = errors.New("worker heartbeat expired")
	// ErrInvalidMessageValue is not referenced in this part of the file.
	ErrInvalidMessageValue = errors.New("invalid message value")
)

// dispatchTask is a task queued on Worker.taskCh for delivery to a worker.
type dispatchTask struct {
	// task is the content pushed to the worker.
	task  *odysseus.TaskContent
	// errCh receives the result of the stream send (nil on success). The send
	// is non-blocking, so the channel should be buffered or actively read.
	errCh chan error
}

// workerInfo holds the most recent information reported by a worker. The
// fields are filled in by handleWorkerMsg as the corresponding messages arrive.
type workerInfo struct {
	nodeInfo         *omanager.NodeInfoResponse
	deviceUsageInfo  []*omanager.DeviceUsage
	deviceInfo       *omanager.DeviceInfoMessage
	deviceStatusInfo *omanager.StatusResponse
	// resourceInfo is not assigned anywhere in this part of the file.
	resourceInfo     *omanager.SubmitResourceMap
}

// Worker is the manager-side state of one connected worker session.
type Worker struct {
	// quit carries the reason the session ended; handleWorkerMsg sends on it
	// and closes it on exit, and manageWorker returns when it fires.
	quit     chan interface{}
	// taskCh carries tasks to push to the worker; consumed by manageWorker.
	taskCh   chan *dispatchTask
	// resultCh carries task results received from the worker; consumed by manageWorker.
	resultCh chan *omanager.SubmitTaskResult

	uuid            int64 // local uuid of the worker.
	registed        bool  // true once the worker has sent a register message to this node manager.
	online          bool  // set to true whenever a message is received, false on goodbye/disconnect.
	nonce           int
	latestNmValue   string
	// addFirstSucceed gates the AddWorkerToQueue retries; it is only read in
	// this part of the file.
	addFirstSucceed bool

	info           workerInfo
	workerAddr     string // worker address derived from its public key.
	deviceInfoHash []byte // sha3-256 of the JSON-encoded device info last reported.
	recentTask     *lru.Cache // tasks pushed to the worker, keyed by task id.
	status         string // "connected" or "disconnected" ("" before manageWorker starts).

	// stream is the RegisterWorker stream used to send to and receive from the worker.
	stream omanager.NodeManagerService_RegisterWorkerServer
}

// ProfitAccount returns the benefit address reported in the worker's node
// info, or the zero address when no node info has been received yet.
func (w *Worker) ProfitAccount() common.Address {
	reported := w.info.nodeInfo
	if reported == nil {
		return common.Address{}
	}
	return common.HexToAddress(reported.BenefitAddress)
}

// WorkerAccount returns the worker's own address (the one derived from its
// public key) as a common.Address.
func (w *Worker) WorkerAccount() common.Address {
	account := common.HexToAddress(w.workerAddr)
	return account
}

// WorkerManager tracks the workers connected to this node manager.
type WorkerManager struct {
	rdb       *redis.Client
	// heartBeat maps a worker uuid to the unix time (seconds) of its last heartbeat.
	heartBeat map[int64]int64
	// hbRwLock guards heartBeat.
	hbRwLock  sync.RWMutex

	// workerByIp is not used in this part of the file.
	workerByIp sync.Map
	// workers indexes workers by local uuid.
	workers    map[int64]*Worker
	// workid indexes workers by worker address.
	workid     map[string]*Worker
	// workerReg holds the service registry started for each registered worker, by uuid.
	workerReg  map[int64]*registry.Registry
	// wkRwLock guards workers, workid and workerReg.
	wkRwLock   sync.RWMutex
	// quit is closed by Stop to signal shutdown to the per-worker goroutines.
	quit       chan struct{}

	node *Node
	std  *standardlib.StandardTasks
}

// NewWorkerManager builds a WorkerManager bound to the given redis client and
// node, with empty worker/heartbeat/registry tables, an open quit channel and
// a fresh standard-task library.
func NewWorkerManager(rdb *redis.Client, node *Node) *WorkerManager {
	wm := new(WorkerManager)

	wm.rdb = rdb
	wm.node = node

	wm.heartBeat = make(map[int64]int64)
	wm.workers = make(map[int64]*Worker)
	wm.workid = make(map[string]*Worker)
	wm.workerReg = make(map[int64]*registry.Registry)

	wm.quit = make(chan struct{})
	wm.std = standardlib.NewStandardTasks()

	return wm
}

// Stop signals shutdown by closing the manager's quit channel. Calling it
// again after it has returned is a no-op instead of a double-close panic.
// The check and the close are not atomic, so Stop must not be called
// concurrently from several goroutines.
func (wm *WorkerManager) Stop() {
	select {
	case <-wm.quit:
		// Already closed by an earlier Stop.
	default:
		close(wm.quit)
	}
}

// UpdateHeartBeat records the current unix time (seconds) as the latest
// heartbeat of the worker identified by uuid.
func (wm *WorkerManager) UpdateHeartBeat(uuid int64) {
	wm.hbRwLock.Lock()
	wm.heartBeat[uuid] = time.Now().Unix()
	wm.hbRwLock.Unlock()
}

// GetHeartBeat returns the unix time (seconds) of the last heartbeat recorded
// for uuid, or 0 when none has been recorded.
func (wm *WorkerManager) GetHeartBeat(uuid int64) int64 {
	wm.hbRwLock.RLock()
	last := wm.heartBeat[uuid]
	wm.hbRwLock.RUnlock()
	return last
}

// GetWorker returns the worker tracked under uuid, or nil when there is none.
// It is the same lookup as GetWorkerById.
func (wm *WorkerManager) GetWorker(uuid int64) *Worker {
	return wm.GetWorkerById(uuid)
}

// SetWorkerRegistry stores reg as the registry of the worker identified by
// uuid, replacing any previously stored one.
func (wm *WorkerManager) SetWorkerRegistry(uuid int64, reg *registry.Registry) {
	wm.wkRwLock.Lock()
	wm.workerReg[uuid] = reg
	wm.wkRwLock.Unlock()
}

// StopRegistry stops and forgets the registry stored for uuid. It does
// nothing when no registry is stored. The registry is stopped while the
// worker lock is held.
func (wm *WorkerManager) StopRegistry(uuid int64) {
	wm.wkRwLock.Lock()
	defer wm.wkRwLock.Unlock()

	stored, found := wm.workerReg[uuid]
	if !found {
		return
	}
	stored.Stop()
	delete(wm.workerReg, uuid)
}

// SetWorkerAddr assigns addr to the worker and indexes the worker under that
// address. An entry under a previous address, if any, is left in place.
func (wm *WorkerManager) SetWorkerAddr(worker *Worker, addr string) {
	wm.wkRwLock.Lock()
	worker.workerAddr = addr
	wm.workid[addr] = worker
	wm.wkRwLock.Unlock()
}

// GetWorkerByAddr returns the worker indexed under addr, or nil when there is
// none.
func (wm *WorkerManager) GetWorkerByAddr(addr string) *Worker {
	wm.wkRwLock.RLock()
	found := wm.workid[addr]
	wm.wkRwLock.RUnlock()
	return found
}

// GetWorkerById returns the worker tracked under the local uuid id, or nil
// when there is none.
func (wm *WorkerManager) GetWorkerById(id int64) *Worker {
	wm.wkRwLock.RLock()
	found := wm.workers[id]
	wm.wkRwLock.RUnlock()
	return found
}

// AddNewWorker creates the session state for a worker connected on the given
// stream, tracks it under id and starts the goroutine that reads the worker's
// messages. It returns ErrWorkerExist when id is already tracked, or the
// error from creating the recent-task cache.
func (wm *WorkerManager) AddNewWorker(id int64, worker omanager.NodeManagerService_RegisterWorkerServer) (*Worker, error) {
	wm.wkRwLock.Lock()
	defer wm.wkRwLock.Unlock()

	if _, dup := wm.workers[id]; dup {
		return nil, ErrWorkerExist
	}

	// Remembers the last 100 tasks pushed to this worker so results can be matched.
	recent, err := lru.New(100)
	if err != nil {
		return nil, err
	}

	// Fields not listed (registed, online, info, workerAddr, ...) start at
	// their zero values.
	created := &Worker{
		quit:       make(chan interface{}),
		taskCh:     make(chan *dispatchTask),
		resultCh:   make(chan *omanager.SubmitTaskResult),
		uuid:       id,
		recentTask: recent,
		stream:     worker,
	}
	wm.workers[id] = created

	go wm.handleWorkerMsg(created)

	return created, nil
}

// Callback is invoked by manageWorker after it tries to send a message to a
// worker; err is the result of the stream send (nil on success). The boolean
// result is ignored by manageWorker.
type Callback func(err error) bool

// doCallback serializes response and posts it to the hook URL. Failures are
// logged, not returned. If the response cannot be marshalled nothing is
// posted, so the hook never receives an empty payload.
func (wm *WorkerManager) doCallback(hook string, response *odysseus.TaskResponse) {
	d, err := proto.Marshal(response)
	if err != nil {
		log.WithError(err).Error("marshal task response failed")
		return
	}
	log.WithField("task-id", response.TaskId).Debug("marshal task response")

	if err = utils.Post(hook, d); err != nil {
		log.WithError(err).Error("post task result failed")
		return
	}
	log.WithField("task-id", response.TaskId).Debug("post task result")
}

// disconnect marks the worker as offline, deactivates it, stops its registry
// if it had registered, and removes it from the uuid and address indexes.
func (wm *WorkerManager) disconnect(worker *Worker) {
	worker.online, worker.status = false, "disconnected"

	wm.InActiveWorker(worker)

	// Only registered workers have a registry to stop.
	if registered := worker.registed; registered {
		wm.StopRegistry(worker.uuid)
	}

	wm.wkRwLock.Lock()
	defer wm.wkRwLock.Unlock()
	delete(wm.workers, worker.uuid)
	delete(wm.workid, worker.workerAddr)
}

// manageWorker runs the outbound side of a worker session. It loops until the
// worker's quit channel fires or its task channel is closed, multiplexing
// periodic requests (heartbeat, node info, device usage, status), dispatched
// tasks and submitted task results. Messages are written to worker.stream.
// When the loop exits the worker is disconnected from the manager. The
// returned error is always nil.
func (wm *WorkerManager) manageWorker(worker *Worker) error {
	log.WithField("worker", worker.uuid).Info("start manage worker")
	tickerConf := config.GetConfig().Tickers

	// Every periodic request is first sent quickly after connect; the tickers
	// are slowed down to their configured periods once the worker has answered.
	initialHeartBeatInterval := time.Second * 1
	initialInterval := initialHeartBeatInterval * 2

	heartBeatDuration := time.Second * time.Duration(tickerConf.HeartBeat)
	workerCheckDuration := heartBeatDuration * 3

	heartBeatTicker := time.NewTicker(initialHeartBeatInterval)
	defer heartBeatTicker.Stop()

	nodeinfoTicker := time.NewTicker(initialHeartBeatInterval)
	defer nodeinfoTicker.Stop()

	workerCheckTicker := time.NewTicker(workerCheckDuration)
	defer workerCheckTicker.Stop()

	statusTicker := time.NewTicker(initialInterval)
	defer statusTicker.Stop()

	deviceUsageTicker := time.NewTicker(initialInterval)
	defer deviceUsageTicker.Stop()

	worker.status = "connected"

	defer func() {
		log.WithFields(log.Fields{
			"worker-addr": worker.workerAddr,
			"worker-uuid": worker.uuid,
		}).Info("exit manage worker")

		wm.disconnect(worker)
	}()

	// shutdown is set to nil after the goodbye has been sent. A closed channel
	// is always ready, so without this the loop would send goodbye messages
	// back to back; receiving from a nil channel blocks forever, which
	// disables the case.
	shutdown := wm.quit

	for {
		var msg = new(omanager.ManagerMessage)
		var callback = Callback(func(err error) bool {
			// do nothing
			return true
		})

		select {
		case <-shutdown:
			gb := new(omanager.ManagerMessage_GoodbyeMessage)
			gb.GoodbyeMessage = &omanager.GoodbyeMessage{}
			msg.Message = gb
			shutdown = nil

		case reason, ok := <-worker.quit:
			// ok is false when the channel was closed without a reason.
			if ok {
				log.WithField("reason", reason).WithField("worker-uuid", worker.uuid).Error("worker quit")
			}
			return nil

		case <-workerCheckTicker.C:
			// Slow each request ticker down once the worker has answered it.
			if worker.info.nodeInfo != nil {
				nodeinfoTicker.Reset(time.Hour * 24)
			}

			if worker.info.deviceStatusInfo != nil {
				statusTicker.Reset(time.Second * time.Duration(tickerConf.StatusTicker))
			}

			if worker.info.deviceUsageInfo != nil {
				deviceUsageTicker.Reset(time.Second * time.Duration(tickerConf.DeviceUsageTicker))
			}

			if worker.registed && !worker.addFirstSucceed && len(worker.deviceInfoHash) > 0 {
				wm.AddWorkerToQueue(worker)
			}

			wm.UpdateWorkerActive(worker)

		case <-heartBeatTicker.C:
			hb := new(omanager.ManagerMessage_HeartbeatRequest)
			hb.HeartbeatRequest = &omanager.HeartbeatRequest{
				Timestamp: uint64(time.Now().Unix()),
			}
			heartBeatTicker.Reset(heartBeatDuration)
			msg.Message = hb
			callback = func(err error) bool {
				log.WithField("worker", worker.uuid).Info("send hear beat to worker")
				return true
			}

		case <-nodeinfoTicker.C:
			nodeinfo := new(omanager.ManagerMessage_NodeInfoRequest)
			nodeinfo.NodeInfoRequest = &omanager.NodeInfoRequest{}
			msg.Message = nodeinfo

		case <-deviceUsageTicker.C:
			// if worker is not registed to me, ignore device usage info.
			if !worker.registed {
				continue
			}
			deviceUsage := new(omanager.ManagerMessage_DeviceUsage)
			deviceUsage.DeviceUsage = &omanager.DeviceUsageRequest{}
			msg.Message = deviceUsage

		case <-statusTicker.C:
			// if worker is not registed to me, ignore device status info.
			if !worker.registed {
				continue
			}
			status := new(omanager.ManagerMessage_StatusRequest)
			status.StatusRequest = &omanager.StatusRequest{}
			msg.Message = status

		case dtask, ok := <-worker.taskCh:
			if !ok {
				return nil
			}
			task := dtask.task
			taskMsg := new(omanager.ManagerMessage_PushTaskMessage)
			taskMsg.PushTaskMessage = &omanager.PushTaskMessage{
				TaskId:    task.TaskId,
				TaskType:  task.TaskType,
				TaskKind:  task.TaskKind,
				Workload:  uint64(task.TaskWorkload),
				TaskCmd:   task.TaskCmd,
				TaskParam: task.TaskParam,
			}
			msg.Message = taskMsg
			callback = func(err error) bool {
				if err == nil {
					// Remember the task so its result can be matched later.
					worker.recentTask.Add(task.TaskId, task)
				}
				log.WithField("worker", worker.uuid).Info("dispatch task to worker")

				// Report the send result without blocking if nobody listens.
				select {
				case dtask.errCh <- err:
				default:
					// err ch is invalid
				}
				return true
			}

		case result := <-worker.resultCh:
			// verify result and make a new signature.
			data, exist := worker.recentTask.Get(result.TaskId)
			if !exist {
				log.WithField("worker", worker.uuid).Error("task not found for verify result")
				continue
			}

			task, ok := data.(*odysseus.TaskContent)
			if !ok {
				log.WithField("worker", worker.uuid).Error("cached task has unexpected type")
				continue
			}
			if result.TaskId != task.TaskId {
				log.WithField("worker", worker.uuid).Error("task id not match")
				continue
			}

			if _, err := wm.taskResult(worker, task, result); err != nil {
				continue
			}

			worker.recentTask.Remove(result.TaskId)
			if task.TaskKind != odysseus.TaskKind_StandardTask {
				// NOTE(review): the result of AddWorkerSingle is discarded as in
				// the original; confirm that is intentional.
				_ = wm.AddWorkerSingle(worker)
				if err := wm.Payment(task); err != nil {
					log.WithError(err).WithField("task-id", task.TaskId).Error("pay for task failed")
				}
			}
		}

		// Cases that only do bookkeeping leave msg empty; there is nothing to send.
		if msg.Message == nil {
			continue
		}
		err := worker.stream.Send(msg)
		if err != nil {
			log.WithError(err).Error("send message to worker failed")
		}
		callback(err)
	}
}
// handleWorkerMsg runs the inbound side of a worker session. It reads
// messages from worker.stream and updates the worker state, the manager's
// caches and the task/result channels. It returns when the manager is
// stopped, the worker's heartbeat expires, the stream fails, the worker says
// goodbye, or the worker reports an address already owned by another worker.
// On return worker.quit is closed.
//
// The quit, task and result channels are unbuffered, so sends on them wait
// for manageWorker to receive. Because Recv blocks inside the select's default
// branch, the manager-quit and heartbeat-expiry checks only run between
// received messages.
func (wm *WorkerManager) handleWorkerMsg(worker *Worker) {
	l := log.WithField("worker-uuid", worker.uuid)
	l.WithField("worker-addr", worker.workerAddr).Info("start handle worker message")
	defer func() {
		close(worker.quit)
		// Read the address at exit so the log carries the one learned during
		// the session rather than the value at function entry.
		l.WithField("worker-addr", worker.workerAddr).Info("exit handle worker message")
	}()

	checkDuration := config.GetConfig().Tickers.HeartBeat * 3

	workerCheckTicker := time.NewTicker(time.Second * time.Duration(checkDuration))
	defer workerCheckTicker.Stop()

	for {
		select {
		case <-wm.quit:
			return
		case <-workerCheckTicker.C:
			if time.Now().Unix()-wm.GetHeartBeat(worker.uuid) > int64(checkDuration) {
				log.WithField("worker-uuid", worker.uuid).Error("worker heartbeat expired")
				worker.quit <- ErrHeartBeatExpired
				return
			}
		default:
			wmsg, err := worker.stream.Recv()
			if err != nil {
				l.WithError(err).WithField("worker-addr", worker.workerAddr).Error("recv msg failed")
				worker.quit <- "recv msg failed"
				return
			}
			worker.online = true

			switch msg := wmsg.Message.(type) {

			case *omanager.WorkerMessage_GoodbyeMessage:
				worker.online = false
				worker.quit <- msg.GoodbyeMessage.Reason
				close(worker.taskCh)
				return

			case *omanager.WorkerMessage_SubmitTaskResult:
				worker.resultCh <- msg.SubmitTaskResult

			case *omanager.WorkerMessage_HeartbeatResponse:
				worker.online = true
				wm.UpdateHeartBeat(worker.uuid)
				l.WithFields(log.Fields{
					"worker-addr": worker.workerAddr,
					"hearBeat":    time.Now().Unix() - int64(msg.HeartbeatResponse.Timestamp),
				}).Debug("receive worker heartbeat")

			case *omanager.WorkerMessage_NodeInfo:
				worker.info.nodeInfo = msg.NodeInfo
				var addr = ""
				if pubkey, err := utils.HexToPubkey(msg.NodeInfo.MinerPubkey); err != nil {
					l.WithFields(log.Fields{
						"worker-addr": worker.workerAddr,
						"error":       err,
					}).Error("parse pubkey failed")
				} else {
					addr = utils.PubkeyToAddress(pubkey)
				}
				if addr == worker.workerAddr || addr == "" {
					// addr is not change.
					continue
				}

				if worker.workerAddr == "" {
					// First address for this session: refuse it when another
					// worker already owns it, and end the session.
					if w := wm.GetWorkerByAddr(addr); w != nil {
						log.WithField("worker-addr", addr).Error("worker with the address is existed")
						return
					}
				} else {
					// todo: worker change pubkey.
					wm.InActiveWorker(worker)
				}

				// update new worker.
				wm.SetWorkerAddr(worker, addr)

			case *omanager.WorkerMessage_Status:
				if !worker.registed {
					continue
				}
				worker.info.deviceStatusInfo = msg.Status
				l.WithFields(log.Fields{
					"worker-addr": worker.workerAddr,
				}).Debugf("receive worker status:0x%x", msg.Status.DeviceStatus)
				wm.UpdateWorkerDeviceStatusInfo(worker, msg.Status.DeviceStatus)

			case *omanager.WorkerMessage_ResourceMap:
				if !worker.registed {
					continue
				}
				l.WithFields(log.Fields{
					"worker-addr": worker.workerAddr,
				}).Debugf("receive worker resource map:%v", msg.ResourceMap)
				wm.UpdateWorkerResourceInfo(worker, msg.ResourceMap.ResourceMap)

			case *omanager.WorkerMessage_FetchStandardTask:
				if worker.info.nodeInfo == nil {
					continue
				}
				l.WithFields(log.Fields{
					"worker-addr": worker.workerAddr,
				}).Debugf("receive worker fetch std task request:%v", msg.FetchStandardTask.TaskType)

				task, exist := wm.std.GetTask(msg.FetchStandardTask.TaskType)
				if !exist {
					l.WithField("task-type", msg.FetchStandardTask.TaskType).Warn("not found std task")
					continue
				}
				stdlib := wm.std.GetStdLib(task.TaskType)
				if stdlib == nil {
					l.WithField("task-type", task.TaskType).Warn("not found std lib")
					continue
				}
				param, err := stdlib.GenerateParam(0)
				if err != nil {
					l.WithError(err).WithField("task-type", task.TaskType).Error("generate param failed")
					continue
				}

				// Build a fresh copy of the standard task with its own id and
				// generated parameters, then hand it to manageWorker.
				pushTask := task
				pushTask.TaskId = uuid.NewString()
				pushTask.TaskParam = []byte(param)
				pushTask.TaskInLen = int32(len(param))
				pushTask.TaskUid = "0"
				pushTask.TaskTimestamp = uint64(time.Now().UnixNano())
				pushTask.TaskKind = odysseus.TaskKind_StandardTask
				pushTask.TaskFee = "0"
				worker.taskCh <- &dispatchTask{
					task:  &pushTask.TaskContent,
					errCh: make(chan error, 1),
				}

			case *omanager.WorkerMessage_DeviceInfo:
				l.WithFields(log.Fields{
					"worker-addr": worker.workerAddr,
				}).Debugf("receive worker device info:%v", msg.DeviceInfo)
				if !worker.registed {
					// ignore the info.
					continue
				}

				infoData, err := json.Marshal(msg.DeviceInfo)
				if err != nil {
					l.WithFields(log.Fields{
						"worker-addr": worker.workerAddr,
						"error":       err,
					}).Error("marshal device info failed")
					continue
				}
				if len(infoData) == 0 {
					continue
				}

				infoHash := sha3.Sum256(infoData)
				worker.info.deviceInfo = msg.DeviceInfo

				if worker.registed && !worker.addFirstSucceed {
					wm.AddWorkerToQueue(worker)
				}
				// check device info changed, and update to cache.
				if !bytes.Equal(infoHash[:], worker.deviceInfoHash) {
					wm.UpdateWorkerDeviceInfo(worker, string(infoData))
				}
				worker.deviceInfoHash = infoHash[:]

			case *omanager.WorkerMessage_DeviceUsage:
				if !worker.registed {
					continue
				}
				// NOTE(review): usage data is stored through
				// UpdateWorkerDeviceInfo, the same call used for device info;
				// confirm this is the intended destination.
				usageData, err := json.Marshal(msg.DeviceUsage)
				if err != nil {
					// Skip the cache update rather than storing an empty value.
					l.WithFields(log.Fields{
						"worker-addr": worker.workerAddr,
						"error":       err,
					}).Error("marshal device usage failed")
				} else {
					wm.UpdateWorkerDeviceInfo(worker, string(usageData))
				}

				worker.info.deviceUsageInfo = msg.DeviceUsage.Usage
				l.WithFields(log.Fields{
					"worker-addr": worker.workerAddr,
				}).Debugf("receive worker device usage:%v", msg.DeviceUsage.Usage)

			case *omanager.WorkerMessage_RegisteMessage:
				if worker.registed {
					continue
				}
				l.WithFields(log.Fields{
					"worker-addr": worker.workerAddr,
				}).Debug("receive registed message")
				worker.registed = true

				if pubkey, err := utils.HexToPubkey(msg.RegisteMessage.MinerPubkey); err != nil {
					// Keep whatever address is already known; do not index the
					// worker under an empty address.
					l.WithFields(log.Fields{
						"worker-addr": worker.workerAddr,
						"error":       err,
					}).Error("parse pubkey failed")
				} else {
					wm.SetWorkerAddr(worker, utils.PubkeyToAddress(pubkey))
				}

				wreg := workerRegistry{
					worker: worker,
					wm:     wm,
				}

				reg := registry.NewRegistry(registry.RedisConnParam{
					Addr:     config.GetConfig().Redis.Addr,
					Password: config.GetConfig().Redis.Password,
					DbIndex:  config.GetConfig().Redis.DbIndex,
				}, wreg, wreg.Instance)

				go reg.Start()
				wm.SetWorkerRegistry(worker.uuid, reg)

			default:
				l.WithField("worker-addr", worker.workerAddr).Error(fmt.Sprintf("unsupport msg type %T", msg))
			}
		}
	}
}

// Payment charges the task's fee to the task's user through the node cache.
// It does nothing and returns nil when payment is disabled in the config. It
// returns an error when the task's fee or uid is not a base-10 integer, or
// when the cache rejects the payment. On a parse error nothing is charged.
func (wm *WorkerManager) Payment(task *odysseus.TaskContent) error {
	if !config.GetConfig().EnablePay {
		return nil
	}

	// pay for task.
	fee, err := strconv.ParseInt(task.TaskFee, 10, 64)
	if err != nil {
		return fmt.Errorf("parse task fee %q: %w", task.TaskFee, err)
	}
	uid, err := strconv.ParseInt(task.TaskUid, 10, 64)
	if err != nil {
		return fmt.Errorf("parse task uid %q: %w", task.TaskUid, err)
	}

	if err := wm.node.cache.PayforFee(uid, fee); err != nil {
		return err
	}
	return nil
}

// makeReceipt builds the receipt for a finished task.
//
// TaskDuration is the time since task.TaskTimestamp divided by 1000; the
// current time is taken in nanoseconds. The fee is charged only when the
// worker reports success; in that case TaskResult is the Succeed message.
// Otherwise TaskFee stays 0 and TaskResult is err's message, or
// "unknown error" when err is nil.
func (wm *WorkerManager) makeReceipt(worker *Worker, task *odysseus.TaskContent, result *omanager.SubmitTaskResult, err error) *odysseus.TaskReceipt {
	now := uint64(time.Now().UnixNano())
	receipt := &odysseus.TaskReceipt{
		TaskId:              task.TaskId,
		TaskTimestamp:       task.TaskTimestamp,
		TaskType:            task.TaskType,
		TaskUid:             task.TaskUid,
		TaskWorkload:        task.TaskWorkload,
		TaskDuration:        int64(now-task.TaskTimestamp) / 1000,
		TaskFee:             0,
		TaskOutLen:          int64(len(result.TaskResultBody)),
		TaskProfitAccount:   worker.ProfitAccount().Hex(),
		TaskWorkerAccount:   worker.WorkerAccount().Hex(),
		TaskExecuteDuration: result.TaskExecuteDuration,
	}

	switch {
	case result.IsSuccessed:
		// An unparsable fee is recorded as 0, as before.
		fee, _ := strconv.ParseInt(task.TaskFee, 10, 64)
		receipt.TaskFee = fee
		receipt.TaskResult = Succeed.Error()
	case err != nil:
		receipt.TaskResult = err.Error()
	default:
		// A failed result without an error must not panic on err.Error().
		receipt.TaskResult = "unknown error"
	}
	return receipt
}

// taskProof carries the raw hashes, signatures and finish time that
// makeTaskProof hex-encodes into an odysseus.TaskProof.
type taskProof struct {
	paramHash          []byte // becomes TaskReqHash.
	resultHash         []byte // becomes TaskRespHash.
	finishTime         uint64 // becomes TaskFinishTimestamp.
	nmSignature        []byte // becomes TaskManagerSignature.
	containerSignature []byte // becomes TaskContainerSignature.
	minerSignature     []byte // becomes TaskMinerSignature.
}

// makeTaskProof assembles the proof of a finished task from the task, the
// worker's accounts and the hashes/signatures in t, hex-encoding every byte
// slice. TaskProfitAccount is the benefit address from the worker's node
// info, or empty when no node info has been received yet.
func (wm *WorkerManager) makeTaskProof(worker *Worker, task *odysseus.TaskContent, t taskProof) *odysseus.TaskProof {
	// nodeInfo stays nil until the worker answers a node info request.
	var profitAccount string
	if worker.info.nodeInfo != nil {
		profitAccount = worker.info.nodeInfo.BenefitAddress
	}

	proof := &odysseus.TaskProof{
		TaskId:                 task.TaskId,
		TaskFinishTimestamp:    t.finishTime,
		TaskType:               task.TaskType,
		TaskWorkload:           uint64(task.TaskWorkload),
		TaskReqHash:            hex.EncodeToString(t.paramHash),
		TaskRespHash:           hex.EncodeToString(t.resultHash),
		TaskManagerSignature:   hex.EncodeToString(t.nmSignature),
		TaskContainerSignature: hex.EncodeToString(t.containerSignature),
		TaskMinerSignature:     hex.EncodeToString(t.minerSignature),
		TaskWorkerAccount:      worker.workerAddr,
		TaskProfitAccount:      profitAccount,
	}
	return proof
}
