package server

import (
	"encoding/hex"
	"errors"
	"fmt"
	"github.com/golang/protobuf/proto"
	lru "github.com/hashicorp/golang-lru"
	"github.com/odysseus/cache/cachedata"
	"github.com/odysseus/mogo/operator"
	"github.com/odysseus/nodemanager/config"
	"github.com/odysseus/nodemanager/standardlib"
	"github.com/odysseus/nodemanager/utils"
	odysseus "github.com/odysseus/odysseus-protocol/gen/proto/go/base/v1"
	omanager "github.com/odysseus/odysseus-protocol/gen/proto/go/nodemanager/v2"
	"github.com/odysseus/service-registry/registry"
	"github.com/redis/go-redis/v9"
	log "github.com/sirupsen/logrus"
	"go.mongodb.org/mongo-driver/mongo"
	"math/big"
	"regexp"
	"strconv"
	"sync"
	"time"
)

// Package-level sentinel values used as task results and as reasons for
// dropping a worker connection.
var (
	// Succeed is not a failure: its text is used as the TaskResult of a
	// successful task receipt.
	Succeed = errors.New("succeed")

	// ErrWorkerExist reports that a worker with the same uuid or address is
	// already connected.
	ErrWorkerExist = errors.New("worker exist")

	// ErrHeartBeatExpired reports that no heartbeat arrived within the check window.
	ErrHeartBeatExpired = errors.New("worker heartbeat expired")

	// ErrLongtimeNoTask reports that the worker's last task time is too old.
	ErrLongtimeNoTask = errors.New("worker long time no task")

	// ErrInvalidMessageValue reports a message that failed validation.
	ErrInvalidMessageValue = errors.New("invalid message value")

	// ErrInvalidMsgSignature reports a signature/public key that could not be used.
	ErrInvalidMsgSignature = errors.New("invalid message signature")

	// ErrExpiredMsgSignature reports a signed message whose timestamp is too old.
	ErrExpiredMsgSignature = errors.New("expired message signature")

	// ErrOldConnection is delivered to an existing connection when another
	// connection claims the same worker address.
	ErrOldConnection = errors.New("old connection")
)

// TaskStatus is the lifecycle state of a task dispatched to a worker.
//
// The numeric order matters: result handling accepts a result only while
// status < TASK_FINISHED, so any new state must be placed with that
// comparison in mind. The zero value is TASK_CREATE.
type TaskStatus int

// Task lifecycle states, in increasing numeric order starting at 0.
const (
	TASK_CREATE TaskStatus = iota
	TASK_WAIT_ACK
	TASK_ACKED
	TASK_ACK_TIMEOUT
	// States from here on are terminal for result acceptance.
	TASK_FINISHED
	TASK_TIMEOUT
)

// NodeInterface is the set of node-level services the WorkerManager depends on.
type NodeInterface interface {
	// PostResult hands a task receipt to the node.
	PostResult(receipt *odysseus.TaskReceipt)
	// PostProof hands a task proof to the node.
	PostProof(proof *odysseus.TaskProof)
	// Sign signs hash with the node manager's key.
	Sign(hash []byte) ([]byte, error)
	// Cache returns the node's shared cache data.
	Cache() *cachedata.CacheData
	// PayForFee charges fee to the user identified by uid.
	PayForFee(uid, fee int64) error
	// Mongo returns the MongoDB client used to build the worker operators.
	Mongo() *mongo.Client
}

// WorkerManager tracks the workers connected to this node manager and runs
// the per-worker goroutines that talk to them.
type WorkerManager struct {
	// rdb is the Redis client passed to NewWorkerManager.
	rdb *redis.Client

	// heartBeat and hbRwLock are not referenced in this part of the file;
	// presumably heartBeat maps worker uuid to a timestamp guarded by
	// hbRwLock — NOTE(review): confirm.
	heartBeat map[int64]int64
	hbRwLock  sync.RWMutex

	// workerByIp is not referenced in this part of the file — NOTE(review):
	// confirm it is still used elsewhere.
	workerByIp sync.Map

	// Worker indexes, all guarded by wkRwLock:
	//   workers   — live workers by uuid (filled by AddNewWorker);
	//   workid    — registered workers by address (filled by SetWorkerAddr);
	//   workerReg — service registry of each registered worker by uuid.
	workers   map[int64]*Worker
	workid    map[string]*Worker
	workerReg map[int64]*registry.Registry
	wkRwLock  sync.RWMutex

	// quit is closed by Stop to make worker goroutines exit.
	quit chan struct{}

	// Collaborators and MongoDB operators wired up by NewWorkerManager.
	node                    NodeInterface
	std                     *standardlib.StandardTasks
	workerInfoOperator      *operator.WorkerInfoOperator
	workerInstalledOperator *operator.WorkerInstalledOperator
	workerRunningOperator   *operator.WorkerRunningOperator
}

// NewWorkerManager returns a WorkerManager with empty worker indexes, an open
// quit channel, and MongoDB operators bound to the configured database on the
// node's Mongo client.
func NewWorkerManager(rdb *redis.Client, node NodeInterface) *WorkerManager {
	client := node.Mongo()
	dbName := config.GetConfig().Mongo.Database

	wm := &WorkerManager{
		rdb:  rdb,
		node: node,
		std:  standardlib.NewStandardTasks(),
		quit: make(chan struct{}),

		heartBeat: make(map[int64]int64),
		workers:   make(map[int64]*Worker),
		workid:    make(map[string]*Worker),
		workerReg: make(map[int64]*registry.Registry),

		workerInfoOperator:      operator.NewDBWorker(client, dbName),
		workerInstalledOperator: operator.NewDBWorkerInstalled(client, dbName),
		workerRunningOperator:   operator.NewDBWorkerRunning(client, dbName),
	}
	return wm
}

// Stop closes wm.quit, signalling every worker goroutine to shut down.
//
// A repeated call is a no-op instead of panicking on a double close. The
// guard is not atomic, so Stop must still not be called concurrently from
// several goroutines.
func (wm *WorkerManager) Stop() {
	select {
	case <-wm.quit:
		// Already stopped.
	default:
		close(wm.quit)
	}
}

// GetWorker returns the live worker with the given uuid, or nil if none.
func (wm *WorkerManager) GetWorker(uuid int64) *Worker {
	wm.wkRwLock.RLock()
	found := wm.workers[uuid]
	wm.wkRwLock.RUnlock()
	return found
}

// SetWorkerRegistry stores reg as the service registry for the worker with
// the given uuid. An existing entry is overwritten without being stopped.
func (wm *WorkerManager) SetWorkerRegistry(uuid int64, reg *registry.Registry) {
	wm.wkRwLock.Lock()
	defer wm.wkRwLock.Unlock()

	registries := wm.workerReg
	registries[uuid] = reg
}

// StopRegistry removes the service registry recorded for uuid and then clears
// and stops it. It does nothing if no registry is recorded.
//
// The entry is detached while holding wkRwLock, but Clear/Stop run after the
// lock is released: the registry is built from Redis connection parameters,
// so these calls may perform network I/O and should not block every
// concurrent worker lookup that needs the same lock.
func (wm *WorkerManager) StopRegistry(uuid int64) {
	wm.wkRwLock.Lock()
	reg, exist := wm.workerReg[uuid]
	if exist {
		delete(wm.workerReg, uuid)
	}
	wm.wkRwLock.Unlock()

	if !exist {
		return
	}
	reg.Clear()
	reg.Stop()
}

// SetWorkerAddr records addr as the worker's address and indexes the worker
// by that address, both under wkRwLock.
func (wm *WorkerManager) SetWorkerAddr(worker *Worker, addr string) {
	wm.wkRwLock.Lock()
	defer wm.wkRwLock.Unlock()

	byAddr := wm.workid
	worker.workerAddr = addr
	byAddr[addr] = worker
}

// GetWorkerByAddr returns the registered worker with the given address, or
// nil if none.
func (wm *WorkerManager) GetWorkerByAddr(addr string) *Worker {
	wm.wkRwLock.RLock()
	found := wm.workid[addr]
	wm.wkRwLock.RUnlock()
	return found
}

// GetWorkerById returns the live worker with the given id, or nil if none.
// It is equivalent to GetWorker.
func (wm *WorkerManager) GetWorkerById(id int64) *Worker {
	return wm.GetWorker(id)
}

// AddNewWorker creates a Worker for a newly opened RegisterWorker stream,
// indexes it under id and starts handleWorkerMsg for it in a new goroutine.
//
// It returns ErrWorkerExist if a worker with the same id is already tracked,
// or the error from creating the worker's recent-task cache.
func (wm *WorkerManager) AddNewWorker(id int64, worker omanager.NodeManagerService_RegisterWorkerServer) (*Worker, error) {
	const (
		// Buffer length of the message, send and result channels.
		queueLen = 30
		// Number of dispatched tasks kept so submitted results can be matched.
		recentTaskCapacity = 10000
	)

	wm.wkRwLock.Lock()
	defer wm.wkRwLock.Unlock()

	if _, dup := wm.workers[id]; dup {
		return nil, ErrWorkerExist
	}

	// Fields not set here (registed, online, info, workerAddr,
	// deviceInfoHash, status) start at their zero values.
	nw := &Worker{
		wm:     wm,
		uuid:   id,
		stream: worker,

		quit:     make(chan interface{}),
		errCh:    make(chan error, 1),
		taskCh:   make(chan *dispatchTask),
		msgCh:    make(chan *omanager.WorkerMessage, queueLen),
		sendCh:   make(chan sendMsgCallback, queueLen),
		resultCh: make(chan *omanager.SubmitTaskResult, queueLen),

		infoOp:    wm.workerInfoOperator,
		installOp: wm.workerInstalledOperator,
		runningOp: wm.workerRunningOperator,
	}
	nw.latestNmValue = wm.LastNmValue(nw)

	cache, err := lru.NewARC(recentTaskCapacity)
	if err != nil {
		return nil, err
	}
	nw.recentTask = cache
	wm.workers[id] = nw

	go wm.handleWorkerMsg(nw)
	return nw, nil
}

// Callback is invoked with the outcome of sending a message to a worker; err
// is nil when the send succeeded.
type Callback func(err error) bool

// doCallback marshals response and POSTs it to hook, making up to five
// attempts with a one-second pause between them.
//
// If the response cannot be marshalled nothing is posted: sending the nil
// body left by a failed Marshal would only deliver garbage to the hook.
func (wm *WorkerManager) doCallback(hook string, response *odysseus.TaskResponse) {
	const (
		maxAttempts = 5
		retryDelay  = time.Second
	)

	d, err := proto.Marshal(response)
	if err != nil {
		log.WithError(err).Error("marshal task response failed")
		return
	}
	l := log.WithField("task-id", response.TaskId)
	l.Debug("marshal task response")

	for i := 0; i < maxAttempts; i++ {
		// Pause only between attempts, not after the final failure.
		if i > 0 {
			time.Sleep(retryDelay)
		}
		if err = utils.Post(hook, d); err == nil {
			l.Debug("post task result")
			return
		}
		l.WithField("try", i).WithError(err).Error("post task result failed")
	}
}

// disconnect marks worker offline, disconnects it, marks it inactive, stops
// its service registry if it had registered, and removes it from the
// manager's uuid and address indexes.
func (wm *WorkerManager) disconnect(worker *Worker) {
	worker.online = false
	worker.status = "disconnected"
	worker.Disconnect()
	wm.InActiveWorker(worker)
	if worker.registed {
		wm.StopRegistry(worker.uuid)
	}

	wm.wkRwLock.Lock()
	defer wm.wkRwLock.Unlock()
	// Only remove entries that still point at this worker. The duplicate-
	// address check during registration and SetWorkerAddr are not atomic,
	// so another connection may own the address entry by now and must keep it.
	if wm.workers[worker.uuid] == worker {
		delete(wm.workers, worker.uuid)
	}
	if wm.workid[worker.workerAddr] == worker {
		delete(wm.workid, worker.workerAddr)
	}
}

// manageWorker drives the manager-to-worker side of a worker connection.
//
// It starts worker.SendMessage, then loops: each ticker or channel event
// builds at most one ManagerMessage, which is handed to worker.SendToWorker
// together with a completion callback. Events that only update local state
// (periodic worker check, incoming task results) send nothing.
//
// It returns nil when worker.quit yields a value or is closed, or when
// worker.taskCh is closed; the deferred cleanup then calls wm.disconnect.
func (wm *WorkerManager) manageWorker(worker *Worker) error {
	log.WithField("worker", worker.uuid).Info("start manage worker")
	tickerConf := config.GetConfig().Tickers

	// Probes fire quickly right after connect; heartbeat, node-info and
	// device-usage tickers are Reset to their configured periods later.
	initialHeartBeatInterval := time.Second * 1
	initialInterval := initialHeartBeatInterval * 2

	heartBeatDuration := time.Second * time.Duration(tickerConf.HeartBeat)
	workerCheckDuration := heartBeatDuration * 3

	heartBeatTicker := time.NewTicker(initialHeartBeatInterval)
	defer heartBeatTicker.Stop()

	nodeinfoTicker := time.NewTicker(initialHeartBeatInterval * 10)
	defer nodeinfoTicker.Stop()

	workerCheckTicker := time.NewTicker(workerCheckDuration)
	defer workerCheckTicker.Stop()

	deviceUsageTicker := time.NewTicker(initialInterval)
	defer deviceUsageTicker.Stop()

	gpuUsageTicker := time.NewTicker(initialInterval * 5)
	defer gpuUsageTicker.Stop()

	worker.status = "connected"

	defer func() {
		log.WithFields(log.Fields{
			"worker-addr": worker.workerAddr,
			"worker-uuid": worker.uuid,
		}).Info("exit manage worker")

		wm.disconnect(worker)
	}()
	go worker.SendMessage()

	// Stop closes wm.quit, after which a receive on it is always ready. Use a
	// local copy and set it to nil once the goodbye is queued (a nil channel
	// never becomes ready), so the goodbye is sent once instead of on every
	// iteration until worker.quit fires.
	managerQuit := wm.quit

	for {
		var msg = new(omanager.ManagerMessage)
		var callback = Callback(func(err error) bool {
			// do nothing
			return true
		})

		select {
		case <-managerQuit:
			managerQuit = nil
			gb := new(omanager.ManagerMessage_GoodbyeMessage)
			gb.GoodbyeMessage = &omanager.GoodbyeMessage{}
			msg.Message = gb

		case reason, ok := <-worker.quit:
			if ok {
				log.WithField("reason", reason).WithField("worker-uuid", worker.uuid).Error("worker quit")
			}
			return nil

		case <-workerCheckTicker.C:
			// Once the worker has answered, slow the probes down.
			if worker.info != nil {
				nodeinfoTicker.Reset(time.Minute * 30)
			}

			if worker.usage.hwUsage != nil {
				deviceUsageTicker.Reset(time.Second * time.Duration(tickerConf.DeviceUsageTicker))
			}

			// Retry the initial database insert if it failed at registration.
			if worker.registed && !worker.addFirstSucceed {
				if err := wm.AddWorker(worker); err == nil {
					worker.addFirstSucceed = true
				}
			}

			wm.UpdateWorkerActive(worker)

		case <-heartBeatTicker.C:
			hb := new(omanager.ManagerMessage_HeartbeatRequest)
			hb.HeartbeatRequest = &omanager.HeartbeatRequest{
				Timestamp: uint64(time.Now().Unix()),
			}
			heartBeatTicker.Reset(heartBeatDuration)
			msg.Message = hb
			callback = func(err error) bool {
				log.WithField("worker", worker.uuid).Debug("send heart beat to worker")
				return true
			}

		case <-nodeinfoTicker.C:
			nodeinfo := new(omanager.ManagerMessage_NodeInfoRequest)
			nodeinfo.NodeInfoRequest = &omanager.NodeInfoRequest{}
			msg.Message = nodeinfo
			callback = func(err error) bool {
				return true
			}

		case <-deviceUsageTicker.C:
			// if worker is not registed to me, ignore device usage info.
			if !worker.registed {
				continue
			}
			deviceUsage := new(omanager.ManagerMessage_DeviceUsageRequest)
			deviceUsage.DeviceUsageRequest = &omanager.DeviceUsageRequest{}
			msg.Message = deviceUsage
			callback = func(err error) bool {
				return true
			}

		case <-gpuUsageTicker.C:
			gpu := new(omanager.ManagerMessage_GpuUsageRequest)
			gpu.GpuUsageRequest = &omanager.GPUUsageRequest{}

			msg.Message = gpu
			callback = func(err error) bool {
				log.WithField("worker", worker.uuid).Debug("send gpu usage request to worker")
				return true
			}

		case dtask, ok := <-worker.taskCh:
			if !ok {
				return nil
			}
			task := dtask.task
			taskMsg := new(omanager.ManagerMessage_PushTask)
			taskMsg.PushTask = &omanager.PushTaskMessage{
				TaskId:    task.TaskId,
				TaskType:  task.TaskType,
				TaskKind:  task.TaskKind,
				Workload:  uint64(task.TaskWorkload),
				TaskCmd:   task.TaskCmd,
				TaskParam: task.TaskParam,
			}
			msg.Message = taskMsg
			// Remember the task so the worker's result can be matched later.
			worker.recentTask.Add(task.TaskId, dtask)
			go dtask.dispatched(wm)

			callback = func(err error) bool {
				if err == nil && task.TaskKind == odysseus.TaskKind_ComputeTask {
					if e := wm.setWorkerLastTaskTime(worker, time.Now().Unix()); e != nil {
						log.WithField("worker", worker.uuid).WithError(e).Error("set worker last task time failed")
					}
				}
				log.WithFields(log.Fields{
					"task":        task.TaskId,
					"worker":      worker.uuid,
					"worker addr": worker.workerAddr,
				}).Info("dispatch task to worker")
				if err != nil {
					dtask.errCh <- err
				}
				return true
			}

		case result := <-worker.resultCh:
			log.WithFields(log.Fields{
				"task":          result.TaskId,
				"worker addr":   worker.workerAddr,
				"success":       result.IsSuccessed,
				"result code":   result.TaskExecuteCode,
				"duration":      result.TaskExecuteDuration / 1000,
				"result length": len(result.TaskResultBody),
			}).Info("got result from worker")
			if result.TaskExecuteCode == 303 {
				log.WithField("result data", string(result.TaskResultBody)).WithField("task", result.TaskId).Info("got worker result body")
			}
			// verify result and make a new signature.
			data, exist := worker.recentTask.Get(result.TaskId)
			if !exist {
				log.WithField("worker", worker.uuid).Error("task not found for verify result")
				continue
			}
			dtask := data.(*dispatchTask)
			if dtask.status < TASK_FINISHED {
				dtask.setResult(result)
			} else {
				log.WithFields(log.Fields{
					"task":        result.TaskId,
					"worker addr": worker.workerAddr,
				}).Warn("task is timeout")
			}
		}

		// Events that only touched local state leave msg.Message unset;
		// don't push an empty ManagerMessage to the worker.
		if msg.Message == nil {
			continue
		}
		worker.SendToWorker(msg, callback)
	}
}
// workerIPPattern is the IPv4 check applied to the IP a worker reports at
// registration. It is compiled once instead of on every registration.
// NOTE(review): the pattern is unanchored, so any value that merely contains
// an IPv4-looking substring (e.g. "1.2.3.4:8080") is accepted — confirm
// whether a bare address should be required.
var workerIPPattern = regexp.MustCompile("((2(5[0-5]|[0-4]\\d))|[0-1]?\\d{1,2})(\\.((2(5[0-5]|[0-4]\\d))|[0-1]?\\d{1,2})){3}")

// handleWorkerMsg drives the worker-to-manager side of a worker connection.
//
// It starts worker.RecvMessage, dispatches every message read from
// worker.msgCh to its handler, and periodically checks the worker's last
// heartbeat and last task time. When the worker must be dropped, the reason
// is sent on worker.quit before returning. It also returns on wm.quit or a
// goodbye message. On return it calls worker.Disconnect and closes
// worker.quit.
func (wm *WorkerManager) handleWorkerMsg(worker *Worker) {
	l := log.WithField("worker-uuid", worker.uuid)
	l.WithField("worker-addr", worker.workerAddr).Info("start handle worker message")
	// Deferred calls run in reverse: Disconnect, close(quit), then this log.
	// The log is wrapped in a closure so it reports the address the worker
	// registered with; a plain defer would capture the empty start-up address.
	defer func() {
		l.WithField("worker-addr", worker.workerAddr).Info("exit handle worker message")
	}()
	defer close(worker.quit)
	defer worker.Disconnect()

	checkDuration := config.GetConfig().Tickers.HeartBeat * 3

	workerCheckTicker := time.NewTicker(time.Second * time.Duration(checkDuration))
	defer workerCheckTicker.Stop()

	go worker.RecvMessage()
	for {
		select {
		case <-wm.quit:
			return
		case workerErr := <-worker.errCh:
			log.WithError(workerErr).WithField("worker-uuid", worker.uuid).Error("worker error")
			worker.quit <- workerErr
			return

		case <-workerCheckTicker.C:
			if time.Now().Unix()-worker.heartBeat > int64(checkDuration) {
				log.WithField("worker-uuid", worker.uuid).Error("worker heartbeat expired")
				worker.quit <- ErrHeartBeatExpired
				return
			}
			if lastTaskTm, err := wm.getWorkerLastTaskTime(worker); err != nil {
				log.WithError(err).Error("get worker last task time failed")
			} else {
				expire := config.GetConfig().Tickers.WorkerTaskExpireTicker
				if expire <= 0 {
					expire = 60 // default value
				}
				if time.Now().Unix()-lastTaskTm > int64(expire) {
					log.WithField("worker-uuid", worker.uuid).Error("worker last task time expired")
					worker.quit <- ErrLongtimeNoTask
					return
				}
			}
		case wmsg := <-worker.msgCh:
			worker.online = true
			switch msg := wmsg.Message.(type) {
			case *omanager.WorkerMessage_GoodbyeMessage:
				worker.doGoodBye(msg)
				return
			case *omanager.WorkerMessage_SubmitTaskAck:
				worker.doSubmitAck(msg)
			case *omanager.WorkerMessage_SubmitTaskResult:
				worker.doSubmitResult(msg)
			case *omanager.WorkerMessage_HeartbeatResponse:
				worker.doHeartBeat(msg)
			case *omanager.WorkerMessage_BenefitAddrUpdate:
				worker.doUpdateBenefit(msg)
			case *omanager.WorkerMessage_NodeInfo:
				worker.doGetNodeInfo(msg)
			case *omanager.WorkerMessage_FetchStandardTask:
				worker.doFetchStdTask(msg)
			case *omanager.WorkerMessage_DeviceInfo:
				worker.doGetDeviceInfo(msg)
			case *omanager.WorkerMessage_DeviceUsage:
				worker.doDeviceUsage(msg)
			case *omanager.WorkerMessage_GpuUsage:
				worker.doGPUUsage(msg)
			case *omanager.WorkerMessage_AddModelRunning:
				worker.doAddRunningModel(msg)
			case *omanager.WorkerMessage_DelModeRunning:
				worker.doRemoveRunningModel(msg)
			case *omanager.WorkerMessage_AddModelInstalled:
				worker.doAddInstalledModel(msg)
			case *omanager.WorkerMessage_DelModelInstalled:
				worker.doRemoveInstalledModel(msg)
			case *omanager.WorkerMessage_InstalledModelStatus:
				worker.doInstalledModelStatus(msg)
			case *omanager.WorkerMessage_RunningModelStatus:
				worker.doRunningModelStatus(msg)

			case *omanager.WorkerMessage_RegisteMessage:
				if err := wm.processRegisterMessage(worker, msg, l); err != nil {
					worker.quit <- err
					return
				}

			default:
				l.WithField("worker-addr", worker.workerAddr).Error(fmt.Sprintf("unsupport msg type %T", msg))
			}
		}
	}
}

// processRegisterMessage validates a worker's registration message and, on
// success, records the worker's address and info, applies the white list,
// stores the worker in the database and starts its service registry.
//
// It returns nil if the worker is already registered or registration
// succeeded; a non-nil error is the reason the connection must be dropped.
func (wm *WorkerManager) processRegisterMessage(worker *Worker, msg *omanager.WorkerMessage_RegisteMessage, l *log.Entry) error {
	// 1. ignore repeated registrations.
	if worker.registed {
		return nil
	}
	l.WithFields(log.Fields{
		"worker-addr": worker.workerAddr,
	}).Debug("receive registed message")

	reg := msg.RegisteMessage
	// The message comes from the remote worker: reject it rather than
	// panicking the whole process on a nil dereference below.
	if reg == nil || reg.Info == nil || reg.Hardware == nil || reg.Hardware.NET == nil {
		l.WithField("worker-addr", worker.workerAddr).Error("incomplete register message")
		return ErrInvalidMessageValue
	}
	info := reg.Info

	// 2. check the device signature over info, hardware, models and timestamp.
	data := utils.CombineBytes([]byte(info.String()),
		[]byte(reg.Hardware.String()),
		[]byte(reg.Models.String()),
		big.NewInt(int64(reg.Timestamp)).Bytes())
	if !utils.VerifySignature(data, reg.DeviceSignature, utils.FromHex(info.MinerPubkey)) {
		l.WithFields(log.Fields{
			"worker-addr": worker.workerAddr,
		}).Error("verify device signature failed")
		return ErrInvalidMessageValue
	}

	// 3. check the timestamp is not expired.
	// NOTE(review): timestamps in the future are not rejected.
	if time.Now().Unix()-int64(reg.Timestamp) > config.GetConfig().GetWorkerSignatureExpiredTime() {
		l.WithFields(log.Fields{
			"worker-addr": worker.workerAddr,
			"timestamp":   reg.Timestamp,
		}).Error("message signature expired")
		return ErrExpiredMsgSignature
	}

	// 4. derive the worker address; an existing connection with the same
	// address is told to disconnect and this one is refused.
	pubkey, err := utils.HexToPubkey(info.MinerPubkey)
	if err != nil {
		l.WithFields(log.Fields{
			"worker-addr": worker.workerAddr,
			"error":       err,
		}).Error("parse pubkey failed")
		return ErrInvalidMsgSignature
	}
	addr := utils.PubkeyToAddress(pubkey)
	if old := wm.GetWorkerByAddr(addr); old != nil {
		// errCh has a buffer of one. If it is already full the old
		// connection has a pending error and is going down anyway; a
		// blocking send could hang this goroutine forever.
		select {
		case old.errCh <- ErrOldConnection:
		default:
		}
		l.WithField("worker-addr", worker.workerAddr).Error("worker with the address is existed, and disconnect it")
		return ErrWorkerExist
	}
	worker.workerAddr = addr
	worker.registed = true

	// 5. blank out reported IPs that don't look like IPv4.
	if !workerIPPattern.MatchString(reg.Hardware.NET.Ip) {
		reg.Hardware.NET.Ip = ""
	}
	worker.info = &omanager.NodeInfoResponse{
		Info:     reg.Info,
		Hardware: reg.Hardware,
		Models:   reg.Models,
	}
	wm.SetWorkerAddr(worker, worker.workerAddr)

	// 6. check white list.
	if err := wm.checkWhiteList(worker, info.BenefitAddress); err != nil {
		return err
	}
	wm.addWorkerToWhiteListSet(worker, info.BenefitAddress)

	// 7. add worker to mongo; manageWorker retries if this fails.
	if err := wm.AddWorker(worker); err == nil {
		worker.addFirstSucceed = true
		wm.UpdateWorkerActive(worker)
	}

	// 8. start the worker's service registry.
	wreg := workerRegistry{
		worker: worker,
		wm:     wm,
	}
	registryInst := registry.NewRegistry(registry.RedisConnParam{
		Addr:     config.GetConfig().Redis.Addr,
		Password: config.GetConfig().Redis.Password,
		DbIndex:  config.GetConfig().Redis.DbIndex,
	}, wreg, wreg.Instance)

	go registryInst.Start()
	wm.SetWorkerRegistry(worker.uuid, registryInst)
	if e := wm.setWorkerLastTaskTime(worker, time.Now().Unix()); e != nil {
		log.WithField("worker", worker.uuid).WithError(e).Error("set worker last task time failed")
	}
	return nil
}

// Payment charges task.TaskFee to the user task.TaskUid through the node when
// payment is enabled in the config; otherwise it does nothing.
//
// Both fields are parsed as base-10 int64. A malformed value is reported as
// an error instead of silently charging fee 0 to uid 0.
func (wm *WorkerManager) Payment(task *odysseus.TaskContent) error {
	if !config.GetConfig().EnablePay {
		return nil
	}

	fee, err := strconv.ParseInt(task.TaskFee, 10, 64)
	if err != nil {
		return fmt.Errorf("parse task fee %q: %w", task.TaskFee, err)
	}
	uid, err := strconv.ParseInt(task.TaskUid, 10, 64)
	if err != nil {
		return fmt.Errorf("parse task uid %q: %w", task.TaskUid, err)
	}
	return wm.node.PayForFee(uid, fee)
}

// makeReceipt builds the TaskReceipt for a task that worker has finished.
//
// TaskDuration is (now in UnixNano - task.TaskTimestamp) / 1000; if
// TaskTimestamp is in nanoseconds this is microseconds — TODO confirm.
// On success TaskFee is the parsed task.TaskFee (0 if it does not parse) and
// TaskResult is Succeed's text. On failure TaskFee is 0 and TaskResult is
// err's text, or a generic message carrying the execute code when err is nil
// (previously a nil err panicked here).
func (wm *WorkerManager) makeReceipt(worker *Worker, task *odysseus.TaskContent, result *omanager.SubmitTaskResult, err error) *odysseus.TaskReceipt {
	now := uint64(time.Now().UnixNano())
	receipt := &odysseus.TaskReceipt{
		TaskId:              task.TaskId,
		TaskTimestamp:       task.TaskTimestamp,
		TaskType:            task.TaskType,
		TaskUid:             task.TaskUid,
		TaskWorkload:        task.TaskWorkload,
		TaskDuration:        int64(now-task.TaskTimestamp) / 1000,
		TaskFee:             0,
		TaskOutLen:          int64(len(result.TaskResultBody)),
		TaskProfitAccount:   worker.ProfitAccount().Hex(),
		TaskWorkerAccount:   worker.WorkerAccount().Hex(),
		TaskExecuteDuration: result.TaskExecuteDuration,
	}

	if result.IsSuccessed {
		// An unparseable fee is recorded as 0.
		fee, _ := strconv.ParseInt(task.TaskFee, 10, 64)
		receipt.TaskFee = fee
		receipt.TaskResult = Succeed.Error()
		return receipt
	}

	if err == nil {
		err = fmt.Errorf("task failed with execute code %v", result.TaskExecuteCode)
	}
	receipt.TaskResult = err.Error()
	return receipt
}

// taskProof collects the raw hashes and signatures for one finished task;
// makeTaskProof hex-encodes them into an odysseus.TaskProof.
type taskProof struct {
	// Hashes of the task request parameters and of the result.
	paramHash, resultHash []byte
	// Finish timestamp copied into TaskFinishTimestamp.
	finishTime uint64
	// Signatures from the node manager, the container and the miner.
	nmSignature, containerSignature, minerSignature []byte
}

// makeTaskProof assembles the TaskProof for task from the values collected in
// t. Hashes and signatures are hex-encoded. The worker account is
// worker.workerAddr and the profit account is worker.info.Info.BenefitAddress,
// so worker.info must be non-nil.
func (wm *WorkerManager) makeTaskProof(worker *Worker, task *odysseus.TaskContent, t taskProof) *odysseus.TaskProof {
	enc := hex.EncodeToString
	return &odysseus.TaskProof{
		// Task identity.
		TaskId:       task.TaskId,
		TaskType:     task.TaskType,
		TaskWorkload: uint64(task.TaskWorkload),

		TaskFinishTimestamp: t.finishTime,

		// Hashes and signatures.
		TaskReqHash:            enc(t.paramHash),
		TaskRespHash:           enc(t.resultHash),
		TaskManagerSignature:   enc(t.nmSignature),
		TaskContainerSignature: enc(t.containerSignature),
		TaskMinerSignature:     enc(t.minerSignature),

		// Accounts.
		TaskWorkerAccount: worker.workerAddr,
		TaskProfitAccount: worker.info.Info.BenefitAddress,
	}
}
